93 lines
3.8 KiB
Python
93 lines
3.8 KiB
Python
import os
|
|
import requests
|
|
import logging
|
|
import time
|
|
|
|
# Root-logger setup performed at import time: records at INFO and above are
# emitted via logging's default stderr handler. basicConfig does nothing if the
# root logger already has handlers attached.
logging.basicConfig(level="INFO")
|
|
|
|
# Endpoint for the model this module targets. The API key is sent in the
# ``x-goog-api-key`` header instead of a ``?key=`` query parameter: requests
# embeds the request URL in HTTPError / connection-error messages, which this
# module logs and re-raises, so a key in the URL would leak into logs.
_GEMINI_URL = (
    "https://generativelanguage.googleapis.com/v1/models/"
    "gemini-1.5-flash:generateContent"
)

# HTTP statuses treated as transient: rate limiting (429) and server/gateway errors.
_RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})


def _backoff_seconds(attempt: int) -> int:
    """Return the delay before retrying after zero-based ``attempt``: 2, 4, 8, ... seconds."""
    return (2 ** attempt) * 2


def _build_request(api_key: str, prompt: str, temperature: float) -> tuple:
    """Return ``(url, headers, payload)`` for a single-turn generateContent call.

    The key is placed in the headers only; the returned URL never contains it.
    """
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": api_key,
    }
    payload = {
        "contents": [{"parts": [{"text": prompt}]}],
        "generationConfig": {
            "temperature": temperature,
            "topK": 40,
            "topP": 0.95,
            "maxOutputTokens": 8192,
        },
    }
    return _GEMINI_URL, headers, payload


def _extract_text(result: object, attempt: int) -> str:
    """Map a decoded generateContent response body to the caller-facing string.

    Args:
        result: The decoded JSON body (normally a dict).
        attempt: Zero-based attempt index, used only in the warning message.

    Returns:
        - The first part's ``text`` of the first candidate.
        - ``"[Blocked by Safety Filter]"`` if any of that candidate's safety
          ratings is flagged ``blocked``.
        - ``"[Prompt Blocked: <reason>]"`` if ``promptFeedback.blockReason`` is set.
        - Otherwise ``""``, after logging a warning. Missing or empty fields fall
          through to this case instead of raising.
    """
    data = result if isinstance(result, dict) else {}

    candidates = data.get("candidates") or []
    if candidates:
        candidate = candidates[0]
        # Checked before looking for content: a blocked candidate may carry no parts.
        ratings = candidate.get("safetyRatings") or []
        if any(r.get("blocked") for r in ratings):
            logging.error("API call blocked due to safety ratings: %s", ratings)
            return "[Blocked by Safety Filter]"
        parts = (candidate.get("content") or {}).get("parts") or []
        first_part = parts[0] if parts else {}
        if "text" in first_part:
            return first_part["text"]

    # Valid response without usable content: report a prompt-level block if present.
    block_reason = (data.get("promptFeedback") or {}).get("blockReason")
    if block_reason:
        logging.error("Prompt was blocked by the API. Reason: %s", block_reason)
        return f"[Prompt Blocked: {block_reason}]"

    logging.warning("Unexpected API response structure on attempt %d: %s", attempt + 1, result)
    return ""


def call_gemini_api(
    prompt: str,
    temperature: float = 0.7,
    retries: int = 3,
    timeout: int = 600,
) -> str:
    """
    Calls the Gemini API (gemini-1.5-flash generateContent) with a text prompt.

    Args:
        prompt: The text prompt to send to the API.
        temperature: Sampling temperature; controls randomness in output.
        retries: Total number of attempts, including the first one (despite the
            name). Values below 1 are treated as 1 so the request is always sent.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.

    Returns:
        - The text of the first part of the first candidate.
        - ``"[Blocked by Safety Filter]"`` or ``"[Prompt Blocked: <reason>]"``
          when the API blocked the response or the prompt.
        - An empty string if the response has any other shape.

    Raises:
        ValueError: If the GEMINI_API_KEY environment variable is not set.
        requests.exceptions.HTTPError: On a non-retryable status, or on a
            retryable one (429/5xx) returned by the final attempt.
        requests.exceptions.RequestException: If another requests error
            (connection failure, timeout, ...) persists through the final attempt.
        Exception: Any other unexpected error is logged and re-raised.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        logging.error("GEMINI_API_KEY environment variable not set.")
        raise ValueError("API key not found.")

    url, headers, payload = _build_request(api_key, prompt, temperature)
    attempts = max(1, retries)

    for attempt in range(attempts):
        is_last = attempt == attempts - 1
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            return _extract_text(response.json(), attempt)

        except requests.exceptions.HTTPError as e:
            status = e.response.status_code
            if status in _RETRYABLE_STATUS and not is_last:
                wait_time = _backoff_seconds(attempt)
                logging.warning("Retryable HTTP error %s. Retrying in %ss...", status, wait_time)
                time.sleep(wait_time)
                continue
            logging.error("HTTP Error calling Gemini API: %s %s", status, e.response.text)
            raise

        except requests.exceptions.RequestException as e:
            # Any other requests failure; HTTPError (a subclass) is handled above.
            if not is_last:
                wait_time = _backoff_seconds(attempt)
                logging.warning("Connection Error: %s. Retrying in %ss...", e, wait_time)
                time.sleep(wait_time)
                continue
            logging.error("Final Connection Error calling Gemini API: %s", e)
            raise

        except Exception as e:
            logging.exception("An unexpected error occurred: %s", e)
            raise

    # Unreachable: the final attempt always returns or re-raises.
    return ""
|