Files
Brancheneinstufung2/heatmap-tool/backend/main.py

191 lines
7.2 KiB
Python

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd
import io
from pydantic import BaseModel
from typing import Dict, List
# FastAPI application instance; all routes below are registered on it.
app = FastAPI()

# Configure CORS so a separately-served frontend can call this API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentialed requests require an
# explicit origin) — confirm whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# --- In-memory Storage & Data Loading ---
# df_storage: the most recently uploaded Excel sheet (all values read as str).
# plz_column_name: name of the detected postal-code column in df_storage.
# plz_geocoord_df: lookup table indexed by 5-digit PLZ with columns
#   'y' (latitude) and 'x' (longitude), loaded at startup.
df_storage = None
plz_column_name = None
plz_geocoord_df = None
@app.on_event("startup")
def load_plz_data():
    """Load plz_geocoord.csv into the module-level PLZ -> coordinate table.

    On any failure the table is left as an empty DataFrame so request
    handlers can detect that geocoding is unavailable.
    """
    global plz_geocoord_df
    try:
        print("--- Loading PLZ geocoordinates dataset... ---")
        # The CSV ships with a malformed header row, so after reading we
        # assign names by position: postal code, latitude (y), longitude (x).
        coords = pd.read_csv("plz_geocoord.csv", dtype=str)
        coords.columns = ['plz', 'y', 'x']
        # Left-pad postal codes to the canonical five digits before indexing.
        coords['plz'] = coords['plz'].str.zfill(5)
        plz_geocoord_df = coords.set_index('plz')
        print(f"--- Successfully loaded {len(plz_geocoord_df)} PLZ coordinates. ---")
    except FileNotFoundError:
        print("--- FATAL ERROR: plz_geocoord.csv not found. Geocoding will not work. ---")
        plz_geocoord_df = pd.DataFrame()
    except Exception as e:
        print(f"--- FATAL ERROR loading plz_geocoord.csv: {e} ---")
        plz_geocoord_df = pd.DataFrame()
# --- Pydantic Models ---
class FilterRequest(BaseModel):
    # Maps a column name to the list of values the user selected for it.
    # An empty list for a column means "no restriction" on that column
    # (the /api/heatmap handler skips empty value lists).
    filters: Dict[str, List[str]]
# --- API Endpoints ---
@app.get("/")
def read_root():
    """Landing/health-check route identifying the backend service."""
    payload = {"message": "Heatmap Tool Backend"}
    return payload
@app.post("/api/upload")
async def upload_file(file: UploadFile = File(...)):
    """Accept an .xlsx upload, detect its PLZ column and filterable columns.

    Side effects: replaces the module-level df_storage and plz_column_name.

    Returns:
        dict with the original filename, the per-column filter options
        (sorted unique values), and the name of the detected PLZ column.

    Raises:
        HTTPException 400: file is not .xlsx, or no column contains 'PLZ'.
        HTTPException 500: any other processing failure.
    """
    global df_storage, plz_column_name
    print(f"--- Received request to /api/upload for file: {file.filename} ---")
    if not file.filename.endswith('.xlsx'):
        raise HTTPException(status_code=400, detail="Invalid file format. Please upload an .xlsx file.")
    try:
        contents = await file.read()
        df = pd.read_excel(io.BytesIO(contents), dtype=str)  # Read all as string to be safe
        df.fillna('N/A', inplace=True)

        # --- PLZ Column Detection ---
        # First column whose name contains 'plz' (case-insensitive) wins.
        temp_plz_col = next((col for col in df.columns if 'plz' in col.lower()), None)
        if not temp_plz_col:
            raise HTTPException(status_code=400, detail="No column with 'PLZ' found in the file.")
        plz_column_name = temp_plz_col

        # Normalize PLZ data to the canonical zero-padded 5-digit form.
        df[plz_column_name] = df[plz_column_name].str.strip().str.zfill(5)

        # --- Dynamic Filter Detection ---
        # Every non-PLZ column becomes a filter offering its sorted unique values.
        filters = {
            col: sorted(df[col].unique().tolist())
            for col in df.columns
            if col != plz_column_name
        }

        df_storage = df
        print(f"Successfully processed file. Found PLZ column: '{plz_column_name}'. Detected {len(filters)} filterable columns.")
        return {"filename": file.filename, "filters": filters, "plz_column": plz_column_name}
    except HTTPException:
        # BUGFIX: the 400 raised above for a missing PLZ column was previously
        # caught by the generic handler below and re-surfaced as a 500.
        raise
    except Exception as e:
        print(f"ERROR processing file: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred while processing the file: {e}")
