fix(explorer): resolve initialization and import errors for v0.7.0 backend

This commit is contained in:
2026-01-20 17:11:31 +00:00
parent d9cb096663
commit 58b30dc0ed
4 changed files with 160 additions and 921 deletions

View File

@@ -5,59 +5,39 @@ from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from backend.database import Company, Industry, RoboticsCategory, EnrichmentData, get_db
from backend.config import settings
from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content # Corrected import
from backend.services.scraping import scrape_website_content
logger = logging.getLogger(__name__)
class ClassificationService:
def __init__(self, db: Session):
self.db = db
self.allowed_industries_notion: List[Industry] = self._load_industry_definitions()
self.robotics_categories: List[RoboticsCategory] = self._load_robotics_categories()
# Pre-process allowed industries for LLM prompt
self.llm_industry_definitions = [
{"name": ind.name, "description": ind.description} for ind in self.allowed_industries_notion
]
# Store for quick lookup
self.industry_lookup = {ind.name: ind for ind in self.allowed_industries_notion}
self.category_lookup = {cat.id: cat for cat in self.robotics_categories}
def __init__(self):
# We no longer load industries in init because we don't have a DB session here
pass
def _load_industry_definitions(self) -> List[Industry]:
def _load_industry_definitions(self, db: Session) -> List[Industry]:
"""Loads all industry definitions from the database."""
industries = self.db.query(Industry).all()
industries = db.query(Industry).all()
if not industries:
logger.warning("No industry definitions found in DB. Classification might be limited.")
return industries
def _load_robotics_categories(self) -> List[RoboticsCategory]:
"""Loads all robotics categories from the database."""
categories = self.db.query(RoboticsCategory).all()
if not categories:
logger.warning("No robotics categories found in DB. Potential scoring might be limited.")
return categories
def _get_wikipedia_content(self, company_id: int) -> Optional[str]:
def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[str]:
"""Fetches Wikipedia content from enrichment_data for a given company."""
enrichment = self.db.query(EnrichmentData).filter(
enrichment = db.query(EnrichmentData).filter(
EnrichmentData.company_id == company_id,
EnrichmentData.source_type == "wikipedia"
).order_by(EnrichmentData.created_at.desc()).first()
if enrichment and enrichment.content:
# Wikipedia content is stored as JSON with a 'text' key
wiki_data = enrichment.content
return wiki_data.get('text')
return None
def _run_llm_classification_prompt(self, website_text: str, company_name: str) -> Optional[str]:
def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
"""
Uses LLM to classify the company into one of the predefined industries.
Returns the industry name (string) or "Others".
"""
prompt = r"""
Du bist ein präziser Branchen-Klassifizierer für Unternehmen.
@@ -79,31 +59,23 @@ class ClassificationService:
Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.
Beispiel Output: Hotellerie
Beispiel Output: Automotive - Dealer
Beispiel Output: Others
""".format(
company_name=company_name,
website_text_excerpt=website_text[:10000], # Limit text to avoid token limits
industry_definitions_json=json.dumps(self.llm_industry_definitions, ensure_ascii=False)
website_text_excerpt=website_text[:10000],
industry_definitions_json=json.dumps(industry_definitions, ensure_ascii=False)
)
try:
response = call_gemini_flash(prompt, temperature=0.1, json_mode=False) # Low temp for strict classification
classified_industry = response.strip()
if classified_industry in [ind.name for ind in self.allowed_industries_notion] + ["Others"]:
return classified_industry
logger.warning(f"LLM classified industry '{classified_industry}' not in allowed list. Defaulting to Others.")
return "Others"
response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)
return response.strip()
except Exception as e:
logger.error(f"LLM classification failed for {company_name}: {e}", exc_info=True)
logger.error(f"LLM classification failed for {company_name}: {e}")
return None
def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
"""
Uses LLM to extract the specific metric value from text.
Returns a dict with 'raw_value', 'raw_unit', 'standardized_value' (if found), 'metric_name'.
"""
# Attempt to extract both the raw unit count and a potential area if explicitly mentioned
prompt = r"""
Du bist ein Datenextraktions-Spezialist.
Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.
@@ -119,63 +91,42 @@ class ClassificationService:
1. Finde den numerischen Wert für '{search_term}'.
2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.
Gib NUR ein JSON-Objekt zurück mit den Schlüsseln:
Gib NUR ein JSON-Objekt zurück:
'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').
Beispiel Output (wenn 180 Betten und 4500m² Fläche gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": 4500, "metric_name": "{search_term}"}}
Beispiel Output (wenn nur 180 Betten gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": null, "metric_name": "{search_term}"}}
Beispiel Output (wenn nichts gefunden):
{{"raw_value": null, "raw_unit": null, "area_value": null, "metric_name": "{search_term}"}}
""".format(
industry_name=industry_name,
search_term=search_term,
text_content_excerpt=text_content[:15000] # Adjust as needed for token limits
text_content_excerpt=text_content[:15000]
)
try:
response = call_gemini_flash(prompt, temperature=0.05, json_mode=True) # Very low temp for extraction
result = json.loads(response)
return result
response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)
return json.loads(response)
except Exception as e:
logger.error(f"LLM metric extraction failed for '{search_term}' in '{industry_name}': {e}", exc_info=True)
logger.error(f"LLM metric extraction failed for '{search_term}': {e}")
return None
def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
"""
Safely parses and executes a simple mathematical formula for standardization.
Supports basic arithmetic (+, -, *, /) and integer/float values.
"""
if not formula or not raw_value:
if not formula or raw_value is None:
return None
# Replace 'wert' or 'value' with the actual raw_value
formula_cleaned = formula.replace("wert", str(raw_value)).replace("Value", str(raw_value)).replace("VALUE", str(raw_value))
formula_cleaned = formula.replace("wert", str(raw_value)).replace("Value", str(raw_value))
try:
# Use safe_eval_math from core_utils to prevent arbitrary code execution
return safe_eval_math(formula_cleaned)
except Exception as e:
logger.error(f"Error evaluating standardization logic '{formula}' with value {raw_value}: {e}", exc_info=True)
except:
return None
def _extract_and_calculate_metric_cascade(
self,
db: Session,
company: Company,
industry_name: str,
search_term: str,
standardization_logic: Optional[str],
standardized_unit: Optional[str]
) -> Dict[str, Any]:
"""
Orchestrates the 3-stage (Website -> Wikipedia -> SerpAPI) metric extraction.
"""
results = {
"calculated_metric_name": search_term,
"calculated_metric_value": None,
@@ -185,150 +136,71 @@ class ClassificationService:
"metric_source": None
}
# --- STAGE 1: Website Analysis ---
logger.info(f"Stage 1: Analyzing website for '{search_term}' for {company.name}")
website_content = scrape_website_content(company.website)
if website_content:
llm_result = self._run_llm_metric_extraction_prompt(website_content, search_term, industry_name)
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
results["calculated_metric_value"] = llm_result.get("raw_value")
results["calculated_metric_unit"] = llm_result.get("raw_unit")
results["metric_source"] = "website"
# CASCADE: Website -> Wikipedia -> SerpAPI
sources = [
("website", lambda: scrape_website_content(company.website)),
("wikipedia", lambda: self._get_wikipedia_content(db, company.id)),
("serpapi", lambda: " ".join([res.get("snippet", "") for res in run_serp_search(f"{company.name} {search_term} {industry_name}").get("organic_results", [])]) if run_serp_search(f"{company.name} {search_term} {industry_name}") else None)
]
if llm_result.get("area_value") is not None:
# Prioritize directly found standardized area
results["standardized_metric_value"] = llm_result.get("area_value")
logger.info(f"Direct area value found on website for {company.name}: {llm_result.get('area_value')}")
elif llm_result.get("raw_value") is not None and standardization_logic:
# Calculate if only raw value found
results["standardized_metric_value"] = self._parse_standardization_logic(
standardization_logic, llm_result["raw_value"]
)
return results
# --- STAGE 2: Wikipedia Analysis ---
logger.info(f"Stage 2: Analyzing Wikipedia for '{search_term}' for {company.name}")
wikipedia_content = self._get_wikipedia_content(company.id)
if wikipedia_content:
llm_result = self._run_llm_metric_extraction_prompt(wikipedia_content, search_term, industry_name)
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
results["calculated_metric_value"] = llm_result.get("raw_value")
results["calculated_metric_unit"] = llm_result.get("raw_unit")
results["metric_source"] = "wikipedia"
if llm_result.get("area_value") is not None:
results["standardized_metric_value"] = llm_result.get("area_value")
logger.info(f"Direct area value found on Wikipedia for {company.name}: {llm_result.get('area_value')}")
elif llm_result.get("raw_value") is not None and standardization_logic:
results["standardized_metric_value"] = self._parse_standardization_logic(
standardization_logic, llm_result["raw_value"]
)
return results
# --- STAGE 3: SerpAPI (Google Search) ---
logger.info(f"Stage 3: Running SerpAPI search for '{search_term}' for {company.name}")
search_query = f"{company.name} {search_term} {industry_name}" # Example: "Hotel Moxy Würzburg Anzahl Betten Hotellerie"
serp_results = run_serp_search(search_query) # This returns a dictionary of search results
if serp_results and serp_results.get("organic_results"):
# Concatenate snippets from organic results
snippets = " ".join([res.get("snippet", "") for res in serp_results["organic_results"]])
if snippets:
llm_result = self._run_llm_metric_extraction_prompt(snippets, search_term, industry_name)
for source_name, content_loader in sources:
logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
try:
content = content_loader()
if not content: continue
llm_result = self._run_llm_metric_extraction_prompt(content, search_term, industry_name)
if llm_result and (llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None):
results["calculated_metric_value"] = llm_result.get("raw_value")
results["calculated_metric_unit"] = llm_result.get("raw_unit")
results["metric_source"] = "serpapi"
results["metric_source"] = source_name
if llm_result.get("area_value") is not None:
results["standardized_metric_value"] = llm_result.get("area_value")
logger.info(f"Direct area value found via SerpAPI for {company.name}: {llm_result.get('area_value')}")
elif llm_result.get("raw_value") is not None and standardization_logic:
results["standardized_metric_value"] = self._parse_standardization_logic(
standardization_logic, llm_result["raw_value"]
)
results["standardized_metric_value"] = self._parse_standardization_logic(standardization_logic, llm_result["raw_value"])
return results
logger.info(f"Could not extract metric for '{search_term}' from any source for {company.name}.")
return results # Return results with None values
except Exception as e:
logger.error(f"Error in {source_name} stage: {e}")
def classify_company_potential(self, company: Company) -> Company:
"""
Main method to classify industry and calculate potential metric for a company.
"""
logger.info(f"Starting classification for Company ID: {company.id}, Name: {company.name}")
return results
# --- STEP 1: Strict Industry Classification ---
website_content_for_classification = scrape_website_content(company.website)
if not website_content_for_classification:
logger.warning(f"No website content found for {company.name}. Skipping industry classification.")
company.industry_ai = "Others" # Default if no content
def classify_company_potential(self, company: Company, db: Session) -> Company:
logger.info(f"Starting classification for {company.name}")
# 1. Load Industries
industries = self._load_industry_definitions(db)
industry_defs = [{"name": i.name, "description": i.description} for i in industries]
# 2. Industry Classification
website_content = scrape_website_content(company.website)
if website_content:
industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
company.industry_ai = industry_name if industry_name in [i.name for i in industries] else "Others"
else:
classified_industry_name = self._run_llm_classification_prompt(website_content_for_classification, company.name)
if classified_industry_name:
company.industry_ai = classified_industry_name
logger.info(f"Classified {company.name} into industry: {classified_industry_name}")
else:
company.industry_ai = "Others"
logger.warning(f"Failed to classify industry for {company.name}. Setting to 'Others'.")
company.industry_ai = "Others"
self.db.add(company) # Update industry_ai
self.db.commit()
self.db.refresh(company)
db.commit()
# --- STEP 2: Metric Extraction & Standardization (if not 'Others') ---
if company.industry_ai == "Others" or company.industry_ai is None:
logger.info(f"Company {company.name} classified as 'Others'. Skipping metric extraction.")
return company
# 3. Metric Extraction
if company.industry_ai != "Others":
industry = next((i for i in industries if i.name == company.industry_ai), None)
if industry and industry.scraper_search_term:
# Derive standardized unit
std_unit = "" if "" in (industry.standardization_logic or "") else "Einheiten"
metrics = self._extract_and_calculate_metric_cascade(
db, company, company.industry_ai, industry.scraper_search_term, industry.standardization_logic, std_unit
)
company.calculated_metric_name = metrics["calculated_metric_name"]
company.calculated_metric_value = metrics["calculated_metric_value"]
company.calculated_metric_unit = metrics["calculated_metric_unit"]
company.standardized_metric_value = metrics["standardized_metric_value"]
company.standardized_metric_unit = metrics["standardized_metric_unit"]
company.metric_source = metrics["metric_source"]
industry_definition = self.industry_lookup.get(company.industry_ai)
if not industry_definition:
logger.error(f"Industry definition for '{company.industry_ai}' not found in lookup. Skipping metric extraction.")
return company
if not industry_definition.scraper_search_term:
logger.info(f"Industry '{company.industry_ai}' has no 'Scraper Search Term'. Skipping metric extraction.")
return company
# Determine standardized unit from standardization_logic if possible
standardized_unit = "Einheiten" # Default
if industry_definition.standardization_logic:
# Example: "wert * 25m² (Fläche pro Zimmer)" -> extract "m²"
match = re.search(r'(\w+)$', industry_definition.standardization_logic.replace(' ', ''))
if match:
standardized_unit = match.group(1).replace('(', '').replace(')', '') # Extract unit like "m²"
metric_results = self._extract_and_calculate_metric_cascade(
company,
company.industry_ai,
industry_definition.scraper_search_term,
industry_definition.standardization_logic,
standardized_unit # Pass the derived unit
)
# Update company object with results
company.calculated_metric_name = metric_results["calculated_metric_name"]
company.calculated_metric_value = metric_results["calculated_metric_value"]
company.calculated_metric_unit = metric_results["calculated_metric_unit"]
company.standardized_metric_value = metric_results["standardized_metric_value"]
company.standardized_metric_unit = metric_results["standardized_metric_unit"]
company.metric_source = metric_results["metric_source"]
company.last_classification_at = datetime.utcnow() # Update timestamp
self.db.add(company)
self.db.commit()
self.db.refresh(company) # Refresh to get updated values
logger.info(f"Classification and metric extraction completed for {company.name}.")
company.last_classification_at = datetime.utcnow()
db.commit()
return company
# --- HELPER FOR SAFE MATH EVALUATION (Moved from core_utils.py or assumed to be there) ---
# Assuming safe_eval_math is available via backend.lib.core_utils.safe_eval_math
# Example implementation if not:
# def safe_eval_math(expression: str) -> float:
# # Implement a safe parser/evaluator for simple math expressions
# # For now, a very basic eval might be used, but in production, this needs to be locked down
# allowed_chars = "0123456789.+-*/ "
# if not all(c in allowed_chars for c in expression):
# raise ValueError("Expression contains disallowed characters.")
# return eval(expression)