from typing import Tuple
import json
import logging
import math
import re
from datetime import datetime
from typing import Optional, Dict, Any, List

from sqlalchemy.orm import Session

from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content
from backend.lib.metric_parser import MetricParser

logger = logging.getLogger(__name__)

# Metric sources, most trusted first. This order decides which source wins when
# several yield a value, and also the order in which sources are queried, so the
# more expensive fallbacks (SerpAPI + an extra LLM call) only run when needed.
_METRIC_SOURCE_ORDER = ("wikipedia", "website", "serpapi")

# Company columns that are copied from a metric result dict.
_METRIC_FIELDS = (
    "calculated_metric_name",
    "calculated_metric_value",
    "calculated_metric_unit",
    "standardized_metric_value",
    "standardized_metric_unit",
    "metric_source",
    "metric_proof_text",
    "metric_source_url",
    "metric_confidence",
    "metric_confidence_reason",
)

# Whole-word placeholders in a standardization formula that stand for the raw value.
_VALUE_PLACEHOLDER_RE = re.compile(r"\b(?:wert|value)\b", re.IGNORECASE)
# The last parenthesised group at the very end of a formula, e.g. "(Schätzung)".
_TRAILING_PAREN_RE = re.compile(r"\s*\(([^()]*)\)\s*$")


def _strip_trailing_annotations(expression: str) -> str:
    """Remove trailing parenthesised notes such as "(ca.)" from a formula.

    A trailing group is kept when it contains digits and no letters, because it
    is then part of the arithmetic (e.g. "352.0 * (1 - 0.2)") rather than a note.
    Several consecutive trailing notes are all removed.
    """
    expression = expression.strip()
    while True:
        match = _TRAILING_PAREN_RE.search(expression)
        if not match:
            return expression
        inner = match.group(1)
        is_arithmetic = re.search(r"\d", inner) and not re.search(r"[^\W\d_]", inner)
        if is_arithmetic:
            return expression
        expression = expression[: match.start()].strip()


