Files
Brancheneinstufung2/company-explorer/backend/services/classification.py
Floke 31e1a5fc08 fix(classification): restore service logic and standardization formula
- Restored missing method implementations in ClassificationService (classify, extract_metrics)
- Fixed Standardization Logic not being applied in metric cascade
- Bumped version to v0.7.4 in config.py
- Removed duplicate API endpoint in app.py
- Updated MIGRATION_PLAN.md
2026-01-24 13:34:04 +00:00

268 lines
12 KiB
Python

from typing import Tuple
import json
import logging
import re
from datetime import datetime
from typing import Optional, Dict, Any, List
from sqlalchemy.orm import Session
from backend.database import Company, Industry, RoboticsCategory, EnrichmentData
from backend.lib.core_utils import call_gemini_flash, safe_eval_math, run_serp_search
from backend.services.scraping import scrape_website_content
from backend.lib.metric_parser import MetricParser
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ClassificationService:
    """Classify a company into an industry and extract that industry's metric.

    The service asks an LLM (``call_gemini_flash``) to pick one industry out
    of the definitions stored in the database.  For the matched industry it
    then tries to extract the configured metric from three sources (company
    website, stored Wikipedia enrichment data, SERP snippets) and writes the
    chosen result onto the ``Company`` object.
    """

    # Ranking used when several sources yield a usable value (lower wins).
    _SOURCE_PRIORITY = {"wikipedia": 0, "website": 1, "serpapi": 2}

    # Company attributes filled from the metric result dict (same key names).
    _METRIC_FIELDS = (
        "calculated_metric_name",
        "calculated_metric_value",
        "calculated_metric_unit",
        "standardized_metric_value",
        "standardized_metric_unit",
        "metric_source",
        "metric_proof_text",
        "metric_source_url",
        "metric_confidence",
        "metric_confidence_reason",
    )

    # Square-metre tokens as a stand-alone unit ("m²", "m2", "qm"); the
    # look-arounds avoid hits inside longer words.
    _AREA_UNIT_RE = re.compile(
        r"(?<![A-Za-zÄÖÜäöüß])(?:m[²2]|qm)(?![A-Za-z0-9])", re.IGNORECASE
    )

    def __init__(self):
        """Create the service; it keeps no state of its own."""

    # ------------------------------------------------------------------
    # Data access
    # ------------------------------------------------------------------
    def _load_industry_definitions(self, db: Session) -> List[Industry]:
        """Return all ``Industry`` rows, warning when the table is empty."""
        industries = db.query(Industry).all()
        if not industries:
            logger.warning("No industry definitions found in DB. Classification might be limited.")
        return industries

    def _get_wikipedia_content(self, db: Session, company_id: int) -> Optional[Dict[str, Any]]:
        """Return the ``content`` of the newest Wikipedia enrichment row.

        Returns ``None`` when no such row exists or its content is empty.
        """
        enrichment = (
            db.query(EnrichmentData)
            .filter(
                EnrichmentData.company_id == company_id,
                EnrichmentData.source_type == "wikipedia",
            )
            .order_by(EnrichmentData.created_at.desc())
            .first()
        )
        return enrichment.content if enrichment and enrichment.content else None

    # ------------------------------------------------------------------
    # LLM prompts
    # ------------------------------------------------------------------
    @staticmethod
    def _match_industry_name(response: str, valid_names: List[str]) -> str:
        """Map a free-text LLM answer onto one of ``valid_names``.

        Matching order: exact match after stripping whitespace and quotes,
        then case-insensitive exact match, then the *longest* valid name
        contained in the answer (so "Healthcare - Hospital" beats
        "Healthcare").  Falls back to ``"Others"``.
        """
        cleaned = response.strip().replace('"', '').replace("'", "")
        # Empty names would be a substring of every answer; ignore them.
        names = [name for name in valid_names if name]
        if cleaned in names:
            return cleaned

        folded = cleaned.casefold()
        for name in names:
            if name.casefold() == folded:
                return name

        # sorted() is stable, so equally long names keep their list order.
        for name in sorted(names, key=len, reverse=True):
            if name.casefold() in folded:
                return name
        return "Others"

    def _run_llm_classification_prompt(self, website_text: str, company_name: str, industry_definitions: List[Dict[str, str]]) -> Optional[str]:
        """Ask the LLM for the best-matching industry name.

        Args:
            website_text: Scraped website text; only the first 3000
                characters are sent.
            company_name: Name of the company being classified.
            industry_definitions: Dicts with at least a ``"name"`` key.

        Returns:
            One of the given industry names, or ``"Others"`` when the LLM
            returns nothing, nothing matches, or any error occurs.
        """
        prompt = (
            "\n"
            "Act as a strict B2B Industry Classifier.\n"
            f"Company: {company_name}\n"
            f"Context: {website_text[:3000]}\n"
            "Available Industries:\n"
            f"{json.dumps(industry_definitions, indent=2)}\n"
            "Task: Select the ONE industry that best matches the company.\n"
            "If the company is a Hospital/Klinik, select 'Healthcare - Hospital'.\n"
            "If none match well, select 'Others'.\n"
            "Return ONLY the exact name of the industry.\n"
        )
        try:
            response = call_gemini_flash(prompt)
            if not response:
                return "Others"
            valid_names = [i['name'] for i in industry_definitions] + ["Others"]
            return self._match_industry_name(response, valid_names)
        except Exception as e:
            # Best effort: a failed classification must not abort the caller.
            logger.error(f"Classification Prompt Error: {e}")
            return "Others"

    def _run_llm_metric_extraction_prompt(self, text_content: str, search_term: str, industry_name: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to extract ``search_term`` from ``text_content``.

        Only the first 6000 characters of the text are sent.

        Returns:
            The parsed JSON object (keys requested from the LLM:
            ``raw_value``, ``raw_unit``, ``proof_text``), or ``None`` when
            the LLM returns nothing or the answer cannot be parsed into a
            JSON object.
        """
        prompt = (
            "\n"
            f"Extract the following metric for the company in industry '{industry_name}':\n"
            f'Target Metric: "{search_term}"\n'
            "Source Text:\n"
            f"{text_content[:6000]}\n"
            "Return a JSON object with:\n"
            '- "raw_value": The number found (e.g. 352 or 352.0). If text says "352 Betten", extract 352. If not found, null.\n'
            '- "raw_unit": The unit found (e.g. "Betten", "").\n'
            '- "proof_text": A short quote from the text proving this value.\n'
            "JSON ONLY.\n"
        )
        try:
            response = call_gemini_flash(prompt, json_mode=True)
            if not response:
                return None
            if isinstance(response, str):
                # Strip Markdown code fences the model may wrap around JSON.
                response = response.replace("```json", "").replace("```", "").strip()
                data = json.loads(response)
            else:
                data = response
            if not isinstance(data, dict):
                raise ValueError(f"expected a JSON object, got {type(data).__name__}")
            # The model sometimes returns the string "null" instead of null.
            if data.get("raw_value") == "null":
                data["raw_value"] = None
            return data
        except Exception as e:
            logger.error(f"LLM Extraction Parse Error: {e}")
            return None

    # ------------------------------------------------------------------
    # Metric helpers
    # ------------------------------------------------------------------
    def _is_metric_plausible(self, metric_name: str, value: Optional[float]) -> bool:
        """Return ``True`` when ``value`` converts to a float greater than 0.

        ``metric_name`` is accepted for interface compatibility but not used.
        """
        if value is None:
            return False
        try:
            return float(value) > 0
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _strip_trailing_annotations(expression: str) -> str:
        """Remove trailing parenthesised notes such as ``"(pro Bett)"``.

        A trailing group is dropped only when it is blank or contains a
        letter; a purely numeric group like ``"(1 + 0.5)"`` is part of the
        formula and is kept.  Parentheses are matched by depth so a leading
        math group is never swallowed together with a trailing note.
        """
        text = expression.strip()
        while text.endswith(")"):
            depth = 0
            start = -1
            # Walk backwards to the "(" that balances the final ")".
            for idx in range(len(text) - 1, -1, -1):
                char = text[idx]
                if char == ")":
                    depth += 1
                elif char == "(":
                    depth -= 1
                    if depth == 0:
                        start = idx
                        break
            if start < 0:
                break  # Unbalanced parentheses: leave the text untouched.
            inner = text[start + 1:-1]
            if inner.strip() and not any(ch.isalpha() for ch in inner):
                break  # Real arithmetic, not a note.
            text = text[:start].rstrip()
        return text

    @staticmethod
    def _prepare_formula(formula: str, raw_value: float) -> str:
        """Turn a standardization formula into a plain arithmetic expression.

        The placeholder ``wert`` / ``value`` (any letter case) is replaced by
        ``raw_value``, the unit tokens ``m²`` / ``m2`` / ``qm`` are removed,
        and trailing parenthesised notes are stripped.
        """
        value_text = str(raw_value)
        # A callable replacement keeps value_text literal (no escape handling).
        expression = re.sub(r"wert|value", lambda _m: value_text, formula, flags=re.IGNORECASE)
        expression = re.sub(r"m[²2]", "", expression, flags=re.IGNORECASE)
        expression = re.sub(r"qm", "", expression, flags=re.IGNORECASE)
        return ClassificationService._strip_trailing_annotations(expression)

    def _parse_standardization_logic(self, formula: str, raw_value: float) -> Optional[float]:
        """Apply ``formula`` to ``raw_value`` via ``safe_eval_math``.

        Returns ``None`` when the formula is empty, the value is ``None`` or
        the evaluation fails (the failure is logged).
        """
        if not formula or raw_value is None:
            return None
        try:
            return safe_eval_math(self._prepare_formula(formula, raw_value))
        except Exception as e:
            logger.error(f"Failed to parse standardization logic '{formula}' with value {raw_value}: {e}")
            return None

    @staticmethod
    def _derive_standardized_unit(standardization_logic: Optional[str], search_term: Optional[str]) -> str:
        """Return the unit label for the standardized metric value.

        NOTE(review): the original check compared against an empty string
        literal, which is always contained in any string, so the
        ``"Einheiten"`` branch could never run.  ``"m²"`` is inferred from
        the unit tokens that ``_prepare_formula`` strips - confirm against
        the intended configuration.
        """
        haystack = f"{standardization_logic or ''} {search_term or ''}"
        if ClassificationService._AREA_UNIT_RE.search(haystack):
            return "m²"
        return "Einheiten"

    def _get_best_metric_result(self, results_list: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Pick the preferred result among per-source metric results.

        Results without a ``calculated_metric_value`` are ignored.  The rest
        are ranked by source priority (wikipedia, website, serpapi, unknown)
        and then by descending ``metric_confidence``; on ties the earliest
        entry wins.  Returns ``None`` when nothing usable is left.
        """
        if not results_list:
            return None
        usable = [r for r in results_list if r.get("calculated_metric_value") is not None]
        if not usable:
            return None

        def rank(result: Dict[str, Any]) -> Tuple[int, float]:
            priority = self._SOURCE_PRIORITY.get(result.get("metric_source"), 99)
            # A missing/None confidence counts as 0.0 instead of raising.
            return priority, -(result.get("metric_confidence") or 0.0)

        best = min(usable, key=rank)  # min() keeps the first of equal ranks
        logger.info(f"Best result chosen: {best}")
        return best

    # ------------------------------------------------------------------
    # Content loaders: each returns (text, source_url)
    # ------------------------------------------------------------------
    def _get_website_content_and_url(self, company: Company) -> Tuple[Optional[str], Optional[str]]:
        """Scrape ``company.website`` and return ``(text, website_url)``."""
        return scrape_website_content(company.website), company.website

    def _get_wikipedia_content_and_url(self, db: Session, company_id: int) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(full_text, url)`` from stored Wikipedia data, else ``(None, None)``."""
        wiki_data = self._get_wikipedia_content(db, company_id)
        if not wiki_data:
            return None, None
        return wiki_data.get('full_text'), wiki_data.get('url')

    def _get_serpapi_content_and_url(self, company: Company, search_term: str) -> Tuple[Optional[str], Optional[str]]:
        """Search for the metric and return ``(joined snippets, first link)``.

        Returns ``(None, None)`` when the search yields no response.
        """
        serp_results = run_serp_search(f"{company.name} {company.city or ''} {search_term}")
        if not serp_results:
            return None, None
        organic = serp_results.get("organic_results") or []
        # ``or ""`` also covers results whose snippet is explicitly None.
        content = " ".join(res.get("snippet") or "" for res in organic)
        url = organic[0].get("link") if organic else None
        return content, url

    # ------------------------------------------------------------------
    # Metric cascade
    # ------------------------------------------------------------------
    def _extract_and_calculate_metric_cascade(self, db: Session, company: Company, industry_name: str, search_term: str, standardization_logic: Optional[str], standardized_unit: Optional[str]) -> Dict[str, Any]:
        """Query every source for the metric and return the preferred result.

        All three sources (website, wikipedia, serpapi) are always queried;
        a failure in one source is logged and does not stop the others.

        Returns:
            A dict with the keys listed in ``_METRIC_FIELDS``.  When no
            source yields a plausible value the value fields are ``None``
            and ``metric_confidence`` is ``0.0``.
        """
        fallback = {
            "calculated_metric_name": search_term,
            "calculated_metric_value": None,
            "calculated_metric_unit": None,
            "standardized_metric_value": None,
            "standardized_metric_unit": standardized_unit,
            "metric_source": None,
            "metric_proof_text": None,
            "metric_source_url": None,
            "metric_confidence": 0.0,
            "metric_confidence_reason": "No value found in any source.",
        }
        # Zero-argument loaders so each source's arguments are bound here
        # but evaluated inside the try block below.
        loaders = [
            ("website", lambda: self._get_website_content_and_url(company)),
            ("wikipedia", lambda: self._get_wikipedia_content_and_url(db, company.id)),
            ("serpapi", lambda: self._get_serpapi_content_and_url(company, search_term)),
        ]

        candidates = []
        for source_name, load_content in loaders:
            logger.info(f"Checking {source_name} for '{search_term}' for {company.name}")
            try:
                content_text, source_url = load_content()
                if not content_text:
                    logger.info(f"No content for {source_name}.")
                    continue
                llm_result = self._run_llm_metric_extraction_prompt(content_text, search_term, industry_name)
            except Exception as e:
                logger.error(f"Error in {source_name} stage: {e}")
                continue
            if not llm_result:
                continue

            metric_value = llm_result.get("raw_value")
            if not self._is_metric_plausible(search_term, metric_value):
                logger.info(f"LLM found no plausible metric for {search_term} in {source_name}.")
                continue

            standardized_value = None
            if standardization_logic:
                standardized_value = self._parse_standardization_logic(standardization_logic, metric_value)
            candidates.append({
                "calculated_metric_name": search_term,
                "calculated_metric_value": metric_value,
                "calculated_metric_unit": llm_result.get("raw_unit"),
                "standardized_metric_value": standardized_value,
                "standardized_metric_unit": standardized_unit,
                "metric_source": source_name,
                "metric_proof_text": llm_result.get("proof_text"),
                "metric_source_url": source_url,
                "metric_confidence": 0.95,
                "metric_confidence_reason": "Value found and extracted by LLM.",
            })

        return self._get_best_metric_result(candidates) or fallback

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def extract_metrics_for_industry(self, company: Company, db: Session, industry: Industry) -> Company:
        """Extract the industry's metric, store it on ``company`` and commit.

        Does nothing (apart from a warning) when ``industry`` is missing or
        has no ``scraper_search_term``.
        """
        if not industry or not industry.scraper_search_term:
            logger.warning(f"No metric configuration for industry '{industry.name if industry else 'None'}'")
            return company

        std_unit = self._derive_standardized_unit(industry.standardization_logic, industry.scraper_search_term)
        metrics = self._extract_and_calculate_metric_cascade(
            db, company, industry.name, industry.scraper_search_term, industry.standardization_logic, std_unit
        )
        for field in self._METRIC_FIELDS:
            setattr(company, field, metrics.get(field))
        company.last_classification_at = datetime.utcnow()
        db.commit()
        return company

    def reevaluate_wikipedia_metric(self, company: Company, db: Session, industry: Industry) -> Company:
        """Re-run the metric extraction; delegates to ``extract_metrics_for_industry``."""
        logger.info(f"Re-evaluating metric for {company.name}...")
        return self.extract_metrics_for_industry(company, db, industry)

    def classify_company_potential(self, company: Company, db: Session) -> Company:
        """Classify ``company`` by its website text and extract its metric.

        Without website content the company is returned unchanged.  When the
        LLM's suggestion matches no stored industry, ``industry_ai`` is set
        to ``"Others"`` and no metric extraction takes place.
        """
        logger.info(f"Starting classification for {company.name}...")

        # 1. Load definitions
        industries = self._load_industry_definitions(db)
        industry_defs = [{"name": i.name, "description": i.description} for i in industries]

        # 2. Get content (website)
        website_content, _ = self._get_website_content_and_url(company)
        if not website_content:
            logger.warning(f"No website content for {company.name}. Skipping classification.")
            return company

        # 3. Classify industry
        suggested_industry_name = self._run_llm_classification_prompt(website_content, company.name, industry_defs)
        logger.info(f"AI suggests industry: {suggested_industry_name}")

        # 4. Match the suggestion back to a DB object and update the company
        matched_industry = next((i for i in industries if i.name == suggested_industry_name), None)
        company.industry_ai = matched_industry.name if matched_industry else "Others"

        # 5. Extract metrics (cascade)
        if matched_industry:
            self.extract_metrics_for_industry(company, db, matched_industry)

        company.last_classification_at = datetime.utcnow()
        db.commit()
        return company