# batch_classifier.py — classify raw job titles into persona roles via an LLM.
# Standard library
import argparse
import json
import logging
import os
import sys
from datetime import datetime

# Third-party
from sqlalchemy import create_engine, Column, Integer, String, Boolean, DateTime
from sqlalchemy.orm import sessionmaker, declarative_base

# --- Standalone Configuration ---
# Add the project root to the Python path to find the LLM utility.
# Must run before the local import below.
sys.path.insert(0, '/app')

from company_explorer.backend.lib.core_utils import call_gemini_flash
# --- Script configuration ---
DATABASE_URL = "sqlite:////app/companies_v3_fixed_2.db"
LOG_FILE = "/app/Log_from_docker/batch_classifier.log"
BATCH_SIZE = 50  # Number of titles to process in one LLM call

# --- Logging Setup ---
# Log both to a file (visible from the Docker host mount) and to stdout.
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(LOG_FILE),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# --- SQLAlchemy Models (self-contained) ---
# Models are redeclared here so the script runs standalone; they must stay
# in sync with the main application's schema (assumption — confirm).
Base = declarative_base()
class RawJobTitle(Base):
    """A raw job title harvested from a data source, pending role mapping."""
    __tablename__ = 'raw_job_titles'

    id = Column(Integer, primary_key=True)
    # Literal title string; unique, so repeated sightings presumably bump
    # `count` rather than insert a new row (done by the harvester — confirm).
    title = Column(String, unique=True, index=True)
    count = Column(Integer, default=1)  # how many times this title was seen
    source = Column(String)  # where the title was harvested from
    # Flipped to True by classify_and_store_titles() once a pattern exists.
    is_mapped = Column(Boolean, default=False)
    # NOTE(review): this model uses local-time datetime.now while the other
    # models below use datetime.utcnow — confirm which convention is intended.
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class JobRolePattern(Base):
    """A mapping rule from a job-title pattern to a persona role."""
    __tablename__ = "job_role_patterns"

    id = Column(Integer, primary_key=True, index=True)
    # How pattern_value is matched; this script only creates "exact" patterns.
    pattern_type = Column(String, default="exact", index=True)
    pattern_value = Column(String, unique=True)  # the title text to match
    role = Column(String, index=True)  # persona name the pattern maps to
    # LLM-created patterns get 90 vs the default 100; which value wins when
    # several patterns match is decided by the matcher elsewhere — confirm.
    priority = Column(Integer, default=100)
    is_active = Column(Boolean, default=True)
    created_by = Column(String, default="system")  # e.g. "system", "llm_batch"
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; consider
    # timezone-aware datetimes when the codebase migrates.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class Persona(Base):
    """A target persona whose name serves as a classification label."""
    __tablename__ = "personas"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, unique=True, index=True)  # role label shown to the LLM
    pains = Column(String)  # presumably free text or serialized list — confirm
    gains = Column(String)  # presumably free text or serialized list — confirm
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# --- Database Connection ---
engine = create_engine(DATABASE_URL)
# Session factory; each unit of work opens and closes its own session.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def build_classification_prompt(titles_to_classify, available_roles):
    """Return the LLM prompt asking for a JSON {title: role} classification.

    Args:
        titles_to_classify: list of raw job-title strings for this batch.
        available_roles: list of persona names the LLM must choose from.
    """
    roles_line = ', '.join(available_roles)
    titles_json = json.dumps(titles_to_classify, indent=2)
    return f"""
You are an expert in B2B contact segmentation. Your task is to classify a list of job titles into predefined roles.

Analyze the following list of job titles and assign each one to the most appropriate role from the list provided.

The available roles are:
- {roles_line}

RULES:
1. Respond ONLY with a valid JSON object. Do not include any text, explanations, or markdown code fences before or after the JSON.
2. The JSON object should have the original job title as the key and the assigned role as the value.
3. If a job title is ambiguous or you cannot confidently classify it, assign the value "Influencer". Use this as a fallback.
4. Do not invent new roles. Only use the roles from the provided list.

Here are the job titles to classify:
{titles_json}

Your JSON response:
"""
def _strip_markdown_fences(text):
    """Strip surrounding ``` / ```json code fences from an LLM response.

    Models sometimes wrap JSON in markdown fences despite being told not to.
    The previous cleanup (`strip()[7:-4]`) assumed a newline preceded the
    closing fence and only handled the ```json variant; this handles plain
    ``` fences and the no-trailing-newline case as well.
    """
    cleaned = text.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned[len("```json"):]
    elif cleaned.startswith("```"):
        cleaned = cleaned[len("```"):]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-len("```")]
    return cleaned.strip()


