# Source file: Brancheneinstufung2/company-explorer/backend/services/classification.py
# (Repository file-viewer header: 360 lines, 17 KiB, Python.)

from typing import Tuple
import json
import logging
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content
from backend.lib.metric_parser import MetricParser
# Module-level logger named after this module; used by ClassificationService below.
logger = logging.getLogger(__name__)
class ClassificationService:
    """Classify companies into industries and enrich them with LLM-derived data.

    The service asks an LLM (``call_gemini_flash``) to pick one industry from
    the definitions stored in the database, generates two marketing opener
    sentences, and extracts one industry-specific metric from the company
    website, stored Wikipedia enrichment data and SERP snippets.
    """

    # Lower number wins when several sources yield a usable metric value.
    _SOURCE_PRIORITY = {"wikipedia": 0, "website": 1, "serpapi": 2}

    # Source texts shorter than this are treated as "no content".
    _MIN_CONTENT_LENGTH = 100

    # Units reported for the standardized metric.
    _AREA_UNIT = "m²"
    _DEFAULT_UNIT = "Einheiten"

    def __init__(self):
        """Create the service; it keeps no state of its own."""
        pass

    def _load_industry_definitions(self, db: Session) -> List[Industry]:
        """Return all ``Industry`` rows, logging a warning when there are none."""
        industries = db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[Dict[str, Any]]:
        """Return the content of the newest Wikipedia enrichment record.

        Returns ``None`` when the company has no such record or the record's
        content is empty.
        """
        enrichment = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company_id,
            EnrichmentData.source_type == "wikipedia"
        ).order_by(EnrichmentData.created_at.desc()).first()
        return enrichment.content if enrichment and enrichment.content else None

    def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
        """Ask the LLM to pick one industry name for the company.

        Args:
            website_text: Scraped website text; only the first 3000 characters
                are sent to the model.
            company_name: Name of the company being classified.
            industry_definitions: Dicts with at least a ``name`` key.

        Returns:
            One of the given industry names, or ``"Others"`` when the model
            returns nothing, nothing matches, or the call fails.
        """
        prompt = f"""
        Act as a strict B2B Industry Classifier.
        Company: {company_name}
        Context: {website_text[:3000]}
        Available Industries:
        {json.dumps(industry_definitions, indent=2)}
        Task: Select the ONE industry that best matches the company.
        If the company is a Hospital/Klinik, select 'Healthcare - Hospital'.
        If none match well, select 'Others'.
        Return ONLY the exact name of the industry.
        """
        try:
            response = call_gemini_flash(prompt)
            if not response:
                return "Others"
            cleaned = response.strip().replace('"', '').replace("'", "")
            valid_names = [definition['name'] for definition in industry_definitions] + ["Others"]
            if cleaned in valid_names:
                return cleaned
            # Fallback: look for a known name inside a chattier answer. Longest
            # names are tried first so that a short name ("Healthcare") cannot
            # shadow a more specific one ("Healthcare - Hospital"); empty names
            # are skipped because "" is a substring of every string.
            candidates = sorted((name for name in valid_names if name), key=len, reverse=True)
            for name in candidates:
                if name in cleaned:
                    return name
            return "Others"
        except Exception as e:
            logger.error(f"Classification Prompt Error: {e}")
            return "Others"

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to extract one metric from a source text.

        Only the first 6000 characters of ``text_content`` are sent.

        Returns:
            The parsed JSON object (expected keys: ``raw_value``, ``raw_unit``,
            ``proof_text``), or ``None`` when the model returns nothing, the
            answer is not a JSON object, or the call/parsing fails.
        """
        prompt = f"""
        Extract the following metric for the company in industry '{industry_name}':
        Target Metric: "{search_term}"
        Source Text:
        {text_content[:6000]}
        Return a JSON object with:
        - "raw_value": The number found (e.g. 352 or 352.0). If not found, null.
        - "raw_unit": The unit found (e.g. "Betten", "").
        - "proof_text": A short quote from the text proving this value.
        **IMPORTANT:** Ignore obvious year numbers (like 1900-2026) if other, more plausible metric values are present in the text. Focus on the target metric.
        JSON ONLY.
        """
        try:
            response = call_gemini_flash(prompt, json_mode=True)
            if not response:
                return None
            if isinstance(response, str):
                # The model may wrap its JSON in a Markdown code fence.
                response = response.replace("```json", "").replace("```", "").strip()
                data = json.loads(response)
            else:
                data = response
            if not isinstance(data, dict):
                logger.error(f"LLM Extraction Parse Error: expected a JSON object, got {type(data).__name__}")
                return None
            # Basic cleanup: the model sometimes emits the string "null".
            if data.get("raw_value") == "null":
                data["raw_value"] = None
            return data
        except Exception as e:
            logger.error(f"LLM Extraction Parse Error: {e}")
            return None

    def _is_metric_plausible(self, metric_name: str, value: Optional[float]) -> bool:
        """Return True when ``value`` converts to a float greater than zero.

        ``metric_name`` is currently not used in the check.
        """
        if value is None:
            return False
        try:
            return float(value) > 0
        except (TypeError, ValueError):
            # Not numeric (e.g. "n/a" or a list).
            return False

    @staticmethod
    def _clean_standardization_formula(formula: str, raw_value: Any) -> str:
        """Turn a stored formula into a bare arithmetic expression string.

        The value placeholder (``wert`` / ``value``, any capitalisation) is
        replaced by ``raw_value``, area unit markers (``m²``, ``m2``, ``qm``)
        are removed, and a trailing parenthesised remark is dropped.
        """
        value_text = str(raw_value)
        # A callable replacement keeps backslashes in value_text literal.
        expression = re.sub(r'(?i)wert|value', lambda _match: value_text, formula)
        expression = re.sub(r'(?i)m[²2]', '', expression)
        expression = re.sub(r'(?i)qm', '', expression)
        return re.sub(r'\s*\(.*\)\s*$', '', expression).strip()

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Evaluate the industry's standardization formula for ``raw_value``.

        Returns:
            The result of ``safe_eval_math`` on the cleaned formula, or
            ``None`` when formula/value is missing or evaluation fails.
        """
        if not formula or raw_value is None:
            return None
        formula_cleaned = self._clean_standardization_formula(formula, raw_value)
        try:
            return safe_eval_math(formula_cleaned)
        except Exception as e:
            logger.error(f"Failed to parse standardization logic '{formula}' with value {raw_value}: {e}")
            return None

    def _get_best_metric_result(self, results_list: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Pick the preferred result among those that carry a metric value.

        Results are ranked by source priority (wikipedia, website, serpapi;
        unknown sources last) and then by descending ``metric_confidence``.
        Returns ``None`` when no result has a ``calculated_metric_value``.
        """
        if not results_list:
            return None
        valid_results = [r for r in results_list if r.get("calculated_metric_value") is not None]
        if not valid_results:
            return None
        # min() returns the first minimal element, i.e. the same element a
        # stable sort would place first.
        best = min(
            valid_results,
            key=lambda r: (self._SOURCE_PRIORITY.get(r.get("metric_source"), 99), -r.get("metric_confidence", 0.0)),
        )
        logger.info(f"Best result chosen: {best}")
        return best

    def _get_website_content_and_url(self, company: Company) -> Tuple[Optional[str], Optional[str]]:
        """Return the scraped website content together with ``company.website``."""
        return scrape_website_content(company.website), company.website

    def _get_wikipedia_content_and_url(self, db: Session, company_id: int) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(full_text, url)`` from stored Wikipedia data, or ``(None, None)``."""
        wiki_data = self._get_wikipedia_content(db, company_id)
        if not wiki_data:
            return None, None
        return wiki_data.get('full_text'), wiki_data.get('url')

    def _get_serpapi_content_and_url(self, company: Company, search_term: str) -> Tuple[Optional[str], Optional[str]]:
        """Search for the metric and return the joined snippets plus the first link.

        Returns ``(None, None)`` when the search yields nothing.
        """
        serp_results = run_serp_search(f"{company.name} {company.city or ''} {search_term}")
        if not serp_results:
            return None, None
        # Tolerate a missing/None result list and missing/None snippets.
        organic_results = serp_results.get("organic_results") or []
        content = " ".join(res.get("snippet") or "" for res in organic_results)
        url = organic_results[0].get("link") if organic_results else None
        return content, url

    def _extract_and_calculate_metric_cascade(self, db: Session, company: Company, industry_name: str, search_term: str, standardization_logic: Optional[str], standardized_unit: Optional[str]) -> Dict[str, Any]:
        """Extract the metric from every source and return the preferred result.

        All three sources (website, Wikipedia, SERP) are queried; a failure in
        one source is logged and does not stop the others. Plausible values
        are standardized and the best result is chosen by
        ``_get_best_metric_result``.

        Returns:
            A dict with the ``calculated_*``, ``standardized_*`` and
            ``metric_*`` keys. When no source yields a plausible value the
            value fields are ``None`` and ``metric_confidence`` is ``0.0``.
        """
        no_result = {"calculated_metric_name": search_term, "calculated_metric_value": None, "calculated_metric_unit": None, "standardized_metric_value": None, "standardized_metric_unit": standardized_unit, "metric_source": None, "metric_proof_text": None, "metric_source_url": None, "metric_confidence": 0.0, "metric_confidence_reason": "No value found in any source."}

        # Zero-argument loaders, so each fetch runs inside the try block below.
        # NOTE(review): SERP results have the lowest priority but are fetched
        # even when an earlier source already produced a value.
        loaders = [
            ("website", lambda: self._get_website_content_and_url(company)),
            ("wikipedia", lambda: self._get_wikipedia_content_and_url(db, company.id)),
            ("serpapi", lambda: self._get_serpapi_content_and_url(company, search_term)),
        ]

        extracted = []
        for source_name, load_content in loaders:
            logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
            try:
                content_text, current_source_url = load_content()
                if not content_text or len(content_text) < self._MIN_CONTENT_LENGTH:
                    logger.info(f"No or insufficient content for {source_name} (Length: {len(content_text) if content_text else 0}).")
                    continue
                llm_result = self._run_llm_metric_extraction_prompt(content_text, search_term, industry_name)
                if llm_result:
                    llm_result['source_url'] = current_source_url
                    extracted.append((source_name, llm_result))
            except Exception as e:
                logger.error(f"Error in {source_name} stage: {e}")

        processed_results = []
        for source_name, llm_result in extracted:
            metric_value = llm_result.get("raw_value")
            # _is_metric_plausible already rejects None.
            if not self._is_metric_plausible(search_term, metric_value):
                logger.info(f"LLM found no plausible metric for {search_term} in {source_name}.")
                continue
            standardized_value = None
            if standardization_logic:
                standardized_value = self._parse_standardization_logic(standardization_logic, metric_value)
            processed_results.append({
                "calculated_metric_name": search_term,
                "calculated_metric_value": metric_value,
                "calculated_metric_unit": llm_result.get("raw_unit"),
                "standardized_metric_value": standardized_value,
                "standardized_metric_unit": standardized_unit,
                "metric_source": source_name,
                "metric_proof_text": llm_result.get("proof_text"),
                "metric_source_url": llm_result.get("source_url"),
                "metric_confidence": 0.95,
                "metric_confidence_reason": "Value found and extracted by LLM."
            })

        return self._get_best_metric_result(processed_results) or no_result

    @classmethod
    def _derive_standardized_unit(cls, standardization_logic: Optional[str], search_term: Optional[str]) -> str:
        """Return the unit label for the standardized metric.

        NOTE(review): the pasted source tested ``"" in ...``, which is always
        true and made the "Einheiten" branch unreachable; the literal looks
        lost in extraction. "m²" is restored here based on the area-unit
        handling in the formula cleanup -- confirm against the repository.
        """
        if cls._AREA_UNIT in (standardization_logic or "") or cls._AREA_UNIT in (search_term or ""):
            return cls._AREA_UNIT
        return cls._DEFAULT_UNIT

    def extract_metrics_for_industry(self, company: Company, db: Session, industry: Industry) -> Company:
        """Extract the industry's metric and store the result on ``company``.

        Does nothing (apart from a warning) when ``industry`` is missing or
        has no ``scraper_search_term``. The session is not committed here;
        that is left to the caller.

        Returns:
            The same ``company`` object.
        """
        if not industry or not industry.scraper_search_term:
            logger.warning(f"No metric configuration for industry '{industry.name if industry else 'None'}'")
            return company

        std_unit = self._derive_standardized_unit(industry.standardization_logic, industry.scraper_search_term)
        metrics = self._extract_and_calculate_metric_cascade(
            db, company, industry.name, industry.scraper_search_term, industry.standardization_logic, std_unit
        )

        company.calculated_metric_name = metrics["calculated_metric_name"]
        company.calculated_metric_value = metrics["calculated_metric_value"]
        company.calculated_metric_unit = metrics["calculated_metric_unit"]
        company.standardized_metric_value = metrics["standardized_metric_value"]
        company.standardized_metric_unit = metrics["standardized_metric_unit"]
        company.metric_source = metrics["metric_source"]
        company.metric_proof_text = metrics["metric_proof_text"]
        company.metric_source_url = metrics.get("metric_source_url")
        company.metric_confidence = metrics["metric_confidence"]
        company.metric_confidence_reason = metrics["metric_confidence_reason"]
        company.last_classification_at = datetime.utcnow()
        # No db.commit() here - committing is handled by the calling function.
        return company

    def reevaluate_wikipedia_metric(self, company: Company, db: Session, industry: Industry) -> Company:
        """Re-run the metric extraction for ``company``.

        Despite the name, this delegates to ``extract_metrics_for_industry``
        and therefore consults all sources, not only Wikipedia.
        """
        logger.info(f"Re-evaluating metric for {company.name}...")
        return self.extract_metrics_for_industry(company, db, industry)

    def _generate_marketing_opener(self, company_name: str, website_text: str, industry_name: str, industry_pains: str, focus_mode: str = "primary") -> Optional[str]:
        """Generate the personalised first sentence (opener) of an e-mail.

        Args:
            company_name: Name of the target company.
            website_text: Website text; the first 2500 characters are used.
            industry_name: Name of the matched industry.
            industry_pains: Operational pain to address; a generic fallback
                is used when empty.
            focus_mode: ``"secondary"`` (service/logistics) selects the
                secondary focus; any other value uses the primary focus
                (infrastructure/cleaning).

        Returns:
            The sentence with surrounding whitespace and double quotes
            removed, or ``None`` when the model returns nothing or fails.
        """
        if not industry_pains:
            industry_pains = "Effizienz und Personalmangel"  # Fallback

        # Dynamic focus instruction
        if focus_mode == "secondary":
            focus_instruction = """
        - **FOKUS: SEKUNDÄR-PROZESSE (Logistik/Service/Versorgung).**
        - Ignoriere das Thema Reinigung. Konzentriere dich auf **Abläufe, Materialfluss, Entlastung von Fachkräften** oder **Gäste-Service**.
        - Der Satz muss einen operativen Entscheider (z.B. Pflegedienstleitung, Produktionsleiter) abholen."""
        else:
            focus_instruction = """
        - **FOKUS: PRIMÄR-PROZESSE (Infrastruktur/Sauberkeit/Sicherheit).**
        - Konzentriere dich auf Anforderungen an das Facility Management, Hygiene, Außenwirkung oder Arbeitssicherheit.
        - Der Satz muss einen Infrastruktur-Entscheider (z.B. FM-Leiter, Geschäftsführer) abholen."""

        prompt = f"""
        Du bist ein exzellenter B2B-Stratege und Texter.
        Deine Aufgabe ist es, einen hochpersonalisierten Einleitungssatz für eine E-Mail an ein potenzielles Kundenunternehmen zu formulieren.
        --- KONTEXT ---
        Zielunternehmen: {company_name}
        Branche: {industry_name}
        Operative Herausforderung (Pain): "{industry_pains}"
        Webseiten-Kontext:
        {website_text[:2500]}
        --- Denkprozess & Stilvorgaben ---
        1. **Analysiere den Kontext:** Verstehe das Kerngeschäft.
        2. **Identifiziere den Hebel:** Was ist der Erfolgsfaktor in Bezug auf den FOKUS?
        3. **Formuliere den Satz (ca. 20-35 Wörter):**
        - Wähle einen eleganten, aktiven Einstieg.
        - Verbinde die **Tätigkeit** mit dem **Hebel** und den **Konsequenzen**.
        - **WICHTIG:** Formuliere als positive Beobachtung über eine Kernkompetenz.
        - **VERMEIDE:** Konkrete Zahlen.
        - Verwende den Firmennamen: {company_name}.
        {focus_instruction}
        --- Deine Ausgabe ---
        Gib NUR den finalen Satz aus. Keine Anführungszeichen.
        """
        try:
            response = call_gemini_flash(prompt)
            if response:
                return response.strip().strip('"')
            return None
        except Exception as e:
            logger.error(f"Opener Generation Error: {e}")
            return None

    def classify_company_potential(self, company: Company, db: Session) -> Company:
        """Classify ``company``, generate openers, extract metrics and commit.

        Steps: load the industry definitions, scrape the website, let the LLM
        pick an industry, generate the primary and secondary opener, and run
        the metric extraction. When the website yields fewer than 100
        characters the company is returned unchanged and nothing is
        committed. When no industry matches, ``industry_ai`` is set to
        ``"Others"`` and openers/metrics are skipped.

        Returns:
            The same ``company`` object.
        """
        logger.info(f"Starting classification for {company.name}...")

        # 1. Load definitions
        industries = self._load_industry_definitions(db)
        industry_defs = [{"name": i.name, "description": i.description} for i in industries]
        logger.debug(f"Loaded {len(industries)} industry definitions.")

        # 2. Get content (website)
        website_content, _ = self._get_website_content_and_url(company)
        if not website_content or len(website_content) < self._MIN_CONTENT_LENGTH:
            logger.warning(f"No or insufficient website content for {company.name} (Length: {len(website_content) if website_content else 0}). Skipping classification.")
            return company
        logger.debug(f"Website content length for classification: {len(website_content)}")

        # 3. Classify industry
        logger.info(f"Running LLM classification prompt for {company.name}...")
        suggested_industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
        logger.info(f"AI suggests industry: {suggested_industry_name}")

        matched_industry = next((i for i in industries if i.name == suggested_industry_name), None)
        if matched_industry:
            # 4. Update company & generate openers
            company.industry_ai = matched_industry.name
            logger.info(f"Matched company to industry: {matched_industry.name}")
            opener_targets = (
                ("PRIMARY", "primary", "ai_opener"),              # Infrastructure/Cleaning
                ("SECONDARY", "secondary", "ai_opener_secondary"),  # Service/Logistics
            )
            for label, focus_mode, attribute in opener_targets:
                logger.info(f"Generating {label} opener for {company.name}...")
                opener = self._generate_marketing_opener(
                    company.name, website_content, matched_industry.name, matched_industry.pains, focus_mode
                )
                if opener:
                    setattr(company, attribute, opener)
                    logger.info(f"Opener ({label.capitalize()}) generated and set.")
                else:
                    logger.warning(f"Failed to generate {label} opener for {company.name}.")

            # 5. Extract metrics (cascade); a failure must not abort the run.
            logger.info(f"Extracting metrics for {company.name} and industry {matched_industry.name}...")
            try:
                self.extract_metrics_for_industry(company, db, matched_industry)
                logger.info(f"Metric extraction completed for {company.name}.")
            except Exception as e:
                logger.error(f"Error during metric extraction for {company.name}: {e}", exc_info=True)
        else:
            company.industry_ai = "Others"
            logger.warning(f"No specific industry matched for {company.name}. Set to 'Others'.")
            logger.warning(f"Skipping metric extraction for {company.name} as no specific industry was matched.")

        company.last_classification_at = datetime.utcnow()
        db.commit()
        logger.info(f"Classification and enrichment for {company.name} completed and committed.")
        return company