Files
Brancheneinstufung2/company-explorer/backend/app.py

839 lines
29 KiB
Python

from fastapi import FastAPI, Depends, HTTPException, Query, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, joinedload
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
from datetime import datetime
import os
import sys
from fastapi.security import HTTPBasic, HTTPBasicCredentials
import secrets
security = HTTPBasic()


async def authenticate_user(credentials: HTTPBasicCredentials = Depends(security)):
    """Validate HTTP Basic credentials and return the authenticated username.

    The expected values are read from the ``API_USER`` and ``API_PASSWORD``
    environment variables on every call.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            username or the password does not match.
    """
    # NOTE(review): the fallbacks "default_user" / "default_password" mean the
    # API accepts well-known credentials whenever the env vars are unset —
    # confirm that every deployment sets both variables.
    expected_username = os.getenv("API_USER", "default_user").encode("utf-8")
    expected_password = os.getenv("API_PASSWORD", "default_password").encode("utf-8")

    # compare_digest() raises TypeError for str arguments containing non-ASCII
    # characters, which would turn a bad login into a 500. Comparing UTF-8
    # bytes accepts any input. Both comparisons always run (no short-circuit).
    username_ok = secrets.compare_digest(credentials.username.encode("utf-8"), expected_username)
    password_ok = secrets.compare_digest(credentials.password.encode("utf-8"), expected_password)

    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
from .config import settings
from .lib.logging_setup import setup_logging
# Setup Logging first
setup_logging()
import logging
logger = logging.getLogger(__name__)
from .database import init_db, get_db, Company, Signal, EnrichmentData, RoboticsCategory, Contact, Industry, JobRoleMapping, ReportedMistake, MarketingMatrix
from .services.deduplication import Deduplicator
from .services.discovery import DiscoveryService
from .services.scraping import ScraperService
from .services.classification import ClassificationService
# Initialize App
# Title and version come from the shared settings object.
app = FastAPI(
    description="Backend for Company Explorer (Robotics Edition)",
    title=settings.APP_NAME,
    version=settings.VERSION,
)

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Service Singletons
# Created once at import time. `classifier` and `discovery` are used by the
# background task functions further down in this file; `scraper` is not
# referenced anywhere else in the visible part of this file.
scraper = ScraperService()
classifier = ClassificationService() # Now works without args
discovery = DiscoveryService()
# --- Pydantic Models ---
# Request body for POST /api/companies: the fields copied onto a new Company row.
class CompanyCreate(BaseModel):
    name: str  # required; the create handler rejects a name that already exists
    city: Optional[str] = None
    country: str = "DE"  # defaults to "DE" when omitted
    website: Optional[str] = None
    crm_id: Optional[str] = None  # stored as Company.crm_id
# Request body for POST /api/companies/bulk.
class BulkImportRequest(BaseModel):
    names: List[str]  # company names; the handler strips each one and skips blanks
# Request body shared by POST /api/enrich/discover and POST /api/enrich/analyze.
class AnalysisRequest(BaseModel):
    company_id: int  # primary key of the Company to process
    # NOTE(review): force_scrape is not read by either handler in the visible
    # part of this file — confirm whether it is still needed.
    force_scrape: bool = False
# Request body for PUT /api/companies/{company_id}/industry.
class IndustryUpdateModel(BaseModel):
    industry_ai: str  # new value written to Company.industry_ai
# Request body for POST /api/companies/{company_id}/report-mistake.
# Every field is copied one-to-one onto a new ReportedMistake row.
class ReportMistakeRequest(BaseModel):
    field_name: str  # the only required field
    wrong_value: Optional[str] = None
    corrected_value: Optional[str] = None
    source_url: Optional[str] = None
    quote: Optional[str] = None
    user_comment: Optional[str] = None
# Request body for POST /api/provision/superoffice-contact.
class ProvisioningRequest(BaseModel):
    so_contact_id: int  # matched against Company.crm_id (as a string)
    so_person_id: Optional[int] = None  # when None, only the company is synced
    crm_name: Optional[str] = None  # required to auto-create an unknown company
    crm_website: Optional[str] = None
    # The provisioning handler reads `req.job_title` as a fallback for role
    # classification; without this declaration that access raises
    # AttributeError. Optional, so existing clients remain valid.
    job_title: Optional[str] = None
