station-announcer/station_announcer/openai_client.py
2025-11-15 19:00:09 +00:00

143 lines
5.1 KiB
Python

"""Thin wrapper around the OpenAI API for generating alt text."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import requests
class OpenAIClientError(RuntimeError):
    """Base error for any request the OpenAI API could not satisfy.

    The more specific alt-text and post-improvement errors in this module
    derive from it, so catching this class handles both.
    """
class AltTextGenerationError(OpenAIClientError):
    """Signals that alt text could not be obtained from the OpenAI API."""
class TextImprovementError(OpenAIClientError):
    """Signals that the OpenAI API could not rewrite a draft post."""
class AltTextGenerator:
    """Request alt text and post rewrites from a GPT-4o compatible OpenAI endpoint.

    Each instance owns one ``requests.Session`` with the bearer token and JSON
    content type pre-set. Every public call POSTs a chat-completions payload to
    ``self.endpoint`` and returns the stripped text of the first choice.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o-mini",
        *,
        endpoint: str = "https://api.openai.com/v1/chat/completions",
        timeout: float = 30,
    ) -> None:
        """Create a client bound to *api_key*.

        Args:
            api_key: OpenAI API key, sent as a bearer token.
            model: Chat model name placed in every request payload.
            endpoint: Chat-completions URL; override to target a compatible
                server.
            timeout: Request timeout in seconds, passed to ``requests``.

        Raises:
            ValueError: If *api_key* is empty.
        """
        if not api_key:
            raise ValueError("OPENAI_API_KEY is required")
        self.model = model
        self.endpoint = endpoint
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )

    def generate_alt_text(
        self,
        image_source: str,
        notes: Optional[str] = None,
        captured_at: Optional[str] = None,
        location: Optional[str] = None,
        coordinates: Optional[str] = None,
    ) -> str:
        """Describe the image at *image_source* as Mastodon alt text.

        Args:
            image_source: Value sent as the ``image_url`` URL of the request.
            notes: Optional creator notes; ignored when blank.
            captured_at: Optional capture time, added to the prompt verbatim.
            location: Optional location, added to the prompt verbatim.
            coordinates: Optional coordinates, added to the prompt verbatim.

        Returns:
            The model's alt text with surrounding whitespace removed.

        Raises:
            AltTextGenerationError: If *image_source* is empty, the HTTP request
                fails, or the response carries no usable content.
        """
        if not image_source:
            raise AltTextGenerationError("Image URL required for alt text generation")
        prompt_lines = [
            "You write vivid but concise Mastodon alt text.",
            "Keep it under 400 characters and mention key visual details, colours, "
            "actions, and text. No need to mention the mood unless you think it is "
            "super relevant.",
            "Avoid speculation beyond what is visible. Use UK English spelling.",
        ]
        # Skip blank notes so the prompt never contains an empty "Creator notes:" line.
        if notes and notes.strip():
            prompt_lines.append(f"Creator notes: {notes.strip()}")
        if captured_at:
            prompt_lines.append(f"Captured: {captured_at}")
        if location:
            prompt_lines.append(f"Location: {location}")
        if coordinates:
            prompt_lines.append(f"Coordinates: {coordinates}")
        # Multimodal user message: the instructions first, then the image.
        content: List[Dict[str, Any]] = [
            {"type": "text", "text": "\n".join(prompt_lines)},
            {"type": "image_url", "image_url": {"url": image_source}},
        ]
        payload = {
            "model": self.model,
            "temperature": 0.2,
            "max_tokens": 300,
            "messages": [
                {
                    "role": "system",
                    "content": "You help write accessible alt text for social media posts.",
                },
                {
                    "role": "user",
                    "content": content,
                },
            ],
        }
        try:
            return self._request_content(payload)
        except (requests.RequestException, ValueError) as exc:
            raise AltTextGenerationError(str(exc)) from exc

    def improve_post_text(
        self, draft_text: str, instructions: Optional[str] = None
    ) -> str:
        """Rewrite *draft_text* as a polished UK-English Mastodon post.

        Args:
            draft_text: The draft post; must contain non-whitespace text.
            instructions: Optional extra guidance; ignored when blank.

        Returns:
            The improved post text with surrounding whitespace removed.

        Raises:
            TextImprovementError: If the draft is empty, the HTTP request fails,
                or the response carries no usable content.
        """
        if not draft_text or not draft_text.strip():
            raise TextImprovementError("Post text cannot be empty")
        prompt_parts = [
            "You review Mastodon drafts and rewrite them in UK English.",
            "Keep the tone warm, accessible, and descriptive without exaggeration.",
            "Ensure clarity, fix spelling or grammar, and keep content suitable for social media.",
        ]
        if instructions and instructions.strip():
            prompt_parts.append(f"Additional instructions: {instructions.strip()}")
        prompt_parts.append("Return only the improved post text.")
        # Use a real newline (not an escaped backslash-n) between the label and the draft.
        user_content = f"Draft post:\n{draft_text.strip()}"
        payload = {
            "model": self.model,
            "temperature": 0.4,
            "max_tokens": 400,
            "messages": [
                {"role": "system", "content": "\n".join(prompt_parts)},
                {"role": "user", "content": user_content},
            ],
        }
        try:
            return self._request_content(payload)
        except (requests.RequestException, ValueError) as exc:
            raise TextImprovementError(str(exc)) from exc

    def _request_content(self, payload: Dict[str, Any]) -> str:
        """POST *payload* to the endpoint and return the first choice's text.

        Raises:
            requests.RequestException: On transport or HTTP-status errors.
            ValueError: If the body is not JSON or lacks a usable
                ``choices[0].message.content`` string.
        """
        response = self.session.post(self.endpoint, json=payload, timeout=self.timeout)
        response.raise_for_status()
        # ``response.json()`` raises a ValueError subclass on a non-JSON body.
        return self._extract_content(response.json())

    @staticmethod
    def _extract_content(data: Any) -> str:
        """Return the stripped ``choices[0].message.content`` from a response body.

        Raises:
            ValueError: If the body has no choices, or the first choice has no
                non-blank string content.
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        if not isinstance(choices, list) or not choices:
            raise ValueError("OpenAI response did not include choices")
        first = choices[0]
        message = first.get("message") if isinstance(first, dict) else None
        text = message.get("content") if isinstance(message, dict) else None
        if not isinstance(text, str) or not text.strip():
            raise ValueError("OpenAI response missing content")
        return text.strip()