207 lines
9.2 KiB
Python
207 lines
9.2 KiB
Python
import json
import logging
import re
from datetime import datetime
from typing import Optional, Dict, Any, List

from sqlalchemy.orm import Session

from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content
|
|
|
|
# Module-scoped logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class ClassificationService:
    """Classify companies into DB-defined industries and extract an industry metric.

    The service holds no per-instance state: every method that touches the
    database receives the SQLAlchemy session it should use.
    """

    # Placeholder tokens accepted in Industry.standardization_logic formulas.
    # Matched case-insensitively so "wert", "Wert", "value" and "Value" all work.
    _VALUE_PLACEHOLDER = re.compile(r"wert|value", re.IGNORECASE)

    def __init__(self):
        # Industry definitions are loaded per call (see _load_industry_definitions)
        # because no DB session is available at construction time.
        pass

    def _load_industry_definitions(self, db: Session) -> List[Industry]:
        """Load all industry definitions from the database.

        Args:
            db: Active SQLAlchemy session.

        Returns:
            Every ``Industry`` row; an empty list (with a warning logged) if none exist.
        """
        industries = db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[str]:
        """Return the text of the most recent Wikipedia enrichment for a company.

        Args:
            db: Active SQLAlchemy session.
            company_id: Primary key of the company whose enrichment is looked up.

        Returns:
            The ``'text'`` entry of the newest ``EnrichmentData`` row with
            ``source_type == "wikipedia"``, or ``None`` if there is no such row or
            its content is empty.
        """
        enrichment = (
            db.query(EnrichmentData)
            .filter(
                EnrichmentData.company_id == company_id,
                EnrichmentData.source_type == "wikipedia",
            )
            .order_by(EnrichmentData.created_at.desc())
            .first()
        )
        if not enrichment or not enrichment.content:
            return None
        # NOTE(review): assumes `content` is a dict-like JSON value with a 'text' key;
        # confirm against the EnrichmentData model.
        return enrichment.content.get("text")

    @staticmethod
    def _clean_llm_label(raw_label: str) -> str:
        """Normalize a free-text LLM answer to a bare label.

        Strips surrounding whitespace plus any wrapping quotes, backticks and
        trailing periods, e.g. ``' "Hotellerie". '`` -> ``'Hotellerie'``.
        """
        return raw_label.strip().strip("\"'`.").strip()

    def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
        """Ask the LLM which of the given industries best fits the company.

        Args:
            website_text: Scraped website text; only the first 10,000 characters are sent.
            company_name: Company name, included in the prompt.
            industry_definitions: ``{"name": ..., "description": ...}`` entries the LLM
                must choose from.

        Returns:
            The industry name answered by the LLM, cleaned of whitespace, wrapping
            quotes and trailing periods; ``None`` if the call fails or returns nothing.
            The caller must still check the name against the known industries.
        """
        prompt = r"""
        Du bist ein präziser Branchen-Klassifizierer für Unternehmen.
        Deine Aufgabe ist es, das vorliegende Unternehmen basierend auf seinem Website-Inhalt
        einer der untenstehenden Branchen zuzuordnen.

        --- UNTERNEHMEN ---
        Name: {company_name}
        Website-Inhalt (Auszug):
        {website_text_excerpt}

        --- ZU VERWENDENDE BRANCHEN-DEFINITIONEN (STRIKT) ---
        Wähle EINE der folgenden Branchen. Jede Branche hat eine Definition.
        {industry_definitions_json}

        --- AUFGABE ---
        Analysiere den Website-Inhalt. Wähle die Branchen-Definition, die am besten zum Unternehmen passt.
        Wenn keine der Definitionen zutrifft oder du unsicher bist, wähle "Others".
        Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.

        Beispiel Output: Hotellerie
        """.format(
            company_name=company_name,
            website_text_excerpt=website_text[:10000],
            industry_definitions_json=json.dumps(industry_definitions, ensure_ascii=False),
        )

        try:
            response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)
        except Exception as e:
            logger.error(f"LLM classification failed for {company_name}: {e}")
            return None
        if not response:
            logger.warning(f"LLM classification returned an empty response for {company_name}")
            return None
        return self._clean_llm_label(response)

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to extract the industry metric (and an optional area in m²) from text.

        Args:
            text_content: Source text; only the first 15,000 characters are sent.
            search_term: Name of the raw metric to look for (e.g. a bed count).
            industry_name: Industry of the company, given to the LLM as context.

        Returns:
            The parsed JSON object, expected to contain ``raw_value``, ``raw_unit``,
            ``area_value`` and ``metric_name``; ``None`` if the call fails, the
            response is not valid JSON, or it is not a JSON object.
        """
        prompt = r"""
        Du bist ein Datenextraktions-Spezialist.
        Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.

        --- KONTEXT ---
        Unternehmen ist in der Branche: {industry_name}
        Gesuchter Wert (Rohdaten): '{search_term}'

        --- TEXT ---
        {text_content_excerpt}

        --- AUFGABE ---
        1. Finde den numerischen Wert für '{search_term}'.
        2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.

        Gib NUR ein JSON-Objekt zurück:
        'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
        'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
        'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
        'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').
        """.format(
            industry_name=industry_name,
            search_term=search_term,
            text_content_excerpt=text_content[:15000],
        )

        try:
            response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)
            parsed = json.loads(response)
        except Exception as e:
            logger.error(f"LLM metric extraction failed for '{search_term}': {e}")
            return None

        # Tolerate the object being wrapped in a one-element list.
        if isinstance(parsed, list) and len(parsed) == 1:
            parsed = parsed[0]
        # Callers use .get() on the result, so anything but a dict is rejected here.
        if not isinstance(parsed, dict):
            logger.error(
                f"LLM metric extraction for '{search_term}' returned unexpected JSON type: {type(parsed).__name__}"
            )
            return None
        return parsed

    @classmethod
    def _substitute_value_placeholder(cls, formula: str, raw_value: Any) -> str:
        """Replace every ``wert``/``value`` token (any case) in ``formula`` with ``str(raw_value)``."""
        value_text = str(raw_value)
        # A callable replacement keeps value_text literal (no backslash/group expansion).
        return cls._VALUE_PLACEHOLDER.sub(lambda _match: value_text, formula)

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Evaluate an industry's standardization formula for a raw metric value.

        Args:
            formula: Expression text where ``wert``/``value`` (any case) stands for
                the raw value, e.g. ``"wert * 25"``.
            raw_value: The extracted raw metric value.

        Returns:
            The result of ``safe_eval_math`` on the substituted expression, or
            ``None`` if ``formula`` is empty, ``raw_value`` is ``None``, or the
            evaluation raises.
        """
        if not formula or raw_value is None:
            return None
        # NOTE(review): formulas may contain unit text such as "m²" (the caller
        # checks for it); presumably safe_eval_math tolerates or strips it. Confirm.
        expression = self._substitute_value_placeholder(formula, raw_value)
        try:
            return safe_eval_math(expression)
        except Exception as e:
            logger.warning(f"Could not evaluate standardization formula '{formula}' with value {raw_value}: {e}")
            return None

    def _get_serp_content(self, company_name: str, search_term: str, industry_name: str) -> Optional[str]:
        """Run ONE SerpAPI search and join the organic-result snippets into a single text.

        Returns:
            ``None`` if the search returns nothing, otherwise the space-joined
            snippets (an empty string when there are no usable snippets).
        """
        search_response = run_serp_search(f"{company_name} {search_term} {industry_name}")
        if not search_response:
            return None
        organic_results = search_response.get("organic_results") or []
        return " ".join((res.get("snippet") or "") for res in organic_results)

    def _extract_and_calculate_metric_cascade(
        self,
        db: Session,
        company: Company,
        industry_name: str,
        search_term: str,
        standardization_logic: Optional[str],
        standardized_unit: Optional[str],
        website_content: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Look for the industry metric in website, Wikipedia and SerpAPI text, in that order.

        The first source whose text yields a ``raw_value`` or ``area_value`` wins.
        An explicit area is used as the standardized value. Otherwise the raw value
        is passed through ``standardization_logic`` when one is given.

        Args:
            db: Active SQLAlchemy session (used for the Wikipedia lookup).
            company: Company being enriched.
            industry_name: Classified industry name, passed to the LLM and the search query.
            search_term: Name of the raw metric to extract.
            standardization_logic: Optional formula converting the raw value.
            standardized_unit: Unit reported for the standardized value.
            website_content: Already-scraped website text. When given, it is used
                instead of scraping the website again.

        Returns:
            A dict with ``calculated_metric_name``, ``calculated_metric_value``,
            ``calculated_metric_unit``, ``standardized_metric_value``,
            ``standardized_metric_unit`` and ``metric_source``. Value fields stay
            ``None`` when no source produced a result. Per-source errors are logged
            and the cascade moves on to the next source.
        """
        results: Dict[str, Any] = {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None,
        }

        # CASCADE: Website -> Wikipedia -> SerpAPI (loaders are lazy so later
        # sources are only fetched when earlier ones yield nothing).
        sources = [
            ("website", lambda: website_content if website_content is not None else scrape_website_content(company.website)),
            ("wikipedia", lambda: self._get_wikipedia_content(db, company.id)),
            ("serpapi", lambda: self._get_serp_content(company.name, search_term, industry_name)),
        ]

        for source_name, content_loader in sources:
            logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
            try:
                content = content_loader()
                if not content:
                    continue

                llm_result = self._run_llm_metric_extraction_prompt(content, search_term, industry_name)
                if not llm_result:
                    continue
                raw_value = llm_result.get("raw_value")
                area_value = llm_result.get("area_value")
                if raw_value is None and area_value is None:
                    continue

                results["calculated_metric_value"] = raw_value
                results["calculated_metric_unit"] = llm_result.get("raw_unit")
                results["metric_source"] = source_name

                if area_value is not None:
                    results["standardized_metric_value"] = area_value
                elif standardization_logic:
                    # raw_value is non-None here because area_value is None.
                    results["standardized_metric_value"] = self._parse_standardization_logic(standardization_logic, raw_value)

                return results
            except Exception as e:
                logger.exception(f"Error in {source_name} stage: {e}")

        return results

    def classify_company_potential(self, company: Company, db: Session) -> Company:
        """Classify a company's industry and, when possible, extract its industry metric.

        Steps:
            1. Load the industry definitions from the DB.
            2. Scrape the company website and ask the LLM for an industry. An answer
               that does not match a known industry name (case-insensitively)
               becomes ``"Others"``. The result is committed immediately.
            3. For an industry other than ``"Others"`` that has a
               ``scraper_search_term``, run the website -> Wikipedia -> SerpAPI
               metric cascade and copy its results onto the company.

        Args:
            company: ORM company, updated in place.
            db: Active SQLAlchemy session used for lookups and commits.

        Returns:
            The same ``company`` instance after the final commit.
        """
        logger.info(f"Starting classification for {company.name}")

        # 1. Load Industries
        industries = self._load_industry_definitions(db)
        # Case-insensitive lookup so e.g. "hotellerie" from the LLM maps to "Hotellerie".
        industries_by_key = {i.name.casefold(): i for i in industries if i.name}

        # 2. Industry Classification (without definitions the answer is always "Others",
        # so the scrape and LLM call are skipped).
        matched_industry = None
        website_content = None
        if industries:
            website_content = scrape_website_content(company.website)
            if website_content:
                industry_defs = [{"name": i.name, "description": i.description} for i in industries]
                industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
                if industry_name:
                    matched_industry = industries_by_key.get(industry_name.casefold())

        # Store the canonical DB name, not the LLM's spelling.
        company.industry_ai = matched_industry.name if matched_industry else "Others"
        db.commit()

        # 3. Metric Extraction (skipped for "Others", including a DB industry literally named so)
        if matched_industry and company.industry_ai != "Others" and matched_industry.scraper_search_term:
            # Derive standardized unit from the formula text.
            std_unit = "m²" if "m²" in (matched_industry.standardization_logic or "") else "Einheiten"

            metrics = self._extract_and_calculate_metric_cascade(
                db,
                company,
                company.industry_ai,
                matched_industry.scraper_search_term,
                matched_industry.standardization_logic,
                std_unit,
                website_content=website_content,  # reuse the scrape from step 2
            )

            company.calculated_metric_name = metrics["calculated_metric_name"]
            company.calculated_metric_value = metrics["calculated_metric_value"]
            company.calculated_metric_unit = metrics["calculated_metric_unit"]
            company.standardized_metric_value = metrics["standardized_metric_value"]
            company.standardized_metric_unit = metrics["standardized_metric_unit"]
            company.metric_source = metrics["metric_source"]

        # Naive UTC timestamp, as before.
        company.last_classification_at = datetime.utcnow()
        db.commit()
        return company
|