# Brancheneinstufung2/company-explorer/backend/scripts/classify_unmapped_titles.py
# (file-listing metadata — 171 lines, 7.0 KiB, Python — preserved as a comment
# so the module remains importable)
import sys
import os
import argparse
import json
import logging
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import create_engine, Column, Integer, String, Boolean, DateTime
from datetime import datetime
# --- Standalone Configuration ---
# Add the project root to the Python path to find the LLM utility
sys.path.insert(0, '/app')
from company_explorer.backend.lib.core_utils import call_gemini_flash

# Container-absolute paths: this script is intended to run inside Docker.
DATABASE_URL = "sqlite:////app/companies_v3_fixed_2.db"
LOG_FILE = "/app/Log_from_docker/batch_classifier.log"
BATCH_SIZE = 50 # Number of titles to process in one LLM call

# --- Logging Setup ---
# Ensure the log directory exists before attaching the file handler,
# otherwise FileHandler would raise on first use.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),  # persistent log file on the mounted volume
        logging.StreamHandler()         # mirror output so `docker logs` shows progress
    ]
)
logger = logging.getLogger(__name__)
# --- SQLAlchemy Models (self-contained) ---
Base = declarative_base()


class RawJobTitle(Base):
    """A raw job title harvested from contact data, awaiting role mapping.

    Fix: timestamp defaults previously used ``datetime.now`` (naive local
    time) while the sibling models use ``datetime.utcnow``; aligned to UTC
    so all tables share one time base.
    """
    __tablename__ = 'raw_job_titles'

    id = Column(Integer, primary_key=True)
    # Exact title string; unique so each distinct title is stored once.
    title = Column(String, unique=True, index=True)
    # Occurrence counter for this title in the source data.
    count = Column(Integer, default=1)
    # Origin of the title (free text).
    source = Column(String)
    # Flipped to True by classify_and_store_titles once a pattern exists.
    is_mapped = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class JobRolePattern(Base):
    """A title-matching rule that maps a job-title pattern to a Persona role."""
    __tablename__ = "job_role_patterns"

    id = Column(Integer, primary_key=True, index=True)
    # Matching strategy; this script only ever creates 'exact' patterns.
    pattern_type = Column(String, default="exact", index=True)
    # The text to match (for 'exact' patterns: the literal job title).
    pattern_value = Column(String, unique=True)
    # Assigned role; this script stores Persona.name values here.
    role = Column(String, index=True)
    # Relative priority used by the matcher; whether higher or lower wins is
    # defined elsewhere — TODO confirm (LLM-created patterns use 90).
    priority = Column(Integer, default=100)
    is_active = Column(Boolean, default=True)
    # Provenance marker: 'system' for seeds, 'llm_batch' for this script.
    created_by = Column(String, default="system")
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class Persona(Base):
    """A target persona; its unique name doubles as a classification label."""
    __tablename__ = "personas"

    id = Column(Integer, primary_key=True, index=True)
    # Unique role name; the set of all names is the LLM's allowed label set.
    name = Column(String, unique=True, index=True)
    # Free-text persona attributes (not read by this script).
    pains = Column(String)
    gains = Column(String)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# --- Database Connection ---
engine = create_engine(DATABASE_URL)
# Session factory; classify_and_store_titles opens and closes its own session.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def build_classification_prompt(titles_to_classify, available_roles):
    """Construct the LLM prompt asking for a JSON title -> role mapping.

    ``titles_to_classify`` is a list of job-title strings; ``available_roles``
    is the closed set of labels the model may use.
    """
    roles_line = ', '.join(available_roles)
    titles_json = json.dumps(titles_to_classify, indent=2)
    return f"""
You are an expert in B2B contact segmentation. Your task is to classify a list of job titles into predefined roles.
Analyze the following list of job titles and assign each one to the most appropriate role from the list provided.
The available roles are:
- {roles_line}
RULES:
1. Respond ONLY with a valid JSON object. Do not include any text, explanations, or markdown code fences before or after the JSON.
2. The JSON object should have the original job title as the key and the assigned role as the value.
3. If a job title is ambiguous or you cannot confidently classify it, assign the value "Influencer". Use this as a fallback.
4. Do not invent new roles. Only use the roles from the provided list.
Here are the job titles to classify:
{titles_json}
Your JSON response:
"""
def classify_and_store_titles():
db = SessionLocal()
try:
# 1. Fetch available persona names (roles)
personas = db.query(Persona).all()
available_roles = [p.name for p in personas]
if not available_roles:
logger.error("No Personas/Roles found in the database. Cannot classify. Please seed personas first.")
return
logger.info(f"Classifying based on these roles: {available_roles}")
# 2. Fetch unmapped titles
unmapped_titles = db.query(RawJobTitle).filter(RawJobTitle.is_mapped == False).all()
if not unmapped_titles:
logger.info("No unmapped job titles found. Nothing to do.")
return
logger.info(f"Found {len(unmapped_titles)} unmapped job titles to process.")
# 3. Process in batches
for i in range(0, len(unmapped_titles), BATCH_SIZE):
batch = unmapped_titles[i:i + BATCH_SIZE]
title_strings = [item.title for item in batch]
logger.info(f"Processing batch {i//BATCH_SIZE + 1} of { (len(unmapped_titles) + BATCH_SIZE - 1) // BATCH_SIZE } with {len(title_strings)} titles...")
# 4. Call LLM
prompt = build_classification_prompt(title_strings, available_roles)
response_text = ""
try:
response_text = call_gemini_flash(prompt, json_mode=True)
# Clean potential markdown fences
if response_text.strip().startswith("```json"):
response_text = response_text.strip()[7:-4]
classifications = json.loads(response_text)
except Exception as e:
logger.error(f"Failed to get or parse LLM response for batch. Skipping. Error: {e}")
logger.error(f"Raw response was: {response_text}")
continue
# 5. Process results
new_patterns = 0
for title_obj in batch:
original_title = title_obj.title
assigned_role = classifications.get(original_title)
if assigned_role and assigned_role in available_roles:
exists = db.query(JobRolePattern).filter(JobRolePattern.pattern_value == original_title).first()
if not exists:
new_pattern = JobRolePattern(
pattern_type='exact',
pattern_value=original_title,
role=assigned_role,
priority=90,
created_by='llm_batch'
)
db.add(new_pattern)
new_patterns += 1
title_obj.is_mapped = True
else:
logger.warning(f"Could not classify '{original_title}' or role '{assigned_role}' is invalid. It will be re-processed later.")
db.commit()
logger.info(f"Batch {i//BATCH_SIZE + 1} complete. Created {new_patterns} new mapping patterns.")
except Exception as e:
logger.error(f"An unexpected error occurred: {e}", exc_info=True)
db.rollback()
finally:
db.close()
if __name__ == "__main__":
    # argparse is kept even without options so `--help` works and unknown
    # arguments are rejected; the parsed namespace itself is unused.
    argparse.ArgumentParser(
        description="Batch classify unmapped job titles using an LLM."
    ).parse_args()
    logger.info("--- Starting Batch Classification Script ---")
    classify_and_store_titles()
    logger.info("--- Batch Classification Script Finished ---")