# station-announcer/station_announcer/openai_client.py
#
# 159 lines
# 5.8 KiB
# Python

"""Thin wrapper around the OpenAI API for generating alt text."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import requests
class OpenAIClientError(RuntimeError):
    """Raised when the OpenAI API cannot fulfill a request.

    Base class for the more specific errors in this module, so callers can
    catch every OpenAI failure with a single ``except`` clause.
    """
class AltTextGenerationError(OpenAIClientError):
    """Raised when the OpenAI API cannot generate alt text.

    Covers invalid input (no image source), transport/HTTP failures, and
    responses that carry no usable content.
    """
class TextImprovementError(OpenAIClientError):
    """Raised when the OpenAI API cannot improve post text.

    Covers an empty draft, transport/HTTP failures, and responses that carry
    no usable content.
    """
class AltTextGenerator:
"""Request alt text from a GPT-4o compatible OpenAI endpoint."""
def __init__(self, api_key: str, model: str = "gpt-4.1") -> None:
if not api_key:
raise ValueError("OPENAI_API_KEY is required")
self.model = model
self.session = requests.Session()
self.session.headers.update(
{
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
)
self.endpoint = "https://api.openai.com/v1/chat/completions"
def generate_alt_text(
self,
image_source: str,
notes: Optional[str] = None,
captured_at: Optional[str] = None,
location: Optional[str] = None,
coordinates: Optional[str] = None,
) -> str:
if not image_source:
raise AltTextGenerationError("Image URL required for alt text generation")
prompt_lines = [
"You write vivid but concise Mastodon alt text.",
"Keep it under 400 characters and mention key visual details, colours, "
"actions, and text. No need to mention the mood unless you think it is "
"super relevant.",
"Avoid speculation beyond what is visible. Use UK English spelling.",
]
if notes:
prompt_lines.append(f"Creator notes: {notes.strip()}")
if captured_at:
prompt_lines.append(f"Captured: {captured_at}")
if location:
prompt_lines.append(f"Location: {location}")
if coordinates:
prompt_lines.append(f"Coordinates: {coordinates}")
text_prompt = "\n".join(prompt_lines)
content: List[Dict[str, Any]] = [
{"type": "text", "text": text_prompt},
{"type": "image_url", "image_url": {"url": image_source}},
]
payload = {
"model": self.model,
"temperature": 0.2,
"max_tokens": 300,
"messages": [
{
"role": "system",
"content": "You help write accessible alt text for social media posts.",
},
{
"role": "user",
"content": content,
},
],
}
try:
response = self.session.post(self.endpoint, json=payload, timeout=30)
response.raise_for_status()
data = response.json()
except requests.RequestException as exc: # pragma: no cover
raise AltTextGenerationError(str(exc)) from exc
choices = data.get("choices") or []
if not choices:
raise AltTextGenerationError("OpenAI response did not include choices")
message = choices[0].get("message", {})
content_text = message.get("content")
if not content_text:
raise AltTextGenerationError("OpenAI response missing content")
return content_text.strip()
def improve_post_text(
self,
draft_text: str,
instructions: Optional[str] = None,
hashtag_counts: Optional[str] = None,
) -> str:
if not draft_text or not draft_text.strip():
raise TextImprovementError("Post text cannot be empty")
prompt_parts = [
"You review Mastodon drafts and improve them. Use UK English spelling.",
"Keep the tone warm, accessible, and descriptive without exaggeration.",
"Ensure clarity, fix spelling or grammar, and keep content suitable for social media.",
"Don't add #TravelVibes if I'm at home in Bristol, okay to add if setting off on a trip.",
"Don't add #BetterByRail if not on a rail journey, equally don't remove it if present.",
"Always format hashtags in CamelCase (e.g. #BetterByRail).",
]
if instructions:
prompt_parts.append(f"Additional instructions: {instructions.strip()}")
if hashtag_counts:
prompt_parts.append(
"When you add hashtags, prefer the ones from my history below, "
"but feel free to invent new ones."
)
prompt_parts.append("Return only the improved post text.")
user_content = f"Draft post:\\n{draft_text.strip()}"
if hashtag_counts:
user_content = (
f"{user_content}\n\nHashtag history (tag with past uses):\n"
f"{hashtag_counts.strip()}"
)
payload = {
"model": self.model,
"temperature": 0.4,
"max_tokens": 400,
"messages": [
{"role": "system", "content": "\n".join(prompt_parts)},
{"role": "user", "content": user_content},
],
}
try:
response = self.session.post(self.endpoint, json=payload, timeout=30)
response.raise_for_status()
data = response.json()
except requests.RequestException as exc: # pragma: no cover
raise TextImprovementError(str(exc)) from exc
choices = data.get("choices") or []
if not choices:
raise TextImprovementError("OpenAI response did not include choices")
content_text = choices[0].get("message", {}).get("content")
if not content_text:
raise TextImprovementError("OpenAI response missing content")
return content_text.strip()