# Brancheneinstufung2/company-explorer/backend/app.py
from fastapi import FastAPI, Depends, HTTPException, Query, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, joinedload
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
from datetime import datetime
import os
import sys
from .config import settings
from .lib.logging_setup import setup_logging
# Setup Logging first
setup_logging()
import logging
logger = logging.getLogger(__name__)
from .database import init_db, get_db, Company, Signal, EnrichmentData, RoboticsCategory, Contact, Industry, JobRoleMapping
from .services.deduplication import Deduplicator
from .services.discovery import DiscoveryService
from .services.scraping import ScraperService
from .services.classification import ClassificationService
# Initialize App
app = FastAPI(
    title=settings.APP_NAME,
    version=settings.VERSION,
    description="Backend for Company Explorer (Robotics Edition)",
    root_path="/ce"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Service Singletons
scraper = ScraperService()
classifier = ClassificationService()
discovery = DiscoveryService()
# --- Pydantic Models ---
class CompanyCreate(BaseModel):
    name: str
    city: Optional[str] = None
    country: str = "DE"
    website: Optional[str] = None


class BulkImportRequest(BaseModel):
    names: List[str]


class AnalysisRequest(BaseModel):
    company_id: int
    force_scrape: bool = False


class ContactBase(BaseModel):
    gender: str
    title: str = ""
    first_name: str
    last_name: str
    email: str
    job_title: str
    language: str = "De"
    role: str
    status: str = ""
    is_primary: bool = False


class ContactCreate(ContactBase):
    company_id: int


class ContactUpdate(BaseModel):
    gender: Optional[str] = None
    title: Optional[str] = None
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    email: Optional[str] = None
    job_title: Optional[str] = None
    language: Optional[str] = None
    role: Optional[str] = None
    status: Optional[str] = None
    is_primary: Optional[bool] = None
# --- Events ---
@app.on_event("startup")
def on_startup():
    logger.info("Startup Event: Initializing Database...")
    try:
        init_db()
        logger.info("Database initialized successfully.")
    except Exception as e:
        logger.critical(f"Database init failed: {e}", exc_info=True)
# --- Routes ---
@app.get("/api/health")
def health_check():
    return {"status": "ok", "version": settings.VERSION, "db": settings.DATABASE_URL}


@app.get("/api/companies")
def list_companies(
    skip: int = 0,
    limit: int = 50,
    search: Optional[str] = None,
    db: Session = Depends(get_db)
):
    try:
        query = db.query(Company)
        if search:
            query = query.filter(Company.name.ilike(f"%{search}%"))
        total = query.count()
        # Sort by ID desc (newest first)
        items = query.order_by(Company.id.desc()).offset(skip).limit(limit).all()
        return {"total": total, "items": items}
    except Exception as e:
        logger.error(f"List Companies Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/companies/{company_id}")
def get_company(company_id: int, db: Session = Depends(get_db)):
    company = db.query(Company).options(
        joinedload(Company.signals),
        joinedload(Company.enrichment_data),
        joinedload(Company.contacts)
    ).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(status_code=404, detail="Company not found")
    return company
@app.post("/api/companies/bulk")
def bulk_import_names(req: BulkImportRequest, db: Session = Depends(get_db)):
    """
    Quick import for testing. Just a list of names.
    """
    logger.info(f"Starting bulk import of {len(req.names)} names.")
    try:
        added = 0
        skipped = 0
        # Deduplicator init
        try:
            dedup = Deduplicator(db)
            logger.info("Deduplicator initialized.")
        except Exception as e:
            logger.warning(f"Deduplicator init failed: {e}")
            dedup = None
        for name in req.names:
            clean_name = name.strip()
            if not clean_name:
                continue
            # 1. Simple Deduplication (Exact Name)
            exists = db.query(Company).filter(Company.name == clean_name).first()
            if exists:
                skipped += 1
                continue
            # 2. Smart Deduplication (if available)
            if dedup:
                matches = dedup.find_duplicates({"name": clean_name})
                if matches and matches[0]['score'] > 95:
                    logger.info(f"Duplicate found for {clean_name}: {matches[0]['name']}")
                    skipped += 1
                    continue
            # 3. Create
            new_comp = Company(
                name=clean_name,
                status="NEW"  # This triggered the error before
            )
            db.add(new_comp)
            added += 1
        db.commit()
        logger.info(f"Import success. Added: {added}, Skipped: {skipped}")
        return {"added": added, "skipped": skipped}
    except Exception as e:
        logger.error(f"Bulk Import Failed: {e}", exc_info=True)
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))
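# Illustrative request body for the bulk name import above, matching BulkImportRequest;
# the concrete company names are invented for the example:
# POST /api/companies/bulk
# {"names": ["Beispiel Robotik GmbH", "Muster Automation AG"]}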
@app.get("/api/robotics/categories")
def list_robotics_categories(db: Session = Depends(get_db)):
    """Lists all configured robotics categories."""
    return db.query(RoboticsCategory).all()


class CategoryUpdate(BaseModel):
    description: str
    reasoning_guide: str


@app.put("/api/robotics/categories/{id}")
def update_robotics_category(id: int, cat: CategoryUpdate, db: Session = Depends(get_db)):
    """Updates a robotics category definition."""
    category = db.query(RoboticsCategory).filter(RoboticsCategory.id == id).first()
    if not category:
        raise HTTPException(404, "Category not found")
    category.description = cat.description
    category.reasoning_guide = cat.reasoning_guide
    db.commit()
    return category
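# Illustrative payload for the category update above, matching the CategoryUpdate model;
# the category id and the description/reasoning text are invented for the example:
# PUT /api/robotics/categories/3
# {"description": "Mobile robots for intralogistics",
#  "reasoning_guide": "Look for AGV/AMR mentions, warehouse automation, fleet software."}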
@app.post("/api/enrich/discover")
def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """
    Triggers Stage 1: Discovery (Website Search + Wikipedia Search)
    """
    try:
        company = db.query(Company).filter(Company.id == req.company_id).first()
        if not company:
            raise HTTPException(404, "Company not found")
        # Run in background
        background_tasks.add_task(run_discovery_task, company.id)
        return {"status": "queued", "message": f"Discovery started for {company.name}"}
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 404 above) instead of masking them as 500s below
        raise
    except Exception as e:
        logger.error(f"Discovery Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
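# Both enrichment endpoints accept the same AnalysisRequest body; illustrative call
# (the company id is invented; note that force_scrape is declared on the model but not
# yet read by either endpoint):
# POST /api/enrich/discover
# {"company_id": 42}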
@app.post("/api/companies/{company_id}/override/wiki")
def override_wiki_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
    """
    Manually sets the Wikipedia URL for a company and triggers re-extraction.
    Locks the data against auto-discovery.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    logger.info(f"Manual Override for {company.name}: Setting Wiki URL to {url}")
    # Update or create EnrichmentData entry
    existing_wiki = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "wikipedia"
    ).first()
    # Extract data immediately
    wiki_data = {"url": url}
    if url and url != "k.A.":
        try:
            wiki_data = discovery.extract_wikipedia_data(url)
            wiki_data['url'] = url  # Ensure URL is correct
        except Exception as e:
            logger.error(f"Extraction failed for manual URL: {e}")
            wiki_data["error"] = str(e)
    if not existing_wiki:
        db.add(EnrichmentData(
            company_id=company.id,
            source_type="wikipedia",
            content=wiki_data,
            is_locked=True
        ))
    else:
        existing_wiki.content = wiki_data
        existing_wiki.updated_at = datetime.utcnow()
        existing_wiki.is_locked = True  # LOCK IT
    db.commit()
    return {"status": "updated", "data": wiki_data}
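# The override endpoints take the URL as a query parameter, not as a JSON body.
# Illustrative call (company id and Wikipedia URL invented; the external path prefix
# depends on how the service is mounted behind the proxy implied by root_path="/ce"):
# curl -X POST "http://localhost:8000/api/companies/42/override/wiki?url=https://de.wikipedia.org/wiki/Beispiel_AG"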
@app.post("/api/companies/{company_id}/override/website")
def override_website_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
    """
    Manually sets the Website URL for a company.
    Clears existing scrape data to force a fresh analysis on next run.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    logger.info(f"Manual Override for {company.name}: Setting Website to {url}")
    company.website = url
    # Remove old scrape data since URL changed
    db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "website_scrape"
    ).delete()
    db.commit()
    return {"status": "updated", "website": url}
@app.post("/api/companies/{company_id}/override/impressum")
def override_impressum_url(company_id: int, url: str = Query(...), db: Session = Depends(get_db)):
    """
    Manually sets the Impressum URL for a company and triggers re-extraction.
    """
    company = db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    logger.info(f"Manual Override for {company.name}: Setting Impressum URL to {url}")
    # 1. Scrape Impressum immediately
    impressum_data = scraper._scrape_impressum_data(url)
    if not impressum_data:
        raise HTTPException(status_code=400, detail="Failed to extract data from provided URL")
    # 2. Find existing scrape data or create new
    existing_scrape = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "website_scrape"
    ).first()
    if not existing_scrape:
        # Create minimal scrape entry
        db.add(EnrichmentData(
            company_id=company.id,
            source_type="website_scrape",
            content={"impressum": impressum_data, "text": "", "title": "Manual Impressum", "url": url}
        ))
    else:
        # Update existing
        content = dict(existing_scrape.content) if existing_scrape.content else {}
        content["impressum"] = impressum_data
        existing_scrape.content = content
        existing_scrape.updated_at = datetime.utcnow()
    db.commit()
    return {"status": "updated", "data": impressum_data}
# --- Contact Routes ---
@app.post("/api/contacts")
def create_contact(contact: ContactCreate, db: Session = Depends(get_db)):
    """Creates a new contact and handles primary contact logic."""
    if contact.is_primary:
        db.query(Contact).filter(Contact.company_id == contact.company_id).update({"is_primary": False})
    db_contact = Contact(**contact.dict())
    db.add(db_contact)
    db.commit()
    db.refresh(db_contact)
    return db_contact
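# Illustrative payload for contact creation, matching ContactCreate (ContactBase + company_id);
# the person, company id and email are invented. Setting is_primary=true demotes any existing
# primary contact of the same company first (see the update() call above):
# POST /api/contacts
# {"company_id": 42, "gender": "männlich", "first_name": "Max", "last_name": "Mustermann",
#  "email": "max.mustermann@example.com", "job_title": "Leiter Produktion",
#  "role": "Operativer Entscheider", "is_primary": true}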
# --- Industry Routes ---
class IndustryCreate(BaseModel):
    name: str
    description: Optional[str] = None
    is_focus: bool = False
    primary_category_id: Optional[int] = None


class IndustryUpdate(BaseModel):
    name: Optional[str] = None
    description: Optional[str] = None
    is_focus: Optional[bool] = None
    primary_category_id: Optional[int] = None


@app.get("/api/industries")
def list_industries(db: Session = Depends(get_db)):
    return db.query(Industry).all()


@app.post("/api/industries")
def create_industry(ind: IndustryCreate, db: Session = Depends(get_db)):
    # 1. Prepare data
    ind_data = ind.dict()
    base_name = ind_data['name']
    # 2. Check for duplicate name
    existing = db.query(Industry).filter(Industry.name == base_name).first()
    if existing:
        # Auto-increment name if duplicated
        counter = 1
        while db.query(Industry).filter(Industry.name == f"{base_name} ({counter})").first():
            counter += 1
        ind_data['name'] = f"{base_name} ({counter})"
    # 3. Create
    db_ind = Industry(**ind_data)
    db.add(db_ind)
    db.commit()
    db.refresh(db_ind)
    return db_ind
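# Example of the duplicate-name handling above (names invented): if "Automotive" already
# exists, the next POST with the same name is stored as "Automotive (1)", then "Automotive (2)", etc.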
@app.put("/api/industries/{id}")
def update_industry(id: int, ind: IndustryUpdate, db: Session = Depends(get_db)):
    db_ind = db.query(Industry).filter(Industry.id == id).first()
    if not db_ind:
        raise HTTPException(404, "Industry not found")
    for key, value in ind.dict(exclude_unset=True).items():
        setattr(db_ind, key, value)
    db.commit()
    db.refresh(db_ind)
    return db_ind


@app.delete("/api/industries/{id}")
def delete_industry(id: int, db: Session = Depends(get_db)):
    db_ind = db.query(Industry).filter(Industry.id == id).first()
    if not db_ind:
        raise HTTPException(404, "Industry not found")
    db.delete(db_ind)
    db.commit()
    return {"status": "deleted"}
# --- Job Role Mapping Routes ---
class JobRoleMappingCreate(BaseModel):
    pattern: str
    role: str


@app.get("/api/job_roles")
def list_job_roles(db: Session = Depends(get_db)):
    return db.query(JobRoleMapping).all()


@app.post("/api/job_roles")
def create_job_role(mapping: JobRoleMappingCreate, db: Session = Depends(get_db)):
    db_mapping = JobRoleMapping(**mapping.dict())
    db.add(db_mapping)
    db.commit()
    db.refresh(db_mapping)
    return db_mapping


@app.delete("/api/job_roles/{id}")
def delete_job_role(id: int, db: Session = Depends(get_db)):
    db_mapping = db.query(JobRoleMapping).filter(JobRoleMapping.id == id).first()
    if not db_mapping:
        raise HTTPException(404, "Mapping not found")
    db.delete(db_mapping)
    db.commit()
    return {"status": "deleted"}
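# Illustrative payload for a job role mapping, matching JobRoleMappingCreate; the pattern
# and role values are invented examples. How patterns are matched against job titles is
# defined wherever JobRoleMapping is consumed, not in this file:
# POST /api/job_roles
# {"pattern": "Leiter Produktion", "role": "Operativer Entscheider"}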
@app.put("/api/contacts/{contact_id}")
def update_contact(contact_id: int, contact: ContactUpdate, db: Session = Depends(get_db)):
    """Updates an existing contact."""
    db_contact = db.query(Contact).filter(Contact.id == contact_id).first()
    if not db_contact:
        raise HTTPException(404, "Contact not found")
    update_data = contact.dict(exclude_unset=True)
    if update_data.get("is_primary"):
        db.query(Contact).filter(Contact.company_id == db_contact.company_id).update({"is_primary": False})
    for key, value in update_data.items():
        setattr(db_contact, key, value)
    db.commit()
    db.refresh(db_contact)
    return db_contact


@app.delete("/api/contacts/{contact_id}")
def delete_contact(contact_id: int, db: Session = Depends(get_db)):
    """Deletes a contact."""
    db_contact = db.query(Contact).filter(Contact.id == contact_id).first()
    if not db_contact:
        raise HTTPException(404, "Contact not found")
    db.delete(db_contact)
    db.commit()
    return {"status": "deleted"}
@app.get("/api/contacts/all")
def list_all_contacts(
    skip: int = 0,
    limit: int = 50,
    search: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """
    Lists all contacts across all companies with pagination and search.
    """
    query = db.query(Contact).join(Company)
    if search:
        search_term = f"%{search}%"
        query = query.filter(
            (Contact.first_name.ilike(search_term)) |
            (Contact.last_name.ilike(search_term)) |
            (Contact.email.ilike(search_term)) |
            (Company.name.ilike(search_term))
        )
    total = query.count()
    # Sort by ID desc
    contacts = query.order_by(Contact.id.desc()).offset(skip).limit(limit).all()
    # Enrich with Company Name for the frontend list
    result = []
    for c in contacts:
        c_dict = {k: v for k, v in c.__dict__.items() if not k.startswith('_')}
        c_dict['company_name'] = c.company.name if c.company else "Unknown"
        result.append(c_dict)
    return {"total": total, "items": result}
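# Illustrative shape of one item in the response above: the loaded Contact columns (only a
# subset shown, all values invented) plus the injected company_name key:
# {"id": 7, "company_id": 42, "first_name": "Max", "last_name": "Mustermann",
#  "email": "max.mustermann@example.com", "job_title": "Leiter Produktion",
#  "is_primary": true, "company_name": "Beispiel Robotik GmbH"}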
class BulkContactImportItem(BaseModel):
    company_name: str
    first_name: str
    last_name: str
    email: Optional[str] = None
    job_title: Optional[str] = None
    role: Optional[str] = "Operativer Entscheider"
    gender: Optional[str] = "männlich"


class BulkContactImportRequest(BaseModel):
    contacts: List[BulkContactImportItem]


@app.post("/api/contacts/bulk")
def bulk_import_contacts(req: BulkContactImportRequest, db: Session = Depends(get_db)):
    """
    Bulk imports contacts.
    Matches Company by Name (creates if missing).
    Dedupes Contact by Email.
    """
    logger.info(f"Starting bulk contact import: {len(req.contacts)} items")
    stats = {"added": 0, "skipped": 0, "companies_created": 0}
    for item in req.contacts:
        if not item.company_name:
            continue
        # 1. Find or Create Company
        company = db.query(Company).filter(Company.name.ilike(item.company_name.strip())).first()
        if not company:
            company = Company(name=item.company_name.strip(), status="NEW")
            db.add(company)
            db.commit()  # Commit to get ID
            db.refresh(company)
            stats["companies_created"] += 1
        # 2. Check for Duplicate Contact (by Email)
        if item.email:
            exists = db.query(Contact).filter(Contact.email == item.email.strip()).first()
            if exists:
                stats["skipped"] += 1
                continue
        # 3. Create Contact
        new_contact = Contact(
            company_id=company.id,
            first_name=item.first_name,
            last_name=item.last_name,
            email=item.email,
            job_title=item.job_title,
            role=item.role,
            gender=item.gender,
            status="Init"  # Default status
        )
        db.add(new_contact)
        stats["added"] += 1
    db.commit()
    return stats
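# Illustrative payload for the bulk contact import above, matching BulkContactImportItem
# (company and person invented); omitted fields fall back to the model defaults
# ("Operativer Entscheider" / "männlich"):
# POST /api/contacts/bulk
# {"contacts": [{"company_name": "Beispiel Robotik GmbH", "first_name": "Erika",
#                "last_name": "Musterfrau", "email": "erika.musterfrau@example.com",
#                "job_title": "Geschäftsführerin"}]}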
def run_discovery_task(company_id: int):
    # New Session for Background Task
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return
        logger.info(f"Running Discovery Task for {company.name}")
        # 1. Website Search (Always try if missing)
        if not company.website or company.website == "k.A.":
            found_url = discovery.find_company_website(company.name, company.city)
            if found_url and found_url != "k.A.":
                company.website = found_url
                logger.info(f"-> Found URL: {found_url}")
        # 2. Wikipedia Search & Extraction
        # Check if locked
        existing_wiki = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "wikipedia"
        ).first()
        if existing_wiki and existing_wiki.is_locked:
            logger.info(f"Skipping Wiki Discovery for {company.name} - Data is LOCKED.")
        else:
            # Pass available info for better validation
            current_website = company.website if company.website and company.website != "k.A." else None
            wiki_url = discovery.find_wikipedia_url(company.name, website=current_website, city=company.city)
            company.last_wiki_search_at = datetime.utcnow()
            wiki_data = {"url": wiki_url}
            if wiki_url and wiki_url != "k.A.":
                logger.info(f"Extracting full data from Wikipedia for {company.name}...")
                wiki_data = discovery.extract_wikipedia_data(wiki_url)
            if not existing_wiki:
                db.add(EnrichmentData(company_id=company.id, source_type="wikipedia", content=wiki_data))
            else:
                existing_wiki.content = wiki_data
                existing_wiki.updated_at = datetime.utcnow()
        if company.status == "NEW" and company.website and company.website != "k.A.":
            company.status = "DISCOVERED"
        db.commit()
        logger.info(f"Discovery finished for {company.id}")
    except Exception as e:
        logger.error(f"Background Task Error: {e}", exc_info=True)
        db.rollback()
    finally:
        db.close()
@app.post("/api/enrich/analyze")
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    company = db.query(Company).filter(Company.id == req.company_id).first()
    if not company:
        raise HTTPException(404, "Company not found")
    if not company.website or company.website == "k.A.":
        return {"error": "No website to analyze. Run Discovery first."}
    # FORCE SCRAPE LOGIC
    # Old scrape data is always cleared for now, so every manual "Analyze" click produces
    # fresh results (works around the "stuck cache" issue reported by users).
    # Once the frontend sends force_scrape, this should become conditional on that flag.
    db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company.id,
        EnrichmentData.source_type == "website_scrape"
    ).delete()
    db.commit()
    background_tasks.add_task(run_analysis_task, company.id, company.website)
    return {"status": "queued"}
def run_analysis_task(company_id: int, url: str):
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return
        logger.info(f"Running Analysis Task for {company.name}")
        # 1. Scrape Website
        scrape_result = scraper.scrape_url(url)
        # Save Scrape Data
        existing_scrape_data = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "website_scrape"
        ).first()
        if "text" in scrape_result and scrape_result["text"]:
            if not existing_scrape_data:
                db.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_result))
            else:
                existing_scrape_data.content = scrape_result
                existing_scrape_data.updated_at = datetime.utcnow()
        elif "error" in scrape_result:
            logger.warning(f"Scraping failed for {company.name}: {scrape_result['error']}")
        # 2. Classify Robotics Potential
        if "text" in scrape_result and scrape_result["text"]:
            analysis = classifier.analyze_robotics_potential(
                company_name=company.name,
                website_text=scrape_result["text"]
            )
            if "error" in analysis:
                logger.error(f"Robotics classification failed for {company.name}: {analysis['error']}")
            else:
                industry = analysis.get("industry")
                if industry:
                    company.industry_ai = industry
                # Delete old signals
                db.query(Signal).filter(Signal.company_id == company.id).delete()
                # Save new signals
                potentials = analysis.get("potentials", {})
                for signal_type, data in potentials.items():
                    new_signal = Signal(
                        company_id=company.id,
                        signal_type=f"robotics_{signal_type}_potential",
                        confidence=data.get("score", 0),
                        value="High" if data.get("score", 0) > 70 else "Medium" if data.get("score", 0) > 30 else "Low",
                        proof_text=data.get("reason")
                    )
                    db.add(new_signal)
                # Save Full Analysis Blob (Business Model + Evidence)
                existing_analysis = db.query(EnrichmentData).filter(
                    EnrichmentData.company_id == company.id,
                    EnrichmentData.source_type == "ai_analysis"
                ).first()
                if not existing_analysis:
                    db.add(EnrichmentData(company_id=company.id, source_type="ai_analysis", content=analysis))
                else:
                    existing_analysis.content = analysis
                    existing_analysis.updated_at = datetime.utcnow()
                company.status = "ENRICHED"
                company.last_classification_at = datetime.utcnow()
                logger.info(f"Robotics analysis complete for {company.name}.")
        db.commit()
        logger.info(f"Analysis finished for {company.id}")
    except Exception as e:
        logger.error(f"Analyze Task Error: {e}", exc_info=True)
        db.rollback()
    finally:
        db.close()
# --- Serve Frontend ---
# Priority 1: Container Path (outside of /app volume)
static_path = "/frontend_static"
# Priority 2: Local Dev Path (relative to this file)
if not os.path.exists(static_path):
    static_path = os.path.join(os.path.dirname(__file__), "../static")
if os.path.exists(static_path):
    logger.info(f"Serving frontend from {static_path}")
    app.mount("/", StaticFiles(directory=static_path, html=True), name="static")
else:
    logger.warning(f"Frontend static files not found at {static_path} or local fallback.")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("backend.app:app", host="0.0.0.0", port=8000, reload=True)