"""FastAPI backend for Company Explorer (Robotics Edition).

Exposes CRUD and enrichment endpoints for companies, plus admin routes for
reported data mistakes. Long-running discovery/scraping/classification work
is dispatched to FastAPI background tasks, each of which opens its own DB
session via ``SessionLocal``.
"""

from fastapi import FastAPI, Depends, HTTPException, Query, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, joinedload
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
from datetime import datetime
import os
import sys
from fastapi.security import HTTPBasic, HTTPBasicCredentials
import secrets

security = HTTPBasic()


async def authenticate_user(credentials: HTTPBasicCredentials = Depends(security)):
    """Validate HTTP Basic credentials against API_USER / API_PASSWORD env vars.

    Returns the authenticated username; raises 401 with a WWW-Authenticate
    header on mismatch.
    """
    # SECURITY NOTE(review): if API_USER/API_PASSWORD are unset, the API
    # accepts the hard-coded defaults below — confirm the env vars are always
    # provided in deployment, or fail closed instead.
    # Compare as UTF-8 bytes: secrets.compare_digest raises TypeError for
    # non-ASCII str input, which would surface as a 500 instead of a 401.
    correct_username = secrets.compare_digest(
        credentials.username.encode("utf-8"),
        os.getenv("API_USER", "default_user").encode("utf-8"),
    )
    correct_password = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        os.getenv("API_PASSWORD", "default_password").encode("utf-8"),
    )
    if not (correct_username and correct_password):
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username


from .config import settings
from .lib.logging_setup import setup_logging

# Setup Logging first, before any module that grabs a logger at import time.
setup_logging()
import logging

logger = logging.getLogger(__name__)

from .database import (
    init_db,
    get_db,
    Company,
    Signal,
    EnrichmentData,
    RoboticsCategory,
    Contact,
    Industry,
    JobRoleMapping,
    ReportedMistake,
)
from .services.deduplication import Deduplicator
from .services.discovery import DiscoveryService
from .services.scraping import ScraperService
from .services.classification import ClassificationService

# Initialize App
app = FastAPI(
    title=settings.APP_NAME,
    version=settings.VERSION,
    description="Backend for Company Explorer (Robotics Edition)",
    root_path="/ce",
)

# NOTE(review): browsers reject wildcard origins combined with credentials per
# the CORS spec; if credentialed cross-origin requests are needed, list
# explicit origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Service Singletons (shared by all requests and background tasks)
scraper = ScraperService()
classifier = ClassificationService()  # Now works without args
discovery = DiscoveryService()


# --- Pydantic Models ---

class CompanyCreate(BaseModel):
    name: str
    city: Optional[str] = None
    country: str = "DE"
    website: Optional[str] = None


class BulkImportRequest(BaseModel):
    names: List[str]


class AnalysisRequest(BaseModel):
    company_id: int
    force_scrape: bool = False


class IndustryUpdateModel(BaseModel):
    industry_ai: str


class ReportMistakeRequest(BaseModel):
    field_name: str
    wrong_value: Optional[str] = None
    corrected_value: Optional[str] = None
    source_url: Optional[str] = None
    quote: Optional[str] = None
    user_comment: Optional[str] = None


# --- Events ---

# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; kept as-is to avoid a structural change.
@app.on_event("startup")
def on_startup():
    """Initialize the database schema on application startup."""
    logger.info("Startup Event: Initializing Database...")
    try:
        init_db()
        logger.info("Database initialized successfully.")
    except Exception as e:
        # App still starts so the failure is visible via logs/health checks.
        logger.critical(f"Database init failed: {e}", exc_info=True)


# --- Routes ---

@app.get("/api/health")
def health_check(username: str = Depends(authenticate_user)):
    """Liveness check returning app version and DB connection string."""
    # SECURITY NOTE(review): DATABASE_URL may embed credentials — consider
    # redacting it before exposing it here.
    return {"status": "ok", "version": settings.VERSION, "db": settings.DATABASE_URL}


@app.get("/api/companies")
def list_companies(
    skip: int = 0,
    limit: int = 50,
    search: Optional[str] = None,
    sort_by: Optional[str] = Query("name_asc"),
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user),
):
    """Paginated company list with optional name search and sorting.

    Each returned company is annotated with ``has_pending_mistakes`` so the
    UI can surface unreviewed mistake reports.
    """
    try:
        query = db.query(Company)
        if search:
            query = query.filter(Company.name.ilike(f"%{search}%"))
        total = query.count()

        if sort_by == "updated_desc":
            query = query.order_by(Company.updated_at.desc())
        elif sort_by == "created_desc":
            query = query.order_by(Company.id.desc())
        else:
            # Default: name_asc
            query = query.order_by(Company.name.asc())

        items = query.offset(skip).limit(limit).all()

        # Efficiently check for pending mistakes with a single IN query
        # instead of one query per company.
        company_ids = [c.id for c in items]
        if company_ids:
            pending_mistakes = db.query(ReportedMistake.company_id).filter(
                ReportedMistake.company_id.in_(company_ids),
                ReportedMistake.status == 'PENDING'
            ).distinct().all()
            companies_with_pending_mistakes = {row[0] for row in pending_mistakes}
        else:
            companies_with_pending_mistakes = set()

        # Add the flag to each company object (transient attribute, not persisted)
        for company in items:
            company.has_pending_mistakes = company.id in companies_with_pending_mistakes

        return {"total": total, "items": items}
    except Exception as e:
        logger.error(f"List Companies Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/companies/export")
