973 lines
33 KiB
Python
973 lines
33 KiB
Python
# Standard library
import os
import secrets
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional

# Third-party
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session, joinedload
|
|
|
|
security = HTTPBasic()


async def authenticate_user(credentials: HTTPBasicCredentials = Depends(security)):
    """Validate HTTP Basic credentials against API_USER / API_PASSWORD env vars.

    Uses constant-time comparison (secrets.compare_digest) to avoid timing
    side channels. Raises HTTP 401 with a WWW-Authenticate header when
    either part does not match; returns the username on success.
    """
    expected_user = os.getenv("API_USER", "default_user")
    expected_password = os.getenv("API_PASSWORD", "default_password")

    user_ok = secrets.compare_digest(credentials.username, expected_user)
    password_ok = secrets.compare_digest(credentials.password, expected_password)

    if user_ok and password_ok:
        return credentials.username

    raise HTTPException(
        status_code=401,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Basic"},
    )
|
|
|
|
from .config import settings
|
|
from .lib.logging_setup import setup_logging
|
|
|
|
# Setup Logging first
|
|
setup_logging()
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
from .database import init_db, get_db, Company, Signal, EnrichmentData, RoboticsCategory, Contact, Industry, JobRoleMapping, ReportedMistake, MarketingMatrix, Persona, RawJobTitle
|
|
from .services.deduplication import Deduplicator
|
|
from .services.discovery import DiscoveryService
|
|
from .services.scraping import ScraperService
|
|
from .services.classification import ClassificationService
|
|
|
|
# Initialize App
|
|
# FastAPI application instance; title/version come from central settings.
app = FastAPI(
    title=settings.APP_NAME,
    version=settings.VERSION,
    description="Backend for Company Explorer (Robotics Edition)"
)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# disallowed by the CORS spec for browsers; Starlette compensates by echoing
# the request origin, which effectively trusts every origin. Consider
# restricting allow_origins to the known frontend host(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Service Singletons — created once at import time and shared across requests.
scraper = ScraperService()
classifier = ClassificationService()  # Now works without args
discovery = DiscoveryService()
|
|
|
|
# --- Pydantic Models ---
|
|
class CompanyCreate(BaseModel):
    """Request payload for creating a single company record."""
    name: str
    city: Optional[str] = None
    country: str = "DE"  # country code; defaults to Germany
    website: Optional[str] = None
    crm_id: Optional[str] = None  # external CRM identifier, if already known
|
|
|
|
class BulkImportRequest(BaseModel):
    """Request payload for bulk-importing companies by name only."""
    names: List[str]
|
|
|
|
class AnalysisRequest(BaseModel):
    """Request payload for triggering discovery/analysis of one company."""
    company_id: int
    force_scrape: bool = False  # not read by the endpoints in this module
|
|
|
|
class IndustryUpdateModel(BaseModel):
    """Request payload for manually overriding a company's AI industry."""
    industry_ai: str
|
|
|
|
class ReportMistakeRequest(BaseModel):
    """Request payload for reporting a wrong enrichment value on a company."""
    field_name: str  # which company field is wrong
    wrong_value: Optional[str] = None
    corrected_value: Optional[str] = None
    source_url: Optional[str] = None  # where the correct value was found
    quote: Optional[str] = None  # supporting quote from the source
    user_comment: Optional[str] = None
|
|
|
|
class ProvisioningRequest(BaseModel):
    """Inbound sync payload from the SuperOffice connector.

    `so_contact_id` identifies the company; `so_person_id` optionally
    identifies a person at that company. The crm_* fields carry the
    CRM-side snapshot used to detect changes.
    """
    so_contact_id: int
    so_person_id: Optional[int] = None
    crm_name: Optional[str] = None
    crm_website: Optional[str] = None
    job_title: Optional[str] = None
    crm_industry_name: Optional[str] = None
|
|
|
|
class ProvisioningResponse(BaseModel):
    """Response to the SuperOffice connector.

    `status` is "processing" while discovery/analysis is still running and
    "success" once enriched data (and optionally marketing texts) is ready.
    """
    status: str
    company_name: str
    website: Optional[str] = None
    vertical_name: Optional[str] = None
    role_name: Optional[str] = None
    opener: Optional[str] = None  # Primary opener (Infrastructure/Cleaning)
    opener_secondary: Optional[str] = None  # Secondary opener (Service/Logistics)
    texts: Dict[str, Optional[str]] = {}  # subject / intro / social_proof

    # Enrichment Data for Write-Back into the CRM
    address_city: Optional[str] = None
    address_zip: Optional[str] = None
    address_street: Optional[str] = None
    address_country: Optional[str] = None
    vat_id: Optional[str] = None
