# Files
# Brancheneinstufung2/company-explorer/backend/services/classification.py
#
# 207 lines
# 9.2 KiB
# Python

import json
import logging
import re
from datetime import datetime
from typing import Optional, Dict, Any, List

from sqlalchemy.orm import Session

from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content
# Module-level logger, named after this module, used by ClassificationService below.
logger = logging.getLogger(__name__)
class ClassificationService:
    """Classify a company into an industry and extract an industry-specific metric.

    The workflow (see ``classify_company_potential``) is:

    1. Load all industry definitions from the database.
    2. Scrape the company website and let the LLM pick one of the industries.
    3. If an industry other than "Others" was chosen and it defines a search
       term, try to extract that metric from several sources in order
       (website -> Wikipedia enrichment -> SerpAPI snippets).
    4. Persist the results on the ``Company`` row.
    """

    # Placeholder tokens a standardization formula may use for the raw value.
    # Matched case-insensitively so "wert", "Wert", "value" and "Value" all work.
    _FORMULA_PLACEHOLDER_RE = re.compile(r"wert|value", re.IGNORECASE)

    def __init__(self):
        # Industries are not loaded here because no DB session is available
        # at construction time; they are loaded per call instead.
        pass

    def _load_industry_definitions(self, db: "Session") -> "List[Industry]":
        """Return all ``Industry`` rows; logs a warning when the table is empty."""
        industries = db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _get_wikipedia_content(self, db: "Session", company_id: int) -> Optional[str]:
        """Return the text of the newest Wikipedia enrichment record for a company.

        Returns ``None`` when no record exists or its content is empty.
        NOTE(review): assumes ``EnrichmentData.content`` is dict-like with a
        ``'text'`` key; verify against the model definition.
        """
        enrichment = (
            db.query(EnrichmentData)
            .filter(
                EnrichmentData.company_id == company_id,
                EnrichmentData.source_type == "wikipedia",
            )
            .order_by(EnrichmentData.created_at.desc())
            .first()
        )
        if enrichment and enrichment.content:
            return enrichment.content.get('text')
        return None

    @staticmethod
    def _normalize_industry_label(response: str) -> str:
        """Strip whitespace and surrounding quotes/backticks from an LLM label."""
        return response.strip().strip("\"'`").strip()

    def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
        """Ask the LLM to assign the company to one of the given industries.

        Args:
            website_text: Scraped website content; only the first 10000
                characters are sent.
            company_name: Name of the company, used in the prompt and in logs.
            industry_definitions: List of ``{"name", "description"}`` dicts.

        Returns:
            The industry label returned by the LLM (whitespace and surrounding
            quotes removed), or ``None`` if the call failed.
        """
        prompt = r"""
Du bist ein präziser Branchen-Klassifizierer für Unternehmen.
Deine Aufgabe ist es, das vorliegende Unternehmen basierend auf seinem Website-Inhalt
einer der untenstehenden Branchen zuzuordnen.
--- UNTERNEHMEN ---
Name: {company_name}
Website-Inhalt (Auszug):
{website_text_excerpt}
--- ZU VERWENDENDE BRANCHEN-DEFINITIONEN (STRIKT) ---
Wähle EINE der folgenden Branchen. Jede Branche hat eine Definition.
{industry_definitions_json}
--- AUFGABE ---
Analysiere den Website-Inhalt. Wähle die Branchen-Definition, die am besten zum Unternehmen passt.
Wenn keine der Definitionen zutrifft oder du unsicher bist, wähle "Others".
Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.
Beispiel Output: Hotellerie
""".format(
            company_name=company_name,
            website_text_excerpt=website_text[:10000],
            industry_definitions_json=json.dumps(industry_definitions, ensure_ascii=False)
        )
        try:
            response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)
            # A quoted answer such as '"Hotellerie"' would otherwise never
            # match a defined industry name and silently fall back to "Others".
            return self._normalize_industry_label(response)
        except Exception as e:
            logger.error("LLM classification failed for %s: %s", company_name, e)
            return None

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to extract the metric named by ``search_term`` from text.

        Args:
            text_content: Source text; only the first 15000 characters are sent.
            search_term: Name of the metric to look for.
            industry_name: Industry of the company, given to the LLM as context.

        Returns:
            The parsed JSON object (expected keys: ``raw_value``, ``raw_unit``,
            ``area_value``, ``metric_name``), or ``None`` when the call fails,
            the response is not valid JSON, or it is not a JSON object.
        """
        prompt = r"""
