import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy.orm import Session

from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content
from backend.lib.metric_parser import MetricParser

logger = logging.getLogger(__name__)

# Placeholder words in a standardization formula that stand for the raw metric value.
# Matched case-insensitively, so "Wert", "wert", "WERT", "Value" and "value" all work.
_VALUE_PLACEHOLDER_RE = re.compile(r"wert|value", re.IGNORECASE)
# Square-metre unit markers that may appear in formulas and search terms.
_SQM_UNIT_RE = re.compile(r"m[²2]", re.IGNORECASE)
_QM_UNIT_RE = re.compile(r"qm", re.IGNORECASE)
# A trailing parenthesised group at the end of a formula, e.g. "(Annahme)".
_TRAILING_PAREN_RE = re.compile(r"\s*\(.*\)\s*$")
# Any letter; used to tell a prose remark like "(geschätzt)" from arithmetic like "(1/100)".
_LETTER_RE = re.compile(r"[^\W\d_]")

# Lower value = preferred source when several sources yield a metric; unknown sources rank last.
_SOURCE_PRIORITY = {"wikipedia": 0, "website": 1, "serpapi": 2}

# Company attributes copied from a metric result dict (the dict uses the same key names).
_METRIC_FIELDS = (
    "calculated_metric_name",
    "calculated_metric_value",
    "calculated_metric_unit",
    "standardized_metric_value",
    "standardized_metric_unit",
    "metric_source",
    "metric_proof_text",
    "metric_source_url",
    "metric_confidence",
    "metric_confidence_reason",
)


