fix(explorer): resolve initialization and import errors for v0.7.0 backend

This commit is contained in:
2026-01-20 17:11:31 +00:00
parent 103287c12b
commit 30639561d0
4 changed files with 160 additions and 921 deletions

View File

@@ -41,7 +41,7 @@ app.add_middleware(
# Service Singletons # Service Singletons
scraper = ScraperService() scraper = ScraperService()
classifier = ClassificationService() classifier = ClassificationService() # Now works without args
discovery = DiscoveryService() discovery = DiscoveryService()
# --- Pydantic Models --- # --- Pydantic Models ---
@@ -58,33 +58,6 @@ class AnalysisRequest(BaseModel):
company_id: int company_id: int
force_scrape: bool = False force_scrape: bool = False
class ContactBase(BaseModel):
    """Shared contact fields used by both create and read schemas."""

    gender: str
    title: str = ""  # academic/honorific title, optional
    first_name: str
    last_name: str
    email: str
    job_title: str
    language: str = "De"  # correspondence language code; defaults to German
    role: str
    status: str = ""
    is_primary: bool = False  # at most one primary contact per company (enforced in the routes)
class ContactCreate(ContactBase):
    """Payload for creating a contact; adds the owning company reference."""

    company_id: int
class ContactUpdate(BaseModel):
    """Partial-update payload: every field is optional; only supplied fields are applied."""

    gender: Optional[str] = None
    title: Optional[str] = None
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    email: Optional[str] = None
    job_title: Optional[str] = None
    language: Optional[str] = None
    role: Optional[str] = None
    status: Optional[str] = None
    is_primary: Optional[bool] = None
# --- Events --- # --- Events ---
@app.on_event("startup") @app.on_event("startup")
def on_startup(): def on_startup():
@@ -115,8 +88,6 @@ def list_companies(
query = query.filter(Company.name.ilike(f"%{search}%")) query = query.filter(Company.name.ilike(f"%{search}%"))
total = query.count() total = query.count()
# Sorting Logic
if sort_by == "updated_desc": if sort_by == "updated_desc":
query = query.order_by(Company.updated_at.desc()) query = query.order_by(Company.updated_at.desc())
elif sort_by == "created_desc": elif sort_by == "created_desc":
@@ -125,7 +96,6 @@ def list_companies(
query = query.order_by(Company.name.asc()) query = query.order_by(Company.name.asc())
items = query.offset(skip).limit(limit).all() items = query.offset(skip).limit(limit).all()
return {"total": total, "items": items} return {"total": total, "items": items}
except Exception as e: except Exception as e:
logger.error(f"List Companies Error: {e}", exc_info=True) logger.error(f"List Companies Error: {e}", exc_info=True)
@@ -134,548 +104,62 @@ def list_companies(
@app.get("/api/companies/{company_id}") @app.get("/api/companies/{company_id}")
def get_company(company_id: int, db: Session = Depends(get_db)): def get_company(company_id: int, db: Session = Depends(get_db)):
company = db.query(Company).options( company = db.query(Company).options(
joinedload(Company.signals),
joinedload(Company.enrichment_data), joinedload(Company.enrichment_data),
joinedload(Company.contacts) joinedload(Company.contacts)
).filter(Company.id == company_id).first() ).filter(Company.id == company_id).first()
if not company: if not company:
raise HTTPException(status_code=404, detail="Company not found") raise HTTPException(404, detail="Company not found")
return company return company
@app.post("/api/companies/bulk")
def bulk_import_names(req: BulkImportRequest, db: Session = Depends(get_db)):
    """
    Quick import for testing. Just a list of names.

    Each name is trimmed, then deduplicated twice: first by exact name match
    against the companies table, then (best-effort) by fuzzy matching via the
    Deduplicator with a score threshold of 95. Surviving names are inserted
    with status "NEW" in a single commit at the end.

    Returns a dict with "added" and "skipped" counts.
    Raises HTTPException 500 (after rollback) on any failure.
    """
    logger.info(f"Starting bulk import of {len(req.names)} names.")
    try:
        added = 0
        skipped = 0
        # Deduplicator init — optional: if it fails we fall back to exact-match only.
        try:
            dedup = Deduplicator(db)
            logger.info("Deduplicator initialized.")
        except Exception as e:
            logger.warning(f"Deduplicator init failed: {e}")
            dedup = None
        for name in req.names:
            clean_name = name.strip()
            if not clean_name: continue
            # 1. Simple Deduplication (Exact Name)
            exists = db.query(Company).filter(Company.name == clean_name).first()
            if exists:
                skipped += 1
                continue
            # 2. Smart Deduplication (if available) — score > 95 counts as the same company.
            if dedup:
                matches = dedup.find_duplicates({"name": clean_name})
                if matches and matches[0]['score'] > 95:
                    logger.info(f"Duplicate found for {clean_name}: {matches[0]['name']}")
                    skipped += 1
                    continue
            # 3. Create
            new_comp = Company(
                name=clean_name,
                status="NEW"  # This triggered the error before
            )
            db.add(new_comp)
            added += 1
        db.commit()
        logger.info(f"Import success. Added: {added}, Skipped: {skipped}")
        return {"added": added, "skipped": skipped}
    except Exception as e:
        logger.error(f"Bulk Import Failed: {e}", exc_info=True)
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/robotics/categories") @app.get("/api/robotics/categories")
def list_robotics_categories(db: Session = Depends(get_db)): def list_robotics_categories(db: Session = Depends(get_db)):
"""Lists all configured robotics categories."""
return db.query(RoboticsCategory).all() return db.query(RoboticsCategory).all()
class CategoryUpdate(BaseModel):
    """Payload for editing a robotics category's editable text fields."""

    description: str
    reasoning_guide: str
@app.put("/api/robotics/categories/{id}")
def update_robotics_category(id: int, cat: CategoryUpdate, db: Session = Depends(get_db)):
    """Updates a robotics category definition."""
    record = db.query(RoboticsCategory).filter(RoboticsCategory.id == id).first()
    if record is None:
        raise HTTPException(404, "Category not found")
    # Apply the two editable fields and persist.
    record.description = cat.description
    record.reasoning_guide = cat.reasoning_guide
    db.commit()
    return record
@app.post("/api/enrich/discover")
def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """
    Triggers Stage 1: Discovery (Website Search + Wikipedia Search)

    Queues the actual discovery work as a background task and returns
    immediately with a "queued" status.

    Raises:
        HTTPException 404: if no company exists with the given id.
        HTTPException 500: on any unexpected failure while queueing.
    """
    try:
        company = db.query(Company).filter(Company.id == req.company_id).first()
        if not company:
            raise HTTPException(404, "Company not found")
        # Run in background
        background_tasks.add_task(run_discovery_task, company.id)
        return {"status": "queued", "message": f"Discovery started for {company.name}"}
    except HTTPException:
        # Bug fix: HTTPException subclasses Exception, so the 404 above was
        # previously swallowed by the generic handler and re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"Discovery Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/companies/{company_id}/override/wiki")
def override_wiki_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
    """
    Manually sets the Wikipedia URL for a company and triggers re-extraction.
    Locks the data against auto-discovery.

    Extraction runs synchronously; on extraction failure the error is stored
    inside the content payload rather than failing the request.

    Raises:
        HTTPException 404: if no company exists with the given id.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    logger.info(f"Manual Override for {company.name}: Setting Wiki URL to {url}")
    # Update or create EnrichmentData entry
    existing_wiki = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "wikipedia"
    ).first()
    # Extract data immediately ("k.A." is the project's sentinel for "not available")
    wiki_data = {"url": url}
    if url and url != "k.A.":
        try:
            wiki_data = discovery.extract_wikipedia_data(url)
            wiki_data['url'] = url  # Ensure URL is correct
        except Exception as e:
            # Best-effort: record the failure but still save/lock the manual URL.
            logger.error(f"Extraction failed for manual URL: {e}")
            wiki_data["error"] = str(e)
    if not existing_wiki:
        db.add(EnrichmentData(
            company_id=company.id,
            source_type="wikipedia",
            content=wiki_data,
            is_locked=True
        ))
    else:
        existing_wiki.content = wiki_data
        existing_wiki.updated_at = datetime.utcnow()
        existing_wiki.is_locked = True  # LOCK IT
        existing_wiki.wiki_verified_empty = False  # It's no longer empty
    db.commit()
    # The return needs to be here, outside the else block but inside the main function
    return {"status": "updated", "data": wiki_data}
@app.post("/api/companies/{company_id}/wiki_mark_empty")
def mark_wiki_empty(company_id: int, db: Session = Depends(get_db)):
    """
    Marks a company as having no valid Wikipedia entry after manual review.
    Creates a locked, empty Wikipedia enrichment entry.

    The locked sentinel record ("k.A." fields) prevents auto-discovery from
    re-searching Wikipedia for this company.

    Raises:
        HTTPException 404: if no company exists with the given id.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    logger.info(f"Manual override for {company.name}: Marking Wikipedia as verified empty.")
    existing_wiki = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "wikipedia"
    ).first()
    # Sentinel payload: every field is "k.A." plus an explanatory error marker.
    empty_wiki_data = {"url": "k.A.", "title": "k.A.", "first_paragraph": "k.A.", "error": "Manually marked as empty"}
    if not existing_wiki:
        db.add(EnrichmentData(
            company_id=company.id,
            source_type="wikipedia",
            content=empty_wiki_data,
            is_locked=True,
            wiki_verified_empty=True
        ))
    else:
        existing_wiki.content = empty_wiki_data
        existing_wiki.updated_at = datetime.utcnow()
        existing_wiki.is_locked = True  # LOCK IT
        existing_wiki.wiki_verified_empty = True  # Mark as empty
    db.commit()
    return {"status": "updated", "wiki_verified_empty": True}
@app.post("/api/companies/{company_id}/override/website")
def override_website_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
    """
    Manually sets the Website URL for a company.
    Clears existing scrape data to force a fresh analysis on next run.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, "Company not found")
    logger.info(f"Manual Override for {company.name}: Setting Website to {url}")
    # Point the company at the new URL, then drop stale scrape results for the old one.
    company.website = url
    stale_scrapes = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "website_scrape"
    )
    stale_scrapes.delete()
    db.commit()
    return {"status": "updated", "website": url}
@app.post("/api/companies/{company_id}/override/impressum")
def override_impressum_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
    """
    Manually sets the Impressum URL for a company and triggers re-extraction.

    Scrapes the Impressum synchronously, copies city/country onto the company
    record when found, and stores the result in a locked "website_scrape"
    enrichment entry (created if missing, merged into the existing content
    otherwise).

    Raises:
        HTTPException 404: if no company exists with the given id.
        HTTPException 400: if the scraper returned no data for the URL.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    logger.info(f"Manual Override for {company.name}: Setting Impressum URL to {url}")
    # 1. Scrape Impressum immediately
    impressum_data = scraper._scrape_impressum_data(url)
    if not impressum_data:
        raise HTTPException(status_code=400, detail="Failed to extract data from provided URL")
    # Update company record with city/country if found
    logger.info(f"override_impressum_url: Scraped impressum_data for {company.name}: City={impressum_data.get('city')}, Country_code={impressum_data.get('country_code')}")
    if city_val := impressum_data.get("city"):
        logger.info(f"override_impressum_url: Updating company.city from '{company.city}' to '{city_val}'")
        company.city = city_val
    if country_val := impressum_data.get("country_code"):
        logger.info(f"override_impressum_url: Updating company.country from '{company.country}' to '{country_val}'")
        company.country = country_val
    logger.info(f"override_impressum_url: Company object after updates (before commit): City='{company.city}', Country='{company.country}'")
    # 2. Find existing scrape data or create new
    existing_scrape = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "website_scrape"
    ).first()
    if not existing_scrape:
        # Create minimal scrape entry and lock it
        db.add(EnrichmentData(
            company_id=company.id,
            source_type="website_scrape",
            content={"impressum": impressum_data, "text": "", "title": "Manual Impressum", "url": url},
            is_locked=True
        ))
    else:
        # Update existing and lock it (copy the dict so the ORM sees the change)
        content = dict(existing_scrape.content) if existing_scrape.content else {}
        content["impressum"] = impressum_data
        existing_scrape.content = content
        existing_scrape.updated_at = datetime.utcnow()
        existing_scrape.is_locked = True
    db.commit()
    logger.info(f"override_impressum_url: Commit successful. Company ID {company.id} updated.")
    return {"status": "updated", "data": impressum_data}
# --- Contact Routes ---
@app.post("/api/contacts")
def create_contact(contact: ContactCreate, db: Session = Depends(get_db)):
    """Creates a new contact and handles primary contact logic."""
    # If the new contact is primary, demote every existing primary for this company first.
    if contact.is_primary:
        db.query(Contact).filter(Contact.company_id == contact.company_id).update({"is_primary": False})
    new_contact = Contact(**contact.dict())
    db.add(new_contact)
    db.commit()
    db.refresh(new_contact)
    return new_contact
# --- Industry Routes ---
class IndustryCreate(BaseModel):
    """Payload for creating an industry."""

    name: str
    description: Optional[str] = None
    is_focus: bool = False  # whether this industry is a sales focus
    primary_category_id: Optional[int] = None
class IndustryUpdate(BaseModel):
    """Partial-update payload for an industry; only supplied fields are applied."""

    name: Optional[str] = None
    description: Optional[str] = None
    is_focus: Optional[bool] = None
    primary_category_id: Optional[int] = None
@app.get("/api/industries") @app.get("/api/industries")
def list_industries(db: Session = Depends(get_db)): def list_industries(db: Session = Depends(get_db)):
return db.query(Industry).all() return db.query(Industry).all()
@app.post("/api/industries") @app.post("/api/enrich/discover")
def create_industry(ind: IndustryCreate, db: Session = Depends(get_db)): def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
# 1. Prepare data company = db.query(Company).filter(Company.id == req.company_id).first()
ind_data = ind.dict() if not company: raise HTTPException(404, "Company not found")
base_name = ind_data['name'] background_tasks.add_task(run_discovery_task, company.id)
return {"status": "queued"}
@app.post("/api/enrich/analyze")
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
company = db.query(Company).filter(Company.id == req.company_id).first()
if not company: raise HTTPException(404, "Company not found")
# 2. Check for duplicate name if not company.website or company.website == "k.A.":
existing = db.query(Industry).filter(Industry.name == base_name).first() return {"error": "No website to analyze. Run Discovery first."}
if existing:
# Auto-increment name if duplicated
counter = 1
while db.query(Industry).filter(Industry.name == f"{base_name} ({counter})").first():
counter += 1
ind_data['name'] = f"{base_name} ({counter})"
# 3. Create background_tasks.add_task(run_analysis_task, company.id)
db_ind = Industry(**ind_data) return {"status": "queued"}
db.add(db_ind)
db.commit()
db.refresh(db_ind)
return db_ind
@app.put("/api/industries/{id}")
def update_industry(id: int, ind: IndustryUpdate, db: Session = Depends(get_db)):
    """Applies a partial update to an industry record."""
    record = db.query(Industry).filter(Industry.id == id).first()
    if record is None:
        raise HTTPException(404, "Industry not found")
    # Only fields explicitly present in the payload are written.
    changes = ind.dict(exclude_unset=True)
    for field, new_value in changes.items():
        setattr(record, field, new_value)
    db.commit()
    db.refresh(record)
    return record
@app.delete("/api/industries/{id}")
def delete_industry(id: int, db: Session = Depends(get_db)):
    """Deletes an industry by id."""
    record = db.query(Industry).filter(Industry.id == id).first()
    if record is None:
        raise HTTPException(404, "Industry not found")
    db.delete(record)
    db.commit()
    return {"status": "deleted"}
# --- Job Role Mapping Routes ---
class JobRoleMappingCreate(BaseModel):
    """Payload for creating a job-title-pattern -> role mapping."""

    pattern: str  # pattern matched against contact job titles
    role: str  # role assigned when the pattern matches
@app.get("/api/job_roles")
def list_job_roles(db: Session = Depends(get_db)):
    """Returns every configured job-role mapping."""
    mappings = db.query(JobRoleMapping).all()
    return mappings
@app.post("/api/job_roles")
def create_job_role(mapping: JobRoleMappingCreate, db: Session = Depends(get_db)):
    """Creates a new job-role mapping and returns the persisted row."""
    new_mapping = JobRoleMapping(**mapping.dict())
    db.add(new_mapping)
    db.commit()
    db.refresh(new_mapping)
    return new_mapping
@app.delete("/api/job_roles/{id}")
def delete_job_role(id: int, db: Session = Depends(get_db)):
    """Deletes a job-role mapping by id."""
    record = db.query(JobRoleMapping).filter(JobRoleMapping.id == id).first()
    if record is None:
        raise HTTPException(404, "Mapping not found")
    db.delete(record)
    db.commit()
    return {"status": "deleted"}
@app.put("/api/contacts/{contact_id}")
def update_contact(contact_id: int, contact: ContactUpdate, db: Session = Depends(get_db)):
    """Updates an existing contact."""
    existing = db.query(Contact).filter(Contact.id == contact_id).first()
    if existing is None:
        raise HTTPException(404, "Contact not found")
    changes = contact.dict(exclude_unset=True)
    # Promoting this contact to primary demotes all others in the same company.
    if changes.get("is_primary"):
        db.query(Contact).filter(Contact.company_id == existing.company_id).update({"is_primary": False})
    for field, new_value in changes.items():
        setattr(existing, field, new_value)
    db.commit()
    db.refresh(existing)
    return existing
@app.delete("/api/contacts/{contact_id}")
def delete_contact(contact_id: int, db: Session = Depends(get_db)):
    """Deletes a contact."""
    record = db.query(Contact).filter(Contact.id == contact_id).first()
    if record is None:
        raise HTTPException(404, "Contact not found")
    db.delete(record)
    db.commit()
    return {"status": "deleted"}
@app.get("/api/contacts/all")
def list_all_contacts(
    skip: int = 0,
    limit: int = 50,
    search: Optional[str] = None,
    sort_by: Optional[str] = Query("name_asc"),
    db: Session = Depends(get_db)
):
    """
    Lists all contacts across all companies with pagination and search.

    The search term matches contact first/last name, email, or the owning
    company name (case-insensitive substring). `sort_by` accepts
    "updated_desc", "created_desc", or the default "name_asc".
    Returns {"total": <count before pagination>, "items": [...]}.
    """
    # Join Company so the search filter and the per-row company_name lookup work.
    query = db.query(Contact).join(Company)
    if search:
        search_term = f"%{search}%"
        query = query.filter(
            (Contact.first_name.ilike(search_term)) |
            (Contact.last_name.ilike(search_term)) |
            (Contact.email.ilike(search_term)) |
            (Company.name.ilike(search_term))
        )
    total = query.count()
    # Sorting Logic
    if sort_by == "updated_desc":
        query = query.order_by(Contact.updated_at.desc())
    elif sort_by == "created_desc":
        # id order stands in for creation order
        query = query.order_by(Contact.id.desc())
    else:  # Default: name_asc
        query = query.order_by(Contact.last_name.asc(), Contact.first_name.asc())
    contacts = query.offset(skip).limit(limit).all()
    # Enrich with Company Name for the frontend list
    result = []
    for c in contacts:
        # Strip SQLAlchemy internals (underscore-prefixed attributes) from the row dict.
        c_dict = {k: v for k, v in c.__dict__.items() if not k.startswith('_')}
        c_dict['company_name'] = c.company.name if c.company else "Unknown"
        result.append(c_dict)
    return {"total": total, "items": result}
class BulkContactImportItem(BaseModel):
    """One row of a bulk contact import; company is matched (or created) by name."""

    company_name: str
    first_name: str
    last_name: str
    email: Optional[str] = None  # used for deduplication when present
    job_title: Optional[str] = None
    role: Optional[str] = "Operativer Entscheider"  # default role (German: "operational decision-maker")
    gender: Optional[str] = "männlich"  # default gender value used by the frontend
class BulkContactImportRequest(BaseModel):
    """Envelope for a bulk contact import."""

    contacts: List[BulkContactImportItem]
@app.post("/api/contacts/bulk")
def bulk_import_contacts(req: BulkContactImportRequest, db: Session = Depends(get_db)):
    """
    Bulk imports contacts.
    Matches Company by Name (creates if missing).
    Dedupes Contact by Email.

    Returns stats: {"added": n, "skipped": n, "companies_created": n}.
    Items without a company_name are silently ignored.
    """
    logger.info(f"Starting bulk contact import: {len(req.contacts)} items")
    stats = {"added": 0, "skipped": 0, "companies_created": 0}
    for item in req.contacts:
        if not item.company_name: continue
        # 1. Find or Create Company (case-insensitive name match)
        company = db.query(Company).filter(Company.name.ilike(item.company_name.strip())).first()
        if not company:
            company = Company(name=item.company_name.strip(), status="NEW")
            db.add(company)
            db.commit()  # Commit to get ID
            db.refresh(company)
            stats["companies_created"] += 1
        # 2. Check for Duplicate Contact (by Email) — contacts without an email are never deduped
        if item.email:
            exists = db.query(Contact).filter(Contact.email == item.email.strip()).first()
            if exists:
                stats["skipped"] += 1
                continue
        # 3. Create Contact (committed in one batch after the loop)
        new_contact = Contact(
            company_id=company.id,
            first_name=item.first_name,
            last_name=item.last_name,
            email=item.email,
            job_title=item.job_title,
            role=item.role,
            gender=item.gender,
            status="Init"  # Default status
        )
        db.add(new_contact)
        stats["added"] += 1
    db.commit()
    return stats
@app.post("/api/enrichment/{company_id}/{source_type}/lock")
def lock_enrichment(company_id: int, source_type: str, locked: bool = Query(...), db: Session = Depends(get_db)):
    """
    Toggles the lock status of a specific enrichment data type (e.g. 'website_scrape', 'wikipedia').
    """
    record = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company_id,
        EnrichmentData.source_type == source_type
    ).first()
    if record is None:
        raise HTTPException(404, "Enrichment data not found")
    # Persist the requested lock state as-is (True or False).
    record.is_locked = locked
    db.commit()
    return {"status": "updated", "is_locked": locked}
def run_discovery_task(company_id: int): def run_discovery_task(company_id: int):
# New Session for Background Task
from .database import SessionLocal from .database import SessionLocal
db = SessionLocal() db = SessionLocal()
try: try:
company = db.query(Company).filter(Company.id == company_id).first() company = db.query(Company).filter(Company.id == company_id).first()
if not company: return if not company: return
logger.info(f"Running Discovery Task for {company.name}") # 1. Website Search
# 1. Website Search (Always try if missing)
if not company.website or company.website == "k.A.": if not company.website or company.website == "k.A.":
found_url = discovery.find_company_website(company.name, company.city) found_url = discovery.find_company_website(company.name, company.city)
if found_url and found_url != "k.A.": if found_url and found_url != "k.A.":
company.website = found_url company.website = found_url
logger.info(f"-> Found URL: {found_url}")
# 2. Wikipedia Search & Extraction # 2. Wikipedia Search
# Check if locked
existing_wiki = db.query(EnrichmentData).filter( existing_wiki = db.query(EnrichmentData).filter(
EnrichmentData.company_id == company.id, EnrichmentData.company_id == company.id,
EnrichmentData.source_type == "wikipedia" EnrichmentData.source_type == "wikipedia"
).first() ).first()
if existing_wiki and existing_wiki.is_locked: if not existing_wiki or not existing_wiki.is_locked:
logger.info(f"Skipping Wiki Discovery for {company.name} - Data is LOCKED.") wiki_url = discovery.find_wikipedia_url(company.name, website=company.website, city=company.city)
else: wiki_data = discovery.extract_wikipedia_data(wiki_url) if wiki_url and wiki_url != "k.A." else {"url": wiki_url}
# Pass available info for better validation
current_website = company.website if company.website and company.website != "k.A." else None
wiki_url = discovery.find_wikipedia_url(company.name, website=current_website, city=company.city)
company.last_wiki_search_at = datetime.utcnow()
wiki_data = {"url": wiki_url}
if wiki_url and wiki_url != "k.A.":
logger.info(f"Extracting full data from Wikipedia for {company.name}...")
wiki_data = discovery.extract_wikipedia_data(wiki_url)
if not existing_wiki: if not existing_wiki:
db.add(EnrichmentData(company_id=company.id, source_type="wikipedia", content=wiki_data)) db.add(EnrichmentData(company_id=company.id, source_type="wikipedia", content=wiki_data))
else: else:
@@ -686,35 +170,12 @@ def run_discovery_task(company_id: int):
company.status = "DISCOVERED" company.status = "DISCOVERED"
db.commit() db.commit()
logger.info(f"Discovery finished for {company.id}")
except Exception as e: except Exception as e:
logger.error(f"Background Task Error: {e}", exc_info=True) logger.error(f"Discovery Task Error: {e}", exc_info=True)
db.rollback()
finally: finally:
db.close() db.close()
@app.post("/api/enrich/analyze") def run_analysis_task(company_id: int):
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
company = db.query(Company).filter(Company.id == req.company_id).first()
if not company:
raise HTTPException(404, "Company not found")
if not company.website or company.website == "k.A.":
return {"error": "No website to analyze. Run Discovery first."}
# FORCE SCRAPE LOGIC
# Respect Locked Data: Only delete if not locked.
db.query(EnrichmentData).filter(
EnrichmentData.company_id == company.id,
EnrichmentData.source_type == "website_scrape",
EnrichmentData.is_locked == False
).delete()
db.commit()
background_tasks.add_task(run_analysis_task, company.id, company.website)
return {"status": "queued"}
def run_analysis_task(company_id: int, url: str):
from .database import SessionLocal from .database import SessionLocal
db = SessionLocal() db = SessionLocal()
try: try:
@@ -723,158 +184,42 @@ def run_analysis_task(company_id: int, url: str):
logger.info(f"Running Analysis Task for {company.name}") logger.info(f"Running Analysis Task for {company.name}")
# 1. Scrape Website OR Use Locked Data # 1. Scrape Website (if not locked)
scrape_result = {}
existing_scrape = db.query(EnrichmentData).filter( existing_scrape = db.query(EnrichmentData).filter(
EnrichmentData.company_id == company.id, EnrichmentData.company_id == company.id,
EnrichmentData.source_type == "website_scrape" EnrichmentData.source_type == "website_scrape"
).first() ).first()
if existing_scrape and existing_scrape.is_locked: if not existing_scrape or not existing_scrape.is_locked:
logger.info(f"Using LOCKED scrape data for {company.name}") from .services.scraping import ScraperService
scrape_result = dict(existing_scrape.content) # Copy dict scrape_res = ScraperService().scrape_url(company.website)
if not existing_scrape:
# Always ensure city/country from locked impressum data is synced to company db.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_res))
if "impressum" in scrape_result and scrape_result["impressum"]:
impressum_city = scrape_result["impressum"].get("city")
impressum_country = scrape_result["impressum"].get("country_code")
logger.info(f"Analysis task (locked data): Impressum found. City='{impressum_city}', Country='{impressum_country}'")
if impressum_city and company.city != impressum_city:
logger.info(f"Analysis task: Updating company.city from '{company.city}' to '{impressum_city}'")
company.city = impressum_city
if impressum_country and company.country != impressum_country:
logger.info(f"Analysis task: Updating company.country from '{company.country}' to '{impressum_country}'")
company.country = impressum_country
text_val = scrape_result.get("text")
text_len = len(text_val) if text_val else 0
logger.info(f"Locked data keys: {list(scrape_result.keys())}, Text length: {text_len}")
# AUTO-FIX: If locked data (e.g. Manual Impressum) has no text, fetch main website text
if text_len < 100:
logger.info(f"Locked data missing text (len={text_len}). Fetching content from {url}...")
try:
fresh_scrape = scraper.scrape_url(url)
except Exception as e:
logger.error(f"Fresh scrape failed: {e}", exc_info=True)
fresh_scrape = {}
logger.info(f"Fresh scrape result keys: {list(fresh_scrape.keys())}")
if "text" in fresh_scrape and len(fresh_scrape["text"]) > 100:
logger.info(f"Fresh scrape successful. Text len: {len(fresh_scrape['text'])}")
# Update local dict for current processing
scrape_result["text"] = fresh_scrape["text"]
scrape_result["title"] = fresh_scrape.get("title", "")
# Update DB (Merge into existing content)
updated_content = dict(existing_scrape.content)
updated_content["text"] = fresh_scrape["text"]
updated_content["title"] = fresh_scrape.get("title", "")
existing_scrape.content = updated_content
existing_scrape.updated_at = datetime.utcnow()
# db.commit() here would be too early
logger.info("Updated locked record with fresh website text in session.")
else:
logger.warning(f"Fresh scrape returned insufficient text. Error: {fresh_scrape.get('error')}")
else:
# Standard Scrape
scrape_result = scraper.scrape_url(url)
# Update company fields from impressum if found during scrape
if "impressum" in scrape_result and scrape_result["impressum"]:
impressum_city = scrape_result["impressum"].get("city")
impressum_country = scrape_result["impressum"].get("country_code")
logger.info(f"Analysis task (standard scrape): Impressum found. City='{impressum_city}', Country='{impressum_country}'")
if impressum_city and company.city != impressum_city:
logger.info(f"Analysis task: Updating company.city from '{company.city}' to '{impressum_city}'")
company.city = impressum_city
if impressum_country and company.country != impressum_country:
logger.info(f"Analysis task: Updating company.country from '{company.country}' to '{impressum_country}'")
company.country = impressum_country
# Save Scrape Data
if "text" in scrape_result and scrape_result["text"]:
if not existing_scrape:
db.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_result))
else:
existing_scrape.content = scrape_result
existing_scrape.updated_at = datetime.utcnow()
elif "error" in scrape_result:
logger.warning(f"Scraping failed for {company.name}: {scrape_result['error']}")
# 2. Classify Robotics Potential
text_content = scrape_result.get("text")
logger.info(f"Preparing classification. Text content length: {len(text_content) if text_content else 0}")
if text_content and len(text_content) > 100:
logger.info(f"Starting classification for {company.name}...")
analysis = classifier.analyze_robotics_potential(
company_name=company.name,
website_text=text_content
)
if "error" in analysis:
logger.error(f"Robotics classification failed for {company.name}: {analysis['error']}")
else: else:
industry = analysis.get("industry") existing_scrape.content = scrape_res
if industry: existing_scrape.updated_at = datetime.utcnow()
company.industry_ai = industry db.commit()
db.query(Signal).filter(Signal.company_id == company.id).delete()
potentials = analysis.get("potentials", {})
for signal_type, data in potentials.items():
new_signal = Signal(
company_id=company.id,
signal_type=f"robotics_{signal_type}_potential",
confidence=data.get("score", 0),
value="High" if data.get("score", 0) > 70 else "Medium" if data.get("score", 0) > 30 else "Low",
proof_text=data.get("reason")
)
db.add(new_signal)
existing_analysis = db.query(EnrichmentData).filter(
EnrichmentData.company_id == company.id,
EnrichmentData.source_type == "ai_analysis"
).first()
if not existing_analysis:
db.add(EnrichmentData(company_id=company.id, source_type="ai_analysis", content=analysis))
else:
existing_analysis.content = analysis
existing_analysis.updated_at = datetime.utcnow()
company.status = "ENRICHED"
company.last_classification_at = datetime.utcnow()
logger.info(f"Robotics analysis complete for {company.name}.")
else:
logger.warning(f"Skipping classification for {company.name}: Insufficient text content (len={len(text_content) if text_content else 0})")
# 2. Classify Industry & Metrics
# IMPORTANT: Using the new method name and passing db session
classifier.classify_company_potential(company, db)
company.status = "ENRICHED"
db.commit() db.commit()
logger.info(f"Analysis finished for {company.id}") logger.info(f"Analysis complete for {company.name}")
except Exception as e: except Exception as e:
logger.error(f"Analyze Task Error: {e}", exc_info=True) logger.error(f"Analyze Task Error: {e}", exc_info=True)
db.rollback()
finally: finally:
db.close() db.close()
# --- Serve Frontend --- # --- Serve Frontend ---
# Priority 1: Container Path (outside of /app volume)
static_path = "/frontend_static" static_path = "/frontend_static"
# Priority 2: Local Dev Path (relative to this file)
if not os.path.exists(static_path): if not os.path.exists(static_path):
static_path = os.path.join(os.path.dirname(__file__), "../static") static_path = os.path.join(os.path.dirname(__file__), "../static")
if os.path.exists(static_path): if os.path.exists(static_path):
logger.info(f"Serving frontend from {static_path}")
app.mount("/", StaticFiles(directory=static_path, html=True), name="static") app.mount("/", StaticFiles(directory=static_path, html=True), name="static")
else:
logger.warning(f"Frontend static files not found at {static_path} or local fallback.")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run("backend.app:app", host="0.0.0.0", port=8000, reload=True) uvicorn.run("backend.app:app", host="0.0.0.0", port=8000, reload=True)