# Response model of POST /api/provision/superoffice-contact.
class ProvisioningResponse(BaseModel):
    status: str  # the handler emits "processing" or "success"
    company_name: str
    website: Optional[str] = None
    vertical_name: Optional[str] = None  # filled from Company.industry_ai
    role_name: Optional[str] = None
    # The handler fills the keys "subject", "intro" and "social_proof".
    texts: Dict[str, Optional[str]] = {}
# --- Events ---
@app.on_event("startup")
def on_startup():
    """Initialise the database when the application starts.

    A failure is logged at CRITICAL level with the traceback and then
    swallowed, so the application keeps starting even if init_db() fails.
    """
    logger.info("Startup Event: Initializing Database...")
    try:
        init_db()
    except Exception as exc:
        logger.critical(f"Database init failed: {exc}", exc_info=True)
    else:
        logger.info("Database initialized successfully.")
# --- Routes ---
@app.get("/api/health")
def health_check(username: str = Depends(authenticate_user)):
    """Report service liveness, the application version and the database URL.

    Any password embedded in the database URL (``scheme://user:password@host``)
    is replaced by ``***`` so credentials are never returned to API clients.
    URLs without credentials (e.g. ``sqlite:///...``) are returned unchanged.
    """
    import re

    db_url = settings.DATABASE_URL
    if isinstance(db_url, str):
        # Only matches when a "user:password@" section directly follows "://".
        db_url = re.sub(r"(://[^:/@]*):[^@/]*@", r"\1:***@", db_url)
    return {"status": "ok", "version": settings.VERSION, "db": db_url}
def _normalize_website(url):
    """Reduce a URL to a comparable form: lower-case, no scheme, no "www.", no edge slashes."""
    return str(url).lower().replace("https://", "").replace("http://", "").replace("www.", "").strip("/")


def _role_from_job_title(db, job_title):
    """Return the role of the first JobRoleMapping whose pattern occurs in job_title, else None."""
    title = job_title.lower()
    for mapping in db.query(JobRoleMapping).all():
        # '%' characters are stripped (presumably SQL LIKE wildcards) and the
        # remainder is matched as a plain, case-insensitive substring.
        needle = (mapping.pattern or "").replace("%", "").lower()
        # An empty needle is a substring of every title; skip it so a blank or
        # wildcard-only pattern cannot hijack the classification.
        if needle and needle in title:
            return mapping.role
    return None


def _matrix_texts(db, vertical_name, role_name):
    """Look up the MarketingMatrix texts for an industry/role pair; values stay None when nothing matches."""
    texts = {"subject": None, "intro": None, "social_proof": None}
    if not (vertical_name and role_name):
        return texts

    industry_obj = db.query(Industry).filter(Industry.name == vertical_name).first()
    if not industry_obj:
        return texts

    # Several mappings can share one role string; any of their ids may be the
    # one the matrix row is linked to.
    role_ids = [m.id for m in db.query(JobRoleMapping).filter(JobRoleMapping.role == role_name).all()]
    if not role_ids:
        return texts

    matrix_entry = db.query(MarketingMatrix).filter(
        MarketingMatrix.industry_id == industry_obj.id,
        MarketingMatrix.role_id.in_(role_ids)
    ).first()
    if matrix_entry:
        texts["subject"] = matrix_entry.subject
        texts["intro"] = matrix_entry.intro
        texts["social_proof"] = matrix_entry.social_proof
    return texts


