"""Company industry classification and potential-metric extraction.

Workflow per company:
1. Classify the company into one of the DB-defined industries via an LLM
   prompt over scraped website content (strict whitelist, else "Others").
2. For the matched industry, extract the industry-specific raw metric
   (e.g. number of beds) through a three-stage cascade:
   company website -> Wikipedia enrichment -> SerpAPI search snippets.
3. Standardize the raw metric (e.g. to m²) using the industry's simple
   arithmetic formula, evaluated with a sandboxed math evaluator.
"""

import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

from sqlalchemy.orm import Session

from backend.config import settings
from backend.database import (
    Company,
    EnrichmentData,
    Industry,
    RoboticsCategory,
    get_db,
)
from backend.lib.core_utils import call_gemini_flash, run_serp_search, safe_eval_math
from backend.services.scraping import scrape_website_content  # Corrected import

logger = logging.getLogger(__name__)


class ClassificationService:
    """Classifies companies into industries and derives a potential metric.

    All definitions (industries, robotics categories) are loaded once from
    the database at construction time; per-company work happens in
    :meth:`classify_company_potential`.
    """

    def __init__(self, db: Session):
        self.db = db
        self.allowed_industries_notion: List[Industry] = self._load_industry_definitions()
        self.robotics_categories: List[RoboticsCategory] = self._load_robotics_categories()
        # Pre-rendered industry definitions for embedding into the LLM prompt.
        self.llm_industry_definitions = [
            {"name": ind.name, "description": ind.description}
            for ind in self.allowed_industries_notion
        ]
        # Fast lookups by industry name / category id.
        self.industry_lookup = {ind.name: ind for ind in self.allowed_industries_notion}
        self.category_lookup = {cat.id: cat for cat in self.robotics_categories}
        # Whitelist of valid LLM classification outputs, computed once instead
        # of being rebuilt on every classification call.
        self._allowed_industry_names = {ind.name for ind in self.allowed_industries_notion}

    def _load_industry_definitions(self) -> List[Industry]:
        """Load all industry definitions from the database."""
        industries = self.db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _load_robotics_categories(self) -> List[RoboticsCategory]:
        """Load all robotics categories from the database."""
        categories = self.db.query(RoboticsCategory).all()
        if not categories:
            logger.warning("No robotics categories found in DB. Potential scoring might be limited.")
        return categories

    def _get_wikipedia_content(self, company_id: int) -> Optional[str]:
        """Return the newest Wikipedia enrichment text for a company, if any."""
        enrichment = self.db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company_id,
            EnrichmentData.source_type == "wikipedia"
        ).order_by(EnrichmentData.created_at.desc()).first()
        if enrichment and enrichment.content:
            # Wikipedia content is stored as JSON with a 'text' key.
            return enrichment.content.get('text')
        return None

    def _run_llm_classification_prompt(self, website_text: str, company_name: str) -> Optional[str]:
        """Classify the company into one of the predefined industries via LLM.

        Returns the matched industry name, ``"Others"`` when no definition
        fits (or the LLM answers off-whitelist), or ``None`` when the LLM
        call itself fails.
        """
        prompt = r"""
Du bist ein präziser Branchen-Klassifizierer für Unternehmen. Deine Aufgabe ist es, das vorliegende Unternehmen basierend auf seinem Website-Inhalt einer der untenstehenden Branchen zuzuordnen.

--- UNTERNEHMEN ---
Name: {company_name}
Website-Inhalt (Auszug):
{website_text_excerpt}

--- ZU VERWENDENDE BRANCHEN-DEFINITIONEN (STRIKT) ---
Wähle EINE der folgenden Branchen. Jede Branche hat eine Definition.
{industry_definitions_json}

--- AUFGABE ---
Analysiere den Website-Inhalt. Wähle die Branchen-Definition, die am besten zum Unternehmen passt.
Wenn keine der Definitionen zutrifft oder du unsicher bist, wähle "Others".
Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.

Beispiel Output: Hotellerie
Beispiel Output: Automotive - Dealer
Beispiel Output: Others
""".format(
            company_name=company_name,
            website_text_excerpt=website_text[:10000],  # Limit text to avoid token limits
            industry_definitions_json=json.dumps(self.llm_industry_definitions, ensure_ascii=False)
        )
        try:
            # Low temperature for strict, near-deterministic classification.
            response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)
            classified_industry = response.strip()
            # Validate against the precomputed whitelist (O(1) set lookup).
            if classified_industry == "Others" or classified_industry in self._allowed_industry_names:
                return classified_industry
            logger.warning(f"LLM classified industry '{classified_industry}' not in allowed list. Defaulting to Others.")
            return "Others"
        except Exception as e:
            logger.error(f"LLM classification failed for {company_name}: {e}", exc_info=True)
            return None

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        """Extract the specific metric value from free text via LLM.

        Returns a dict with keys ``raw_value``, ``raw_unit``, ``area_value``
        and ``metric_name`` (values may be None), or ``None`` when the LLM
        call or JSON parsing fails.
        """
        # Ask for both the raw unit count and an explicitly stated area (m²),
        # so a directly published area can short-circuit standardization.
        prompt = r"""
Du bist ein Datenextraktions-Spezialist. Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.

--- KONTEXT ---
Unternehmen ist in der Branche: {industry_name}
Gesuchter Wert (Rohdaten): '{search_term}'

--- TEXT ---
{text_content_excerpt}

--- AUFGABE ---
1. Finde den numerischen Wert für '{search_term}'.
2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.

Gib NUR ein JSON-Objekt zurück mit den Schlüsseln:
'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').

Beispiel Output (wenn 180 Betten und 4500m² Fläche gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": 4500, "metric_name": "{search_term}"}}
Beispiel Output (wenn nur 180 Betten gefunden):
{{"raw_value": 180, "raw_unit": "Betten", "area_value": null, "metric_name": "{search_term}"}}
Beispiel Output (wenn nichts gefunden):
{{"raw_value": null, "raw_unit": null, "area_value": null, "metric_name": "{search_term}"}}
""".format(
            industry_name=industry_name,
            search_term=search_term,
            text_content_excerpt=text_content[:15000]  # Adjust as needed for token limits
        )
        try:
            # Very low temperature: extraction must be reproducible.
            response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)
            return json.loads(response)
        except Exception as e:
            logger.error(f"LLM metric extraction failed for '{search_term}' in '{industry_name}': {e}", exc_info=True)
            return None

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Safely evaluate a simple arithmetic standardization formula.

        The formula references the extracted raw value through the
        placeholder 'wert' (German) or 'value' in any capitalization,
        e.g. "wert * 25". Evaluation is delegated to safe_eval_math to
        prevent arbitrary code execution.
        """
        # BUGFIX: `not raw_value` also rejected a legitimate 0/0.0 raw value;
        # only missing inputs are rejected now.
        if not formula or raw_value is None:
            return None
        # BUGFIX: the old chained .replace() handled 'wert'/'Value'/'VALUE'
        # but missed 'Wert' and lowercase 'value'. One case-insensitive,
        # word-bounded substitution covers every capitalization.
        formula_cleaned = re.sub(r"\b(wert|value)\b", str(raw_value), formula, flags=re.IGNORECASE)
        try:
            return safe_eval_math(formula_cleaned)
        except Exception as e:
            logger.error(f"Error evaluating standardization logic '{formula}' with value {raw_value}: {e}", exc_info=True)
            return None

    def _try_extract_from_text(
        self,
        results: Dict[str, Any],
        text: str,
        search_term: str,
        industry_name: str,
        standardization_logic: Optional[str],
        source: str,
        company_name: str,
    ) -> bool:
        """Run LLM extraction over one text source and fill ``results``.

        Returns True when a raw metric value or an explicit area value was
        found (i.e. the cascade can stop), False otherwise. Shared by all
        three cascade stages to avoid triplicated logic.
        """
        llm_result = self._run_llm_metric_extraction_prompt(text, search_term, industry_name)
        if not llm_result:
            return False
        raw_value = llm_result.get("raw_value")
        area_value = llm_result.get("area_value")
        if raw_value is None and area_value is None:
            return False
        results["calculated_metric_value"] = raw_value
        results["calculated_metric_unit"] = llm_result.get("raw_unit")
        results["metric_source"] = source
        if area_value is not None:
            # Prefer an explicitly published area over a derived estimate.
            results["standardized_metric_value"] = area_value
            logger.info(f"Direct area value found via {source} for {company_name}: {area_value} m²")
        elif standardization_logic:
            # Only the raw value was found -> derive the standardized value.
            results["standardized_metric_value"] = self._parse_standardization_logic(
                standardization_logic, raw_value
            )
        return True

    def _extract_and_calculate_metric_cascade(
        self, company: Company, industry_name: str, search_term: str,
        standardization_logic: Optional[str], standardized_unit: Optional[str]
    ) -> Dict[str, Any]:
        """Orchestrate the 3-stage (Website -> Wikipedia -> SerpAPI) extraction.

        Returns a result dict with the raw and standardized metric fields;
        all value fields stay None when no source yielded a hit.
        """
        results = {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None
        }

        # --- STAGE 1: Website analysis ---
        logger.info(f"Stage 1: Analyzing website for '{search_term}' for {company.name}")
        website_content = scrape_website_content(company.website)
        if website_content and self._try_extract_from_text(
            results, website_content, search_term, industry_name,
            standardization_logic, "website", company.name
        ):
            return results

        # --- STAGE 2: Wikipedia analysis ---
        logger.info(f"Stage 2: Analyzing Wikipedia for '{search_term}' for {company.name}")
        wikipedia_content = self._get_wikipedia_content(company.id)
        if wikipedia_content and self._try_extract_from_text(
            results, wikipedia_content, search_term, industry_name,
            standardization_logic, "wikipedia", company.name
        ):
            return results

        # --- STAGE 3: SerpAPI (Google Search) ---
        logger.info(f"Stage 3: Running SerpAPI search for '{search_term}' for {company.name}")
        # Example query: "Hotel Moxy Würzburg Anzahl Betten Hotellerie"
        search_query = f"{company.name} {search_term} {industry_name}"
        serp_results = run_serp_search(search_query)  # Returns a dict of search results
        if serp_results and serp_results.get("organic_results"):
            # Concatenate snippets from organic results into one text blob.
            snippets = " ".join(res.get("snippet", "") for res in serp_results["organic_results"])
            if snippets.strip() and self._try_extract_from_text(
                results, snippets, search_term, industry_name,
                standardization_logic, "serpapi", company.name
            ):
                return results

        logger.info(f"Could not extract metric for '{search_term}' from any source for {company.name}.")
        return results  # Return results with None values

    def classify_company_potential(self, company: Company) -> Company:
        """Classify industry and compute the potential metric for a company.

        Persists ``industry_ai``, the metric fields and the classification
        timestamp on the Company row and returns the refreshed instance.
        """
        logger.info(f"Starting classification for Company ID: {company.id}, Name: {company.name}")

        # --- STEP 1: Strict industry classification ---
        website_content_for_classification = scrape_website_content(company.website)
        if not website_content_for_classification:
            logger.warning(f"No website content found for {company.name}. Skipping industry classification.")
            company.industry_ai = "Others"  # Default if no content
        else:
            classified_industry_name = self._run_llm_classification_prompt(
                website_content_for_classification, company.name
            )
            if classified_industry_name:
                company.industry_ai = classified_industry_name
                logger.info(f"Classified {company.name} into industry: {classified_industry_name}")
            else:
                company.industry_ai = "Others"
                logger.warning(f"Failed to classify industry for {company.name}. Setting to 'Others'.")

        # Persist the classification before the (slow) metric extraction.
        self.db.add(company)
        self.db.commit()
        self.db.refresh(company)

        # --- STEP 2: Metric extraction & standardization (if not 'Others') ---
        if company.industry_ai == "Others" or company.industry_ai is None:
            logger.info(f"Company {company.name} classified as 'Others'. Skipping metric extraction.")
            return company

        industry_definition = self.industry_lookup.get(company.industry_ai)
        if not industry_definition:
            logger.error(f"Industry definition for '{company.industry_ai}' not found in lookup. Skipping metric extraction.")
            return company
        if not industry_definition.scraper_search_term:
            logger.info(f"Industry '{company.industry_ai}' has no 'Scraper Search Term'. Skipping metric extraction.")
            return company

        # Derive the standardized unit from the trailing token of the formula,
        # e.g. "wert * 25m² (Fläche pro Zimmer)" -> "m²" is the intent;
        # falls back to a generic unit when no formula is defined.
        standardized_unit = "Einheiten"  # Default
        if industry_definition.standardization_logic:
            match = re.search(r'(\w+)$', industry_definition.standardization_logic.replace(' ', ''))
            if match:
                standardized_unit = match.group(1).replace('(', '').replace(')', '')

        metric_results = self._extract_and_calculate_metric_cascade(
            company,
            company.industry_ai,
            industry_definition.scraper_search_term,
            industry_definition.standardization_logic,
            standardized_unit,  # Pass the derived unit
        )

        # Copy results onto the company row.
        company.calculated_metric_name = metric_results["calculated_metric_name"]
        company.calculated_metric_value = metric_results["calculated_metric_value"]
        company.calculated_metric_unit = metric_results["calculated_metric_unit"]
        company.standardized_metric_value = metric_results["standardized_metric_value"]
        company.standardized_metric_unit = metric_results["standardized_metric_unit"]
        company.metric_source = metric_results["metric_source"]
        # BUGFIX: `datetime` was used here without ever being imported
        # (NameError at runtime). Kept naive UTC to match existing rows —
        # TODO(review): migrate column + call site to timezone-aware UTC.
        company.last_classification_at = datetime.utcnow()

        self.db.add(company)
        self.db.commit()
        self.db.refresh(company)  # Refresh to get updated values
        logger.info(f"Classification and metric extraction completed for {company.name}.")
        return company


# NOTE: safe_eval_math (a locked-down arithmetic evaluator) is provided by
# backend.lib.core_utils; a commented-out eval()-based sketch that used to
# live here was removed as dead (and unsafe-looking) code.