diff --git a/train_model.py b/train_model.py index 1c8a06a1..25bca032 100644 --- a/train_model.py +++ b/train_model.py @@ -11,9 +11,8 @@ from collections import Counter import logging import sys -# Importiere deine bestehenden Helfer +# Importiere NUR noch den GoogleSheetHandler from google_sheet_handler import GoogleSheetHandler -from helpers import normalize_company_name # Logging Setup logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -34,7 +33,22 @@ STOP_TOKENS_BASE = { } CITY_TOKENS = set() -# --- Hilfsfunktionen --- +# --- Hilfsfunktionen (jetzt direkt hier definiert, um helpers.py nicht zu benötigen) --- + +def normalize_company_name(name: str): + """Normalisiert einen Firmennamen für den Vergleich.""" + if not isinstance(name, str): return '' + # Kleinschreibung + name = name.lower() + # Inhalte in Klammern entfernen + name = re.sub(r'\(.*?\)', '', name) + name = re.sub(r'\[.*?\]', '', name) + # Alle Nicht-alphanumerischen Zeichen durch Leerzeichen ersetzen (Umlaute beibehalten) + name = re.sub(r'[^a-z0-9äöüß\s]', ' ', name) + # Mehrfache Leerzeichen durch ein einzelnes ersetzen und an den Rändern trimmen + name = re.sub(r'\s+', ' ', name).strip() + return name + def _tokenize(s: str): if not s: return [] return re.split(r"[^a-z0-9äöüß]+", str(s).lower()) @@ -49,12 +63,13 @@ def clean_name_for_scoring(norm_name: str): def choose_rarest_token(norm_name: str, term_weights: dict): _, toks = clean_name_for_scoring(norm_name) if not toks: return None + # Gibt das Token mit dem höchsten Gewicht (höchster IDF-Score) zurück return max(toks, key=lambda t: term_weights.get(t, 0)) def create_features(mrec: dict, crec: dict, term_weights: dict): features = {} - n1_raw = mrec.get('normalized_name', '') - n2_raw = crec.get('normalized_name', '') + n1_raw = mrec.get('normalized_CRM Name', '') # Angepasst an Spaltennamen + n2_raw = crec.get('normalized_Kandidat', '') # Angepasst an Spaltennamen clean1, toks1 = clean_name_for_scoring(n1_raw) clean2, toks2 = clean_name_for_scoring(n2_raw) @@ -63,9 +78,10 @@ def create_features(mrec: dict, crec: dict, term_weights: dict): features['fuzz_token_set_ratio'] = fuzz.token_set_ratio(clean1, clean2) features['fuzz_token_sort_ratio'] = fuzz.token_sort_ratio(clean1, clean2) - # Normalisiere Domains für den Vergleich - domain1 = str(mrec.get('CRM Website', '')).lower().replace('www.', '').split('/')[0] - domain2 = str(crec.get('Kandidat Website', '')).lower().replace('www.', '').split('/')[0] + domain1_raw = str(mrec.get('CRM Website', '')).lower() + domain2_raw = str(crec.get('Kandidat Website', '')).lower() + domain1 = domain1_raw.replace('www.', '').split('/')[0].strip() + domain2 = domain2_raw.replace('www.', '').split('/')[0].strip() features['domain_match'] = 1 if domain1 and domain1 == domain2 else 0 features['city_match'] = 1 if mrec.get('CRM Ort') and mrec.get('CRM Ort') == crec.get('Kandidat Ort') else 0 @@ -113,12 +129,16 @@ if __name__ == "__main__": logging.info("Erstelle Features für den Trainingsdatensatz...") features_list = [] labels = [] + + # Sicherstellen, dass die Spalte 'Best Match Option' existiert, um Fehler zu vermeiden + if 'Best Match Option' not in gold_df.columns: + logging.error("Die Spalte 'Best Match Option' wurde in deiner CSV nicht gefunden. Bitte überprüfe den Spaltennamen.") + sys.exit(1) for _, row in gold_df.iterrows(): - # Nur Zeilen mit einem Kandidaten und einem Best Match verarbeiten - if pd.notna(row['Kandidat']) and pd.notna(row['Best Match Option']): + if pd.notna(row['Kandidat']) and pd.notna(row['Best Match Option']) and str(row['Best Match Option']).strip() != '': mrec = row.to_dict() - crec = {'Kandidat Website': row['Kandidat Website'], 'Kandidat Ort': row['Kandidat Ort'], 'Kandidat Land': row['Kandidat Land']} + crec = {'normalized_name': row['normalized_Kandidat'], 'Kandidat Website': row['Kandidat Website'], 'Kandidat Ort': row['Kandidat Ort'], 'Kandidat Land': row['Kandidat Land']} features = create_features(mrec, crec, term_weights) features_list.append(features) @@ -129,13 +149,20 @@ if __name__ == "__main__": X = pd.DataFrame(features_list) y = np.array(labels) + if len(X) == 0: + logging.critical("Keine gültigen Trainingsdaten gefunden. Überprüfe die Spalten 'Kandidat' and 'Best Match Option' in deiner CSV.") + sys.exit(1) + logging.info(f"Trainingsdatensatz erstellt mit {X.shape[0]} Beispielen und {X.shape[1]} Features.") logging.info(f"Verteilung der Klassen: {Counter(y)}") logging.info("Trainiere das XGBoost-Modell...") X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y) - model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss', scale_pos_weight= (len(y) - sum(y)) / sum(y)) + # Balance der Klassen für das Training + scale_pos_weight = (len(y_train) - sum(y_train)) / sum(y_train) if sum(y_train) > 0 else 1 + + model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss', scale_pos_weight=scale_pos_weight) model.fit(X_train, y_train) logging.info("Modell erfolgreich trainiert.")