@app.post("/api/heatmap")
async def get_heatmap_data(request: FilterRequest):
    """Aggregate the uploaded data per PLZ and attach map coordinates.

    Applies the requested column filters to the stored upload, counts rows
    per postal code, joins the counts with the geocoordinate table, and
    returns one record per PLZ: plz, lat, lon, count, plus a per-column
    summary of up to 3 unique attribute values seen in that PLZ.

    Raises:
        HTTPException 404: no file has been uploaded yet.
        HTTPException 500: geocoding table unavailable, or any processing error.
    """
    global df_storage, plz_column_name, plz_geocoord_df
    print(f"--- Received request to /api/heatmap with filters: {request.filters} ---")
    if df_storage is None:
        print("ERROR: No data in df_storage. File must be uploaded first.")
        raise HTTPException(status_code=404, detail="No data available. Please upload a file first.")
    if plz_geocoord_df.empty:
        raise HTTPException(status_code=500, detail="Geocoding data is not available on the server.")
    try:
        filtered_df = df_storage.copy()
        # Apply filters from the request; an empty value list leaves the
        # corresponding column unfiltered.
        for column, values in request.filters.items():
            if values:
                filtered_df = filtered_df[filtered_df[column].isin(values)]
        if filtered_df.empty:
            return []
        # Aggregate data by PLZ, and also collect attribute summaries
        plz_grouped = filtered_df.groupby(plz_column_name)
        plz_counts = plz_grouped.size().reset_index(name='count')
        # Collect unique attributes for each PLZ
        attribute_summaries = {}
        for plz_val, group in plz_grouped:
            summary = {}
            for col in filtered_df.columns:
                # Exclude lat/lon if they somehow exist (they would collide
                # with the coordinate columns produced by the merges below)
                if col != plz_column_name and col != 'lat' and col != 'lon':
                    unique_attrs = group[col].unique().tolist()
                    # Limit to top 3 unique values for readability
                    summary[col] = unique_attrs[:3]
            attribute_summaries[plz_val] = summary
        # Convert summaries to a DataFrame for merging (index = PLZ values)
        summary_df = pd.DataFrame.from_dict(attribute_summaries, orient='index')
        summary_df.index.name = plz_column_name
        # --- Geocoding Step ---
        # Merge the aggregated counts with the geocoding dataframe.
        # Inner join: PLZs with no known coordinates are silently dropped.
        merged_df = pd.merge(
            plz_counts,
            plz_geocoord_df,
            left_on=plz_column_name,
            right_index=True,
            how='inner'
        )
        # Merge with attribute summaries (left join keeps every geocoded PLZ)
        merged_df = pd.merge(
            merged_df,
            summary_df,
            left_on=plz_column_name,
            right_index=True,
            how='left'
        )
        # Rename columns to match frontend expectations ('lon' and 'lat')
        merged_df.rename(columns={'x': 'lon', 'y': 'lat'}, inplace=True)
        # Also rename the original PLZ column to the consistent name 'plz'
        merged_df.rename(columns={plz_column_name: 'plz'}, inplace=True)
        # Convert to the required JSON format, including all remaining columns
        # (which are the attributes). Dynamically collect attribute columns
        # so the output follows whatever columns the upload contained.
        output_columns = ['plz', 'lat', 'lon', 'count']
        for col in merged_df.columns:
            if col not in output_columns and col != plz_column_name:  # Ensure we don't duplicate PLZ or coords
                output_columns.append(col)
        heatmap_data = merged_df[output_columns].to_dict(orient='records')
        # The frontend expects 'attributes_summary' as a single field, so
        # restructure each record: everything that is not plz/lat/lon/count
        # moves into the nested attributes_summary dict.
        final_heatmap_data = []
        for record in heatmap_data:
            attrs = {k: v for k, v in record.items() if k not in ['plz', 'lat', 'lon', 'count']}
            final_heatmap_data.append({
                "plz": record['plz'],
                "lat": record['lat'],
                "lon": record['lon'],
                "count": record['count'],
                "attributes_summary": attrs
            })
        print(f"Generated heatmap data with {len(final_heatmap_data)} PLZ points.")
        return final_heatmap_data
    except Exception as e:
        print(f"ERROR generating heatmap: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred while generating heatmap data: {e}")