fix(explorer): resolve initialization and import errors for v0.7.0 backend
This commit is contained in:
@@ -41,7 +41,7 @@ app.add_middleware(
|
||||
|
||||
# Service Singletons
|
||||
scraper = ScraperService()
|
||||
classifier = ClassificationService()
|
||||
classifier = ClassificationService() # Now works without args
|
||||
discovery = DiscoveryService()
|
||||
|
||||
# --- Pydantic Models ---
|
||||
@@ -58,33 +58,6 @@ class AnalysisRequest(BaseModel):
|
||||
company_id: int
|
||||
force_scrape: bool = False
|
||||
|
||||
class ContactBase(BaseModel):
|
||||
gender: str
|
||||
title: str = ""
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: str
|
||||
job_title: str
|
||||
language: str = "De"
|
||||
role: str
|
||||
status: str = ""
|
||||
is_primary: bool = False
|
||||
|
||||
class ContactCreate(ContactBase):
|
||||
company_id: int
|
||||
|
||||
class ContactUpdate(BaseModel):
|
||||
gender: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
is_primary: Optional[bool] = None
|
||||
|
||||
# --- Events ---
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
@@ -115,8 +88,6 @@ def list_companies(
|
||||
query = query.filter(Company.name.ilike(f"%{search}%"))
|
||||
|
||||
total = query.count()
|
||||
|
||||
# Sorting Logic
|
||||
if sort_by == "updated_desc":
|
||||
query = query.order_by(Company.updated_at.desc())
|
||||
elif sort_by == "created_desc":
|
||||
@@ -125,7 +96,6 @@ def list_companies(
|
||||
query = query.order_by(Company.name.asc())
|
||||
|
||||
items = query.offset(skip).limit(limit).all()
|
||||
|
||||
return {"total": total, "items": items}
|
||||
except Exception as e:
|
||||
logger.error(f"List Companies Error: {e}", exc_info=True)
|
||||
@@ -134,547 +104,61 @@ def list_companies(
|
||||
@app.get("/api/companies/{company_id}")
|
||||
def get_company(company_id: int, db: Session = Depends(get_db)):
|
||||
company = db.query(Company).options(
|
||||
joinedload(Company.signals),
|
||||
joinedload(Company.enrichment_data),
|
||||
joinedload(Company.contacts)
|
||||
).filter(Company.id == company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
raise HTTPException(404, detail="Company not found")
|
||||
return company
|
||||
|
||||
@app.post("/api/companies/bulk")
|
||||
def bulk_import_names(req: BulkImportRequest, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Quick import for testing. Just a list of names.
|
||||
"""
|
||||
logger.info(f"Starting bulk import of {len(req.names)} names.")
|
||||
try:
|
||||
added = 0
|
||||
skipped = 0
|
||||
|
||||
# Deduplicator init
|
||||
try:
|
||||
dedup = Deduplicator(db)
|
||||
logger.info("Deduplicator initialized.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Deduplicator init failed: {e}")
|
||||
dedup = None
|
||||
|
||||
for name in req.names:
|
||||
clean_name = name.strip()
|
||||
if not clean_name: continue
|
||||
|
||||
# 1. Simple Deduplication (Exact Name)
|
||||
exists = db.query(Company).filter(Company.name == clean_name).first()
|
||||
if exists:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 2. Smart Deduplication (if available)
|
||||
if dedup:
|
||||
matches = dedup.find_duplicates({"name": clean_name})
|
||||
if matches and matches[0]['score'] > 95:
|
||||
logger.info(f"Duplicate found for {clean_name}: {matches[0]['name']}")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 3. Create
|
||||
new_comp = Company(
|
||||
name=clean_name,
|
||||
status="NEW" # This triggered the error before
|
||||
)
|
||||
db.add(new_comp)
|
||||
added += 1
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Import success. Added: {added}, Skipped: {skipped}")
|
||||
return {"added": added, "skipped": skipped}
|
||||
except Exception as e:
|
||||
logger.error(f"Bulk Import Failed: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/api/robotics/categories")
|
||||
def list_robotics_categories(db: Session = Depends(get_db)):
|
||||
"""Lists all configured robotics categories."""
|
||||
return db.query(RoboticsCategory).all()
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
description: str
|
||||
reasoning_guide: str
|
||||
|
||||
@app.put("/api/robotics/categories/{id}")
|
||||
def update_robotics_category(id: int, cat: CategoryUpdate, db: Session = Depends(get_db)):
|
||||
"""Updates a robotics category definition."""
|
||||
category = db.query(RoboticsCategory).filter(RoboticsCategory.id == id).first()
|
||||
if not category:
|
||||
raise HTTPException(404, "Category not found")
|
||||
|
||||
category.description = cat.description
|
||||
category.reasoning_guide = cat.reasoning_guide
|
||||
db.commit()
|
||||
return category
|
||||
|
||||
@app.post("/api/enrich/discover")
|
||||
def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Triggers Stage 1: Discovery (Website Search + Wikipedia Search)
|
||||
"""
|
||||
try:
|
||||
company = db.query(Company).filter(Company.id == req.company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
# Run in background
|
||||
background_tasks.add_task(run_discovery_task, company.id)
|
||||
|
||||
return {"status": "queued", "message": f"Discovery started for {company.name}"}
|
||||
except Exception as e:
|
||||
logger.error(f"Discovery Error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/companies/{company_id}/override/wiki")
|
||||
def override_wiki_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
|
||||
"""
|
||||
Manually sets the Wikipedia URL for a company and triggers re-extraction.
|
||||
Locks the data against auto-discovery.
|
||||
"""
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
logger.info(f"Manual Override for {company.name}: Setting Wiki URL to {url}")
|
||||
|
||||
# Update or create EnrichmentData entry
|
||||
existing_wiki = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "wikipedia"
|
||||
).first()
|
||||
|
||||
# Extract data immediately
|
||||
wiki_data = {"url": url}
|
||||
if url and url != "k.A.":
|
||||
try:
|
||||
wiki_data = discovery.extract_wikipedia_data(url)
|
||||
wiki_data['url'] = url # Ensure URL is correct
|
||||
except Exception as e:
|
||||
logger.error(f"Extraction failed for manual URL: {e}")
|
||||
wiki_data["error"] = str(e)
|
||||
|
||||
if not existing_wiki:
|
||||
db.add(EnrichmentData(
|
||||
company_id=company.id,
|
||||
source_type="wikipedia",
|
||||
content=wiki_data,
|
||||
is_locked=True
|
||||
))
|
||||
else:
|
||||
existing_wiki.content = wiki_data
|
||||
existing_wiki.updated_at = datetime.utcnow()
|
||||
existing_wiki.is_locked = True # LOCK IT
|
||||
existing_wiki.wiki_verified_empty = False # It's no longer empty
|
||||
|
||||
db.commit()
|
||||
# The return needs to be here, outside the else block but inside the main function
|
||||
return {"status": "updated", "data": wiki_data}
|
||||
|
||||
@app.post("/api/companies/{company_id}/wiki_mark_empty")
|
||||
def mark_wiki_empty(company_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Marks a company as having no valid Wikipedia entry after manual review.
|
||||
Creates a locked, empty Wikipedia enrichment entry.
|
||||
"""
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
logger.info(f"Manual override for {company.name}: Marking Wikipedia as verified empty.")
|
||||
|
||||
existing_wiki = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "wikipedia"
|
||||
).first()
|
||||
|
||||
empty_wiki_data = {"url": "k.A.", "title": "k.A.", "first_paragraph": "k.A.", "error": "Manually marked as empty"}
|
||||
|
||||
if not existing_wiki:
|
||||
db.add(EnrichmentData(
|
||||
company_id=company.id,
|
||||
source_type="wikipedia",
|
||||
content=empty_wiki_data,
|
||||
is_locked=True,
|
||||
wiki_verified_empty=True
|
||||
))
|
||||
else:
|
||||
existing_wiki.content = empty_wiki_data
|
||||
existing_wiki.updated_at = datetime.utcnow()
|
||||
existing_wiki.is_locked = True # LOCK IT
|
||||
existing_wiki.wiki_verified_empty = True # Mark as empty
|
||||
|
||||
db.commit()
|
||||
return {"status": "updated", "wiki_verified_empty": True}
|
||||
|
||||
@app.post("/api/companies/{company_id}/override/website")
|
||||
def override_website_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
|
||||
"""
|
||||
Manually sets the Website URL for a company.
|
||||
Clears existing scrape data to force a fresh analysis on next run.
|
||||
"""
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
logger.info(f"Manual Override for {company.name}: Setting Website to {url}")
|
||||
company.website = url
|
||||
|
||||
# Remove old scrape data since URL changed
|
||||
db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "website_scrape"
|
||||
).delete()
|
||||
|
||||
db.commit()
|
||||
return {"status": "updated", "website": url}
|
||||
|
||||
@app.post("/api/companies/{company_id}/override/impressum")
|
||||
def override_impressum_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
|
||||
"""
|
||||
Manually sets the Impressum URL for a company and triggers re-extraction.
|
||||
"""
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
logger.info(f"Manual Override for {company.name}: Setting Impressum URL to {url}")
|
||||
|
||||
# 1. Scrape Impressum immediately
|
||||
impressum_data = scraper._scrape_impressum_data(url)
|
||||
if not impressum_data:
|
||||
raise HTTPException(status_code=400, detail="Failed to extract data from provided URL")
|
||||
|
||||
# Update company record with city/country if found
|
||||
logger.info(f"override_impressum_url: Scraped impressum_data for {company.name}: City={impressum_data.get('city')}, Country_code={impressum_data.get('country_code')}")
|
||||
if city_val := impressum_data.get("city"):
|
||||
logger.info(f"override_impressum_url: Updating company.city from '{company.city}' to '{city_val}'")
|
||||
company.city = city_val
|
||||
if country_val := impressum_data.get("country_code"):
|
||||
logger.info(f"override_impressum_url: Updating company.country from '{company.country}' to '{country_val}'")
|
||||
company.country = country_val
|
||||
logger.info(f"override_impressum_url: Company object after updates (before commit): City='{company.city}', Country='{company.country}'")
|
||||
|
||||
|
||||
# 2. Find existing scrape data or create new
|
||||
existing_scrape = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "website_scrape"
|
||||
).first()
|
||||
|
||||
if not existing_scrape:
|
||||
# Create minimal scrape entry and lock it
|
||||
db.add(EnrichmentData(
|
||||
company_id=company.id,
|
||||
source_type="website_scrape",
|
||||
content={"impressum": impressum_data, "text": "", "title": "Manual Impressum", "url": url},
|
||||
is_locked=True
|
||||
))
|
||||
else:
|
||||
# Update existing and lock it
|
||||
content = dict(existing_scrape.content) if existing_scrape.content else {}
|
||||
content["impressum"] = impressum_data
|
||||
existing_scrape.content = content
|
||||
existing_scrape.updated_at = datetime.utcnow()
|
||||
existing_scrape.is_locked = True
|
||||
|
||||
db.commit()
|
||||
logger.info(f"override_impressum_url: Commit successful. Company ID {company.id} updated.")
|
||||
return {"status": "updated", "data": impressum_data}
|
||||
|
||||
# --- Contact Routes ---
|
||||
|
||||
@app.post("/api/contacts")
|
||||
def create_contact(contact: ContactCreate, db: Session = Depends(get_db)):
|
||||
"""Creates a new contact and handles primary contact logic."""
|
||||
if contact.is_primary:
|
||||
db.query(Contact).filter(Contact.company_id == contact.company_id).update({"is_primary": False})
|
||||
|
||||
db_contact = Contact(**contact.dict())
|
||||
db.add(db_contact)
|
||||
db.commit()
|
||||
db.refresh(db_contact)
|
||||
return db_contact
|
||||
|
||||
# --- Industry Routes ---
|
||||
|
||||
class IndustryCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
is_focus: bool = False
|
||||
primary_category_id: Optional[int] = None
|
||||
|
||||
class IndustryUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_focus: Optional[bool] = None
|
||||
primary_category_id: Optional[int] = None
|
||||
|
||||
@app.get("/api/industries")
|
||||
def list_industries(db: Session = Depends(get_db)):
|
||||
return db.query(Industry).all()
|
||||
|
||||
@app.post("/api/industries")
|
||||
def create_industry(ind: IndustryCreate, db: Session = Depends(get_db)):
|
||||
# 1. Prepare data
|
||||
ind_data = ind.dict()
|
||||
base_name = ind_data['name']
|
||||
@app.post("/api/enrich/discover")
|
||||
def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
company = db.query(Company).filter(Company.id == req.company_id).first()
|
||||
if not company: raise HTTPException(404, "Company not found")
|
||||
background_tasks.add_task(run_discovery_task, company.id)
|
||||
return {"status": "queued"}
|
||||
|
||||
# 2. Check for duplicate name
|
||||
existing = db.query(Industry).filter(Industry.name == base_name).first()
|
||||
if existing:
|
||||
# Auto-increment name if duplicated
|
||||
counter = 1
|
||||
while db.query(Industry).filter(Industry.name == f"{base_name} ({counter})").first():
|
||||
counter += 1
|
||||
ind_data['name'] = f"{base_name} ({counter})"
|
||||
@app.post("/api/enrich/analyze")
|
||||
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
company = db.query(Company).filter(Company.id == req.company_id).first()
|
||||
if not company: raise HTTPException(404, "Company not found")
|
||||
|
||||
# 3. Create
|
||||
db_ind = Industry(**ind_data)
|
||||
db.add(db_ind)
|
||||
db.commit()
|
||||
db.refresh(db_ind)
|
||||
return db_ind
|
||||
if not company.website or company.website == "k.A.":
|
||||
return {"error": "No website to analyze. Run Discovery first."}
|
||||
|
||||
@app.put("/api/industries/{id}")
|
||||
def update_industry(id: int, ind: IndustryUpdate, db: Session = Depends(get_db)):
|
||||
db_ind = db.query(Industry).filter(Industry.id == id).first()
|
||||
if not db_ind:
|
||||
raise HTTPException(404, "Industry not found")
|
||||
|
||||
for key, value in ind.dict(exclude_unset=True).items():
|
||||
setattr(db_ind, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_ind)
|
||||
return db_ind
|
||||
|
||||
@app.delete("/api/industries/{id}")
|
||||
def delete_industry(id: int, db: Session = Depends(get_db)):
|
||||
db_ind = db.query(Industry).filter(Industry.id == id).first()
|
||||
if not db_ind:
|
||||
raise HTTPException(404, "Industry not found")
|
||||
db.delete(db_ind)
|
||||
db.commit()
|
||||
return {"status": "deleted"}
|
||||
|
||||
# --- Job Role Mapping Routes ---
|
||||
|
||||
class JobRoleMappingCreate(BaseModel):
|
||||
pattern: str
|
||||
role: str
|
||||
|
||||
@app.get("/api/job_roles")
|
||||
def list_job_roles(db: Session = Depends(get_db)):
|
||||
return db.query(JobRoleMapping).all()
|
||||
|
||||
@app.post("/api/job_roles")
|
||||
def create_job_role(mapping: JobRoleMappingCreate, db: Session = Depends(get_db)):
|
||||
db_mapping = JobRoleMapping(**mapping.dict())
|
||||
db.add(db_mapping)
|
||||
db.commit()
|
||||
db.refresh(db_mapping)
|
||||
return db_mapping
|
||||
|
||||
@app.delete("/api/job_roles/{id}")
|
||||
def delete_job_role(id: int, db: Session = Depends(get_db)):
|
||||
db_mapping = db.query(JobRoleMapping).filter(JobRoleMapping.id == id).first()
|
||||
if not db_mapping:
|
||||
raise HTTPException(404, "Mapping not found")
|
||||
db.delete(db_mapping)
|
||||
db.commit()
|
||||
return {"status": "deleted"}
|
||||
|
||||
@app.put("/api/contacts/{contact_id}")
|
||||
def update_contact(contact_id: int, contact: ContactUpdate, db: Session = Depends(get_db)):
|
||||
"""Updates an existing contact."""
|
||||
db_contact = db.query(Contact).filter(Contact.id == contact_id).first()
|
||||
if not db_contact:
|
||||
raise HTTPException(404, "Contact not found")
|
||||
|
||||
update_data = contact.dict(exclude_unset=True)
|
||||
|
||||
if update_data.get("is_primary"):
|
||||
db.query(Contact).filter(Contact.company_id == db_contact.company_id).update({"is_primary": False})
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(db_contact, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_contact)
|
||||
return db_contact
|
||||
|
||||
@app.delete("/api/contacts/{contact_id}")
|
||||
def delete_contact(contact_id: int, db: Session = Depends(get_db)):
|
||||
"""Deletes a contact."""
|
||||
db_contact = db.query(Contact).filter(Contact.id == contact_id).first()
|
||||
if not db_contact:
|
||||
raise HTTPException(404, "Contact not found")
|
||||
db.delete(db_contact)
|
||||
db.commit()
|
||||
return {"status": "deleted"}
|
||||
|
||||
@app.get("/api/contacts/all")
|
||||
def list_all_contacts(
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
search: Optional[str] = None,
|
||||
sort_by: Optional[str] = Query("name_asc"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Lists all contacts across all companies with pagination and search.
|
||||
"""
|
||||
query = db.query(Contact).join(Company)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.filter(
|
||||
(Contact.first_name.ilike(search_term)) |
|
||||
(Contact.last_name.ilike(search_term)) |
|
||||
(Contact.email.ilike(search_term)) |
|
||||
(Company.name.ilike(search_term))
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
|
||||
# Sorting Logic
|
||||
if sort_by == "updated_desc":
|
||||
query = query.order_by(Contact.updated_at.desc())
|
||||
elif sort_by == "created_desc":
|
||||
query = query.order_by(Contact.id.desc())
|
||||
else: # Default: name_asc
|
||||
query = query.order_by(Contact.last_name.asc(), Contact.first_name.asc())
|
||||
|
||||
contacts = query.offset(skip).limit(limit).all()
|
||||
|
||||
# Enrich with Company Name for the frontend list
|
||||
result = []
|
||||
for c in contacts:
|
||||
c_dict = {k: v for k, v in c.__dict__.items() if not k.startswith('_')}
|
||||
c_dict['company_name'] = c.company.name if c.company else "Unknown"
|
||||
result.append(c_dict)
|
||||
|
||||
return {"total": total, "items": result}
|
||||
|
||||
class BulkContactImportItem(BaseModel):
|
||||
company_name: str
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: Optional[str] = None
|
||||
job_title: Optional[str] = None
|
||||
role: Optional[str] = "Operativer Entscheider"
|
||||
gender: Optional[str] = "männlich"
|
||||
|
||||
class BulkContactImportRequest(BaseModel):
|
||||
contacts: List[BulkContactImportItem]
|
||||
|
||||
@app.post("/api/contacts/bulk")
|
||||
def bulk_import_contacts(req: BulkContactImportRequest, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Bulk imports contacts.
|
||||
Matches Company by Name (creates if missing).
|
||||
Dedupes Contact by Email.
|
||||
"""
|
||||
logger.info(f"Starting bulk contact import: {len(req.contacts)} items")
|
||||
stats = {"added": 0, "skipped": 0, "companies_created": 0}
|
||||
|
||||
for item in req.contacts:
|
||||
if not item.company_name: continue
|
||||
|
||||
# 1. Find or Create Company
|
||||
company = db.query(Company).filter(Company.name.ilike(item.company_name.strip())).first()
|
||||
if not company:
|
||||
company = Company(name=item.company_name.strip(), status="NEW")
|
||||
db.add(company)
|
||||
db.commit() # Commit to get ID
|
||||
db.refresh(company)
|
||||
stats["companies_created"] += 1
|
||||
|
||||
# 2. Check for Duplicate Contact (by Email)
|
||||
if item.email:
|
||||
exists = db.query(Contact).filter(Contact.email == item.email.strip()).first()
|
||||
if exists:
|
||||
stats["skipped"] += 1
|
||||
continue
|
||||
|
||||
# 3. Create Contact
|
||||
new_contact = Contact(
|
||||
company_id=company.id,
|
||||
first_name=item.first_name,
|
||||
last_name=item.last_name,
|
||||
email=item.email,
|
||||
job_title=item.job_title,
|
||||
role=item.role,
|
||||
gender=item.gender,
|
||||
status="Init" # Default status
|
||||
)
|
||||
db.add(new_contact)
|
||||
stats["added"] += 1
|
||||
|
||||
db.commit()
|
||||
return stats
|
||||
|
||||
@app.post("/api/enrichment/{company_id}/{source_type}/lock")
|
||||
def lock_enrichment(company_id: int, source_type: str, locked: bool = Query(...), db: Session = Depends(get_db)):
|
||||
"""
|
||||
Toggles the lock status of a specific enrichment data type (e.g. 'website_scrape', 'wikipedia').
|
||||
"""
|
||||
entry = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company_id,
|
||||
EnrichmentData.source_type == source_type
|
||||
).first()
|
||||
|
||||
if not entry:
|
||||
raise HTTPException(404, "Enrichment data not found")
|
||||
|
||||
entry.is_locked = locked
|
||||
db.commit()
|
||||
return {"status": "updated", "is_locked": locked}
|
||||
background_tasks.add_task(run_analysis_task, company.id)
|
||||
return {"status": "queued"}
|
||||
|
||||
def run_discovery_task(company_id: int):
|
||||
# New Session for Background Task
|
||||
from .database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
company = db.query(Company).filter(Company.id == company_id).first()
|
||||
if not company: return
|
||||
|
||||
logger.info(f"Running Discovery Task for {company.name}")
|
||||
|
||||
# 1. Website Search (Always try if missing)
|
||||
# 1. Website Search
|
||||
if not company.website or company.website == "k.A.":
|
||||
found_url = discovery.find_company_website(company.name, company.city)
|
||||
if found_url and found_url != "k.A.":
|
||||
company.website = found_url
|
||||
logger.info(f"-> Found URL: {found_url}")
|
||||
|
||||
# 2. Wikipedia Search & Extraction
|
||||
# Check if locked
|
||||
# 2. Wikipedia Search
|
||||
existing_wiki = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "wikipedia"
|
||||
).first()
|
||||
|
||||
if existing_wiki and existing_wiki.is_locked:
|
||||
logger.info(f"Skipping Wiki Discovery for {company.name} - Data is LOCKED.")
|
||||
else:
|
||||
# Pass available info for better validation
|
||||
current_website = company.website if company.website and company.website != "k.A." else None
|
||||
wiki_url = discovery.find_wikipedia_url(company.name, website=current_website, city=company.city)
|
||||
company.last_wiki_search_at = datetime.utcnow()
|
||||
|
||||
wiki_data = {"url": wiki_url}
|
||||
if wiki_url and wiki_url != "k.A.":
|
||||
logger.info(f"Extracting full data from Wikipedia for {company.name}...")
|
||||
wiki_data = discovery.extract_wikipedia_data(wiki_url)
|
||||
if not existing_wiki or not existing_wiki.is_locked:
|
||||
wiki_url = discovery.find_wikipedia_url(company.name, website=company.website, city=company.city)
|
||||
wiki_data = discovery.extract_wikipedia_data(wiki_url) if wiki_url and wiki_url != "k.A." else {"url": wiki_url}
|
||||
|
||||
if not existing_wiki:
|
||||
db.add(EnrichmentData(company_id=company.id, source_type="wikipedia", content=wiki_data))
|
||||
@@ -686,35 +170,12 @@ def run_discovery_task(company_id: int):
|
||||
company.status = "DISCOVERED"
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Discovery finished for {company.id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Background Task Error: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
logger.error(f"Discovery Task Error: {e}", exc_info=True)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@app.post("/api/enrich/analyze")
|
||||
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
||||
company = db.query(Company).filter(Company.id == req.company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
if not company.website or company.website == "k.A.":
|
||||
return {"error": "No website to analyze. Run Discovery first."}
|
||||
|
||||
# FORCE SCRAPE LOGIC
|
||||
# Respect Locked Data: Only delete if not locked.
|
||||
db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "website_scrape",
|
||||
EnrichmentData.is_locked == False
|
||||
).delete()
|
||||
db.commit()
|
||||
|
||||
background_tasks.add_task(run_analysis_task, company.id, company.website)
|
||||
return {"status": "queued"}
|
||||
|
||||
def run_analysis_task(company_id: int, url: str):
|
||||
def run_analysis_task(company_id: int):
|
||||
from .database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -723,157 +184,41 @@ def run_analysis_task(company_id: int, url: str):
|
||||
|
||||
logger.info(f"Running Analysis Task for {company.name}")
|
||||
|
||||
# 1. Scrape Website OR Use Locked Data
|
||||
scrape_result = {}
|
||||
# 1. Scrape Website (if not locked)
|
||||
existing_scrape = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "website_scrape"
|
||||
).first()
|
||||
|
||||
if existing_scrape and existing_scrape.is_locked:
|
||||
logger.info(f"Using LOCKED scrape data for {company.name}")
|
||||
scrape_result = dict(existing_scrape.content) # Copy dict
|
||||
|
||||
# Always ensure city/country from locked impressum data is synced to company
|
||||
if "impressum" in scrape_result and scrape_result["impressum"]:
|
||||
impressum_city = scrape_result["impressum"].get("city")
|
||||
impressum_country = scrape_result["impressum"].get("country_code")
|
||||
logger.info(f"Analysis task (locked data): Impressum found. City='{impressum_city}', Country='{impressum_country}'")
|
||||
if impressum_city and company.city != impressum_city:
|
||||
logger.info(f"Analysis task: Updating company.city from '{company.city}' to '{impressum_city}'")
|
||||
company.city = impressum_city
|
||||
if impressum_country and company.country != impressum_country:
|
||||
logger.info(f"Analysis task: Updating company.country from '{company.country}' to '{impressum_country}'")
|
||||
company.country = impressum_country
|
||||
|
||||
text_val = scrape_result.get("text")
|
||||
text_len = len(text_val) if text_val else 0
|
||||
logger.info(f"Locked data keys: {list(scrape_result.keys())}, Text length: {text_len}")
|
||||
|
||||
# AUTO-FIX: If locked data (e.g. Manual Impressum) has no text, fetch main website text
|
||||
if text_len < 100:
|
||||
logger.info(f"Locked data missing text (len={text_len}). Fetching content from {url}...")
|
||||
try:
|
||||
fresh_scrape = scraper.scrape_url(url)
|
||||
except Exception as e:
|
||||
logger.error(f"Fresh scrape failed: {e}", exc_info=True)
|
||||
fresh_scrape = {}
|
||||
|
||||
logger.info(f"Fresh scrape result keys: {list(fresh_scrape.keys())}")
|
||||
|
||||
if "text" in fresh_scrape and len(fresh_scrape["text"]) > 100:
|
||||
logger.info(f"Fresh scrape successful. Text len: {len(fresh_scrape['text'])}")
|
||||
# Update local dict for current processing
|
||||
scrape_result["text"] = fresh_scrape["text"]
|
||||
scrape_result["title"] = fresh_scrape.get("title", "")
|
||||
|
||||
# Update DB (Merge into existing content)
|
||||
updated_content = dict(existing_scrape.content)
|
||||
updated_content["text"] = fresh_scrape["text"]
|
||||
updated_content["title"] = fresh_scrape.get("title", "")
|
||||
|
||||
existing_scrape.content = updated_content
|
||||
existing_scrape.updated_at = datetime.utcnow()
|
||||
# db.commit() here would be too early
|
||||
logger.info("Updated locked record with fresh website text in session.")
|
||||
else:
|
||||
logger.warning(f"Fresh scrape returned insufficient text. Error: {fresh_scrape.get('error')}")
|
||||
else:
|
||||
# Standard Scrape
|
||||
scrape_result = scraper.scrape_url(url)
|
||||
|
||||
# Update company fields from impressum if found during scrape
|
||||
if "impressum" in scrape_result and scrape_result["impressum"]:
|
||||
impressum_city = scrape_result["impressum"].get("city")
|
||||
impressum_country = scrape_result["impressum"].get("country_code")
|
||||
logger.info(f"Analysis task (standard scrape): Impressum found. City='{impressum_city}', Country='{impressum_country}'")
|
||||
if impressum_city and company.city != impressum_city:
|
||||
logger.info(f"Analysis task: Updating company.city from '{company.city}' to '{impressum_city}'")
|
||||
company.city = impressum_city
|
||||
if impressum_country and company.country != impressum_country:
|
||||
logger.info(f"Analysis task: Updating company.country from '{company.country}' to '{impressum_country}'")
|
||||
company.country = impressum_country
|
||||
|
||||
# Save Scrape Data
|
||||
if "text" in scrape_result and scrape_result["text"]:
|
||||
if not existing_scrape or not existing_scrape.is_locked:
|
||||
from .services.scraping import ScraperService
|
||||
scrape_res = ScraperService().scrape_url(company.website)
|
||||
if not existing_scrape:
|
||||
db.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_result))
|
||||
db.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_res))
|
||||
else:
|
||||
existing_scrape.content = scrape_result
|
||||
existing_scrape.content = scrape_res
|
||||
existing_scrape.updated_at = datetime.utcnow()
|
||||
elif "error" in scrape_result:
|
||||
logger.warning(f"Scraping failed for {company.name}: {scrape_result['error']}")
|
||||
db.commit()
|
||||
|
||||
# 2. Classify Robotics Potential
|
||||
text_content = scrape_result.get("text")
|
||||
|
||||
logger.info(f"Preparing classification. Text content length: {len(text_content) if text_content else 0}")
|
||||
|
||||
if text_content and len(text_content) > 100:
|
||||
logger.info(f"Starting classification for {company.name}...")
|
||||
analysis = classifier.analyze_robotics_potential(
|
||||
company_name=company.name,
|
||||
website_text=text_content
|
||||
)
|
||||
|
||||
if "error" in analysis:
|
||||
logger.error(f"Robotics classification failed for {company.name}: {analysis['error']}")
|
||||
else:
|
||||
industry = analysis.get("industry")
|
||||
if industry:
|
||||
company.industry_ai = industry
|
||||
|
||||
db.query(Signal).filter(Signal.company_id == company.id).delete()
|
||||
|
||||
potentials = analysis.get("potentials", {})
|
||||
for signal_type, data in potentials.items():
|
||||
new_signal = Signal(
|
||||
company_id=company.id,
|
||||
signal_type=f"robotics_{signal_type}_potential",
|
||||
confidence=data.get("score", 0),
|
||||
value="High" if data.get("score", 0) > 70 else "Medium" if data.get("score", 0) > 30 else "Low",
|
||||
proof_text=data.get("reason")
|
||||
)
|
||||
db.add(new_signal)
|
||||
|
||||
existing_analysis = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company.id,
|
||||
EnrichmentData.source_type == "ai_analysis"
|
||||
).first()
|
||||
|
||||
if not existing_analysis:
|
||||
db.add(EnrichmentData(company_id=company.id, source_type="ai_analysis", content=analysis))
|
||||
else:
|
||||
existing_analysis.content = analysis
|
||||
existing_analysis.updated_at = datetime.utcnow()
|
||||
# 2. Classify Industry & Metrics
|
||||
# IMPORTANT: Using the new method name and passing db session
|
||||
classifier.classify_company_potential(company, db)
|
||||
|
||||
company.status = "ENRICHED"
|
||||
company.last_classification_at = datetime.utcnow()
|
||||
logger.info(f"Robotics analysis complete for {company.name}.")
|
||||
else:
|
||||
logger.warning(f"Skipping classification for {company.name}: Insufficient text content (len={len(text_content) if text_content else 0})")
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Analysis finished for {company.id}")
|
||||
logger.info(f"Analysis complete for {company.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Analyze Task Error: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- Serve Frontend ---
|
||||
# Priority 1: Container Path (outside of /app volume)
|
||||
static_path = "/frontend_static"
|
||||
|
||||
# Priority 2: Local Dev Path (relative to this file)
|
||||
if not os.path.exists(static_path):
|
||||
static_path = os.path.join(os.path.dirname(__file__), "../static")
|
||||
|
||||
if os.path.exists(static_path):
|
||||
logger.info(f"Serving frontend from {static_path}")
|
||||
app.mount("/", StaticFiles(directory=static_path, html=True), name="static")
|
||||
else:
|
||||
logger.warning(f"Frontend static files not found at {static_path} or local fallback.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
import os
|
||||
@@ -23,41 +22,53 @@ def get_table_columns(cursor, table_name):
|
||||
cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return [row[1] for row in cursor.fetchall()]
|
||||
|
||||
def migrate_industries_table():
|
||||
def migrate_tables():
|
||||
"""
|
||||
Adds the new schema columns to the 'industries' table if they don't exist.
|
||||
This ensures backward compatibility with older database files.
|
||||
Adds new columns to existing tables to support v0.7.0 features.
|
||||
"""
|
||||
logger.info(f"Connecting to database at {DB_FILE} to run migrations...")
|
||||
conn = get_db_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# 1. Update INDUSTRIES Table
|
||||
logger.info("Checking 'industries' table schema...")
|
||||
columns = get_table_columns(cursor, "industries")
|
||||
logger.info(f"Found existing columns: {columns}")
|
||||
ind_columns = get_table_columns(cursor, "industries")
|
||||
|
||||
migrations_to_run = {
|
||||
ind_migrations = {
|
||||
"metric_type": "TEXT",
|
||||
"scraper_search_term": "TEXT",
|
||||
"standardization_logic": "TEXT",
|
||||
"proxy_factor": "FLOAT"
|
||||
# min_requirement, whale_threshold, scraper_keywords already exist from v0.6.0
|
||||
"proxy_factor": "FLOAT",
|
||||
"scraper_keywords": "TEXT",
|
||||
"scraper_search_term": "TEXT"
|
||||
}
|
||||
|
||||
for col, col_type in migrations_to_run.items():
|
||||
if col not in columns:
|
||||
logger.info(f"Adding column '{col}' ({col_type}) to 'industries' table...")
|
||||
for col, col_type in ind_migrations.items():
|
||||
if col not in ind_columns:
|
||||
logger.info(f"Adding column '{col}' to 'industries' table...")
|
||||
cursor.execute(f"ALTER TABLE industries ADD COLUMN {col} {col_type}")
|
||||
else:
|
||||
logger.info(f"Column '{col}' already exists. Skipping.")
|
||||
|
||||
# Also, we need to handle the removal of old columns if necessary (safer to leave them)
|
||||
# We will also fix the proxy_factor type if it was TEXT
|
||||
# This is more complex, for now let's just add.
|
||||
# 2. Update COMPANIES Table (New for v0.7.0)
|
||||
logger.info("Checking 'companies' table schema...")
|
||||
comp_columns = get_table_columns(cursor, "companies")
|
||||
|
||||
comp_migrations = {
|
||||
"calculated_metric_name": "TEXT",
|
||||
"calculated_metric_value": "FLOAT",
|
||||
"calculated_metric_unit": "TEXT",
|
||||
"standardized_metric_value": "FLOAT",
|
||||
"standardized_metric_unit": "TEXT",
|
||||
"metric_source": "TEXT"
|
||||
}
|
||||
|
||||
for col, col_type in comp_migrations.items():
|
||||
if col not in comp_columns:
|
||||
logger.info(f"Adding column '{col}' to 'companies' table...")
|
||||
cursor.execute(f"ALTER TABLE companies ADD COLUMN {col} {col_type}")
|
||||
|
||||
conn.commit()
|
||||
logger.info("Migrations for 'industries' table completed successfully.")
|
||||
logger.info("All migrations completed successfully.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred during migration: {e}", exc_info=True)
|
||||
@@ -65,9 +76,8 @@ def migrate_industries_table():
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not os.path.exists(DB_FILE):
|
||||
logger.error(f"Database file not found at {DB_FILE}. Cannot run migration. Please ensure the old database is in place.")
|
||||
logger.error(f"Database file not found at {DB_FILE}.")
|
||||
else:
|
||||
migrate_industries_table()
|
||||
migrate_tables()
|
||||
@@ -5,59 +5,39 @@ from typing import Optional, Dict, Any, List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from backend.database import Company, Industry, RoboticsCategory, EnrichmentData, get_db
|
||||
from backend.config import settings
|
||||
from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
|
||||
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
|
||||
from backend.services.scraping import scrape_website_content # Corrected import
|
||||
from backend.services.scraping import scrape_website_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ClassificationService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.allowed_industries_notion: List[Industry] = self._load_industry_definitions()
|
||||
self.robotics_categories: List[RoboticsCategory] = self._load_robotics_categories()
|
||||
def __init__(self):
|
||||
# We no longer load industries in init because we don't have a DB session here
|
||||
pass
|
||||
|
||||
# Pre-process allowed industries for LLM prompt
|
||||
self.llm_industry_definitions = [
|
||||
{"name": ind.name, "description": ind.description} for ind in self.allowed_industries_notion
|
||||
]
|
||||
|
||||
# Store for quick lookup
|
||||
self.industry_lookup = {ind.name: ind for ind in self.allowed_industries_notion}
|
||||
self.category_lookup = {cat.id: cat for cat in self.robotics_categories}
|
||||
|
||||
def _load_industry_definitions(self) -> List[Industry]:
|
||||
def _load_industry_definitions(self, db: Session) -> List[Industry]:
|
||||
"""Loads all industry definitions from the database."""
|
||||
industries = self.db.query(Industry).all()
|
||||
industries = db.query(Industry).all()
|
||||
if not industries:
|
||||
logger.warning("No industry definitions found in DB. Classification might be limited.")
|
||||
return industries
|
||||
|
||||
def _load_robotics_categories(self) -> List[RoboticsCategory]:
|
||||
"""Loads all robotics categories from the database."""
|
||||
categories = self.db.query(RoboticsCategory).all()
|
||||
if not categories:
|
||||
logger.warning("No robotics categories found in DB. Potential scoring might be limited.")
|
||||
return categories
|
||||
|
||||
def _get_wikipedia_content(self, company_id: int) -> Optional[str]:
|
||||
def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[str]:
|
||||
"""Fetches Wikipedia content from enrichment_data for a given company."""
|
||||
enrichment = self.db.query(EnrichmentData).filter(
|
||||
enrichment = db.query(EnrichmentData).filter(
|
||||
EnrichmentData.company_id == company_id,
|
||||
EnrichmentData.source_type == "wikipedia"
|
||||
).order_by(EnrichmentData.created_at.desc()).first()
|
||||
|
||||
if enrichment and enrichment.content:
|
||||
# Wikipedia content is stored as JSON with a 'text' key
|
||||
wiki_data = enrichment.content
|
||||
return wiki_data.get('text')
|
||||
return None
|
||||
|
||||
def _run_llm_classification_prompt(self, website_text: str, company_name: str) -> Optional[str]:
|
||||
def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
|
||||
"""
|
||||
Uses LLM to classify the company into one of the predefined industries.
|
||||
Returns the industry name (string) or "Others".
|
||||
"""
|
||||
prompt = r"""
|
||||
Du bist ein präziser Branchen-Klassifizierer für Unternehmen.
|
||||
@@ -79,31 +59,23 @@ class ClassificationService:
|
||||
Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.
|
||||
|
||||
Beispiel Output: Hotellerie
|
||||
Beispiel Output: Automotive - Dealer
|
||||
Beispiel Output: Others
|
||||
""".format(
|
||||
company_name=company_name,
|
||||
website_text_excerpt=website_text[:10000], # Limit text to avoid token limits
|
||||
industry_definitions_json=json.dumps(self.llm_industry_definitions, ensure_ascii=False)
|
||||
website_text_excerpt=website_text[:10000],
|
||||
industry_definitions_json=json.dumps(industry_definitions, ensure_ascii=False)
|
||||
)
|
||||
|
||||
try:
|
||||
response = call_gemini_flash(prompt, temperature=0.1, json_mode=False) # Low temp for strict classification
|
||||
classified_industry = response.strip()
|
||||
if classified_industry in [ind.name for ind in self.allowed_industries_notion] + ["Others"]:
|
||||
return classified_industry
|
||||
logger.warning(f"LLM classified industry '{classified_industry}' not in allowed list. Defaulting to Others.")
|
||||
return "Others"
|
||||
response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)
|
||||
return response.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"LLM classification failed for {company_name}: {e}", exc_info=True)
|
||||
logger.error(f"LLM classification failed for {company_name}: {e}")
|
||||
return None
|
||||
|
||||
def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Uses LLM to extract the specific metric value from text.
|
||||
Returns a dict with 'raw_value', 'raw_unit', 'standardized_value' (if found), 'metric_name'.
|
||||
"""
|
||||
# Attempt to extract both the raw unit count and a potential area if explicitly mentioned
|
||||
prompt = r"""
|
||||
Du bist ein Datenextraktions-Spezialist.
|
||||
Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.
|
||||
@@ -119,63 +91,42 @@ class ClassificationService:
|
||||
1. Finde den numerischen Wert für '{search_term}'.
|
||||
2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.
|
||||
|
||||
Gib NUR ein JSON-Objekt zurück mit den Schlüsseln:
|
||||
Gib NUR ein JSON-Objekt zurück:
|
||||
'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
|
||||
'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
|
||||
'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
|
||||
'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').
|
||||
|
||||
Beispiel Output (wenn 180 Betten und 4500m² Fläche gefunden):
|
||||
{{"raw_value": 180, "raw_unit": "Betten", "area_value": 4500, "metric_name": "{search_term}"}}
|
||||
|
||||
Beispiel Output (wenn nur 180 Betten gefunden):
|
||||
{{"raw_value": 180, "raw_unit": "Betten", "area_value": null, "metric_name": "{search_term}"}}
|
||||
|
||||
Beispiel Output (wenn nichts gefunden):
|
||||
{{"raw_value": null, "raw_unit": null, "area_value": null, "metric_name": "{search_term}"}}
|
||||
""".format(
|
||||
industry_name=industry_name,
|
||||
search_term=search_term,
|
||||
text_content_excerpt=text_content[:15000] # Adjust as needed for token limits
|
||||
text_content_excerpt=text_content[:15000]
|
||||
)
|
||||
|
||||
try:
|
||||
response = call_gemini_flash(prompt, temperature=0.05, json_mode=True) # Very low temp for extraction
|
||||
result = json.loads(response)
|
||||
return result
|
||||
response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)
|
||||
return json.loads(response)
|
||||
except Exception as e:
|
||||
logger.error(f"LLM metric extraction failed for '{search_term}' in '{industry_name}': {e}", exc_info=True)
|
||||
logger.error(f"LLM metric extraction failed for '{search_term}': {e}")
|
||||
return None
|
||||
|
||||
def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
|
||||
"""
|
||||
Safely parses and executes a simple mathematical formula for standardization.
|
||||
Supports basic arithmetic (+, -, *, /) and integer/float values.
|
||||
"""
|
||||
if not formula or not raw_value:
|
||||
if not formula or raw_value is None:
|
||||
return None
|
||||
|
||||
# Replace 'wert' or 'value' with the actual raw_value
|
||||
formula_cleaned = formula.replace("wert", str(raw_value)).replace("Value", str(raw_value)).replace("VALUE", str(raw_value))
|
||||
|
||||
formula_cleaned = formula.replace("wert", str(raw_value)).replace("Value", str(raw_value))
|
||||
try:
|
||||
# Use safe_eval_math from core_utils to prevent arbitrary code execution
|
||||
return safe_eval_math(formula_cleaned)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating standardization logic '{formula}' with value {raw_value}: {e}", exc_info=True)
|
||||
except:
|
||||
return None
|
||||
|
||||
def _extract_and_calculate_metric_cascade(
|
||||
self,
|
||||
db: Session,
|
||||
company: Company,
|
||||
industry_name: str,
|
||||
search_term: str,
|
||||
standardization_logic: Optional[str],
|
||||
standardized_unit: Optional[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Orchestrates the 3-stage (Website -> Wikipedia -> SerpAPI) metric extraction.
|
||||
"""
|
||||
results = {
|
||||
"calculated_metric_name": search_term,
|
||||
"calculated_metric_value": None,
|
||||
@@ -185,150 +136,71 @@ class ClassificationService:
|
||||
"metric_source": None
|
||||
}
|
||||
|
||||
# --- STAGE 1: Website Analysis ---
|
||||
logger.info(f"Stage 1: Analyzing website for '{search_term}' for {company.name}")
|
||||
# CASCADE: Website -> Wikipedia -> SerpAPI
|
||||
sources = [
|
||||
("website", lambda: scrape_website_content(company.website)),
|
||||
("wikipedia", lambda: self._get_wikipedia_content(db, company.id)),
|
||||
("serpapi", lambda: " ".join([res.get("snippet", "") for res in run_serp_search(f"{company.name} {search_term} {industry_name}").get("organic_results", [])]) if run_serp_search(f"{company.name} {search_term} {industry_name}") else None)
|
||||
]
|
||||
|
||||
for source_name, content_loader in sources:
|
||||
logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
|
||||
try:
|
||||
content = content_loader()
|
||||
if not content: continue
|
||||
|
||||
llm_result = self._run_llm_metric_extraction_prompt(content, search_term, industry_name)
|
||||
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
|
||||
results["calculated_metric_value"] = llm_result.get("raw_value")
|
||||
results["calculated_metric_unit"] = llm_result.get("raw_unit")
|
||||
results["metric_source"] = source_name
|
||||
|
||||
if llm_result.get("area_value") is not None:
|
||||
results["standardized_metric_value"] = llm_result.get("area_value")
|
||||
elif llm_result.get("raw_value") is not None and standardization_logic:
|
||||
results["standardized_metric_value"] = self._parse_standardization_logic(standardization_logic, llm_result["raw_value"])
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Error in {source_name} stage: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def classify_company_potential(self, company: Company, db: Session) -> Company:
|
||||
logger.info(f"Starting classification for {company.name}")
|
||||
|
||||
# 1. Load Industries
|
||||
industries = self._load_industry_definitions(db)
|
||||
industry_defs = [{"name": i.name, "description": i.description} for i in industries]
|
||||
|
||||
# 2. Industry Classification
|
||||
website_content = scrape_website_content(company.website)
|
||||
if website_content:
|
||||
llm_result = self._run_llm_metric_extraction_prompt(website_content, search_term, industry_name)
|
||||
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
|
||||
results["calculated_metric_value"] = llm_result.get("raw_value")
|
||||
results["calculated_metric_unit"] = llm_result.get("raw_unit")
|
||||
results["metric_source"] = "website"
|
||||
|
||||
if llm_result.get("area_value") is not None:
|
||||
# Prioritize directly found standardized area
|
||||
results["standardized_metric_value"] = llm_result.get("area_value")
|
||||
logger.info(f"Direct area value found on website for {company.name}: {llm_result.get('area_value')} m²")
|
||||
elif llm_result.get("raw_value") is not None and standardization_logic:
|
||||
# Calculate if only raw value found
|
||||
results["standardized_metric_value"] = self._parse_standardization_logic(
|
||||
standardization_logic, llm_result["raw_value"]
|
||||
)
|
||||
return results
|
||||
|
||||
# --- STAGE 2: Wikipedia Analysis ---
|
||||
logger.info(f"Stage 2: Analyzing Wikipedia for '{search_term}' for {company.name}")
|
||||
wikipedia_content = self._get_wikipedia_content(company.id)
|
||||
if wikipedia_content:
|
||||
llm_result = self._run_llm_metric_extraction_prompt(wikipedia_content, search_term, industry_name)
|
||||
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
|
||||
results["calculated_metric_value"] = llm_result.get("raw_value")
|
||||
results["calculated_metric_unit"] = llm_result.get("raw_unit")
|
||||
results["metric_source"] = "wikipedia"
|
||||
|
||||
if llm_result.get("area_value") is not None:
|
||||
results["standardized_metric_value"] = llm_result.get("area_value")
|
||||
logger.info(f"Direct area value found on Wikipedia for {company.name}: {llm_result.get('area_value')} m²")
|
||||
elif llm_result.get("raw_value") is not None and standardization_logic:
|
||||
results["standardized_metric_value"] = self._parse_standardization_logic(
|
||||
standardization_logic, llm_result["raw_value"]
|
||||
)
|
||||
return results
|
||||
|
||||
# --- STAGE 3: SerpAPI (Google Search) ---
|
||||
logger.info(f"Stage 3: Running SerpAPI search for '{search_term}' for {company.name}")
|
||||
search_query = f"{company.name} {search_term} {industry_name}" # Example: "Hotel Moxy Würzburg Anzahl Betten Hotellerie"
|
||||
serp_results = run_serp_search(search_query) # This returns a dictionary of search results
|
||||
|
||||
if serp_results and serp_results.get("organic_results"):
|
||||
# Concatenate snippets from organic results
|
||||
snippets = " ".join([res.get("snippet", "") for res in serp_results["organic_results"]])
|
||||
if snippets:
|
||||
llm_result = self._run_llm_metric_extraction_prompt(snippets, search_term, industry_name)
|
||||
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
|
||||
results["calculated_metric_value"] = llm_result.get("raw_value")
|
||||
results["calculated_metric_unit"] = llm_result.get("raw_unit")
|
||||
results["metric_source"] = "serpapi"
|
||||
|
||||
if llm_result.get("area_value") is not None:
|
||||
results["standardized_metric_value"] = llm_result.get("area_value")
|
||||
logger.info(f"Direct area value found via SerpAPI for {company.name}: {llm_result.get('area_value')} m²")
|
||||
elif llm_result.get("raw_value") is not None and standardization_logic:
|
||||
results["standardized_metric_value"] = self._parse_standardization_logic(
|
||||
standardization_logic, llm_result["raw_value"]
|
||||
)
|
||||
return results
|
||||
|
||||
logger.info(f"Could not extract metric for '{search_term}' from any source for {company.name}.")
|
||||
return results # Return results with None values
|
||||
|
||||
def classify_company_potential(self, company: Company) -> Company:
|
||||
"""
|
||||
Main method to classify industry and calculate potential metric for a company.
|
||||
"""
|
||||
logger.info(f"Starting classification for Company ID: {company.id}, Name: {company.name}")
|
||||
|
||||
# --- STEP 1: Strict Industry Classification ---
|
||||
website_content_for_classification = scrape_website_content(company.website)
|
||||
if not website_content_for_classification:
|
||||
logger.warning(f"No website content found for {company.name}. Skipping industry classification.")
|
||||
company.industry_ai = "Others" # Default if no content
|
||||
else:
|
||||
classified_industry_name = self._run_llm_classification_prompt(website_content_for_classification, company.name)
|
||||
if classified_industry_name:
|
||||
company.industry_ai = classified_industry_name
|
||||
logger.info(f"Classified {company.name} into industry: {classified_industry_name}")
|
||||
industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
|
||||
company.industry_ai = industry_name if industry_name in [i.name for i in industries] else "Others"
|
||||
else:
|
||||
company.industry_ai = "Others"
|
||||
logger.warning(f"Failed to classify industry for {company.name}. Setting to 'Others'.")
|
||||
|
||||
self.db.add(company) # Update industry_ai
|
||||
self.db.commit()
|
||||
self.db.refresh(company)
|
||||
db.commit()
|
||||
|
||||
# --- STEP 2: Metric Extraction & Standardization (if not 'Others') ---
|
||||
if company.industry_ai == "Others" or company.industry_ai is None:
|
||||
logger.info(f"Company {company.name} classified as 'Others'. Skipping metric extraction.")
|
||||
return company
|
||||
# 3. Metric Extraction
|
||||
if company.industry_ai != "Others":
|
||||
industry = next((i for i in industries if i.name == company.industry_ai), None)
|
||||
if industry and industry.scraper_search_term:
|
||||
# Derive standardized unit
|
||||
std_unit = "m²" if "m²" in (industry.standardization_logic or "") else "Einheiten"
|
||||
|
||||
industry_definition = self.industry_lookup.get(company.industry_ai)
|
||||
if not industry_definition:
|
||||
logger.error(f"Industry definition for '{company.industry_ai}' not found in lookup. Skipping metric extraction.")
|
||||
return company
|
||||
|
||||
if not industry_definition.scraper_search_term:
|
||||
logger.info(f"Industry '{company.industry_ai}' has no 'Scraper Search Term'. Skipping metric extraction.")
|
||||
return company
|
||||
|
||||
# Determine standardized unit from standardization_logic if possible
|
||||
standardized_unit = "Einheiten" # Default
|
||||
if industry_definition.standardization_logic:
|
||||
# Example: "wert * 25m² (Fläche pro Zimmer)" -> extract "m²"
|
||||
match = re.search(r'(\w+)$', industry_definition.standardization_logic.replace(' ', ''))
|
||||
if match:
|
||||
standardized_unit = match.group(1).replace('(', '').replace(')', '') # Extract unit like "m²"
|
||||
|
||||
metric_results = self._extract_and_calculate_metric_cascade(
|
||||
company,
|
||||
company.industry_ai,
|
||||
industry_definition.scraper_search_term,
|
||||
industry_definition.standardization_logic,
|
||||
standardized_unit # Pass the derived unit
|
||||
metrics = self._extract_and_calculate_metric_cascade(
|
||||
db, company, company.industry_ai, industry.scraper_search_term, industry.standardization_logic, std_unit
|
||||
)
|
||||
|
||||
# Update company object with results
|
||||
company.calculated_metric_name = metric_results["calculated_metric_name"]
|
||||
company.calculated_metric_value = metric_results["calculated_metric_value"]
|
||||
company.calculated_metric_unit = metric_results["calculated_metric_unit"]
|
||||
company.standardized_metric_value = metric_results["standardized_metric_value"]
|
||||
company.standardized_metric_unit = metric_results["standardized_metric_unit"]
|
||||
company.metric_source = metric_results["metric_source"]
|
||||
company.last_classification_at = datetime.utcnow() # Update timestamp
|
||||
company.calculated_metric_name = metrics["calculated_metric_name"]
|
||||
company.calculated_metric_value = metrics["calculated_metric_value"]
|
||||
company.calculated_metric_unit = metrics["calculated_metric_unit"]
|
||||
company.standardized_metric_value = metrics["standardized_metric_value"]
|
||||
company.standardized_metric_unit = metrics["standardized_metric_unit"]
|
||||
company.metric_source = metrics["metric_source"]
|
||||
|
||||
self.db.add(company)
|
||||
self.db.commit()
|
||||
self.db.refresh(company) # Refresh to get updated values
|
||||
|
||||
logger.info(f"Classification and metric extraction completed for {company.name}.")
|
||||
company.last_classification_at = datetime.utcnow()
|
||||
db.commit()
|
||||
return company
|
||||
|
||||
# --- HELPER FOR SAFE MATH EVALUATION (Moved from core_utils.py or assumed to be there) ---
|
||||
# Assuming safe_eval_math is available via backend.lib.core_utils.safe_eval_math
|
||||
# Example implementation if not:
|
||||
# def safe_eval_math(expression: str) -> float:
|
||||
# # Implement a safe parser/evaluator for simple math expressions
|
||||
# # For now, a very basic eval might be used, but in production, this needs to be locked down
|
||||
# allowed_chars = "0123456789.+-*/ "
|
||||
# if not all(c in allowed_chars for c in expression):
|
||||
# raise ValueError("Expression contains disallowed characters.")
|
||||
# return eval(expression)
|
||||
@@ -267,3 +267,15 @@ class ScraperService:
|
||||
except Exception as e:
|
||||
logger.error(f"Critical error in _parse_html: {e}", exc_info=True)
|
||||
return {"title": "", "description": "", "text": "", "emails": [], "error": str(e)}
|
||||
|
||||
# --- HELPER FUNCTION FOR EXTERNAL USE ---
|
||||
def scrape_website_content(url: str) -> Optional[str]:
|
||||
"""
|
||||
Simple wrapper to get just the text content of a URL.
|
||||
Used by ClassificationService.
|
||||
"""
|
||||
scraper = ScraperService()
|
||||
result = scraper.scrape_url(url)
|
||||
if result and result.get("text"):
|
||||
return result["text"]
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user