# File-viewer header captured with this listing: 334 lines, 17 KiB, Python.
import json
import logging
import math
import re
import textwrap
from datetime import datetime, timezone
from typing import Optional, Dict, Any, List

from sqlalchemy.orm import Session

from backend.database import Company, Industry, RoboticsCategory, EnrichmentData, get_db
from backend.config import settings
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content

# Module-level logger named after this module, shared by the service below.
logger = logging.getLogger(__name__)


class ClassificationService:
    """Classify companies into configured industries and extract a sizing metric.

    Industry definitions and robotics categories are loaded from the database once,
    when the service is constructed. Classification asks an LLM (``call_gemini_flash``)
    to choose one of the configured industry names based on the company's website text.
    Metric extraction then tries Website -> Wikipedia -> SerpAPI in order. It stops at
    the first source where the LLM finds a value for the industry's
    ``scraper_search_term``.
    """

    # Industry name used when nothing else fits (also the answer the LLM is told to give).
    FALLBACK_INDUSTRY = "Others"
    # Unit reported for standardized values when none can be derived from the formula.
    DEFAULT_STANDARDIZED_UNIT = "Einheiten"
    # Unit of ``area_value``, as requested in the metric-extraction prompt.
    AREA_UNIT = "m²"
    # Character budgets for text sent to the LLM (keeps prompts within token limits).
    CLASSIFICATION_EXCERPT_CHARS = 10000
    METRIC_EXCERPT_CHARS = 15000

    # Human-readable labels for each cascade source, used in log messages.
    _SOURCE_LABELS = {"website": "on website", "wikipedia": "on Wikipedia", "serpapi": "via SerpAPI"}

    # Word tokens in a formula: runs of letters (umlauts and superscripts such as "²" included).
    _WORD_TOKEN_RE = re.compile(r"[^\W\d_]+")
    # Parentheses left empty once annotation words like "(Fläche pro Zimmer)" are stripped.
    _EMPTY_PARENS_RE = re.compile(r"\(\s*\)")
    # First unit-like word directly after a number, e.g. the "m²" in "wert * 25m²".
    _UNIT_AFTER_NUMBER_RE = re.compile(r"\d\s*([^\W\d_]+)")
    # Placeholder words (compared case-insensitively) that stand for the raw metric value.
    _VALUE_PLACEHOLDERS = frozenset({"wert", "value"})
    # Markdown code fences an LLM may wrap around JSON output.
    _CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$", re.IGNORECASE)

    _CLASSIFICATION_PROMPT = textwrap.dedent(
        """\
        Du bist ein präziser Branchen-Klassifizierer für Unternehmen.
        Deine Aufgabe ist es, das vorliegende Unternehmen basierend auf seinem Website-Inhalt
        einer der untenstehenden Branchen zuzuordnen.

        --- UNTERNEHMEN ---
        Name: {company_name}
        Website-Inhalt (Auszug):
        {website_text_excerpt}

        --- ZU VERWENDENDE BRANCHEN-DEFINITIONEN (STRIKT) ---
        Wähle EINE der folgenden Branchen. Jede Branche hat eine Definition.
        {industry_definitions_json}

        --- AUFGABE ---
        Analysiere den Website-Inhalt. Wähle die Branchen-Definition, die am besten zum Unternehmen passt.
        Wenn keine der Definitionen zutrifft oder du unsicher bist, wähle "Others".
        Gib NUR den Namen der zugeordneten Branche zurück, als reinen String, nichts anderes.

        Beispiel Output: Hotellerie
        Beispiel Output: Automotive - Dealer
        Beispiel Output: Others
        """
    )

    _METRIC_EXTRACTION_PROMPT = textwrap.dedent(
        """\
        Du bist ein Datenextraktions-Spezialist.
        Analysiere den folgenden Text, um spezifische Metrik-Informationen zu extrahieren.

        --- KONTEXT ---
        Unternehmen ist in der Branche: {industry_name}
        Gesuchter Wert (Rohdaten): '{search_term}'

        --- TEXT ---
        {text_content_excerpt}

        --- AUFGABE ---
        1. Finde den numerischen Wert für '{search_term}'.
        2. Versuche auch, eine explizit genannte Gesamtfläche in Quadratmetern (m²) zu finden, falls relevant und vorhanden.

        Gib NUR ein JSON-Objekt zurück mit den Schlüsseln:
        'raw_value': Der gefundene numerische Wert für '{search_term}' (als Zahl). null, falls nicht gefunden.
        'raw_unit': Die Einheit des raw_value (z.B. "Betten", "Stellplätze"). null, falls nicht gefunden.
        'area_value': Ein gefundener numerischer Wert für eine Gesamtfläche in m² (als Zahl). null, falls nicht gefunden.
        'metric_name': Der Name der Metrik, nach der gesucht wurde (also '{search_term}').

        Beispiel Output (wenn 180 Betten und 4500m² Fläche gefunden):
        {{"raw_value": 180, "raw_unit": "Betten", "area_value": 4500, "metric_name": "{search_term}"}}

        Beispiel Output (wenn nur 180 Betten gefunden):
        {{"raw_value": 180, "raw_unit": "Betten", "area_value": null, "metric_name": "{search_term}"}}

        Beispiel Output (wenn nichts gefunden):
        {{"raw_value": null, "raw_unit": null, "area_value": null, "metric_name": "{search_term}"}}
        """
    )

    def __init__(self, db: Session):
        """Load industry definitions and robotics categories from ``db`` and build lookups.

        Args:
            db: SQLAlchemy session, used for all queries and commits made by this service.
        """
        self.db = db
        self.allowed_industries_notion: List[Industry] = self._load_industry_definitions()
        self.robotics_categories: List[RoboticsCategory] = self._load_robotics_categories()

        # Name/description pairs embedded (as JSON) in the classification prompt.
        self.llm_industry_definitions = [
            {"name": ind.name, "description": ind.description} for ind in self.allowed_industries_notion
        ]

        # Quick lookups: industry by name, robotics category by id.
        self.industry_lookup = {ind.name: ind for ind in self.allowed_industries_notion}
        self.category_lookup = {cat.id: cat for cat in self.robotics_categories}

    def _load_industry_definitions(self) -> List[Industry]:
        """Return all ``Industry`` rows, logging a warning if there are none."""
        industries = self.db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _load_robotics_categories(self) -> List[RoboticsCategory]:
        """Return all ``RoboticsCategory`` rows, logging a warning if there are none."""
        categories = self.db.query(RoboticsCategory).all()
        if not categories:
            logger.warning("No robotics categories found in DB. Potential scoring might be limited.")
        return categories

    def _get_wikipedia_content(self, company_id: int) -> Optional[str]:
        """Return the ``text`` of the newest Wikipedia enrichment for ``company_id``, if any.

        ``content`` is expected to be a JSON object with a ``text`` key. It is accepted
        either already decoded (a dict) or as a JSON string.
        NOTE(review): a non-JSON string is assumed to be plain article text and is
        returned as-is; confirm against the code that writes ``EnrichmentData``.
        """
        enrichment = self.db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company_id,
            EnrichmentData.source_type == "wikipedia"
        ).order_by(EnrichmentData.created_at.desc()).first()

        if enrichment is None or not enrichment.content:
            return None

        content = enrichment.content
        if isinstance(content, str):
            try:
                content = json.loads(content)
            except json.JSONDecodeError:
                return content
        if isinstance(content, dict):
            return content.get("text")
        return None

    def _scrape_website(self, company: Company) -> Optional[str]:
        """Scrape ``company.website``; return None (and log) when it is missing or scraping fails."""
        if not company.website:
            logger.warning("Company %s has no website URL; skipping website scrape.", company.name)
            return None
        try:
            return scrape_website_content(company.website)
        except Exception:
            logger.error("Scraping %s failed for %s.", company.website, company.name, exc_info=True)
            return None

    @classmethod
    def _strip_code_fences(cls, text: Optional[str]) -> str:
        """Return ``text`` stripped of whitespace and of any surrounding ```/```json fences."""
        return cls._CODE_FENCE_RE.sub("", (text or "").strip())

    def _run_llm_classification_prompt(self, website_text: str, company_name: str) -> Optional[str]:
        """Ask the LLM which configured industry best matches the company.

        Args:
            website_text: Scraped website text; only the first
                ``CLASSIFICATION_EXCERPT_CHARS`` characters are sent.
            company_name: Company name, included in the prompt.

        Returns:
            One of the configured industry names or ``FALLBACK_INDUSTRY``. Returns
            None if the LLM call fails or returns a non-string (logged).
        """
        prompt = self._CLASSIFICATION_PROMPT.format(
            company_name=company_name,
            website_text_excerpt=website_text[: self.CLASSIFICATION_EXCERPT_CHARS],
            industry_definitions_json=json.dumps(self.llm_industry_definitions, ensure_ascii=False),
        )

        try:
            # Low temperature for a strict, repeatable classification.
            response = call_gemini_flash(prompt, temperature=0.1, json_mode=False)
        except Exception as e:
            logger.error("LLM classification failed for %s: %s", company_name, e, exc_info=True)
            return None

        if not isinstance(response, str):
            logger.error("LLM classification for %s returned no text: %r", company_name, response)
            return None
        return self._normalize_industry_answer(response)

    def _normalize_industry_answer(self, response: str) -> str:
        """Map a raw LLM answer onto an allowed industry name, else ``FALLBACK_INDUSTRY``."""
        allowed = set(self.industry_lookup)
        allowed.add(self.FALLBACK_INDUSTRY)

        raw_answer = response.strip()
        if raw_answer in allowed:
            return raw_answer
        if not raw_answer:
            logger.warning("LLM returned an empty industry classification. Defaulting to Others.")
            return self.FALLBACK_INDUSTRY

        # LLMs sometimes add quotes/markdown, a trailing period, extra lines, or change case.
        cleaned = raw_answer.splitlines()[0].strip().strip("\"'`*").rstrip(".").strip()
        by_casefold = {name.casefold(): name for name in allowed}
        canonical = by_casefold.get(cleaned.casefold())
        if canonical is not None:
            return canonical

        logger.warning("LLM classified industry '%s' not in allowed list. Defaulting to Others.", raw_answer)
        return self.FALLBACK_INDUSTRY

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to extract ``search_term`` (and an optional m² area) from text.

        Args:
            text_content: Source text; only the first ``METRIC_EXCERPT_CHARS``
                characters are sent.
            search_term: Metric to look for (the industry's scraper search term).
            industry_name: Industry name, given to the LLM as context.

        Returns:
            The decoded JSON object. The prompt asks for the keys ``raw_value``,
            ``raw_unit``, ``area_value`` and ``metric_name``; the LLM is not guaranteed
            to supply all of them. Returns None if the call fails or the response is not
            a JSON object (logged).
        """
        prompt = self._METRIC_EXTRACTION_PROMPT.format(
            industry_name=industry_name,
            search_term=search_term,
            text_content_excerpt=text_content[: self.METRIC_EXCERPT_CHARS],
        )

        try:
            # Very low temperature for deterministic extraction.
            response = call_gemini_flash(prompt, temperature=0.05, json_mode=True)
            result = json.loads(self._strip_code_fences(response))
        except Exception as e:
            logger.error("LLM metric extraction failed for '%s' in '%s': %s", search_term, industry_name, e, exc_info=True)
            return None

        if not isinstance(result, dict):
            logger.warning("LLM metric extraction for '%s' returned non-object JSON: %r", search_term, result)
            return None
        return result

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Evaluate a standardization formula for a raw metric value.

        The placeholder words ``wert`` / ``value`` (any letter case) are replaced by
        ``raw_value``. Every other word is removed, including units glued to numbers
        (``m²``) and annotations (``(Fläche pro Zimmer)``). A formula such as
        ``"wert * 25m² (Fläche pro Zimmer)"`` is therefore evaluated as ``"180 * 25"``.
        The cleaned expression is evaluated by ``safe_eval_math`` (from core_utils;
        which operators it accepts is defined there, not here).

        Args:
            formula: Standardization formula from the industry definition.
            raw_value: Extracted value; numbers and numeric strings are accepted.

        Returns:
            The evaluated result. Returns None when ``formula`` or ``raw_value`` is
            missing, ``raw_value`` is not a finite number, or evaluation fails (logged).
        """
        if not formula or raw_value is None:
            return None

        try:
            numeric_value = float(raw_value)
        except (TypeError, ValueError):
            logger.warning("Raw value %r is not numeric; cannot apply standardization logic '%s'.", raw_value, formula)
            return None
        if not math.isfinite(numeric_value):
            logger.warning("Raw value %r is not finite; cannot apply standardization logic '%s'.", raw_value, formula)
            return None
        value_literal = str(int(numeric_value)) if numeric_value.is_integer() else repr(numeric_value)

        def _replace_word(match):
            # Placeholders become the value; all other words (units, comments) are dropped.
            word = match.group(0)
            return value_literal if word.casefold() in self._VALUE_PLACEHOLDERS else ""

        expression = self._WORD_TOKEN_RE.sub(_replace_word, formula)
        expression = self._EMPTY_PARENS_RE.sub("", expression).strip()

        try:
            # safe_eval_math (core_utils) is used instead of eval() to avoid arbitrary code execution.
            return safe_eval_math(expression)
        except Exception as e:
            logger.error("Error evaluating standardization logic '%s' with value %s: %s", formula, raw_value, e, exc_info=True)
            return None

    @classmethod
    def _derive_standardized_unit(cls, standardization_logic: Optional[str]) -> str:
        """Return the first unit written after a number in the formula, else the default.

        Heuristic: ``"wert * 25m² (Fläche pro Zimmer)"`` -> ``"m²"``; ``"wert * 25"`` ->
        ``DEFAULT_STANDARDIZED_UNIT``.
        """
        if not standardization_logic:
            return cls.DEFAULT_STANDARDIZED_UNIT
        match = cls._UNIT_AFTER_NUMBER_RE.search(standardization_logic)
        return match.group(1) if match else cls.DEFAULT_STANDARDIZED_UNIT

    def _fetch_serp_snippets(self, company_name: str, search_term: str, industry_name: str) -> Optional[str]:
        """Run a SerpAPI search and return the organic-result snippets joined by spaces.

        Returns None when the search fails (logged) or yields no non-blank snippet.
        """
        # e.g. "Hotel Moxy Würzburg Anzahl Betten Hotellerie"
        search_query = f"{company_name} {search_term} {industry_name}"
        try:
            serp_results = run_serp_search(search_query)
        except Exception:
            logger.error("SerpAPI search failed for query '%s'.", search_query, exc_info=True)
            return None
        if not serp_results:
            return None

        snippets = []
        for organic_result in serp_results.get("organic_results") or []:
            snippet = organic_result.get("snippet") if isinstance(organic_result, dict) else None
            if isinstance(snippet, str) and snippet.strip():
                snippets.append(snippet)
        return " ".join(snippets) or None

    @staticmethod
    def _has_metric(llm_result: Optional[Dict[str, Any]]) -> bool:
        """True when the LLM result holds a raw value or an area value."""
        return bool(llm_result) and (
            llm_result.get("raw_value") is not None or llm_result.get("area_value") is not None
        )

    def _apply_metric_result(
        self,
        results: Dict[str, Any],
        llm_result: Dict[str, Any],
        source: str,
        standardization_logic: Optional[str],
        company_name: str,
    ) -> None:
        """Copy an LLM extraction into ``results`` and compute the standardized value."""
        raw_value = llm_result.get("raw_value")
        area_value = llm_result.get("area_value")

        results["calculated_metric_value"] = raw_value
        results["calculated_metric_unit"] = llm_result.get("raw_unit")
        results["metric_source"] = source

        if area_value is not None:
            # A directly stated area takes priority over a formula-derived estimate.
            # The prompt requests it in m², so report that unit, not the formula-derived one.
            results["standardized_metric_value"] = area_value
            results["standardized_metric_unit"] = self.AREA_UNIT
            logger.info(
                "Direct area value found %s for %s: %s m²",
                self._SOURCE_LABELS.get(source, source), company_name, area_value,
            )
        elif raw_value is not None and standardization_logic:
            results["standardized_metric_value"] = self._parse_standardization_logic(
                standardization_logic, raw_value
            )

    def _extract_and_calculate_metric_cascade(
        self,
        company: Company,
        industry_name: str,
        search_term: str,
        standardization_logic: Optional[str],
        standardized_unit: Optional[str],
        website_content: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Extract the metric via the 3-stage (Website -> Wikipedia -> SerpAPI) cascade.

        The first stage whose text yields a raw or area value wins.

        Args:
            company: Company being analysed (``website``, ``id`` and ``name`` are read).
            industry_name: Classified industry, used as LLM and search context.
            search_term: Metric to extract.
            standardization_logic: Formula that turns the raw value into a standardized value.
            standardized_unit: Unit reported for the standardized value.
            website_content: Already-scraped website text; when falsy, the site is scraped here.

        Returns:
            Dict with keys ``calculated_metric_name``, ``calculated_metric_value``,
            ``calculated_metric_unit``, ``standardized_metric_value``,
            ``standardized_metric_unit`` and ``metric_source``. The value fields are
            None when no stage finds anything.
        """
        results: Dict[str, Any] = {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None,
        }

        # (source id, progress log message, callable returning that stage's text)
        stages = (
            ("website", "Stage 1: Analyzing website for '%s' for %s",
             lambda: website_content if website_content else self._scrape_website(company)),
            ("wikipedia", "Stage 2: Analyzing Wikipedia for '%s' for %s",
             lambda: self._get_wikipedia_content(company.id)),
            ("serpapi", "Stage 3: Running SerpAPI search for '%s' for %s",
             lambda: self._fetch_serp_snippets(company.name, search_term, industry_name)),
        )

        for source, progress_message, fetch_text in stages:
            logger.info(progress_message, search_term, company.name)
            text = fetch_text()
            if not text:
                continue
            llm_result = self._run_llm_metric_extraction_prompt(text, search_term, industry_name)
            if self._has_metric(llm_result):
                self._apply_metric_result(results, llm_result, source, standardization_logic, company.name)
                return results

        logger.info("Could not extract metric for '%s' from any source for %s.", search_term, company.name)
        return results

    def _persist(self, company: Company) -> None:
        """Add, commit and refresh ``company``; roll back and re-raise if the commit fails."""
        self.db.add(company)
        try:
            self.db.commit()
        except Exception:
            # Keep the session usable for the caller; the original error still propagates.
            self.db.rollback()
            raise
        self.db.refresh(company)

    def classify_company_potential(self, company: Company) -> Company:
        """Classify the company's industry and, where configured, extract its metric.

        The industry is committed first. Metric extraction is skipped when the industry
        is ``FALLBACK_INDUSTRY``, has no definition, or has no scraper search term.

        Args:
            company: Company to classify; updated in place and committed.

        Returns:
            The same (refreshed) company object.
        """
        logger.info("Starting classification for Company ID: %s, Name: %s", company.id, company.name)

        # --- STEP 1: Strict industry classification ---
        # Scraped once; the text is reused in stage 1 of the metric cascade.
        website_content = self._scrape_website(company)
        if not website_content:
            logger.warning("No website content found for %s. Skipping industry classification.", company.name)
            company.industry_ai = self.FALLBACK_INDUSTRY
        else:
            classified_industry_name = self._run_llm_classification_prompt(website_content, company.name)
            if classified_industry_name:
                company.industry_ai = classified_industry_name
                logger.info("Classified %s into industry: %s", company.name, classified_industry_name)
            else:
                company.industry_ai = self.FALLBACK_INDUSTRY
                logger.warning("Failed to classify industry for %s. Setting to 'Others'.", company.name)

        # Commit the industry before the (external, slower) metric step.
        self._persist(company)

        # --- STEP 2: Metric extraction & standardization (not for 'Others') ---
        if company.industry_ai is None or company.industry_ai == self.FALLBACK_INDUSTRY:
            logger.info("Company %s classified as 'Others'. Skipping metric extraction.", company.name)
            return company

        industry_definition = self.industry_lookup.get(company.industry_ai)
        if industry_definition is None:
            logger.error("Industry definition for '%s' not found in lookup. Skipping metric extraction.", company.industry_ai)
            return company

        if not industry_definition.scraper_search_term:
            logger.info("Industry '%s' has no 'Scraper Search Term'. Skipping metric extraction.", company.industry_ai)
            return company

        standardized_unit = self._derive_standardized_unit(industry_definition.standardization_logic)

        metric_results = self._extract_and_calculate_metric_cascade(
            company,
            company.industry_ai,
            industry_definition.scraper_search_term,
            industry_definition.standardization_logic,
            standardized_unit,
            website_content=website_content,
        )

        company.calculated_metric_name = metric_results["calculated_metric_name"]
        company.calculated_metric_value = metric_results["calculated_metric_value"]
        company.calculated_metric_unit = metric_results["calculated_metric_unit"]
        company.standardized_metric_value = metric_results["standardized_metric_value"]
        company.standardized_metric_unit = metric_results["standardized_metric_unit"]
        company.metric_source = metric_results["metric_source"]
        # Naive UTC timestamp, equivalent to the deprecated datetime.utcnow().
        company.last_classification_at = datetime.now(timezone.utc).replace(tzinfo=None)

        self._persist(company)

        logger.info("Classification and metric extraction completed for %s.", company.name)
        return company


# --- HELPER FOR SAFE MATH EVALUATION ---
# safe_eval_math is imported from backend.lib.core_utils; it is not implemented in this file.
# NOTE(review): the sketch below is illustrative only and is NOT safe as written. Its
# character whitelist still admits "**" (e.g. "9**9**9"), so eval() can be driven into an
# effectively unbounded computation. A real implementation should walk an ast.parse() tree
# and allow only numeric constants and the + - * / operators.
# def safe_eval_math(expression: str) -> float:
#     # Implement a safe parser/evaluator for simple math expressions
#     # For now, a very basic eval might be used, but in production, this needs to be locked down
#     allowed_chars = "0123456789.+-*/ "
#     if not all(c in allowed_chars for c in expression):
#         raise ValueError("Expression contains disallowed characters.")
#     return eval(expression)