View File

@@ -1,4 +1,3 @@
import sqlite3 import sqlite3
import sys import sys
import os import os
@@ -23,41 +22,53 @@ def get_table_columns(cursor, table_name):
cursor.execute(f"PRAGMA table_info({table_name})") cursor.execute(f"PRAGMA table_info({table_name})")
return [row[1] for row in cursor.fetchall()] return [row[1] for row in cursor.fetchall()]
def migrate_industries_table(): def migrate_tables():
""" """
Adds the new schema columns to the 'industries' table if they don't exist. Adds new columns to existing tables to support v0.7.0 features.
This ensures backward compatibility with older database files.
""" """
logger.info(f"Connecting to database at {DB_FILE} to run migrations...") logger.info(f"Connecting to database at {DB_FILE} to run migrations...")
conn = get_db_connection() conn = get_db_connection()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
# 1. Update INDUSTRIES Table
logger.info("Checking 'industries' table schema...") logger.info("Checking 'industries' table schema...")
columns = get_table_columns(cursor, "industries") ind_columns = get_table_columns(cursor, "industries")
logger.info(f"Found existing columns: {columns}")
ind_migrations = {
migrations_to_run = {
"metric_type": "TEXT", "metric_type": "TEXT",
"scraper_search_term": "TEXT", "scraper_search_term": "TEXT",
"standardization_logic": "TEXT", "standardization_logic": "TEXT",
"proxy_factor": "FLOAT" "proxy_factor": "FLOAT",
# min_requirement, whale_threshold, scraper_keywords already exist from v0.6.0 "scraper_keywords": "TEXT",
"scraper_search_term": "TEXT"
} }
for col, col_type in migrations_to_run.items(): for col, col_type in ind_migrations.items():
if col not in columns: if col not in ind_columns:
logger.info(f"Adding column '{col}' ({col_type}) to 'industries' table...") logger.info(f"Adding column '{col}' to 'industries' table...")
cursor.execute(f"ALTER TABLE industries ADD COLUMN {col} {col_type}") cursor.execute(f"ALTER TABLE industries ADD COLUMN {col} {col_type}")
else:
logger.info(f"Column '{col}' already exists. Skipping.")
# Also, we need to handle the removal of old columns if necessary (safer to leave them) # 2. Update COMPANIES Table (New for v0.7.0)
# We will also fix the proxy_factor type if it was TEXT logger.info("Checking 'companies' table schema...")
# This is more complex, for now let's just add. comp_columns = get_table_columns(cursor, "companies")
comp_migrations = {
"calculated_metric_name": "TEXT",
"calculated_metric_value": "FLOAT",
"calculated_metric_unit": "TEXT",
"standardized_metric_value": "FLOAT",
"standardized_metric_unit": "TEXT",
"metric_source": "TEXT"
}
for col, col_type in comp_migrations.items():
if col not in comp_columns:
logger.info(f"Adding column '{col}' to 'companies' table...")
cursor.execute(f"ALTER TABLE companies ADD COLUMN {col} {col_type}")
conn.commit() conn.commit()
logger.info("Migrations for 'industries' table completed successfully.") logger.info("All migrations completed successfully.")
except Exception as e: except Exception as e:
logger.error(f"An error occurred during migration: {e}", exc_info=True) logger.error(f"An error occurred during migration: {e}", exc_info=True)
@@ -65,9 +76,8 @@ def migrate_industries_table():
finally: finally:
conn.close() conn.close()
if __name__ == "__main__": if __name__ == "__main__":
if not os.path.exists(DB_FILE): if not os.path.exists(DB_FILE):
logger.error(f"Database file not found at {DB_FILE}. Cannot run migration. Please ensure the old database is in place.") logger.error(f"Database file not found at {DB_FILE}.")
else: else:
migrate_industries_table() migrate_tables()