|
|
|
|
class IndustryDetails(BaseModel):
    """Strategy metadata of an industry, attached to company detail responses."""
    pains: Optional[str] = None
    gains: Optional[str] = None
    priority: Optional[str] = None
    notes: Optional[str] = None
    ops_focus_secondary: bool = False

    # Pydantic v2 configuration (the file already uses the v2 API via
    # `model_validate`); replaces the deprecated inner `class Config`.
    # `from_attributes` allows validating directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
class ContactResponse(BaseModel):
    """Serialized contact/person belonging to a company."""
    id: int
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    job_title: Optional[str] = None
    role: Optional[str] = None  # mapped persona role, may be None if unmapped
    email: Optional[str] = None
    is_primary: bool

    # Pydantic v2 configuration; replaces the deprecated inner `class Config`
    # and allows validation directly from ORM objects (`model_validate`).
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
class EnrichmentDataResponse(BaseModel):
    """Serialized enrichment record (e.g. website scrape, wikipedia data)."""
    id: int
    source_type: str  # e.g. "wikipedia", "website_scrape", "impressum_override"
    content: Dict[str, Any]
    is_locked: bool  # locked records are skipped by background re-discovery
    wiki_verified_empty: bool
    updated_at: datetime

    # Pydantic v2 configuration; replaces the deprecated inner `class Config`
    # and allows validation directly from ORM objects (`model_validate`).
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
class CompanyDetailsResponse(BaseModel):
    """Full company detail view: base data, extracted metrics, AI openers,
    and related contacts / enrichment records."""
    id: int
    name: str
    website: Optional[str] = None
    city: Optional[str] = None
    country: Optional[str] = None
    industry_ai: Optional[str] = None
    status: str  # lifecycle: NEW -> DISCOVERED -> ENRICHED

    # Metrics (extracted per industry by the ClassificationService)
    calculated_metric_name: Optional[str] = None
    calculated_metric_value: Optional[float] = None
    calculated_metric_unit: Optional[str] = None
    standardized_metric_value: Optional[float] = None
    standardized_metric_unit: Optional[str] = None
    metric_source: Optional[str] = None
    metric_proof_text: Optional[str] = None
    metric_source_url: Optional[str] = None
    metric_confidence: Optional[float] = None

    # Openers (AI-generated outreach snippets)
    ai_opener: Optional[str] = None
    ai_opener_secondary: Optional[str] = None

    # Relations
    industry_details: Optional[IndustryDetails] = None  # attached manually in get_company
    contacts: List[ContactResponse] = []
    enrichment_data: List[EnrichmentDataResponse] = []

    # Pydantic v2 configuration; replaces the deprecated inner `class Config`
    # and allows validation directly from ORM objects (`model_validate`).
    model_config = ConfigDict(from_attributes=True)
|
|
|
|
# --- Events ---
|
|
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — fine to keep on the current version.
@app.on_event("startup")
def on_startup():
    """Create database tables on startup; log (but swallow) init failures."""
    logger.info("Startup Event: Initializing Database...")
    try:
        init_db()
        logger.info("Database initialized successfully.")
    except Exception as e:
        # Deliberately not re-raised: the app should still boot so the
        # health endpoint can report status even when the DB is down.
        logger.critical(f"Database init failed: {e}", exc_info=True)
|
|
|
|
# --- Routes ---
|
|
|
|
@app.get("/api/health")
def health_check(username: str = Depends(authenticate_user)):
    """Liveness probe reporting app version and the configured DB URL.

    NOTE(review): returning DATABASE_URL can leak credentials if the URL
    embeds a password — consider masking it before exposing it here.
    """
    return {"status": "ok", "version": settings.VERSION, "db": settings.DATABASE_URL}