def export_companies_csv(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """
    Exports a CSV of all companies with their key metrics.
    """
    import io
    import csv
    from fastapi.responses import StreamingResponse

    output = io.StringIO()
    writer = csv.writer(output)

    # Header
    writer.writerow([
        "ID", "Name", "Website", "City", "Country", "AI Industry",
        "Metric Name", "Metric Value", "Metric Unit", "Standardized Value (m2)",
        "Source", "Source URL", "Confidence", "Proof Text"
    ])

    companies = db.query(Company).order_by(Company.name.asc()).all()
    for c in companies:
        writer.writerow([
            c.id, c.name, c.website, c.city, c.country, c.industry_ai,
            c.calculated_metric_name, c.calculated_metric_value,
            c.calculated_metric_unit, c.standardized_metric_value,
            c.metric_source, c.metric_source_url, c.metric_confidence,
            c.metric_proof_text
        ])

    output.seek(0)
    return StreamingResponse(
        output,
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename=company_export_{datetime.utcnow().strftime('%Y-%m-%d')}.csv"}
    )


@app.get("/api/companies/{company_id}")
def get_company(company_id: int, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Fetch a single company with its enrichment data and contacts eagerly loaded."""
    company = db.query(Company).options(
        joinedload(Company.enrichment_data),
        joinedload(Company.contacts)
    ).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")
    return company


@app.post("/api/companies")
def create_company(company: CompanyCreate, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Create a new company in status NEW; 400 if the name already exists."""
    db_company = db.query(Company).filter(Company.name == company.name).first()
    if db_company:
        raise HTTPException(status_code=400, detail="Company already registered")
    new_company = Company(
        name=company.name,
        city=company.city,
        country=company.country,
        website=company.website,
        status="NEW"
    )
    db.add(new_company)
    db.commit()
    db.refresh(new_company)
    return new_company


@app.post("/api/companies/bulk")
def bulk_import_companies(req: BulkImportRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Import a list of company names, skipping blanks and existing names."""
    imported_count = 0
    for name in req.names:
        name = name.strip()
        if not name:
            continue
        exists = db.query(Company).filter(Company.name == name).first()
        if not exists:
            new_company = Company(name=name, status="NEW")
            db.add(new_company)
            imported_count += 1
            # Optional: Auto-trigger discovery
            # background_tasks.add_task(run_discovery_task, new_company.id)
    db.commit()
    return {"status": "success", "imported": imported_count}


@app.post("/api/companies/{company_id}/override/wikipedia")
def override_wikipedia(company_id: int, url: str, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Manually pin a Wikipedia URL for a company and lock it against re-discovery.

    Triggers a background metric re-evaluation when the URL looks valid.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    # Create or update manual wikipedia lock
    existing = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company_id,
        EnrichmentData.source_type == "wikipedia"
    ).first()

    # If URL is empty, we might want to clear it or set it to "k.A."
    # Assuming 'url' param carries the new URL.
    wiki_data = {"url": url, "full_text": None, "manual_override": True}

    if not existing:
        db.add(EnrichmentData(
            company_id=company_id,
            source_type="wikipedia",
            content=wiki_data,
            is_locked=True
        ))
    else:
        existing.content = wiki_data
        existing.is_locked = True
    db.commit()

    # Trigger Re-evaluation if URL is valid
    if url and url.startswith("http"):
        background_tasks.add_task(run_wikipedia_reevaluation_task, company.id)

    return {"status": "updated"}


@app.get("/api/robotics/categories")
def list_robotics_categories(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """List all robotics categories."""
    return db.query(RoboticsCategory).all()


@app.get("/api/industries")
def list_industries(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """List all known industries."""
    return db.query(Industry).all()


@app.get("/api/job_roles")
def list_job_roles(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """List job-role mappings ordered by pattern."""
    return db.query(JobRoleMapping).order_by(JobRoleMapping.pattern.asc()).all()


@app.get("/api/mistakes")
def list_reported_mistakes(
    status: Optional[str] = Query(None),
    company_id: Optional[int] = Query(None),
    skip: int = 0,
    limit: int = 50,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user),
):
    """Paginated list of reported mistakes, filterable by status and company."""
    query = db.query(ReportedMistake).options(joinedload(ReportedMistake.company))
    if status:
        query = query.filter(ReportedMistake.status == status.upper())
    if company_id:
        query = query.filter(ReportedMistake.company_id == company_id)
    total = query.count()
    items = query.order_by(ReportedMistake.created_at.desc()).offset(skip).limit(limit).all()
    return {"total": total, "items": items}


class MistakeUpdateStatusRequest(BaseModel):
    status: str  # PENDING, APPROVED, REJECTED


@app.put("/api/mistakes/{mistake_id}")
def update_reported_mistake_status(
    mistake_id: int,
    request: MistakeUpdateStatusRequest,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user),
):
    """Set the review status of a reported mistake (PENDING/APPROVED/REJECTED)."""
    mistake = db.query(ReportedMistake).filter(ReportedMistake.id == mistake_id).first()
    if not mistake:
        raise HTTPException(404, detail="Reported mistake not found")
    if request.status.upper() not in ["PENDING", "APPROVED", "REJECTED"]:
        raise HTTPException(400, detail="Invalid status. Must be PENDING, APPROVED, or REJECTED.")
    mistake.status = request.status.upper()
    mistake.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(mistake)
    logger.info(f"Updated status for mistake {mistake_id} to {mistake.status}")
    return {"status": "success", "mistake": mistake}


@app.post("/api/enrich/discover")
def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue website/Wikipedia discovery for a company."""
    company = db.query(Company).filter(Company.id == req.company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    background_tasks.add_task(run_discovery_task, company.id)
    return {"status": "queued"}


@app.post("/api/enrich/analyze")
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue scrape + classification for a company that already has a website."""
    company = db.query(Company).filter(Company.id == req.company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    if not company.website or company.website == "k.A.":
        return {"error": "No website to analyze. Run Discovery first."}
    background_tasks.add_task(run_analysis_task, company.id)
    return {"status": "queued"}


@app.put("/api/companies/{company_id}/industry")
def update_company_industry(
    company_id: int,
    data: IndustryUpdateModel,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user),
):
    """Override a company's AI-derived industry and re-extract metrics."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    # 1. Update Industry
    company.industry_ai = data.industry_ai
    company.updated_at = datetime.utcnow()
    db.commit()

    # 2. Trigger Metric Re-extraction in Background
    background_tasks.add_task(run_metric_reextraction_task, company.id)

    return {"status": "updated", "industry_ai": company.industry_ai}


@app.post("/api/companies/{company_id}/reevaluate-wikipedia")
def reevaluate_wikipedia(company_id: int, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue a re-evaluation of the company's Wikipedia-derived metric."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")
    background_tasks.add_task(run_wikipedia_reevaluation_task, company.id)
    return {"status": "queued"}


@app.delete("/api/companies/{company_id}")
def delete_company(company_id: int, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Delete a company and all of its dependent rows."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    # Delete related data first (Cascade might handle this but being explicit is safer)
    db.query(EnrichmentData).filter(EnrichmentData.company_id == company_id).delete()
    db.query(Signal).filter(Signal.company_id == company_id).delete()
    db.query(Contact).filter(Contact.company_id == company_id).delete()
    # FIX: reported mistakes also reference the company and were previously
    # left behind, risking FK violations / orphaned rows.
    db.query(ReportedMistake).filter(ReportedMistake.company_id == company_id).delete()

    db.delete(company)
    db.commit()
    return {"status": "deleted"}


@app.post("/api/companies/{company_id}/override/website")
def override_website(company_id: int, url: str, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Manually set a company's website URL."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")
    company.website = url
    company.updated_at = datetime.utcnow()
    db.commit()
    return {"status": "updated", "website": company.website}


@app.post("/api/companies/{company_id}/override/impressum")
def override_impressum(company_id: int, url: str, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Manually pin an Impressum URL for a company and lock it."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    # Create or update manual impressum lock
    existing = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company_id,
        EnrichmentData.source_type == "impressum_override"
    ).first()

    if not existing:
        db.add(EnrichmentData(
            company_id=company_id,
            source_type="impressum_override",
            content={"url": url},
            is_locked=True
        ))
    else:
        existing.content = {"url": url}
        existing.is_locked = True
    db.commit()
    return {"status": "updated"}


@app.post("/api/companies/{company_id}/report-mistake")
def report_company_mistake(
    company_id: int,
    request: ReportMistakeRequest,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user),
):
    """Record a user-reported data mistake for a company (status defaults in the model)."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, detail="Company not found")

    new_mistake = ReportedMistake(
        company_id=company_id,
        field_name=request.field_name,
        wrong_value=request.wrong_value,
        corrected_value=request.corrected_value,
        source_url=request.source_url,
        quote=request.quote,
        user_comment=request.user_comment
    )
    db.add(new_mistake)
    db.commit()
    db.refresh(new_mistake)
    logger.info(f"Reported mistake for company {company_id}: {request.field_name} -> {request.corrected_value}")
    return {"status": "success", "mistake_id": new_mistake.id}


# --- Background Tasks ---
# Each task opens its own session because request-scoped sessions are closed
# by the time a BackgroundTask runs.

def run_wikipedia_reevaluation_task(company_id: int):
    """Re-run the Wikipedia metric evaluation for one company."""
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return
        logger.info(f"Re-evaluating Wikipedia metric for {company.name} (Industry: {company.industry_ai})")
        industry = db.query(Industry).filter(Industry.name == company.industry_ai).first()
        if industry:
            classifier.reevaluate_wikipedia_metric(company, db, industry)
            logger.info(f"Wikipedia metric re-evaluation complete for {company.name}")
        else:
            logger.warning(f"Industry '{company.industry_ai}' not found for re-evaluation.")
    except Exception as e:
        logger.error(f"Wikipedia Re-evaluation Task Error: {e}", exc_info=True)
    finally:
        db.close()


def run_metric_reextraction_task(company_id: int):
    """Re-extract industry-specific metrics for one company and mark it ENRICHED."""
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return
        logger.info(f"Re-extracting metrics for {company.name} (Industry: {company.industry_ai})")
        industries = db.query(Industry).all()
        industry = next((i for i in industries if i.name == company.industry_ai), None)
        if industry:
            classifier.extract_metrics_for_industry(company, db, industry)
            company.status = "ENRICHED"
            db.commit()
            logger.info(f"Metric re-extraction complete for {company.name}")
        else:
            logger.warning(f"Industry '{company.industry_ai}' not found for re-extraction.")
    except Exception as e:
        logger.error(f"Metric Re-extraction Task Error: {e}", exc_info=True)
    finally:
        db.close()


def run_discovery_task(company_id: int):
    """Find a company's website and Wikipedia page, honoring manual locks."""
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return

        # 1. Website Search
        if not company.website or company.website == "k.A.":
            found_url = discovery.find_company_website(company.name, company.city)
            if found_url and found_url != "k.A.":
                company.website = found_url

        # 2. Wikipedia Search (skipped if a manual override locked the record)
        existing_wiki = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "wikipedia"
        ).first()
        if not existing_wiki or not existing_wiki.is_locked:
            wiki_url = discovery.find_wikipedia_url(company.name, website=company.website, city=company.city)
            wiki_data = (
                discovery.extract_wikipedia_data(wiki_url)
                if wiki_url and wiki_url != "k.A."
                else {"url": wiki_url}
            )
            if not existing_wiki:
                db.add(EnrichmentData(company_id=company.id, source_type="wikipedia", content=wiki_data))
            else:
                existing_wiki.content = wiki_data
                existing_wiki.updated_at = datetime.utcnow()

        if company.status == "NEW" and company.website and company.website != "k.A.":
            company.status = "DISCOVERED"
        db.commit()
    except Exception as e:
        logger.error(f"Discovery Task Error: {e}", exc_info=True)
    finally:
        db.close()


def run_analysis_task(company_id: int):
    """Scrape the company website (unless locked), classify it, mark ENRICHED."""
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return
        logger.info(f"Running Analysis Task for {company.name}")

        # 1. Scrape Website (if not locked)
        existing_scrape = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "website_scrape"
        ).first()
        if not existing_scrape or not existing_scrape.is_locked:
            # FIX: use the module-level scraper singleton instead of
            # re-importing and instantiating a fresh ScraperService per task.
            scrape_res = scraper.scrape_url(company.website)
            if not existing_scrape:
                db.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_res))
            else:
                existing_scrape.content = scrape_res
                existing_scrape.updated_at = datetime.utcnow()
            db.commit()

        # 2. Classify Industry & Metrics
        # IMPORTANT: Using the new method name and passing db session
        classifier.classify_company_potential(company, db)
        company.status = "ENRICHED"
        db.commit()
        logger.info(f"Analysis complete for {company.name}")
    except Exception as e:
        logger.error(f"Analyze Task Error: {e}", exc_info=True)
    finally:
        db.close()


# --- Serve Frontend ---
# Mounted last so the /api routes registered above keep precedence.
static_path = "/frontend_static"
if not os.path.exists(static_path):
    static_path = os.path.join(os.path.dirname(__file__), "../static")
if os.path.exists(static_path):
    app.mount("/", StaticFiles(directory=static_path, html=True), name="static")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("backend.app:app", host="0.0.0.0", port=8000, reload=True)