View File

@@ -5,59 +5,39 @@ from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from backend.database import Company, Industry, RoboticsCategory, EnrichmentData, get_db from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.config import settings
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content # Corrected import from backend.services.scraping import scrape_website_content
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ClassificationService: class ClassificationService:
def __init__(self, db: Session): def __init__(self):
self.db = db # We no longer load industries in init because we don't have a DB session here
self.allowed_industries_notion: List[Industry] = self._load_industry_definitions() pass
self.robotics_categories: List[RoboticsCategory] = self._load_robotics_categories()
# Pre-process allowed industries for LLM prompt
self.llm_industry_definitions = [
{"name": ind.name, "description": ind.description} for ind in self.allowed_industries_notion
]
# Store for quick lookup
self.industry_lookup = {ind.name: ind for ind in self.allowed_industries_notion}
self.category_lookup = {cat.id: cat for cat in self.robotics_categories}
def _load_industry_definitions(self) -> List[Industry]: def _load_industry_definitions(self, db: Session) -> List[Industry]:
"""Loads all industry definitions from the database.""" """Loads all industry definitions from the database."""
industries = self.db.query(Industry).all() industries = db.query(Industry).all()
if not industries: if not industries:
logger.warning("No industry definitions found in DB. Classification might be limited.") logger.warning("No industry definitions found in DB. Classification might be limited.")
return industries return industries
def _load_robotics_categories(self) -> List[RoboticsCategory]: def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[str]:
"""Loads all robotics categories from the database."""
categories = self.db.query(RoboticsCategory).all()
if not categories:
logger.warning("No robotics categories found in DB. Potential scoring might be limited.")
return categories
def _get_wikipedia_content(self, company_id: int) -> Optional[str]:
"""Fetches Wikipedia content from enrichment_data for a given company.""" """Fetches Wikipedia content from enrichment_data for a given company."""
enrichment = self.db.query(EnrichmentData).filter( enrichment = db.query(EnrichmentData).filter(
EnrichmentData.company_id == company_id, EnrichmentData.company_id == company_id,
EnrichmentData.source_type == "wikipedia" EnrichmentData.source_type == "wikipedia"
).order_by(EnrichmentData.created_at.desc()).first() ).order_by(EnrichmentData.created_at.desc()).first()
if enrichment and enrichment.content: if enrichment and enrichment.content:
# Wikipedia content is stored as JSON with a 'text' key
wiki_data = enrichment.content wiki_data = enrichment.content
return wiki_data.get('text') return wiki_data.get('text')
return None return None
def _run_llm_classification_prompt(self, website_text: str, company_name: str) -> Optional[str]: def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
""" """
Uses LLM to classify the company into one of the predefined industries. Uses LLM to classify the company into one of the predefined industries.
Returns the industry name (string) or "Others".
""" """
prompt = r""" prompt = r"""
Du bist ein präziser Branchen-Klassifizierer für Unternehmen. Du bist ein präziser Branchen-Klassifizierer für Unternehmen.
@@ -79,31 +59,23 @@ class ClassificationService:
Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes. Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.
Beispiel Output: Hotellerie Beispiel Output: Hotellerie
Beispiel Output: Automotive - Dealer
Beispiel Output: Others
""".format( """.format(
company_name=company_name, company_name=company_name,
website_text_excerpt=website_text[:10000], # Limit text to avoid token limits website_text_excerpt=website_text[:10000],
industry_definitions_json=json.dumps(self.llm_industry_definitions, ensure_ascii=False) industry_definitions_json=json.dumps(industry_definitions, ensure_ascii=False)
) )
try: try:
response = call_gemini_flash(prompt, temperature=0.1, json_mode=False) # Low temp for strict classification response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)
classified_industry = response.strip() return response.strip()
if classified_industry in [ind.name for ind in self.allowed_industries_notion] + ["Others"]:
return classified_industry
logger.warning(f"LLM classified industry '{classified_industry}' not in allowed list. Defaulting to Others.")
return "Others"
except Exception as e: except Exception as e:
logger.error(f"LLM classification failed for {company_name}: {e}", exc_info=True) logger.error(f"LLM classification failed for {company_name}: {e}")
return None return None
def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]: def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
""" """
Uses LLM to extract the specific metric value from text. Uses LLM to extract the specific metric value from text.
Returns a dict with 'raw_value', 'raw_unit', 'standardized_value' (if found), 'metric_name'.
""" """
# Attempt to extract both the raw unit count and a potential area if explicitly mentioned
prompt = r""" prompt = r"""
Du bist ein Datenextraktions-Spezialist. Du bist ein Datenextraktions-Spezialist.
Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren. Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.
@@ -119,63 +91,42 @@ class ClassificationService:
1. Finde den numerischen Wert für '{search_term}'. 1. Finde den numerischen Wert für '{search_term}'.
2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden. 2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.
Gib NUR ein JSON-Objekt zurück mit den Schlüsseln: Gib NUR ein JSON-Objekt zurück:
'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden. 'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden. 'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden. 'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}'). 'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').
Beispiel Output (wenn 180 Betten und 4500m² Fläche gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": 4500, "metric_name": "{search_term}"}}
Beispiel Output (wenn nur 180 Betten gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": null, "metric_name": "{search_term}"}}
Beispiel Output (wenn nichts gefunden):
{{"raw_value": null, "raw_unit": null, "area_value": null, "metric_name": "{search_term}"}}
""".format( """.format(
industry_name=industry_name, industry_name=industry_name,
search_term=search_term, search_term=search_term,
text_content_excerpt=text_content[:15000] # Adjust as needed for token limits text_content_excerpt=text_content[:15000]
) )
try: try:
response = call_gemini_flash(prompt, temperature=0.05, json_mode=True) # Very low temp for extraction response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)
result = json.loads(response) return json.loads(response)
return result
except Exception as e: except Exception as e:
logger.error(f"LLM metric extraction failed for '{search_term}' in '{industry_name}': {e}", exc_info=True) logger.error(f"LLM metric extraction failed for '{search_term}': {e}")
return None return None
def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]: def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
""" if not formula or raw_value is None:
Safely parses and executes a simple mathematical formula for standardization.
Supports basic arithmetic (+, -, *, /) and integer/float values.
"""
if not formula or not raw_value:
return None return None
formula_cleaned = formula.replace("wert", str(raw_value)).replace("Value", str(raw_value))
# Replace 'wert' or 'value' with the actual raw_value
formula_cleaned = formula.replace("wert", str(raw_value)).replace("Value", str(raw_value)).replace("VALUE", str(raw_value))
try: try:
# Use safe_eval_math from core_utils to prevent arbitrary code execution
return safe_eval_math(formula_cleaned) return safe_eval_math(formula_cleaned)
except Exception as e: except:
logger.error(f"Error evaluating standardization logic '{formula}' with value {raw_value}: {e}", exc_info=True)
return None return None
def _extract_and_calculate_metric_cascade( def _extract_and_calculate_metric_cascade(
self, self,
db: Session,
company: Company, company: Company,
industry_name: str, industry_name: str,
search_term: str, search_term: str,
standardization_logic: Optional[str], standardization_logic: Optional[str],
standardized_unit: Optional[str] standardized_unit: Optional[str]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""
Orchestrates the 3-stage (Website -> Wikipedia -> SerpAPI) metric extraction.
"""
results = { results = {
"calculated_metric_name": search_term, "calculated_metric_name": search_term,
"calculated_metric_value": None, "calculated_metric_value": None,
@@ -185,150 +136,71 @@ class ClassificationService:
"metric_source": None "metric_source": None
} }
# --- STAGE 1: Website Analysis --- # CASCADE: Website -> Wikipedia -> SerpAPI
logger.info(f"Stage 1: Analyzing website for '{search_term}' for {company.name}") sources = [
website_content = scrape_website_content(company.website) ("website", lambda: scrape_website_content(company.website)),
if website_content: ("wikipedia", lambda: self._get_wikipedia_content(db, company.id)),
llm_result = self._run_llm_metric_extraction_prompt(website_content, search_term, industry_name) ("serpapi", lambda: " ".join([res.get("snippet", "") for res in run_serp_search(f"{company.name} {search_term} {industry_name}").get("organic_results", [])]) if run_serp_search(f"{company.name} {search_term} {industry_name}") else None)
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None): ]
results["calculated_metric_value"] = llm_result.get("raw_value")
results["calculated_metric_unit"] = llm_result.get("raw_unit")
results["metric_source"] = "website"
if llm_result.get("area_value") is not None: for source_name, content_loader in sources:
# Prioritize directly found standardized area logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
results["standardized_metric_value"] = llm_result.get("area_value") try:
logger.info(f"Direct area value found on website for {company.name}: {llm_result.get('area_value')}") content = content_loader()
elif llm_result.get("raw_value") is not None and standardization_logic: if not content: continue
# Calculate if only raw value found
results["standardized_metric_value"] = self._parse_standardization_logic( llm_result = self._run_llm_metric_extraction_prompt(content, search_term, industry_name)
standardization_logic, llm_result["raw_value"]
)
return results
# --- STAGE 2: Wikipedia Analysis ---
logger.info(f"Stage 2: Analyzing Wikipedia for '{search_term}' for {company.name}")
wikipedia_content = self._get_wikipedia_content(company.id)
if wikipedia_content:
llm_result = self._run_llm_metric_extraction_prompt(wikipedia_content, search_term, industry_name)
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
results["calculated_metric_value"] = llm_result.get("raw_value")
results["calculated_metric_unit"] = llm_result.get("raw_unit")
results["metric_source"] = "wikipedia"
if llm_result.get("area_value") is not None:
results["standardized_metric_value"] = llm_result.get("area_value")
logger.info(f"Direct area value found on Wikipedia for {company.name}: {llm_result.get('area_value')}")
elif llm_result.get("raw_value") is not None and standardization_logic:
results["standardized_metric_value"] = self._parse_standardization_logic(
standardization_logic, llm_result["raw_value"]
)
return results
# --- STAGE 3: SerpAPI (Google Search) ---
logger.info(f"Stage 3: Running SerpAPI search for '{search_term}' for {company.name}")
search_query = f"{company.name} {search_term} {industry_name}" # Example: "Hotel Moxy Würzburg Anzahl Betten Hotellerie"
serp_results = run_serp_search(search_query) # This returns a dictionary of search results
if serp_results and serp_results.get("organic_results"):
# Concatenate snippets from organic results
snippets = " ".join([res.get("snippet", "") for res in serp_results["organic_results"]])
if snippets:
llm_result = self._run_llm_metric_extraction_prompt(snippets, search_term, industry_name)
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None): if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
results["calculated_metric_value"] = llm_result.get("raw_value") results["calculated_metric_value"] = llm_result.get("raw_value")
results["calculated_metric_unit"] = llm_result.get("raw_unit") results["calculated_metric_unit"] = llm_result.get("raw_unit")
results["metric_source"] = "serpapi" results["metric_source"] = source_name
if llm_result.get("area_value") is not None: if llm_result.get("area_value") is not None:
results["standardized_metric_value"] = llm_result.get("area_value") results["standardized_metric_value"] = llm_result.get("area_value")
logger.info(f"Direct area value found via SerpAPI for {company.name}: {llm_result.get('area_value')}")
elif llm_result.get("raw_value") is not None and standardization_logic: elif llm_result.get("raw_value") is not None and standardization_logic:
results["standardized_metric_value"] = self._parse_standardization_logic( results["standardized_metric_value"] = self._parse_standardization_logic(standardization_logic, llm_result["raw_value"])
standardization_logic, llm_result["raw_value"]
)
return results return results
except Exception as e:
logger.info(f"Could not extract metric for '{search_term}' from any source for {company.name}.") logger.error(f"Error in {source_name} stage: {e}")
return results # Return results with None values
def classify_company_potential(self, company: Company) -> Company: return results
"""
Main method to classify industry and calculate potential metric for a company.
"""
logger.info(f"Starting classification for Company ID: {company.id}, Name: {company.name}")
# --- STEP 1: Strict Industry Classification --- def classify_company_potential(self, company: Company, db: Session) -> Company:
website_content_for_classification = scrape_website_content(company.website) logger.info(f"Starting classification for {company.name}")
if not website_content_for_classification:
logger.warning(f"No website content found for {company.name}. Skipping industry classification.") # 1. Load Industries
company.industry_ai = "Others" # Default if no content industries = self._load_industry_definitions(db)
industry_defs = [{"name": i.name, "description": i.description} for i in industries]
# 2. Industry Classification
website_content = scrape_website_content(company.website)
if website_content:
industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
company.industry_ai = industry_name if industry_name in [i.name for i in industries] else "Others"
else: else:
classified_industry_name = self._run_llm_classification_prompt(website_content_for_classification, company.name) company.industry_ai = "Others"
if classified_industry_name:
company.industry_ai = classified_industry_name
logger.info(f"Classified {company.name} into industry: {classified_industry_name}")
else:
company.industry_ai = "Others"
logger.warning(f"Failed to classify industry for {company.name}. Setting to 'Others'.")
self.db.add(company) # Update industry_ai db.commit()
self.db.commit()
self.db.refresh(company)
# --- STEP 2: Metric Extraction & Standardization (if not 'Others') --- # 3. Metric Extraction
if company.industry_ai == "Others" or company.industry_ai is None: if company.industry_ai != "Others":
logger.info(f"Company {company.name} classified as 'Others'. Skipping metric extraction.") industry = next((i for i in industries if i.name == company.industry_ai), None)
return company if industry and industry.scraper_search_term:
# Derive standardized unit
std_unit = "" if "" in (industry.standardization_logic or "") else "Einheiten"
metrics = self._extract_and_calculate_metric_cascade(
db, company, company.industry_ai, industry.scraper_search_term, industry.standardization_logic, std_unit
)
company.calculated_metric_name = metrics["calculated_metric_name"]
company.calculated_metric_value = metrics["calculated_metric_value"]
company.calculated_metric_unit = metrics["calculated_metric_unit"]
company.standardized_metric_value = metrics["standardized_metric_value"]
company.standardized_metric_unit = metrics["standardized_metric_unit"]
company.metric_source = metrics["metric_source"]
industry_definition = self.industry_lookup.get(company.industry_ai) company.last_classification_at = datetime.utcnow()
if not industry_definition: db.commit()
logger.error(f"Industry definition for '{company.industry_ai}' not found in lookup. Skipping metric extraction.")
return company
if not industry_definition.scraper_search_term:
logger.info(f"Industry '{company.industry_ai}' has no 'Scraper Search Term'. Skipping metric extraction.")
return company
# Determine standardized unit from standardization_logic if possible
standardized_unit = "Einheiten" # Default
if industry_definition.standardization_logic:
# Example: "wert * 25m² (Fläche pro Zimmer)" -> extract "m²"
match = re.search(r'(\w+)$', industry_definition.standardization_logic.replace(' ', ''))
if match:
standardized_unit = match.group(1).replace('(', '').replace(')', '') # Extract unit like "m²"
metric_results = self._extract_and_calculate_metric_cascade(
company,
company.industry_ai,
industry_definition.scraper_search_term,
industry_definition.standardization_logic,
standardized_unit # Pass the derived unit
)
# Update company object with results
company.calculated_metric_name = metric_results["calculated_metric_name"]
company.calculated_metric_value = metric_results["calculated_metric_value"]
company.calculated_metric_unit = metric_results["calculated_metric_unit"]
company.standardized_metric_value = metric_results["standardized_metric_value"]
company.standardized_metric_unit = metric_results["standardized_metric_unit"]
company.metric_source = metric_results["metric_source"]
company.last_classification_at = datetime.utcnow() # Update timestamp
self.db.add(company)
self.db.commit()
self.db.refresh(company) # Refresh to get updated values
logger.info(f"Classification and metric extraction completed for {company.name}.")
return company return company
# --- HELPER FOR SAFE MATH EVALUATION (Moved from core_utils.py or assumed to be there) ---
# Assuming safe_eval_math is available via backend.lib.core_utils.safe_eval_math
# Example implementation if not:
# def safe_eval_math(expression: str) -> float:
# # Implement a safe parser/evaluator for simple math expressions
# # For now, a very basic eval might be used, but in production, this needs to be locked down
# allowed_chars = "0123456789.+-*/ "
# if not all(c in allowed_chars for c in expression):
# raise ValueError("Expression contains disallowed characters.")
# return eval(expression)

View File

@@ -267,3 +267,15 @@ class ScraperService:
except Exception as e: except Exception as e:
logger.error(f"Critical error in _parse_html: {e}", exc_info=True) logger.error(f"Critical error in _parse_html: {e}", exc_info=True)
return {"title": "", "description": "", "text": "", "emails": [], "error": str(e)} return {"title": "", "description": "", "text": "", "emails": [], "error": str(e)}
# --- HELPER FUNCTION FOR EXTERNAL USE ---
def scrape_website_content(url: str) -> Optional[str]:
"""
Simple wrapper to get just the text content of a URL.
Used by ClassificationService.
"""
scraper = ScraperService()
result = scraper.scrape_url(url)
if result and result.get("text"):
return result["text"]
return None