class ClassificationService:
    """Service for industry-definition lookup and metric extraction for companies."""

    def __init__(self):
        pass

    def _load_industry_definitions(self, db: Session) -> List[Industry]:
        """Return all Industry rows, logging a warning when there are none."""
        industries = db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[Dict[str, Any]]:
        """Return the ``content`` of the newest "wikipedia" EnrichmentData row for the company.

        Returns None when no such row exists or its content is empty.
        """
        enrichment = (
            db.query(EnrichmentData)
            .filter(
                EnrichmentData.company_id == company_id,
                EnrichmentData.source_type == "wikipedia",
            )
            .order_by(EnrichmentData.created_at.desc())
            .first()
        )
        return enrichment.content if enrichment and enrichment.content else None

    def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
        # ... [omitted for brevity, no changes here] ...
        pass

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        # ... [omitted for brevity, no changes here] ...
        pass

    def _is_metric_plausible(self, metric_name: str, value: Optional[float]) -> bool:
        # ... [omitted for brevity, no changes here] ...
        pass

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Evaluate a standardization formula with ``raw_value`` substituted in.

        The placeholders "Wert"/"Value" (any letter case) are replaced by ``raw_value``.
        Square-metre markers ("m²", "m2", "qm") are removed. The remaining expression
        is evaluated with ``safe_eval_math``.

        A trailing parenthesised group containing letters, e.g. "(Annahme)", is dropped.
        A trailing group without letters, e.g. "Wert * (1/100)", is evaluated as part of
        the formula first, and dropped only if that attempt fails.

        Returns:
            The evaluated value, or None if ``formula`` is empty, ``raw_value`` is None,
            or no variant of the expression could be evaluated (an error is logged).
        """
        if not formula or raw_value is None:
            return None

        value_text = str(raw_value)
        # A function replacement keeps value_text literal (no backslash/group expansion).
        expression = _VALUE_PLACEHOLDER_RE.sub(lambda _match: value_text, formula)
        expression = _SQM_UNIT_RE.sub("", expression)
        expression = _QM_UNIT_RE.sub("", expression)

        trailing = _TRAILING_PAREN_RE.search(expression)
        if trailing is None:
            candidates = [expression.strip()]
        else:
            # The match is greedy from the FIRST "(" to the final ")".
            without_trailing = (expression[:trailing.start()] + expression[trailing.end():]).strip()
            if _LETTER_RE.search(trailing.group()):
                candidates = [without_trailing]
            else:
                # Pure arithmetic: stripping it would break e.g. "(Wert * 2) / (3)".
                candidates = [expression.strip(), without_trailing]

        last_error: Optional[Exception] = None
        for candidate in candidates:
            try:
                result = safe_eval_math(candidate)
            except Exception as e:
                last_error = e
                continue
            if result is not None:
                return result

        logger.error(f"Failed to parse standardization logic '{formula}' with value {raw_value}: {last_error}")
        return None

    def _get_best_metric_result(self, results_list: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Pick the preferred result among those that have a ``calculated_metric_value``.

        Results are ranked by source priority: wikipedia, then website, then serpapi,
        then any other source. Ties are broken by higher ``metric_confidence``. A missing
        or None confidence counts as 0.0. Among equal keys, the earliest result wins.

        Returns:
            The chosen result dict, or None if the list is empty/None or no result has a value.
        """
        valid_results = [r for r in results_list or [] if r.get("calculated_metric_value") is not None]
        if not valid_results:
            return None
        best = min(
            valid_results,
            key=lambda r: (
                _SOURCE_PRIORITY.get(r.get("metric_source"), 99),
                -(r.get("metric_confidence") or 0.0),
            ),
        )
        logger.info(f"Best result chosen: {best}")
        return best

    def _get_website_content_and_url(self, company: Company) -> Tuple[Optional[str], Optional[str]]:
        """Return (scraped text, URL) for the company's website, or (None, None) if it has none."""
        if not company.website:
            return None, None
        return scrape_website_content(company.website), company.website

    def _get_wikipedia_content_and_url(self, db: Session, company_id: int) -> Tuple[Optional[str], Optional[str]]:
        """Return the ("full_text", "url") entries of the latest Wikipedia enrichment, or (None, None)."""
        wiki_data = self._get_wikipedia_content(db, company_id)
        if not wiki_data:
            return None, None
        return wiki_data.get("full_text"), wiki_data.get("url")

    def _get_serpapi_content_and_url(self, company: Company, search_term: str) -> Tuple[Optional[str], Optional[str]]:
        """Search "<name> <city> <search_term>" and return (joined snippets, first result link).

        Empty parts are left out of the query. Returns (None, None) when the search yields nothing.
        """
        query = " ".join(part for part in (company.name, company.city, search_term) if part)
        serp_results = run_serp_search(query)
        if not serp_results:
            return None, None
        organic_results = serp_results.get("organic_results") or []
        # Missing or None snippets contribute an empty string instead of breaking the join.
        content = " ".join(res.get("snippet") or "" for res in organic_results)
        url = organic_results[0].get("link") if organic_results else None
        return content, url

    def _extract_and_calculate_metric_cascade(
        self,
        db: Session,
        company: Company,
        industry_name: str,
        search_term: str,
        standardization_logic: Optional[str],
        standardized_unit: Optional[str],
    ) -> Dict[str, Any]:
        """Look for ``search_term`` on the website, in Wikipedia and via SerpAPI.

        Each source's text goes through ``_run_llm_metric_extraction_prompt``. An error in
        one source is logged with its traceback, and the cascade continues with the next.

        Returns:
            A dict containing every key of ``defaults`` below. When a best result exists,
            its values override the defaults. Otherwise a "No value found" result is returned.
        """
        defaults = {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None,
            "metric_proof_text": None,
            "metric_source_url": None,
            "metric_confidence": 0.0,
            "metric_confidence_reason": None,
        }
        final_result = {**defaults, "metric_confidence_reason": "No value found in any source."}

        # Zero-argument loaders, each returning (content_text, source_url).
        sources = [
            ("website", lambda: self._get_website_content_and_url(company)),
            ("wikipedia", lambda: self._get_wikipedia_content_and_url(db, company.id)),
            ("serpapi", lambda: self._get_serpapi_content_and_url(company, search_term)),
        ]

        all_source_results = []
        for source_name, load_content in sources:
            logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
            try:
                content_text, current_source_url = load_content()
                if not content_text:
                    logger.info(f"No content for {source_name}.")
                    continue
                llm_result = self._run_llm_metric_extraction_prompt(content_text, search_term, industry_name)
                if llm_result:
                    llm_result["source_url"] = current_source_url
                    all_source_results.append((source_name, llm_result))
            except Exception as e:
                # Best-effort per source: keep the traceback, then try the next source.
                logger.exception(f"Error in {source_name} stage: {e}")

        processed_results = []
        # ... [processing logic as before, no changes] ...
        # NOTE(review): the processing step is elided in this copy. As written,
        # processed_results stays empty, all_source_results is never read, and the
        # cascade always returns final_result. Confirm against the full file.
        best_result = self._get_best_metric_result(processed_results)
        if not best_result:
            return final_result
        # Fill any keys the processed result omitted, so callers can read every field.
        return {**defaults, **best_result}

    # ... [rest of the class, no changes] ...

    def extract_metrics_for_industry(self, company: Company, db: Session, industry: Industry) -> Company:
        """Run the metric cascade for ``industry`` and store the result on ``company``.

        Commits the session and returns the same company object. If the industry is
        missing or has no ``scraper_search_term``, a warning is logged and the company is
        returned untouched, without a commit.
        """
        if not industry or not industry.scraper_search_term:
            logger.warning(f"No metric configuration for industry '{industry.name if industry else 'None'}'")
            return company

        # Area metrics: recognise the same square-metre spellings that
        # _parse_standardization_logic strips ("m²", "m2", "qm").
        unit_hints = (industry.standardization_logic or "", industry.scraper_search_term or "")
        is_area_metric = any(_SQM_UNIT_RE.search(text) or _QM_UNIT_RE.search(text) for text in unit_hints)
        std_unit = "m²" if is_area_metric else "Einheiten"

        metrics = self._extract_and_calculate_metric_cascade(
            db,
            company,
            industry.name,
            industry.scraper_search_term,
            industry.standardization_logic,
            std_unit,
        )
        for field in _METRIC_FIELDS:
            setattr(company, field, metrics.get(field))
        # datetime.utcnow() yields a naive UTC datetime. Kept as-is; switching to an
        # aware datetime may not match the column type (verify before changing).
        company.last_classification_at = datetime.utcnow()
        db.commit()
        return company

    def reevaluate_wikipedia_metric(self, company: Company, db: Session, industry: Industry) -> Company:
        # ... [omitted for brevity, no changes here] ...
        pass

    def classify_company_potential(self, company: Company, db: Session) -> Company:
        # ... [omitted for brevity, no changes here] ...
        pass