@app.post("/api/provision/superoffice-contact", response_model=ProvisioningResponse)
def provision_superoffice_contact(
    req: ProvisioningRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Sync a SuperOffice contact with the local company and return marketing texts.

    Flow:
      1. Find the company by ``crm_id``; if unknown, create it (needs
         ``crm_name``), queue discovery and answer ``"processing"``.
      2. If the company is still NEW or DISCOVERED, queue the next enrichment
         step and answer ``"processing"``.
      3. Refresh the CRM snapshot fields and the website mismatch score.
      4. Without ``so_person_id`` return the company data only; otherwise
         resolve the role and return the matching matrix texts.

    Raises:
        HTTPException: 400 when the company is unknown and ``crm_name`` is missing.
    """
    # 1. Find Company (via SO ID)
    company = db.query(Company).filter(Company.crm_id == str(req.so_contact_id)).first()

    if not company:
        # AUTO-CREATE Logic
        if not req.crm_name:
            # Cannot create without name. Should ideally not happen if Connector does its job.
            raise HTTPException(400, "Cannot create company: crm_name missing")

        company = Company(
            name=req.crm_name,
            crm_id=str(req.so_contact_id),
            crm_name=req.crm_name,
            crm_website=req.crm_website,
            status="NEW"
        )
        db.add(company)
        db.commit()
        db.refresh(company)
        logger.info(f"Auto-created company {company.name} from SuperOffice request.")

        # Trigger Discovery
        background_tasks.add_task(run_discovery_task, company.id)
        return ProvisioningResponse(
            status="processing",
            company_name=company.name
        )

    # 1b. Check Status & Progress
    # If NEW or DISCOVERED, we are not ready to provide texts.
    if company.status in ["NEW", "DISCOVERED"]:
        # If we have a website, ensure analysis is triggered
        if company.status == "DISCOVERED" or (company.website and company.website != "k.A."):
            background_tasks.add_task(run_analysis_task, company.id)
        elif company.status == "NEW":
            # Ensure discovery runs
            background_tasks.add_task(run_discovery_task, company.id)
        return ProvisioningResponse(
            status="processing",
            company_name=company.name
        )

    # 1c. Update CRM Snapshot Data (The Double Truth)
    # `changed` is only raised when a value really differs, so updated_at is
    # not bumped (and no commit issued) by a sync that carries no new data.
    changed = False
    if req.crm_name and company.crm_name != req.crm_name:
        company.crm_name = req.crm_name
        changed = True
    if req.crm_website and company.crm_website != req.crm_website:
        company.crm_website = req.crm_website
        changed = True

    # Simple Mismatch Check: 0.8 when the normalized URLs differ, else 0.0.
    if company.website and company.crm_website:
        urls_differ = _normalize_website(company.website) != _normalize_website(company.crm_website)
        new_score = 0.8 if urls_differ else 0.0
        if company.data_mismatch_score != new_score:
            company.data_mismatch_score = new_score
            changed = True

    if changed:
        company.updated_at = datetime.utcnow()
        db.commit()

    # 2. Find Contact (Person)
    if req.so_person_id is None:
        # Just a company sync, no texts needed
        return ProvisioningResponse(
            status="success",
            company_name=company.name,
            website=company.website,
            vertical_name=company.industry_ai
        )

    person = db.query(Contact).filter(Contact.so_person_id == req.so_person_id).first()

    # 3. Determine Role
    role_name = None
    # getattr: the request model may not declare job_title; a plain attribute
    # access would raise AttributeError in that case.
    job_title = getattr(req, "job_title", None)
    if person and person.role:
        role_name = person.role
    elif job_title:
        # Simple classification fallback
        role_name = _role_from_job_title(db, job_title)

    # 4. Determine Vertical (Industry)
    vertical_name = company.industry_ai

    # 5. Fetch Texts from Matrix
    texts = _matrix_texts(db, vertical_name, role_name)

    return ProvisioningResponse(
        status="success",
        company_name=company.name,
        website=company.website,
        vertical_name=vertical_name,
        role_name=role_name,
        texts=texts
    )
@app.get("/api/companies")
def list_companies(
    skip: int = 0,
    limit: int = 50,
    search: Optional[str] = None,
    sort_by: Optional[str] = Query("name_asc"),
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Return one page of companies plus the total number of matches.

    ``search`` filters by a case-insensitive substring of the name. ``sort_by``
    accepts "updated_desc" and "created_desc"; anything else sorts by name
    ascending. Each returned company gets a ``has_pending_mistakes`` flag.

    Raises:
        HTTPException: 500 carrying the error text when anything fails.
    """
    try:
        base_query = db.query(Company)
        if search:
            base_query = base_query.filter(Company.name.ilike(f"%{search}%"))

        # Count before ordering/paging so it reflects all matches.
        total = base_query.count()

        order_factories = {
            "updated_desc": Company.updated_at.desc,
            "created_desc": Company.id.desc,
        }
        make_order = order_factories.get(sort_by, Company.name.asc)  # default: name_asc
        page = base_query.order_by(make_order()).offset(skip).limit(limit).all()

        # One query for the whole page instead of one per company.
        page_ids = [entry.id for entry in page]
        flagged_ids = set()
        if page_ids:
            pending_rows = db.query(ReportedMistake.company_id).filter(
                ReportedMistake.company_id.in_(page_ids),
                ReportedMistake.status == 'PENDING'
            ).distinct().all()
            flagged_ids = {row[0] for row in pending_rows}

        for entry in page:
            entry.has_pending_mistakes = entry.id in flagged_ids

        return {"total": total, "items": page}
    except Exception as e:
        logger.error(f"List Companies Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/companies/export")
def export_companies_csv(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """
    Exports a CSV of all companies with their key metrics.

    The whole file is built in memory, one row per company ordered by name,
    and returned as an attachment whose filename carries today's UTC date.
    """
    import io
    import csv
    from fastapi.responses import StreamingResponse

    header = [
        "ID", "Name", "Website", "City", "Country", "AI Industry",
        "Metric Name", "Metric Value", "Metric Unit", "Standardized Value (m2)",
        "Source", "Source URL", "Confidence", "Proof Text"
    ]
    # Company attributes in the same order as the header columns.
    columns = (
        "id", "name", "website", "city", "country", "industry_ai",
        "calculated_metric_name",
        "calculated_metric_value",
        "calculated_metric_unit",
        "standardized_metric_value",
        "metric_source",
        "metric_source_url",
        "metric_confidence",
        "metric_proof_text",
    )

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(header)

    for record in db.query(Company).order_by(Company.name.asc()).all():
        writer.writerow([getattr(record, column) for column in columns])

    buffer.seek(0)
    date_stamp = datetime.utcnow().strftime('%Y-%m-%d')
    return StreamingResponse(
        buffer,
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename=company_export_{date_stamp}.csv"}
    )
@app.get("/api/companies/{company_id}")
def get_company(company_id: int, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Return one company as a plain dict, including relations and industry strategy data.

    The result holds the company's loaded column values, an
    ``industry_details`` entry (None when the company has no industry or the
    industry is unknown) and the ``enrichment_data``, ``contacts`` and
    ``signals`` relations.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    company = (
        db.query(Company)
        .options(joinedload(Company.enrichment_data), joinedload(Company.contacts))
        .filter(Company.id == company_id)
        .first()
    )
    if not company:
        raise HTTPException(404, detail="Company not found")

    # Enrich with Industry Details (Strategy)
    industry_details = None
    if company.industry_ai:
        industry_row = db.query(Industry).filter(Industry.name == company.industry_ai).first()
        if industry_row:
            industry_details = {
                "pains": industry_row.pains,
                "gains": industry_row.gains,
                "priority": industry_row.priority,
                "notes": industry_row.notes,
                "ops_focus_secondary": industry_row.ops_focus_secondary
            }

    # HACK: the ORM instance's attribute dict is used as the response base
    # (minus SQLAlchemy's bookkeeping entry) instead of a dedicated schema.
    payload = {key: value for key, value in company.__dict__.items() if key != "_sa_instance_state"}
    payload["industry_details"] = industry_details
    # Relations are set explicitly so they are always present in the response.
    payload["enrichment_data"] = company.enrichment_data
    payload["contacts"] = company.contacts
    payload["signals"] = company.signals
    return payload
@app.post("/api/companies")
def create_company(company: CompanyCreate, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Create a company with status "NEW" and return the stored row.

    Raises:
        HTTPException: 400 when a company with the same name already exists.
    """
    name_taken = db.query(Company).filter(Company.name == company.name).first()
    if name_taken:
        raise HTTPException(status_code=400, detail="Company already registered")

    attributes = {
        "name": company.name,
        "city": company.city,
        "country": company.country,
        "website": company.website,
        "crm_id": company.crm_id,
        "status": "NEW",
    }
    new_company = Company(**attributes)
    db.add(new_company)
    db.commit()
    # Reload so generated values (e.g. the id) are present in the response.
    db.refresh(new_company)
    return new_company
@app.post("/api/companies/bulk")
def bulk_import_companies(req: BulkImportRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Create a "NEW" company for every name that does not exist yet.

    Names are stripped; blank names, names already in the database and names
    repeated within the same request are skipped. Everything is committed in
    a single transaction.

    Returns:
        ``{"status": "success", "imported": <number of companies created>}``
    """
    imported_count = 0
    # Names already handled in this request. The database lookup alone only
    # sees rows added earlier in the loop if the session autoflushes, so a
    # repeated name could otherwise be inserted twice.
    seen_names = set()

    for raw_name in req.names:
        name = raw_name.strip()
        if not name or name in seen_names:
            continue
        seen_names.add(name)

        exists = db.query(Company).filter(Company.name == name).first()
        if not exists:
            db.add(Company(name=name, status="NEW"))
            imported_count += 1
            # Discovery is intentionally not auto-triggered here;
            # `background_tasks` is kept in the signature for that purpose.

    db.commit()
    return {"status": "success", "imported": imported_count}
@app.post("/api/companies/{company_id}/override/wikipedia")
def override_wikipedia(company_id: int, url: str, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Store a manually chosen Wikipedia URL for a company and lock the entry.

    The company's "wikipedia" enrichment record is created or overwritten
    with the given URL and marked as locked. When the URL starts with "http",
    a Wikipedia re-evaluation is queued in the background.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    wiki_entry = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company_id,
        EnrichmentData.source_type == "wikipedia"
    ).first()

    # The URL is stored as given, even when empty; any previously stored text
    # is discarded.
    manual_content = {"url": url, "full_text": None, "manual_override": True}

    if wiki_entry:
        wiki_entry.content = manual_content
        wiki_entry.is_locked = True
    else:
        db.add(EnrichmentData(
            company_id=company_id,
            source_type="wikipedia",
            content=manual_content,
            is_locked=True
        ))
    db.commit()

    # Only an http(s) URL is worth re-evaluating.
    if url and url.startswith("http"):
        background_tasks.add_task(run_wikipedia_reevaluation_task, company.id)

    return {"status": "updated"}
@app.get("/api/robotics/categories")
def list_robotics_categories(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Return every RoboticsCategory row."""
    categories = db.query(RoboticsCategory).all()
    return categories
@app.get("/api/industries")
def list_industries(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Return every Industry row."""
    industries = db.query(Industry).all()
    return industries
@app.get("/api/job_roles")
def list_job_roles(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Return every JobRoleMapping row, ordered by pattern ascending."""
    ordered_query = db.query(JobRoleMapping).order_by(JobRoleMapping.pattern.asc())
    return ordered_query.all()
@app.get("/api/mistakes")
def list_reported_mistakes(
    status: Optional[str] = Query(None),
    company_id: Optional[int] = Query(None),
    skip: int = 0,
    limit: int = 50,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Return one page of reported mistakes, newest first, plus the total count.

    ``status`` is upper-cased before filtering; ``company_id`` restricts the
    result to one company. Each item has its company eagerly loaded.
    """
    conditions = []
    if status:
        conditions.append(ReportedMistake.status == status.upper())
    if company_id:
        conditions.append(ReportedMistake.company_id == company_id)

    mistakes = db.query(ReportedMistake).options(joinedload(ReportedMistake.company))
    if conditions:
        mistakes = mistakes.filter(*conditions)

    # Count all matches before paging.
    total = mistakes.count()
    newest_first = mistakes.order_by(ReportedMistake.created_at.desc())
    items = newest_first.offset(skip).limit(limit).all()
    return {"total": total, "items": items}
# Request body for PUT /api/mistakes/{mistake_id}.
class MistakeUpdateStatusRequest(BaseModel):
    # The handler upper-cases the value and rejects anything outside this set.
    status: str # PENDING, APPROVED, REJECTED
@app.put("/api/mistakes/{mistake_id}")
def update_reported_mistake_status(
    mistake_id: int,
    request: MistakeUpdateStatusRequest,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Set the status of a reported mistake (case-insensitive input).

    Raises:
        HTTPException: 404 when the mistake does not exist; 400 when the
            status is not PENDING, APPROVED or REJECTED.
    """
    mistake = db.query(ReportedMistake).filter(ReportedMistake.id == mistake_id).first()
    if not mistake:
        raise HTTPException(404, detail="Reported mistake not found")

    new_status = request.status.upper()
    if new_status not in ("PENDING", "APPROVED", "REJECTED"):
        raise HTTPException(400, detail="Invalid status. Must be PENDING, APPROVED, or REJECTED.")

    mistake.status = new_status
    mistake.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(mistake)

    logger.info(f"Updated status for mistake {mistake_id} to {mistake.status}")
    return {"status": "success", "mistake": mistake}
@app.post("/api/enrich/discover")
def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue the discovery background task for a company.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    target = db.query(Company).filter(Company.id == req.company_id).first()
    if not target:
        raise HTTPException(404, "Company not found")

    background_tasks.add_task(run_discovery_task, target.id)
    return {"status": "queued"}
@app.post("/api/enrich/analyze")
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue the analysis background task for a company that has a website.

    A company without a usable website (empty or "k.A.") is not an HTTP
    error: the response is 200 with an ``error`` key instead.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    target = db.query(Company).filter(Company.id == req.company_id).first()
    if not target:
        raise HTTPException(404, "Company not found")

    website = target.website
    if not website or website == "k.A.":
        return {"error": "No website to analyze. Run Discovery first."}

    background_tasks.add_task(run_analysis_task, target.id)
    return {"status": "queued"}
@app.put("/api/companies/{company_id}/industry")
def update_company_industry(
    company_id: int,
    data: IndustryUpdateModel,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Overwrite a company's industry and queue a metric re-extraction.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    target = db.query(Company).filter(Company.id == company_id).first()
    if not target:
        raise HTTPException(404, detail="Company not found")

    # Persist the new industry first so the background task sees it.
    target.industry_ai = data.industry_ai
    target.updated_at = datetime.utcnow()
    db.commit()

    background_tasks.add_task(run_metric_reextraction_task, target.id)

    response = {"status": "updated", "industry_ai": target.industry_ai}
    return response
@app.post("/api/companies/{company_id}/reevaluate-wikipedia")
def reevaluate_wikipedia(company_id: int, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue the Wikipedia re-evaluation background task for a company.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    target = db.query(Company).filter(Company.id == company_id).first()
    if not target:
        raise HTTPException(404, detail="Company not found")

    background_tasks.add_task(run_wikipedia_reevaluation_task, target.id)
    return {"status": "queued"}
@app.delete("/api/companies/{company_id}")
def delete_company(company_id: int, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Delete a company together with every row that references it.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    # Delete related data first (Cascade might handle this but being explicit is safer).
    # ReportedMistake is included: its rows carry company_id as well and would
    # otherwise be left pointing at a company that no longer exists.
    for child_model in (EnrichmentData, Signal, Contact, ReportedMistake):
        db.query(child_model).filter(child_model.company_id == company_id).delete()

    db.delete(company)
    db.commit()
    return {"status": "deleted"}
@app.post("/api/companies/{company_id}/override/website")
def override_website(company_id: int, url: str, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Replace a company's website with the given URL, stored as-is.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    target = db.query(Company).filter(Company.id == company_id).first()
    if not target:
        raise HTTPException(404, detail="Company not found")

    target.website = url
    target.updated_at = datetime.utcnow()
    db.commit()

    response = {"status": "updated", "website": target.website}
    return response
@app.post("/api/companies/{company_id}/override/impressum")
def override_impressum(company_id: int, url: str, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Store a manually chosen impressum URL for a company and lock the entry.

    The company's "impressum_override" enrichment record is created or
    overwritten. No background work is queued by this handler.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    override_entry = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company_id,
        EnrichmentData.source_type == "impressum_override"
    ).first()

    if override_entry:
        override_entry.content = {"url": url}
        override_entry.is_locked = True
    else:
        db.add(EnrichmentData(
            company_id=company_id,
            source_type="impressum_override",
            content={"url": url},
            is_locked=True
        ))

    db.commit()
    return {"status": "updated"}
@app.post("/api/companies/{company_id}/report-mistake")
def report_company_mistake(
    company_id: int,
    request: ReportMistakeRequest,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Record a user-reported data mistake for a company and return its id.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    # Request fields map one-to-one onto the ReportedMistake columns.
    mistake_fields = {
        "company_id": company_id,
        "field_name": request.field_name,
        "wrong_value": request.wrong_value,
        "corrected_value": request.corrected_value,
        "source_url": request.source_url,
        "quote": request.quote,
        "user_comment": request.user_comment,
    }
    new_mistake = ReportedMistake(**mistake_fields)
    db.add(new_mistake)
    db.commit()
    db.refresh(new_mistake)

    logger.info(f"Reported mistake for company {company_id}: {request.field_name} -> {request.corrected_value}")
    return {"status": "success", "mistake_id": new_mistake.id}
def run_wikipedia_reevaluation_task(company_id: int):
    """Background task: re-evaluate a company's Wikipedia metric for its current industry.

    Opens its own database session. Does nothing when the company is missing
    and only logs a warning when its industry is unknown. Errors are logged
    with the traceback and never propagate; the session is always closed.
    """
    from .database import SessionLocal

    session = SessionLocal()
    try:
        company = session.query(Company).filter(Company.id == company_id).first()
        if not company:
            return

        logger.info(f"Re-evaluating Wikipedia metric for {company.name} (Industry: {company.industry_ai})")

        industry = session.query(Industry).filter(Industry.name == company.industry_ai).first()
        if not industry:
            logger.warning(f"Industry '{company.industry_ai}' not found for re-evaluation.")
            return

        classifier.reevaluate_wikipedia_metric(company, session, industry)
        logger.info(f"Wikipedia metric re-evaluation complete for {company.name}")
    except Exception as e:
        logger.error(f"Wikipedia Re-evaluation Task Error: {e}", exc_info=True)
    finally:
        session.close()
def run_metric_reextraction_task(company_id: int):
    """Background task: re-extract a company's metrics for its current industry.

    Opens its own database session. Does nothing when the company is missing
    and only logs a warning when its industry is unknown. On success the
    company status becomes "ENRICHED". Errors are logged with the traceback
    and never propagate; the session is always closed.
    """
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return

        logger.info(f"Re-extracting metrics for {company.name} (Industry: {company.industry_ai})")

        # Let the database pick the matching industry instead of loading the
        # whole table and scanning it in Python (same lookup as the Wikipedia
        # re-evaluation task).
        industry = db.query(Industry).filter(Industry.name == company.industry_ai).first()

        if industry:
            classifier.extract_metrics_for_industry(company, db, industry)
            company.status = "ENRICHED"
            db.commit()
            logger.info(f"Metric re-extraction complete for {company.name}")
        else:
            logger.warning(f"Industry '{company.industry_ai}' not found for re-extraction.")
    except Exception as e:
        logger.error(f"Metric Re-extraction Task Error: {e}", exc_info=True)
    finally:
        db.close()
def run_discovery_task(company_id: int):
    """Background task: find a company's website and Wikipedia entry.

    Opens its own database session. A website is only searched when none is
    stored (empty or "k.A."). The "wikipedia" enrichment record is created or
    refreshed unless it is locked. A NEW company that ends up with a usable
    website is promoted to DISCOVERED. Errors are logged with the traceback
    and never propagate; the session is always closed.
    """
    from .database import SessionLocal

    session = SessionLocal()
    try:
        company = session.query(Company).filter(Company.id == company_id).first()
        if not company:
            return

        # 1. Website Search
        if not company.website or company.website == "k.A.":
            candidate = discovery.find_company_website(company.name, company.city)
            if candidate and candidate != "k.A.":
                company.website = candidate

        # 2. Wikipedia Search
        wiki_entry = session.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "wikipedia"
        ).first()

        locked = wiki_entry and wiki_entry.is_locked
        if not locked:
            wiki_url = discovery.find_wikipedia_url(company.name, website=company.website, city=company.city)
            if wiki_url and wiki_url != "k.A.":
                wiki_data = discovery.extract_wikipedia_data(wiki_url)
            else:
                # Nothing usable found: store the raw search result only.
                wiki_data = {"url": wiki_url}

            if wiki_entry:
                wiki_entry.content = wiki_data
                wiki_entry.updated_at = datetime.utcnow()
            else:
                session.add(EnrichmentData(company_id=company.id, source_type="wikipedia", content=wiki_data))

        has_website = company.website and company.website != "k.A."
        if company.status == "NEW" and has_website:
            company.status = "DISCOVERED"

        session.commit()
    except Exception as e:
        logger.error(f"Discovery Task Error: {e}", exc_info=True)
    finally:
        session.close()
def run_analysis_task(company_id: int):
    """Background task: scrape a company's website and classify the company.

    Opens its own database session. The "website_scrape" enrichment record
    is created or refreshed unless it is locked; afterwards the classifier
    runs and the company status becomes "ENRICHED". Errors are logged with
    the traceback and never propagate; the session is always closed.
    """
    from .database import SessionLocal

    session = SessionLocal()
    try:
        company = session.query(Company).filter(Company.id == company_id).first()
        if not company:
            return

        logger.info(f"Running Analysis Task for {company.name}")

        # 1. Scrape Website (if not locked)
        scrape_entry = session.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "website_scrape"
        ).first()

        locked = scrape_entry and scrape_entry.is_locked
        if not locked:
            from .services.scraping import ScraperService

            # A fresh scraper instance is created for every run.
            scrape_result = ScraperService().scrape_url(company.website)
            if scrape_entry:
                scrape_entry.content = scrape_result
                scrape_entry.updated_at = datetime.utcnow()
            else:
                session.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_result))
            session.commit()

        # 2. Classify Industry & Metrics (the classifier receives the session)
        classifier.classify_company_potential(company, session)

        company.status = "ENRICHED"
        session.commit()
        logger.info(f"Analysis complete for {company.name}")
    except Exception as e:
        logger.error(f"Analyze Task Error: {e}", exc_info=True)
    finally:
        session.close()
# --- Serve Frontend ---
# Resolve the directory holding the built frontend. Candidates are tried in
# order: "/frontend_static", then "../../frontend/dist" and "../static"
# relative to this file. The last candidate is kept even if it does not exist.
static_path = "/frontend_static"
if not os.path.exists(static_path):
    # Local dev fallback
    static_path = os.path.join(os.path.dirname(__file__), "../../frontend/dist")
    if not os.path.exists(static_path):
        static_path = os.path.join(os.path.dirname(__file__), "../static")
logger.info(f"Static files path: {static_path} (Exists: {os.path.exists(static_path)})")
if os.path.exists(static_path):
    # The explicit "/" route is registered before the static mount on "/".
    @app.get("/")
    async def serve_index():
        # Serve the frontend entry page.
        return FileResponse(os.path.join(static_path, "index.html"))
    app.mount("/", StaticFiles(directory=static_path, html=True), name="static")
else:
    # No frontend build found: answer "/" with a JSON hint instead.
    @app.get("/")
    def root_no_frontend():
        return {"message": "Company Explorer API is running, but frontend was not found.", "path_tried": static_path}
# Development entry point: run the app with uvicorn on all interfaces,
# port 8000, with auto-reload enabled.
# NOTE(review): the import string "backend.app:app" must be resolvable from
# the working directory — confirm how this script is launched.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("backend.app:app", host="0.0.0.0", port=8000, reload=True)