|
|
|
|
@app.post("/api/provision/superoffice-contact", response_model=ProvisioningResponse)
def provision_superoffice_contact(
    req: ProvisioningRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Main SuperOffice connector endpoint: sync one company (and optionally
    one person) and return enrichment data plus personalized outreach texts.

    Flow:
      1.  Auto-create the company when the SuperOffice contact ID is unknown.
      1b. While discovery/analysis is still pending, answer status="processing"
          (the connector is expected to poll again later).
      1c. Mirror CRM name/website changes; renaming an already ENRICHED
          company invalidates its AI results and restarts discovery.
      2.  Without a person ID, return company-level data only; otherwise
          upsert the person and map the job title to a persona role.
      3-5. Resolve role + vertical and look up matching marketing texts.
    """
    # 1. Find Company (via SO ID)
    company = db.query(Company).filter(Company.crm_id == str(req.so_contact_id)).first()

    if not company:
        # AUTO-CREATE Logic
        if not req.crm_name:
            # Cannot create without name. Should ideally not happen if Connector does its job.
            raise HTTPException(400, "Cannot create company: crm_name missing")

        company = Company(
            name=req.crm_name,
            crm_id=str(req.so_contact_id),
            crm_name=req.crm_name,
            crm_website=req.crm_website,
            status="NEW"
        )
        db.add(company)
        db.commit()
        db.refresh(company)
        logger.info(f"Auto-created company {company.name} from SuperOffice request.")

        # Trigger Discovery asynchronously; the caller polls until enriched.
        background_tasks.add_task(run_discovery_task, company.id)

        return ProvisioningResponse(
            status="processing",
            company_name=company.name
        )

    # 1b. Check Status & Progress
    # If NEW or DISCOVERED, we are not ready to provide texts.
    if company.status in ["NEW", "DISCOVERED"]:
        # If we have a website, ensure analysis is triggered
        if company.status == "DISCOVERED" or (company.website and company.website != "k.A."):
            background_tasks.add_task(run_analysis_task, company.id)
        elif company.status == "NEW":
            # Ensure discovery runs
            background_tasks.add_task(run_discovery_task, company.id)

        return ProvisioningResponse(
            status="processing",
            company_name=company.name
        )

    # 1c. Update CRM Snapshot Data (The Double Truth):
    # the crm_* columns mirror what the CRM currently says, independently
    # of what our own discovery found.
    changed = False
    name_changed_significantly = False

    if req.crm_name and req.crm_name != company.crm_name:
        logger.info(f"CRM Name Change detected for ID {company.crm_id}: {company.crm_name} -> {req.crm_name}")
        company.crm_name = req.crm_name
        # If the name changes, we should potentially re-evaluate the whole company
        # especially if the status was already ENRICHED
        if company.status == "ENRICHED":
            name_changed_significantly = True
        changed = True

    if req.crm_website:
        if company.crm_website != req.crm_website:
            company.crm_website = req.crm_website
            changed = True

    # ...

    if changed:
        company.updated_at = datetime.utcnow()
        if name_changed_significantly:
            logger.info(f"Triggering FRESH discovery for {company.name} due to CRM name change.")
            company.status = "NEW"
            # We don't change the internal 'name' yet, Discovery will do that or we keep it as anchor.
            # But we must clear old results to avoid stale data.
            company.industry_ai = None
            company.ai_opener = None
            company.ai_opener_secondary = None
            background_tasks.add_task(run_discovery_task, company.id)

        db.commit()

    # If we just triggered a fresh discovery, tell the worker to wait.
    if name_changed_significantly:
        return ProvisioningResponse(
            status="processing",
            company_name=company.crm_name
        )

    # 2. Find Contact (Person)
    if req.so_person_id is None:
        # Just a company sync, but return all company-level metadata
        return ProvisioningResponse(
            status="success",
            company_name=company.name,
            website=company.website,
            vertical_name=company.industry_ai,
            opener=company.ai_opener,
            opener_secondary=company.ai_opener_secondary,
            address_city=company.city,
            address_street=company.street,
            address_zip=company.zip_code,
            address_country=company.country,
            vat_id=company.crm_vat
        )

    person = db.query(Contact).filter(Contact.so_person_id == req.so_person_id).first()

    # Auto-Create/Update Person
    if not person:
        person = Contact(
            company_id=company.id,
            so_contact_id=req.so_contact_id,
            so_person_id=req.so_person_id,
            status="ACTIVE"
        )
        db.add(person)
        logger.info(f"Created new person {req.so_person_id} for company {company.name}")

    # Update Job Title & Role logic (only runs when a job title was supplied)
    if req.job_title:
        person.job_title = req.job_title

        # Simple classification fallback: substring match of each mapping
        # pattern (SQL-style '%' wildcards stripped) against the title.
        mappings = db.query(JobRoleMapping).all()
        found_role = None
        for m in mappings:
            pattern_clean = m.pattern.replace("%", "").lower()
            if pattern_clean in req.job_title.lower():
                found_role = m.role
                break

        # ALWAYS update role, even if to None, to avoid 'sticking' old roles
        if found_role != person.role:
            logger.info(f"Role Change for {person.so_person_id}: {person.role} -> {found_role}")
            person.role = found_role

    db.commit()
    db.refresh(person)

    # 3. Determine Role
    role_name = person.role

    # 4. Determine Vertical (Industry)
    vertical_name = company.industry_ai

    # 5. Fetch Texts from Matrix (industry x persona -> outreach snippets)
    texts = {"subject": None, "intro": None, "social_proof": None}

    if vertical_name and role_name:
        industry_obj = db.query(Industry).filter(Industry.name == vertical_name).first()
        persona_obj = db.query(Persona).filter(Persona.name == role_name).first()

        if industry_obj and persona_obj:
            matrix_entry = db.query(MarketingMatrix).filter(
                MarketingMatrix.industry_id == industry_obj.id,
                MarketingMatrix.persona_id == persona_obj.id
            ).first()

            if matrix_entry:
                texts["subject"] = matrix_entry.subject
                texts["intro"] = matrix_entry.intro
                texts["social_proof"] = matrix_entry.social_proof

    return ProvisioningResponse(
        status="success",
        company_name=company.name,
        website=company.website,
        vertical_name=vertical_name,
        role_name=role_name,
        opener=company.ai_opener,
        opener_secondary=company.ai_opener_secondary,
        texts=texts,
        address_city=company.city,
        address_street=company.street,
        address_zip=company.zip_code,
        address_country=company.country,
        # TODO: Add VAT field to Company model if not present, for now using crm_vat if available
        vat_id=company.crm_vat
    )
|
|
|
|
@app.get("/api/companies")
def list_companies(
    skip: int = 0,
    limit: int = 50,
    search: Optional[str] = None,
    sort_by: Optional[str] = Query("name_asc"),
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Paginated company listing with optional name search and sorting.

    Each returned company is annotated with `has_pending_mistakes`,
    computed with a single additional query over ReportedMistake.
    """
    try:
        base_query = db.query(Company)
        if search:
            base_query = base_query.filter(Company.name.ilike(f"%{search}%"))

        total = base_query.count()

        # Translate the sort key into an ORDER BY clause; anything unknown
        # falls back to alphabetical order by name.
        if sort_by == "updated_desc":
            ordering = Company.updated_at.desc()
        elif sort_by == "created_desc":
            ordering = Company.id.desc()
        else:  # Default: name_asc
            ordering = Company.name.asc()

        items = base_query.order_by(ordering).offset(skip).limit(limit).all()

        # One query to find which of the listed companies have PENDING mistakes.
        flagged_ids = set()
        listed_ids = [c.id for c in items]
        if listed_ids:
            rows = db.query(ReportedMistake.company_id).filter(
                ReportedMistake.company_id.in_(listed_ids),
                ReportedMistake.status == 'PENDING'
            ).distinct().all()
            flagged_ids = {cid for (cid,) in rows}

        # Attach the flag to each ORM object before serialization.
        for c in items:
            c.has_pending_mistakes = c.id in flagged_ids

        return {"total": total, "items": items}
    except Exception as e:
        logger.error(f"List Companies Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/api/companies/export")
def export_companies_csv(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """
    Exports a CSV of all companies with their key metrics.
    """
    import csv
    import io
    from fastapi.responses import StreamingResponse

    buffer = io.StringIO()
    csv_writer = csv.writer(buffer)

    # Header row
    csv_writer.writerow([
        "ID", "Name", "Website", "City", "Country", "AI Industry",
        "Metric Name", "Metric Value", "Metric Unit", "Standardized Value (m2)",
        "Source", "Source URL", "Confidence", "Proof Text"
    ])

    # One row per company, alphabetically by name.
    for company in db.query(Company).order_by(Company.name.asc()).all():
        csv_writer.writerow([
            company.id, company.name, company.website, company.city,
            company.country, company.industry_ai,
            company.calculated_metric_name,
            company.calculated_metric_value,
            company.calculated_metric_unit,
            company.standardized_metric_value,
            company.metric_source,
            company.metric_source_url,
            company.metric_confidence,
            company.metric_proof_text
        ])

    buffer.seek(0)

    filename = f"company_export_{datetime.utcnow().strftime('%Y-%m-%d')}.csv"
    return StreamingResponse(
        buffer,
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename={filename}"}
    )
|
|
|
|
@app.get("/api/companies/{company_id}", response_model=CompanyDetailsResponse)
def get_company(company_id: int, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Full detail view of one company, eager-loading contacts and
    enrichment data and attaching the strategy details of its AI industry."""
    company = (
        db.query(Company)
        .options(
            joinedload(Company.enrichment_data),
            joinedload(Company.contacts),
        )
        .filter(Company.id == company_id)
        .first()
    )
    if company is None:
        raise HTTPException(404, detail="Company not found")

    # Resolve the strategy record for the AI-classified industry, if any.
    details = None
    if company.industry_ai:
        industry_row = db.query(Industry).filter(Industry.name == company.industry_ai).first()
        if industry_row:
            details = IndustryDetails.model_validate(industry_row)

    # Serialize the ORM object into the response schema, then attach the
    # extra industry_details which has no ORM relationship of its own.
    payload = CompanyDetailsResponse.model_validate(company)
    payload.industry_details = details
    return payload
|
|
|
|
@app.post("/api/companies")
def create_company(company: CompanyCreate, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Create a company; rejects a duplicate name with HTTP 400."""
    duplicate = db.query(Company).filter(Company.name == company.name).first()
    if duplicate is not None:
        raise HTTPException(status_code=400, detail="Company already registered")

    record = Company(
        name=company.name,
        city=company.city,
        country=company.country,
        website=company.website,
        crm_id=company.crm_id,
        status="NEW",
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
@app.post("/api/companies/bulk")
def bulk_import_companies(req: BulkImportRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Import a list of company names, skipping blanks and existing names."""
    created = 0
    for raw_name in req.names:
        candidate = raw_name.strip()
        if not candidate:
            continue

        already_present = db.query(Company).filter(Company.name == candidate).first()
        if already_present:
            continue

        db.add(Company(name=candidate, status="NEW"))
        created += 1
        # Optional: Auto-trigger discovery
        # background_tasks.add_task(run_discovery_task, new_company.id)

    db.commit()
    return {"status": "success", "imported": created}
|
|
|
|
@app.post("/api/companies/{company_id}/override/wikipedia")
def override_wikipedia(company_id: int, url: str, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Manually pin a company's Wikipedia URL and lock it against re-discovery."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, detail="Company not found")

    # The override payload; full_text is refilled by the re-evaluation task.
    payload = {"url": url, "full_text": None, "manual_override": True}

    # Upsert the (locked) wikipedia enrichment record.
    record = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company_id,
        EnrichmentData.source_type == "wikipedia"
    ).first()

    if record is None:
        db.add(EnrichmentData(
            company_id=company_id,
            source_type="wikipedia",
            content=payload,
            is_locked=True
        ))
    else:
        record.content = payload
        record.is_locked = True

    db.commit()

    # Only re-evaluate when the override points at a real URL.
    if url and url.startswith("http"):
        background_tasks.add_task(run_wikipedia_reevaluation_task, company.id)

    return {"status": "updated"}
|
|
|
|
@app.get("/api/robotics/categories")
def list_robotics_categories(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Return all robotics categories (unfiltered, unpaginated)."""
    return db.query(RoboticsCategory).all()
|
|
|
|
@app.get("/api/industries")
def list_industries(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Return all industries (unfiltered, unpaginated)."""
    return db.query(Industry).all()
|
|
|
|
@app.get("/api/job_roles")
def list_job_roles(db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Return all job-title-to-role mappings, sorted by pattern."""
    return db.query(JobRoleMapping).order_by(JobRoleMapping.pattern.asc()).all()
|
|
|
|
@app.get("/api/job_roles/raw")
def list_raw_job_titles(
    limit: int = 100,
    unmapped_only: bool = True,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """
    Returns unique raw job titles from CRM imports, prioritized by frequency.

    Args:
        limit: Maximum number of titles to return.
        unmapped_only: If True (default), only titles not yet mapped to a role.
    """
    query = db.query(RawJobTitle)
    if unmapped_only:
        # Use .is_(False) instead of `== False`: equivalent result for a
        # boolean column, but idiomatic SQLAlchemy and lint-clean (E712).
        query = query.filter(RawJobTitle.is_mapped.is_(False))

    return query.order_by(RawJobTitle.count.desc()).limit(limit).all()
|
|
|
|
@app.get("/api/mistakes")
def list_reported_mistakes(
    status: Optional[str] = Query(None),
    company_id: Optional[int] = Query(None),
    skip: int = 0,
    limit: int = 50,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Paginated list of reported mistakes, newest first, with optional
    status and company filters. The related company is eager-loaded."""
    query = db.query(ReportedMistake).options(joinedload(ReportedMistake.company))

    if status:
        query = query.filter(ReportedMistake.status == status.upper())
    if company_id:
        query = query.filter(ReportedMistake.company_id == company_id)

    total = query.count()
    page = (
        query.order_by(ReportedMistake.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    return {"total": total, "items": page}
|
|
|
|
class MistakeUpdateStatusRequest(BaseModel):
    """Request payload for changing a reported mistake's review status."""
    status: str  # PENDING, APPROVED, REJECTED (case-insensitive)
|
|
|
|
@app.put("/api/mistakes/{mistake_id}")
def update_reported_mistake_status(
    mistake_id: int,
    request: MistakeUpdateStatusRequest,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Set a reported mistake's review status (PENDING/APPROVED/REJECTED)."""
    mistake = db.query(ReportedMistake).filter(ReportedMistake.id == mistake_id).first()
    if mistake is None:
        raise HTTPException(404, detail="Reported mistake not found")

    # Normalize once; reject anything outside the allowed states.
    new_status = request.status.upper()
    if new_status not in ["PENDING", "APPROVED", "REJECTED"]:
        raise HTTPException(400, detail="Invalid status. Must be PENDING, APPROVED, or REJECTED.")

    mistake.status = new_status
    mistake.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(mistake)

    logger.info(f"Updated status for mistake {mistake_id} to {mistake.status}")
    return {"status": "success", "mistake": mistake}
|
|
|
|
@app.post("/api/enrich/discover")
def discover_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue the background discovery task (website/Wikipedia lookup)."""
    company = db.query(Company).filter(Company.id == req.company_id).first()
    if company is None:
        raise HTTPException(404, "Company not found")

    background_tasks.add_task(run_discovery_task, company.id)
    return {"status": "queued"}
|
|
|
|
@app.post("/api/enrich/analyze")
def analyze_company(req: AnalysisRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue the scrape+classification background task for one company.

    NOTE(review): the missing-website case returns HTTP 200 with an "error"
    key instead of raising an HTTPException like the other endpoints do —
    callers may depend on this shape, so it is deliberately left as-is.
    """
    company = db.query(Company).filter(Company.id == req.company_id).first()
    if not company: raise HTTPException(404, "Company not found")

    # "k.A." is the module-wide sentinel for "no website known".
    if not company.website or company.website == "k.A.":
        return {"error": "No website to analyze. Run Discovery first."}

    background_tasks.add_task(run_analysis_task, company.id)
    return {"status": "queued"}
|
|
|
|
@app.put("/api/companies/{company_id}/industry")
def update_company_industry(
    company_id: int,
    data: IndustryUpdateModel,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Manually set a company's AI industry, then re-extract its metrics."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, detail="Company not found")

    # Persist the new industry first ...
    company.industry_ai = data.industry_ai
    company.updated_at = datetime.utcnow()
    db.commit()

    # ... then recompute the industry-specific metrics asynchronously.
    background_tasks.add_task(run_metric_reextraction_task, company.id)

    return {"status": "updated", "industry_ai": company.industry_ai}
|
|
|
|
|
|
@app.post("/api/companies/{company_id}/reevaluate-wikipedia")
def reevaluate_wikipedia(company_id: int, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Queue a background re-evaluation of the company's Wikipedia metric."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, detail="Company not found")

    background_tasks.add_task(run_wikipedia_reevaluation_task, company.id)
    return {"status": "queued"}
|
|
|
|
|
|
@app.delete("/api/companies/{company_id}")
def delete_company(company_id: int, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Delete a company together with its enrichment, signal and contact rows."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, detail="Company not found")

    # Remove children explicitly rather than relying on cascade rules.
    db.query(EnrichmentData).filter(EnrichmentData.company_id == company_id).delete()
    db.query(Signal).filter(Signal.company_id == company_id).delete()
    db.query(Contact).filter(Contact.company_id == company_id).delete()

    db.delete(company)
    db.commit()
    return {"status": "deleted"}
|
|
|
|
@app.post("/api/companies/{company_id}/override/website")
def override_website(company_id: int, url: str, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Manually overwrite a company's website URL."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, detail="Company not found")

    company.website = url
    company.updated_at = datetime.utcnow()
    db.commit()

    return {"status": "updated", "website": company.website}
|
|
|
|
@app.post("/api/companies/{company_id}/override/impressum")
def override_impressum(company_id: int, url: str, background_tasks: BackgroundTasks, db: Session = Depends(get_db), username: str = Depends(authenticate_user)):
    """Manually pin the company's Impressum URL as a locked override record."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, detail="Company not found")

    # Upsert the locked impressum_override enrichment record.
    record = db.query(EnrichmentData).filter(
        EnrichmentData.company_id == company_id,
        EnrichmentData.source_type == "impressum_override"
    ).first()

    if record is None:
        db.add(EnrichmentData(
            company_id=company_id,
            source_type="impressum_override",
            content={"url": url},
            is_locked=True
        ))
    else:
        record.content = {"url": url}
        record.is_locked = True

    db.commit()
    return {"status": "updated"}
|
|
|
|
|
|
|
|
@app.post("/api/companies/{company_id}/report-mistake")
def report_company_mistake(
    company_id: int,
    request: ReportMistakeRequest,
    db: Session = Depends(get_db),
    username: str = Depends(authenticate_user)
):
    """Record a user-reported data mistake on a company for later review."""
    company = db.query(Company).filter(Company.id == company_id).first()
    if company is None:
        raise HTTPException(404, detail="Company not found")

    mistake = ReportedMistake(
        company_id=company_id,
        field_name=request.field_name,
        wrong_value=request.wrong_value,
        corrected_value=request.corrected_value,
        source_url=request.source_url,
        quote=request.quote,
        user_comment=request.user_comment,
    )
    db.add(mistake)
    db.commit()
    db.refresh(mistake)

    logger.info(f"Reported mistake for company {company_id}: {request.field_name} -> {request.corrected_value}")
    return {"status": "success", "mistake_id": mistake.id}
|
|
|
|
|
|
|
|
|
|
|
|
def run_wikipedia_reevaluation_task(company_id: int):
    """Background task: re-run the Wikipedia metric evaluation for one company.

    Opens its own DB session (background tasks run outside the request's
    session scope); all errors are logged instead of raised.
    """
    from .database import SessionLocal

    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if company is None:
            return

        logger.info(f"Re-evaluating Wikipedia metric for {company.name} (Industry: {company.industry_ai})")

        industry = db.query(Industry).filter(Industry.name == company.industry_ai).first()
        if industry is None:
            logger.warning(f"Industry '{company.industry_ai}' not found for re-evaluation.")
            return

        classifier.reevaluate_wikipedia_metric(company, db, industry)
        logger.info(f"Wikipedia metric re-evaluation complete for {company.name}")

    except Exception as e:
        logger.error(f"Wikipedia Re-evaluation Task Error: {e}", exc_info=True)
    finally:
        db.close()
|
|
|
|
def run_metric_reextraction_task(company_id: int):
    """Background task: re-extract industry metrics after an industry change.

    Resolves the company's current AI industry, delegates metric extraction
    to the ClassificationService, and marks the company ENRICHED. Uses its
    own DB session; all errors are logged instead of raised.
    """
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            return

        logger.info(f"Re-extracting metrics for {company.name} (Industry: {company.industry_ai})")

        # Filter in the database instead of loading every Industry row and
        # scanning in Python (consistent with run_wikipedia_reevaluation_task).
        industry = db.query(Industry).filter(Industry.name == company.industry_ai).first()

        if industry:
            classifier.extract_metrics_for_industry(company, db, industry)
            company.status = "ENRICHED"
            db.commit()
            logger.info(f"Metric re-extraction complete for {company.name}")
        else:
            logger.warning(f"Industry '{company.industry_ai}' not found for re-extraction.")

    except Exception as e:
        logger.error(f"Metric Re-extraction Task Error: {e}", exc_info=True)
    finally:
        db.close()
|
|
|
|
def run_discovery_task(company_id: int):
    """Background task: find a company's website and Wikipedia page.

    Opens its own DB session (background tasks run outside the request's
    session scope). Promotes status NEW -> DISCOVERED once a usable website
    is known; "k.A." acts as the "nothing found" sentinel throughout.
    """
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company: return

        # 1. Website Search — only when no usable website is stored yet.
        if not company.website or company.website == "k.A.":
            found_url = discovery.find_company_website(company.name, company.city)
            if found_url and found_url != "k.A.":
                company.website = found_url

        # 2. Wikipedia Search — skipped when a manual override locked the record.
        existing_wiki = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "wikipedia"
        ).first()

        if not existing_wiki or not existing_wiki.is_locked:
            wiki_url = discovery.find_wikipedia_url(company.name, website=company.website, city=company.city)
            # Only fetch full article data for a real hit; otherwise store
            # just the (possibly "k.A.") URL.
            wiki_data = discovery.extract_wikipedia_data(wiki_url) if wiki_url and wiki_url != "k.A." else {"url": wiki_url}

            if not existing_wiki:
                db.add(EnrichmentData(company_id=company.id, source_type="wikipedia", content=wiki_data))
            else:
                existing_wiki.content = wiki_data
                existing_wiki.updated_at = datetime.utcnow()

        # Promote status once a usable website is known.
        if company.status == "NEW" and company.website and company.website != "k.A.":
            company.status = "DISCOVERED"

        db.commit()
    except Exception as e:
        logger.error(f"Discovery Task Error: {e}", exc_info=True)
    finally:
        db.close()
|
|
|
|
def run_analysis_task(company_id: int):
    """Background task: scrape the company website and classify industry/metrics.

    Steps:
      1. Scrape the website into the 'website_scrape' enrichment record,
         unless that record is locked by a manual override.
      2. Delegate industry + metric classification to the
         ClassificationService and mark the company ENRICHED.

    Uses its own DB session; errors are logged, never raised.
    """
    from .database import SessionLocal
    db = SessionLocal()
    try:
        company = db.query(Company).filter(Company.id == company_id).first()
        if not company:
            logger.error(f"Analysis Task: Company with ID {company_id} not found.")
            return

        logger.info(f"--- [BACKGROUND TASK] Starting for {company.name} ---")

        # --- 1. Scrape Website (if not locked) ---
        existing_scrape = db.query(EnrichmentData).filter(
            EnrichmentData.company_id == company.id,
            EnrichmentData.source_type == "website_scrape"
        ).first()

        if not existing_scrape or not existing_scrape.is_locked:
            logger.info(f"Scraping website for {company.name}...")
            scrape_res = scraper.scrape_url(company.website)
            if not existing_scrape:
                db.add(EnrichmentData(company_id=company.id, source_type="website_scrape", content=scrape_res))
                logger.info("Created new website_scrape entry.")
            else:
                existing_scrape.content = scrape_res
                existing_scrape.updated_at = datetime.utcnow()
                logger.info("Updated existing website_scrape entry.")
            # Commit the scrape before classification so it is preserved even
            # if the classification step below fails.
            db.commit()
        else:
            logger.info("Website scrape is locked. Skipping.")

        # --- 2. Classify Industry & Metrics ---
        logger.info(f"Handing over to ClassificationService for {company.name}...")
        classifier.classify_company_potential(company, db)

        company.status = "ENRICHED"
        db.commit()
        logger.info(f"--- [BACKGROUND TASK] Successfully finished for {company.name} ---")

    except Exception as e:
        logger.critical(f"--- [BACKGROUND TASK] CRITICAL ERROR for Company ID {company_id} ---", exc_info=True)
    finally:
        db.close()
|
|
|
|
# --- Serve Frontend ---
|
|
# Resolve the static frontend directory: container path first, then local
# dev fallbacks (built frontend, then a plain static folder).
static_path = "/frontend_static"
if not os.path.exists(static_path):
    # Local dev fallback
    static_path = os.path.join(os.path.dirname(__file__), "../../frontend/dist")
    if not os.path.exists(static_path):
        static_path = os.path.join(os.path.dirname(__file__), "../static")

logger.info(f"Static files path: {static_path} (Exists: {os.path.exists(static_path)})")

if os.path.exists(static_path):
    @app.get("/")
    async def serve_index():
        # Serve the SPA entry point explicitly for "/".
        return FileResponse(os.path.join(static_path, "index.html"))

    # Mounted last so the API routes registered above take precedence.
    app.mount("/", StaticFiles(directory=static_path, html=True), name="static")
else:
    @app.get("/")
    def root_no_frontend():
        # API-only fallback when no frontend build is present.
        return {"message": "Company Explorer API is running, but frontend was not found.", "path_tried": static_path}
|
|
|
|
if __name__ == "__main__":
    # Local dev entry point: run uvicorn with auto-reload enabled.
    import uvicorn
    uvicorn.run("backend.app:app", host="0.0.0.0", port=8000, reload=True)
|