Du bist ein Datenextraktions-Spezialist.
Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.
--- KONTEXT ---
Unternehmen ist in der Branche: {industry_name}
Gesuchter Wert (Rohdaten): '{search_term}'
--- TEXT ---
{text_content_excerpt}
--- AUFGABE ---
1. Finde den numerischen Wert für '{search_term}'.
2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.
Gib NUR ein JSON-Objekt zurück:
'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').
""".format(
            industry_name=industry_name,
            search_term=search_term,
            text_content_excerpt=text_content[:15000]
        )
        try:
            response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)
            parsed = json.loads(response)
        except Exception as e:
            logger.error("LLM metric extraction failed for '%s': %s", search_term, e)
            return None
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (e.g. a list): callers rely on .get().
            logger.error(
                "LLM metric extraction failed for '%s': %s",
                search_term,
                f"expected a JSON object, got {type(parsed).__name__}",
            )
            return None
        return parsed

    @classmethod
    def _substitute_formula_value(cls, formula: str, raw_value: Any) -> str:
        """Replace every ``wert``/``value`` placeholder (any case) with ``raw_value``."""
        replacement = str(raw_value)
        # A callable replacement avoids regex escape processing of the value.
        return cls._FORMULA_PLACEHOLDER_RE.sub(lambda _match: replacement, formula)

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Evaluate a standardization formula for a raw metric value.

        Args:
            formula: Arithmetic expression containing a ``wert``/``value``
                placeholder (case-insensitive).
            raw_value: Value substituted for the placeholder.

        Returns:
            The result of ``safe_eval_math`` on the substituted expression, or
            ``None`` when the formula is empty, ``raw_value`` is ``None``, or
            evaluation fails.
        """
        if not formula or raw_value is None:
            return None
        expression = self._substitute_formula_value(formula, raw_value)
        try:
            return safe_eval_math(expression)
        except Exception as e:
            # Best effort: an unparsable formula must not abort classification,
            # but it should be visible in the logs.
            logger.warning("Could not evaluate standardization formula %r: %s", expression, e)
            return None

    @staticmethod
    def _derive_standardized_unit(standardization_logic: Optional[str]) -> str:
        """Return ``"m²"`` if the formula mentions square metres, else ``"Einheiten"``."""
        return "m²" if "m²" in (standardization_logic or "") else "Einheiten"

    def _get_serp_content(self, company: "Company", search_term: str, industry_name: str) -> Optional[str]:
        """Run one SerpAPI search and join the organic result snippets.

        Returns ``None`` when the search yields no response.
        """
        serp_response = run_serp_search(f"{company.name} {search_term} {industry_name}")
        if not serp_response:
            return None
        return " ".join(
            res.get("snippet", "") for res in serp_response.get("organic_results", [])
        )

    def _extract_and_calculate_metric_cascade(
        self,
        db: "Session",
        company: "Company",
        industry_name: str,
        search_term: str,
        standardization_logic: Optional[str],
        standardized_unit: Optional[str],
        website_content: Optional[str] = None
    ) -> Dict[str, Any]:
        """Try several sources in order until the metric can be extracted.

        Sources are tried as website -> Wikipedia -> SerpAPI. The first source
        for which the LLM reports a ``raw_value`` or an ``area_value`` wins.

        Args:
            db: Database session used for the Wikipedia lookup.
            company: Company whose metric is extracted.
            industry_name: Industry name passed to the LLM and the web search.
            search_term: Name of the metric to look for.
            standardization_logic: Optional formula applied to the raw value
                when no explicit area was found.
            standardized_unit: Unit reported for the standardized value.
            website_content: Already-scraped website text; when omitted the
                website is scraped here.

        Returns:
            Dict with the keys ``calculated_metric_name``,
            ``calculated_metric_value``, ``calculated_metric_unit``,
            ``standardized_metric_value``, ``standardized_metric_unit`` and
            ``metric_source``. Value fields stay ``None`` when nothing was found.
        """
        results = {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None
        }
        # CASCADE: Website -> Wikipedia -> SerpAPI
        # Loaders are lazy so later (more expensive) sources are only queried
        # when the earlier ones did not yield a value.
        sources = [
            ("website", lambda: website_content or scrape_website_content(company.website)),
            ("wikipedia", lambda: self._get_wikipedia_content(db, company.id)),
            ("serpapi", lambda: self._get_serp_content(company, search_term, industry_name)),
        ]
        for source_name, content_loader in sources:
            logger.info("Checking %s for '%s' for %s", source_name, search_term, company.name)
            try:
                content = content_loader()
                if not content:
                    continue
                llm_result = self._run_llm_metric_extraction_prompt(content, search_term, industry_name)
                if not llm_result:
                    continue
                raw_value = llm_result.get("raw_value")
                area_value = llm_result.get("area_value")
                if raw_value is None and area_value is None:
                    continue
                results["calculated_metric_value"] = raw_value
                results["calculated_metric_unit"] = llm_result.get("raw_unit")
                results["metric_source"] = source_name
                if area_value is not None:
                    # An explicitly stated area takes precedence over a derived one.
                    results["standardized_metric_value"] = area_value
                elif standardization_logic:
                    results["standardized_metric_value"] = self._parse_standardization_logic(
                        standardization_logic, raw_value
                    )
                return results
            except Exception as e:
                # A failing source must not stop the cascade; try the next one.
                logger.error("Error in %s stage: %s", source_name, e)
        return results

    def classify_company_potential(self, company: "Company", db: "Session") -> "Company":
        """Classify a company and, where possible, extract its industry metric.

        Sets ``company.industry_ai`` (falling back to ``"Others"`` when the
        website cannot be scraped or the LLM answer is not a known industry),
        fills the metric fields when the industry defines a
        ``scraper_search_term``, stamps ``last_classification_at`` and commits.

        Args:
            company: Company row to update in place.
            db: Database session used for lookups and commits.

        Returns:
            The same ``company`` instance, updated.
        """
        logger.info("Starting classification for %s", company.name)

        # 1. Load Industries
        industries = self._load_industry_definitions(db)
        industry_by_name: Dict[str, Any] = {}
        for industry_row in industries:
            # Keep the first row per name, matching a first-match linear lookup.
            industry_by_name.setdefault(industry_row.name, industry_row)
        industry_defs = [{"name": i.name, "description": i.description} for i in industries]

        # 2. Industry Classification
        website_content = scrape_website_content(company.website)
        company.industry_ai = "Others"
        if website_content:
            industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
            if industry_name in industry_by_name:
                company.industry_ai = industry_name
        db.commit()

        # 3. Metric Extraction
        if company.industry_ai != "Others":
            industry = industry_by_name.get(company.industry_ai)
            if industry and industry.scraper_search_term:
                # Derive standardized unit from the formula text.
                std_unit = self._derive_standardized_unit(industry.standardization_logic)
                metrics = self._extract_and_calculate_metric_cascade(
                    db,
                    company,
                    company.industry_ai,
                    industry.scraper_search_term,
                    industry.standardization_logic,
                    std_unit,
                    website_content=website_content,  # reuse; avoids a second scrape
                )
                company.calculated_metric_name = metrics["calculated_metric_name"]
                company.calculated_metric_value = metrics["calculated_metric_value"]
                company.calculated_metric_unit = metrics["calculated_metric_unit"]
                company.standardized_metric_value = metrics["standardized_metric_value"]
                company.standardized_metric_unit = metrics["standardized_metric_unit"]
                company.metric_source = metrics["metric_source"]

        # Naive UTC timestamp, kept as-is for compatibility with existing rows.
        company.last_classification_at = datetime.utcnow()
        db.commit()
        return company