class ClassificationService:
    """Classify companies into industries and extract the industry's key metric.

    Industry classification uses an LLM over scraped website text. Metric
    extraction cascades over Wikipedia enrichment data, the company website and
    SerpAPI snippets, in that order of trust.
    """

    def __init__(self):
        pass

    def _load_industry_definitions(self, db: Session) -> List[Industry]:
        """Return all Industry rows, warning when none are configured."""
        industries = db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[Dict[str, Any]]:
        """Return the content of the newest 'wikipedia' EnrichmentData row, or None."""
        enrichment = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company_id,
            EnrichmentData.source_type == "wikipedia"
        ).order_by(EnrichmentData.created_at.desc()).first()
        return enrichment.content if enrichment and enrichment.content else None

    @staticmethod
    def _match_industry_name(answer: str, industry_names: List[str]) -> str:
        """Map a free-text LLM answer onto one of the known industry names.

        Tries an exact match, then a case-insensitive exact match, then a
        substring search that prefers the longest name, so that
        'Healthcare - Hospital' wins over a shorter 'Healthcare'.
        Falls back to 'Others'.
        """
        valid_names = [name for name in industry_names if name] + ["Others"]
        # Strip only surrounding quotes, so names containing apostrophes still match.
        cleaned = answer.strip().strip("\"'").strip()
        if cleaned in valid_names:
            return cleaned
        lowered = cleaned.casefold()
        for name in valid_names:
            if name.casefold() == lowered:
                return name
        # The substring search stays case-sensitive (as before), so short names such
        # as "IT" do not match inside ordinary words like "Hospitality".
        for name in sorted(valid_names, key=len, reverse=True):
            if name in cleaned:
                return name
        return "Others"

    def _run_llm_classification_prompt(self, website_text: str, company_name: str,
                                       industry_definitions: List[Dict[str, str]]) -> Optional[str]:
        """Ask the LLM for the best-matching industry name.

        Args:
            website_text: Scraped website text (only the first 3000 chars are sent).
            company_name: Company name shown to the LLM.
            industry_definitions: Dicts with 'name' and 'description' keys.

        Returns:
            One of the industry names, or 'Others' when nothing matches, the LLM
            returns nothing, or the call fails.
        """
        # ensure_ascii=False shows names such as "Gebäude" verbatim, so the LLM can
        # echo them back exactly instead of seeing "\u00e4" escapes.
        prompt = (
            " Act as a strict B2B Industry Classifier."
            f" Company: {company_name}"
            f" Context: {website_text[:3000]}"
            f" Available Industries: {json.dumps(industry_definitions, indent=2, ensure_ascii=False)}"
            " Task: Select the ONE industry that best matches the company."
            " If the company is a Hospital/Klinik, select 'Healthcare - Hospital'."
            " If none match well, select 'Others'."
            " Return ONLY the exact name of the industry. \n"
        )
        try:
            response = call_gemini_flash(prompt)
            if not response:
                return "Others"
            return self._match_industry_name(response, [d.get("name") for d in industry_definitions])
        except Exception as e:
            logger.error(f"Classification Prompt Error: {e}")
            return "Others"

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str,
                                          industry_name: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to extract `search_term` from `text_content` as JSON.

        Returns:
            A dict that is expected to carry 'raw_value', 'raw_unit' and
            'proof_text', or None when the LLM returns nothing, returns
            unparseable JSON, or returns JSON that is not an object.
        """
        prompt = (
            f" Extract the following metric for the company in industry '{industry_name}':"
            f' Target Metric: "{search_term}"'
            f" Source Text: {text_content[:6000]}"
            " Return a JSON object with:"
            ' - "raw_value": The number found (e.g. 352 or 352.0).'
            ' If text says "352 Betten", extract 352. If not found, null.'
            ' - "raw_unit": The unit found (e.g. "Betten", "m²").'
            ' - "proof_text": A short quote from the text proving this value.'
            " JSON ONLY. \n"
        )
        try:
            response = call_gemini_flash(prompt, json_mode=True)
            if not response:
                return None
            if isinstance(response, str):
                # Tolerate Markdown code fences around the JSON payload.
                response = response.replace("```json", "").replace("```", "").strip()
                data = json.loads(response)
            else:
                data = response
            # Some models wrap the object in a single-element array.
            if isinstance(data, list) and len(data) == 1:
                data = data[0]
            if not isinstance(data, dict):
                logger.warning(f"LLM extraction returned unexpected JSON type: {type(data).__name__}")
                return None
            # The model sometimes emits the string "null" instead of JSON null.
            if data.get("raw_value") == "null":
                data["raw_value"] = None
            return data
        except Exception as e:
            logger.error(f"LLM Extraction Parse Error: {e}")
            return None

    def _is_metric_plausible(self, metric_name: str, value: Optional[float]) -> bool:
        """Return True when `value` converts to a finite number greater than zero.

        Booleans, non-numeric strings, NaN and infinities are rejected.
        `metric_name` is currently unused and kept for interface compatibility.
        """
        if value is None or isinstance(value, bool):
            return False
        try:
            numeric = float(value)
        except (TypeError, ValueError):
            return False
        return math.isfinite(numeric) and numeric > 0

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Evaluate a standardization formula with `raw_value` substituted in.

        Whole-word 'wert'/'value' placeholders (any case) are replaced with the
        value. Area units ('m²', 'm2', 'qm') are dropped, and trailing textual
        notes in parentheses are removed. Parenthesised arithmetic is kept.

        Returns:
            The evaluated result, or None when the formula or value is missing or
            evaluation fails.
        """
        if not formula or raw_value is None:
            return None
        value_text = str(raw_value)
        # A callable replacement avoids re's backslash handling in the substituted text.
        expression = _VALUE_PLACEHOLDER_RE.sub(lambda _match: value_text, formula)
        expression = re.sub(r"(?i)m[²2]", "", expression)
        expression = re.sub(r"(?i)qm", "", expression)
        expression = _strip_trailing_annotations(expression)
        try:
            return safe_eval_math(expression)
        except Exception as e:
            logger.error(f"Failed to parse standardization logic '{formula}' with value {raw_value}: {e}")
            return None

    def _get_best_metric_result(self, results_list: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Pick the result from the most trusted source, then the highest confidence.

        Results without a 'calculated_metric_value' are ignored. Returns None
        when no result has a value.
        """
        if not results_list:
            return None
        source_priority = {name: rank for rank, name in enumerate(_METRIC_SOURCE_ORDER)}
        valid_results = [r for r in results_list if r.get("calculated_metric_value") is not None]
        if not valid_results:
            return None
        valid_results.sort(key=lambda r: (source_priority.get(r.get("metric_source"), 99),
                                          -(r.get("metric_confidence") or 0.0)))
        logger.info(f"Best result chosen: {valid_results[0]}")
        return valid_results[0]

    def _get_website_content_and_url(self, company: Company) -> Tuple[Optional[str], Optional[str]]:
        """Return (scraped website text, website URL) for the company."""
        return scrape_website_content(company.website), company.website

    def _get_wikipedia_content_and_url(self, db: Session, company_id: int) -> Tuple[Optional[str], Optional[str]]:
        """Return ('full_text', 'url') from the stored Wikipedia enrichment, or (None, None)."""
        wiki_data = self._get_wikipedia_content(db, company_id)
        if not isinstance(wiki_data, dict):
            return None, None
        return wiki_data.get("full_text"), wiki_data.get("url")

    def _get_serpapi_content_and_url(self, company: Company, search_term: str) -> Tuple[Optional[str], Optional[str]]:
        """Search SerpAPI and return (joined organic snippets, first result link)."""
        serp_results = run_serp_search(f"{company.name} {company.city or ''} {search_term}")
        if not serp_results:
            return None, None
        # `or []` / `or ""` guard against explicit nulls in the SerpAPI payload.
        organic_results = serp_results.get("organic_results") or []
        content = " ".join(res.get("snippet") or "" for res in organic_results)
        url = organic_results[0].get("link") if organic_results else None
        return content, url

    def _evaluate_metric_source(self, source_name: str, content_loader, company: Company,
                                industry_name: str, search_term: str,
                                standardization_logic: Optional[str],
                                standardized_unit: Optional[str]) -> Optional[Dict[str, Any]]:
        """Load one source and turn a plausible LLM extraction into a metric result.

        Returns None when the source has no content, loading or extraction fails,
        or the extracted value is not plausible.
        """
        logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
        try:
            content_text, source_url = content_loader()
            if not content_text:
                logger.info(f"No content for {source_name}.")
                return None
            llm_result = self._run_llm_metric_extraction_prompt(content_text, search_term, industry_name)
        except Exception as e:
            logger.error(f"Error in {source_name} stage: {e}")
            return None
        if not llm_result:
            return None

        raw_value = llm_result.get("raw_value")
        if not self._is_metric_plausible(search_term, raw_value):
            logger.info(f"LLM found no plausible metric for {search_term} in {source_name}.")
            return None

        # Normalise LLM output such as "352" to a number before storing it.
        metric_value = float(raw_value)
        standardized_value = None
        if standardization_logic:
            standardized_value = self._parse_standardization_logic(standardization_logic, metric_value)
        return {
            "calculated_metric_name": search_term,
            "calculated_metric_value": metric_value,
            "calculated_metric_unit": llm_result.get("raw_unit"),
            "standardized_metric_value": standardized_value,
            "standardized_metric_unit": standardized_unit,
            "metric_source": source_name,
            "metric_proof_text": llm_result.get("proof_text"),
            "metric_source_url": source_url,
            "metric_confidence": 0.95,
            "metric_confidence_reason": "Value found and extracted by LLM."
        }

    def _extract_and_calculate_metric_cascade(self, db: Session, company: Company, industry_name: str,
                                              search_term: str, standardization_logic: Optional[str],
                                              standardized_unit: Optional[str]) -> Dict[str, Any]:
        """Find `search_term` for the company via Wikipedia, then website, then SerpAPI.

        Sources are queried in trust order, and the first plausible value wins.
        All candidates carry the same confidence, so this yields the same result
        as querying every source and ranking them. It also avoids unnecessary
        scraping, SerpAPI and LLM calls.

        Returns:
            A dict keyed by the Company metric column names. When no source yields
            a plausible value, the values are empty and 'metric_confidence' is 0.0.
        """
        loaders = {
            "wikipedia": lambda: self._get_wikipedia_content_and_url(db, company.id),
            "website": lambda: self._get_website_content_and_url(company),
            "serpapi": lambda: self._get_serpapi_content_and_url(company, search_term),
        }
        for source_name in _METRIC_SOURCE_ORDER:
            candidate = self._evaluate_metric_source(
                source_name, loaders[source_name], company, industry_name,
                search_term, standardization_logic, standardized_unit,
            )
            if candidate is not None:
                logger.info(f"Best result chosen: {candidate}")
                return candidate

        return {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None,
            "metric_proof_text": None,
            "metric_source_url": None,
            "metric_confidence": 0.0,
            "metric_confidence_reason": "No value found in any source."
        }

    def extract_metrics_for_industry(self, company: Company, db: Session, industry: Industry) -> Company:
        """Extract the industry's configured metric, store it on `company`, and commit.

        Does nothing, and returns `company` unchanged, when the industry has no
        `scraper_search_term`.
        """
        if not industry or not industry.scraper_search_term:
            logger.warning(f"No metric configuration for industry '{industry.name if industry else 'None'}'")
            return company

        # Area metrics are recognised by an explicit "m²" in the config; everything else is a count.
        if "m²" in (industry.standardization_logic or "") or "m²" in (industry.scraper_search_term or ""):
            std_unit = "m²"
        else:
            std_unit = "Einheiten"

        metrics = self._extract_and_calculate_metric_cascade(
            db, company, industry.name, industry.scraper_search_term, industry.standardization_logic, std_unit
        )
        for field in _METRIC_FIELDS:
            setattr(company, field, metrics.get(field))
        company.last_classification_at = datetime.utcnow()
        db.commit()
        return company

    def reevaluate_wikipedia_metric(self, company: Company, db: Session, industry: Industry) -> Company:
        """Re-run the full metric cascade for `company` (not only the Wikipedia source)."""
        logger.info(f"Re-evaluating metric for {company.name}...")
        return self.extract_metrics_for_industry(company, db, industry)

    def classify_company_potential(self, company: Company, db: Session) -> Company:
        """Classify the company's industry from its website and extract the industry metric.

        Sets `company.industry_ai` to the matched industry name, or to 'Others'.
        When the match is a configured industry, the metric cascade also runs.
        Commits the session. Returns early, without changes, when the website
        yields no content.
        """
        logger.info(f"Starting classification for {company.name}...")

        # 1. Load definitions
        industries = self._load_industry_definitions(db)
        industry_defs = [{"name": i.name, "description": i.description} for i in industries]

        # 2. Get content (website)
        website_content, _ = self._get_website_content_and_url(company)
        if not website_content:
            logger.warning(f"No website content for {company.name}. Skipping classification.")
            return company

        # 3. Classify industry
        suggested_industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
        logger.info(f"AI suggests industry: {suggested_industry_name}")

        # 4. Map the suggestion back to a DB row ('Others' has no row of its own)
        matched_industry = next((i for i in industries if i.name == suggested_industry_name), None)
        company.industry_ai = matched_industry.name if matched_industry else "Others"

        # 5. Extract metrics (cascade)
        if matched_industry:
            self.extract_metrics_for_industry(company, db, matched_industry)

        company.last_classification_at = datetime.utcnow()
        db.commit()
        return company