def classify_and_store_titles():
    """Classify all unmapped RawJobTitle rows into persona roles via the LLM.

    For each batch of BATCH_SIZE titles it builds a prompt from the Persona
    names in the DB, asks the LLM for a {title: role} JSON mapping, stores
    each valid mapping as an 'exact' JobRolePattern (priority 90, created_by
    'llm_batch') and marks the title as mapped. Titles the LLM cannot
    classify — or maps to an unknown role — stay unmapped so a later run can
    retry them. Commits once per batch; a failed batch is skipped, not fatal.
    """
    db = SessionLocal()
    try:
        # 1. Fetch available persona names (roles).
        available_roles = [p.name for p in db.query(Persona).all()]
        if not available_roles:
            logger.error("No Personas/Roles found in the database. Cannot classify. Please seed personas first.")
            return

        logger.info(f"Classifying based on these roles: {available_roles}")

        # 2. Fetch unmapped titles.
        unmapped_titles = db.query(RawJobTitle).filter(RawJobTitle.is_mapped == False).all()
        if not unmapped_titles:
            logger.info("No unmapped job titles found. Nothing to do.")
            return

        logger.info(f"Found {len(unmapped_titles)} unmapped job titles to process.")

        total_batches = (len(unmapped_titles) + BATCH_SIZE - 1) // BATCH_SIZE

        # 3. Process in batches to keep each prompt within a sane size.
        for i in range(0, len(unmapped_titles), BATCH_SIZE):
            batch = unmapped_titles[i:i + BATCH_SIZE]
            title_strings = [item.title for item in batch]
            batch_no = i // BATCH_SIZE + 1

            logger.info(f"Processing batch {batch_no} of {total_batches} with {len(title_strings)} titles...")

            # 4. Ask the LLM for a {title: role} JSON mapping.
            prompt = build_classification_prompt(title_strings, available_roles)
            response_text = ""
            try:
                response_text = call_gemini_flash(prompt, json_mode=True)
                classifications = json.loads(_strip_markdown_fences(response_text))
            except Exception as e:
                # Skip only this batch; the remaining batches can still succeed.
                logger.error(f"Failed to get or parse LLM response for batch. Skipping. Error: {e}")
                logger.error(f"Raw response was: {response_text}")
                continue

            # 5. Persist the results.
            new_patterns = 0
            for title_obj in batch:
                original_title = title_obj.title
                assigned_role = classifications.get(original_title)

                # Only accept roles that actually exist — the LLM is told not
                # to invent roles, but may still do so.
                if assigned_role and assigned_role in available_roles:
                    exists = db.query(JobRolePattern).filter(JobRolePattern.pattern_value == original_title).first()
                    if not exists:
                        db.add(JobRolePattern(
                            pattern_type='exact',
                            pattern_value=original_title,
                            role=assigned_role,
                            priority=90,
                            created_by='llm_batch'
                        ))
                        new_patterns += 1
                    # Mark mapped even when a pattern already existed.
                    title_obj.is_mapped = True
                else:
                    logger.warning(f"Could not classify '{original_title}' or role '{assigned_role}' is invalid. It will be re-processed later.")

            db.commit()
            logger.info(f"Batch {batch_no} complete. Created {new_patterns} new mapping patterns.")

    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}", exc_info=True)
        db.rollback()
    finally:
        db.close()
def main():
    """CLI entry point: parse arguments (none defined yet) and run the classifier."""
    arg_parser = argparse.ArgumentParser(description="Batch classify unmapped job titles using an LLM.")
    arg_parser.parse_args()

    logger.info("--- Starting Batch Classification Script ---")
    classify_and_store_titles()
    logger.info("--- Batch Classification Script Finished ---")


if __name__ == "__main__":
    main()