# File: Brancheneinstufung2/company-explorer/backend/services/classification.py
# Listing metadata from the repository file viewer: 334 lines, 17 KiB, Python.
import json
import logging
import math
import re
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List

from sqlalchemy.orm import Session

from backend.database import Company, Industry, RoboticsCategory, EnrichmentData, get_db
from backend.config import settings
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content  # Corrected import
#: Module logger (named after this module) used for classification progress and failures.
logger: logging.Logger = logging.getLogger(name=__name__)
class ClassificationService:
    """Classify a company into a configured industry and extract its potential metric.

    Industry definitions and robotics categories are read from the database once,
    when the service is constructed. Classification is based on the company's
    website text; metric extraction cascades Website -> Wikipedia -> SerpAPI and
    stops at the first source that yields a value.
    """

    # Fallback unit when no unit can be read from an industry's standardization logic.
    _DEFAULT_STANDARDIZED_UNIT = "Einheiten"
    # Placeholder words in a standardization formula that stand for the extracted raw value.
    _PLACEHOLDER_RE = re.compile(r"\b(?:wert|value)\b", re.IGNORECASE)
    # Parenthesised annotations containing letters, e.g. "(Fläche pro Zimmer)".
    _ANNOTATION_RE = re.compile(r"\([^()]*[^\W\d_][^()]*\)")
    # Runs of letters; [^\W\d_] also matches superscripts such as "²", so "m²" is one run.
    _LETTERS_RE = re.compile(r"[^\W\d_]+")
    # One or more parenthesised annotations at the very end of a formula.
    _TRAILING_ANNOTATION_RE = re.compile(r"(?:\s*\([^()]*\))+\s*$")
    # The trailing run of letters, i.e. the unit in a formula like "wert * 25m²".
    _TRAILING_UNIT_RE = re.compile(r"([^\W\d_]+)\s*$")

    def __init__(self, db: Session):
        """Load industry/robotics reference data from db and build lookup tables."""
        self.db = db
        self.allowed_industries_notion: List[Industry] = self._load_industry_definitions()
        self.robotics_categories: List[RoboticsCategory] = self._load_robotics_categories()
        # Pre-process allowed industries for the LLM prompt
        self.llm_industry_definitions = [
            {"name": ind.name, "description": ind.description} for ind in self.allowed_industries_notion
        ]
        # Store for quick lookup
        self.industry_lookup = {ind.name: ind for ind in self.allowed_industries_notion}
        self.category_lookup = {cat.id: cat for cat in self.robotics_categories}

    def _load_industry_definitions(self) -> List[Industry]:
        """Load all industry definitions from the database (warn if there are none)."""
        industries = self.db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _load_robotics_categories(self) -> List[RoboticsCategory]:
        """Load all robotics categories from the database (warn if there are none)."""
        categories = self.db.query(RoboticsCategory).all()
        if not categories:
            logger.warning("No robotics categories found in DB. Potential scoring might be limited.")
        return categories

    def _get_wikipedia_content(self, company_id: int) -> Optional[str]:
        """Return the 'text' field of the newest Wikipedia enrichment record for a company.

        The content is expected to be a JSON object with a 'text' key. It is accepted
        either already decoded (dict) or as a JSON-encoded string. Returns None when
        there is no record or the content has an unexpected shape.
        """
        enrichment = (
            self.db.query(EnrichmentData)
            .filter(
                EnrichmentData.company_id == company_id,
                EnrichmentData.source_type == "wikipedia",
            )
            .order_by(EnrichmentData.created_at.desc())
            .first()
        )
        if not enrichment or not enrichment.content:
            return None
        wiki_data = enrichment.content
        if isinstance(wiki_data, str):
            # NOTE(review): depends on the column type; a text column yields the raw JSON string.
            try:
                wiki_data = json.loads(wiki_data)
            except json.JSONDecodeError:
                logger.warning(f"Wikipedia enrichment for company {company_id} is not valid JSON; ignoring it.")
                return None
        if not isinstance(wiki_data, dict):
            return None
        return wiki_data.get("text")

    def _run_llm_classification_prompt(self, website_text: str, company_name: str) -> Optional[str]:
        """Ask the LLM to map the company onto one of the configured industries.

        Args:
            website_text: Scraped website text (only the first 10,000 characters are sent).
            company_name: Company name, included in the prompt and in log messages.

        Returns:
            The canonical industry name, or "Others" if the answer matches no
            configured industry. Returns None if the LLM call fails or returns a
            non-string response.
        """
        prompt = r"""
Du bist ein präziser Branchen-Klassifizierer für Unternehmen.
Deine Aufgabe ist es, das vorliegende Unternehmen basierend auf seinem Website-Inhalt
einer der untenstehenden Branchen zuzuordnen.
--- UNTERNEHMEN ---
Name: {company_name}
Website-Inhalt (Auszug):
{website_text_excerpt}
--- ZU VERWENDENDE BRANCHEN-DEFINITIONEN (STRIKT) ---
Wähle EINE der folgenden Branchen. Jede Branche hat eine Definition.
{industry_definitions_json}
--- AUFGABE ---
Analysiere den Website-Inhalt. Wähle die Branchen-Definition, die am besten zum Unternehmen passt.
Wenn keine der Definitionen zutrifft oder du unsicher bist, wähle "Others".
Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.
Beispiel Output: Hotellerie
Beispiel Output: Automotive - Dealer
Beispiel Output: Others
""".format(
            company_name=company_name,
            website_text_excerpt=website_text[:10000],  # Limit text to avoid token limits
            industry_definitions_json=json.dumps(self.llm_industry_definitions, ensure_ascii=False),
        )
        try:
            response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)  # Low temp for strict classification
        except Exception as e:
            logger.error(f"LLM classification failed for {company_name}: {e}", exc_info=True)
            return None
        if not isinstance(response, str):
            logger.error(f"LLM classification failed for {company_name}: unexpected response {response!r}")
            return None

        # Models sometimes wrap the answer in quotes/backticks or vary its casing, so
        # normalise before matching against the canonical industry names.
        classified_industry = response.strip().strip("\"'`").strip()
        canonical_names = {name.casefold(): name for name in self.industry_lookup}
        canonical_names.setdefault("others", "Others")
        matched = canonical_names.get(classified_industry.casefold())
        if matched is not None:
            return matched
        logger.warning(f"LLM classified industry '{classified_industry}' not in allowed list. Defaulting to Others.")
        return "Others"

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to extract the metric named search_term from text_content.

        Only the first 15,000 characters of text_content are sent.

        Returns:
            The decoded JSON object, expected to have the keys 'raw_value',
            'raw_unit', 'area_value' and 'metric_name'. Returns None if the call
            fails, the response is not valid JSON, or the JSON is not an object.
        """
        # Attempt to extract both the raw unit count and a potential area if explicitly mentioned
        prompt = r"""
Du bist ein Datenextraktions-Spezialist.
Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.
--- KONTEXT ---
Unternehmen ist in der Branche: {industry_name}
Gesuchter Wert (Rohdaten): '{search_term}'
--- TEXT ---
{text_content_excerpt}
--- AUFGABE ---
1. Finde den numerischen Wert für '{search_term}'.
2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.
Gib NUR ein JSON-Objekt zurück mit den Schlüsseln:
'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').
Beispiel Output (wenn 180 Betten und 4500m² Fläche gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": 4500, "metric_name": "{search_term}"}}
Beispiel Output (wenn nur 180 Betten gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": null, "metric_name": "{search_term}"}}
Beispiel Output (wenn nichts gefunden):
{{"raw_value": null, "raw_unit": null, "area_value": null, "metric_name": "{search_term}"}}
""".format(
            industry_name=industry_name,
            search_term=search_term,
            text_content_excerpt=text_content[:15000],  # Adjust as needed for token limits
        )
        try:
            response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)  # Very low temp for extraction
            result = json.loads(response)
        except Exception as e:
            logger.error(f"LLM metric extraction failed for '{search_term}' in '{industry_name}': {e}", exc_info=True)
            return None
        # Callers use .get(); a JSON list/number/string would crash them.
        if not isinstance(result, dict):
            logger.warning(f"LLM metric extraction for '{search_term}' returned non-object JSON; ignoring it.")
            return None
        return result

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Evaluate a standardization formula with raw_value substituted for its placeholder.

        The placeholder is the word "wert" or "value" (any case). Parenthesised
        annotations that contain letters (e.g. "(Fläche pro Zimmer)") and unit text
        (e.g. "m²") are removed, so "wert * 25m² (Fläche pro Zimmer)" evaluates as
        "<raw_value> * 25". The remaining arithmetic goes through safe_eval_math.

        Returns:
            The evaluated result, or None if the formula is empty, raw_value is
            missing or not a finite number, or evaluation fails.
        """
        if not formula or raw_value is None:
            return None
        try:
            # The LLM may return the number as a string; anything non-numeric is rejected.
            numeric_value = float(raw_value)
        except (TypeError, ValueError):
            logger.warning(f"Cannot standardize non-numeric value {raw_value!r} with logic '{formula}'.")
            return None
        if not math.isfinite(numeric_value):
            return None

        # Fixed-point formatting avoids exponent notation ("1e-05"), whose letter would be stripped below.
        value_text = f"{numeric_value:f}"
        expression = self._PLACEHOLDER_RE.sub(lambda _match: value_text, formula)
        expression = self._ANNOTATION_RE.sub("", expression)
        expression = self._LETTERS_RE.sub("", expression).strip()
        if not expression:
            return None
        try:
            # Use safe_eval_math from core_utils to prevent arbitrary code execution
            return safe_eval_math(expression)
        except Exception as e:
            logger.error(f"Error evaluating standardization logic '{formula}' with value {raw_value}: {e}", exc_info=True)
            return None

    @classmethod
    def _derive_standardized_unit(cls, standardization_logic: Optional[str]) -> str:
        """Return the unit at the end of a standardization formula, or the default unit.

        Example: "wert * 25m² (Fläche pro Zimmer)" -> "m²". A formula without a
        trailing unit (e.g. "wert * 25") or whose trailing word is the placeholder
        itself yields "Einheiten".
        """
        if not standardization_logic:
            return cls._DEFAULT_STANDARDIZED_UNIT
        without_annotations = cls._TRAILING_ANNOTATION_RE.sub("", standardization_logic)
        match = cls._TRAILING_UNIT_RE.search(without_annotations)
        if not match or match.group(1).casefold() in {"wert", "value"}:
            return cls._DEFAULT_STANDARDIZED_UNIT
        return match.group(1)

    def _apply_metric_result(
        self,
        results: Dict[str, Any],
        llm_result: Optional[Dict[str, Any]],
        source: str,
        location: str,
        standardization_logic: Optional[str],
        company_name: str,
    ) -> bool:
        """Copy one stage's extraction into results; return True if the stage found a value.

        Args:
            results: Cascade result dict, updated in place.
            llm_result: Output of _run_llm_metric_extraction_prompt (may be None).
            source: Value stored as results["metric_source"] (e.g. "website").
            location: Phrase used in the log message (e.g. "on website").
            standardization_logic: Formula used when only a raw value was found.
            company_name: Used in log messages.
        """
        if not llm_result:
            return False
        raw_value = llm_result.get("raw_value")
        area_value = llm_result.get("area_value")
        if raw_value is None and area_value is None:
            return False
        results["calculated_metric_value"] = raw_value
        results["calculated_metric_unit"] = llm_result.get("raw_unit")
        results["metric_source"] = source
        if area_value is not None:
            # Prioritize a directly found standardized area over a computed one
            results["standardized_metric_value"] = area_value
            logger.info(f"Direct area value found {location} for {company_name}: {area_value}")
        elif standardization_logic:
            # Only the raw value was found (area_value is None, so raw_value is not): calculate it
            results["standardized_metric_value"] = self._parse_standardization_logic(standardization_logic, raw_value)
        return True

    def _extract_and_calculate_metric_cascade(
        self,
        company: Company,
        industry_name: str,
        search_term: str,
        standardization_logic: Optional[str],
        standardized_unit: Optional[str],
        website_content: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run the 3-stage (Website -> Wikipedia -> SerpAPI) metric extraction.

        The first stage that yields a raw value or an area value wins; later stages
        are skipped.

        Args:
            website_content: Already-scraped website text. When None, the website
                is scraped here (if the company has one). Pass the text from
                classification to avoid scraping twice.

        Returns:
            A dict with the calculated/standardized metric fields and
            'metric_source'. Values stay None when no stage found anything.
        """
        results: Dict[str, Any] = {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None,
        }

        # --- STAGE 1: Website Analysis ---
        logger.info(f"Stage 1: Analyzing website for '{search_term}' for {company.name}")
        if website_content is None and company.website:
            website_content = scrape_website_content(company.website)
        if website_content:
            llm_result = self._run_llm_metric_extraction_prompt(website_content, search_term, industry_name)
            if self._apply_metric_result(results, llm_result, "website", "on website", standardization_logic, company.name):
                return results

        # --- STAGE 2: Wikipedia Analysis ---
        logger.info(f"Stage 2: Analyzing Wikipedia for '{search_term}' for {company.name}")
        wikipedia_content = self._get_wikipedia_content(company.id)
        if wikipedia_content:
            llm_result = self._run_llm_metric_extraction_prompt(wikipedia_content, search_term, industry_name)
            if self._apply_metric_result(results, llm_result, "wikipedia", "on Wikipedia", standardization_logic, company.name):
                return results

        # --- STAGE 3: SerpAPI (Google Search) ---
        logger.info(f"Stage 3: Running SerpAPI search for '{search_term}' for {company.name}")
        search_query = f"{company.name} {search_term} {industry_name}"  # Example: "Hotel Moxy Würzburg Anzahl Betten Hotellerie"
        serp_results = run_serp_search(search_query)  # Expected to be a dict of search results
        organic_results = serp_results.get("organic_results") if isinstance(serp_results, dict) else None
        if organic_results:
            # Snippets may be missing or null; skip them instead of failing the join.
            snippet_texts = [res.get("snippet") for res in organic_results if isinstance(res, dict)]
            snippets = " ".join(text for text in snippet_texts if isinstance(text, str) and text.strip())
            if snippets:
                llm_result = self._run_llm_metric_extraction_prompt(snippets, search_term, industry_name)
                if self._apply_metric_result(results, llm_result, "serpapi", "via SerpAPI", standardization_logic, company.name):
                    return results

        logger.info(f"Could not extract metric for '{search_term}' from any source for {company.name}.")
        return results  # Return results with None values

    def _persist(self, company: Company) -> None:
        """Add, commit and refresh company; roll back and re-raise if the commit fails."""
        self.db.add(company)
        try:
            self.db.commit()
        except Exception:
            # A failed commit leaves the session unusable until it is rolled back.
            self.db.rollback()
            raise
        self.db.refresh(company)

    def classify_company_potential(self, company: Company) -> Company:
        """Classify company's industry and, if applicable, extract its potential metric.

        Step 1 sets and commits company.industry_ai ("Others" when there is no
        website content or classification fails). Step 2 runs only for an industry
        with a configured scraper search term. It stores the metric fields and
        last_classification_at, then commits again.

        Returns:
            The same company instance, refreshed from the database.
        """
        logger.info(f"Starting classification for Company ID: {company.id}, Name: {company.name}")

        # --- STEP 1: Strict Industry Classification ---
        # Scraped once here and reused by metric-extraction stage 1 below.
        website_content = scrape_website_content(company.website) if company.website else None
        if not website_content:
            logger.warning(f"No website content found for {company.name}. Skipping industry classification.")
            company.industry_ai = "Others"  # Default if no content
        else:
            classified_industry_name = self._run_llm_classification_prompt(website_content, company.name)
            if classified_industry_name:
                company.industry_ai = classified_industry_name
                logger.info(f"Classified {company.name} into industry: {classified_industry_name}")
            else:
                company.industry_ai = "Others"
                logger.warning(f"Failed to classify industry for {company.name}. Setting to 'Others'.")
        self._persist(company)  # Update industry_ai

        # --- STEP 2: Metric Extraction & Standardization (if not 'Others') ---
        if company.industry_ai in (None, "Others"):
            logger.info(f"Company {company.name} classified as 'Others'. Skipping metric extraction.")
            return company
        industry_definition = self.industry_lookup.get(company.industry_ai)
        if not industry_definition:
            logger.error(f"Industry definition for '{company.industry_ai}' not found in lookup. Skipping metric extraction.")
            return company
        if not industry_definition.scraper_search_term:
            logger.info(f"Industry '{company.industry_ai}' has no 'Scraper Search Term'. Skipping metric extraction.")
            return company

        standardized_unit = self._derive_standardized_unit(industry_definition.standardization_logic)
        metric_results = self._extract_and_calculate_metric_cascade(
            company,
            company.industry_ai,
            industry_definition.scraper_search_term,
            industry_definition.standardization_logic,
            standardized_unit,
            website_content=website_content,
        )

        # Update company object with results
        company.calculated_metric_name = metric_results["calculated_metric_name"]
        company.calculated_metric_value = metric_results["calculated_metric_value"]
        company.calculated_metric_unit = metric_results["calculated_metric_unit"]
        company.standardized_metric_value = metric_results["standardized_metric_value"]
        company.standardized_metric_unit = metric_results["standardized_metric_unit"]
        company.metric_source = metric_results["metric_source"]
        # Naive UTC timestamp (same value datetime.utcnow() produced, without its deprecation).
        company.last_classification_at = datetime.now(timezone.utc).replace(tzinfo=None)
        self._persist(company)
        logger.info(f"Classification and metric extraction completed for {company.name}.")
        return company
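# --- NOTE: safe_eval_math is imported from backend.lib.core_utils at the top of this module. ---
# The commented-out sketch below is for reference only and is NOT used by this module.
# WARNING: a character whitelist in front of eval() is not a safe evaluator. For example,
# "9**9**9" contains only allowed characters, yet eval() would compute an enormous power and
# hang the process. A safe implementation should parse the expression (e.g. with the `ast`
# module) and accept only numeric literals and the + - * / operators.
# def safe_eval_math(expression: str) -> float:
#     # Implement a safe parser/evaluator for simple math expressions
#     # For now, a very basic eval might be used, but in production, this needs to be locked down
#     allowed_chars = "0123456789.+-*/ "
#     if not all(c in allowed_chars for c in expression):
#         raise ValueError("Expression contains disallowed characters.")
#     return eval(expression)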