159 lines
5.8 KiB
Python
159 lines
5.8 KiB
Python
"""Thin wrapper around the OpenAI API for generating alt text."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import requests
|
|
|
|
|
|
class OpenAIClientError(RuntimeError):
    """Base error for any request the OpenAI API could not fulfil."""
|
|
|
|
|
|
class AltTextGenerationError(OpenAIClientError):
    """Signals that alt text could not be obtained from the OpenAI API."""
|
|
|
|
|
|
class TextImprovementError(OpenAIClientError):
    """Signals that improved post text could not be obtained from the OpenAI API."""
|
|
|
|
|
|
class AltTextGenerator:
    """Client for an OpenAI chat-completions endpoint.

    Produces alt text for an image and improves draft post text. A single
    ``requests.Session`` carrying the bearer token is created in ``__init__``
    and reused for every request.
    """

    # Seconds allowed for each HTTP request before ``requests`` gives up.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key: str, model: str = "gpt-4.1") -> None:
        """Create a client that authenticates with *api_key* and uses *model*.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model name placed in every request payload.

        Raises:
            ValueError: If *api_key* is empty.
        """
        if not api_key:
            raise ValueError("OPENAI_API_KEY is required")
        self.model = model
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )
        self.endpoint = "https://api.openai.com/v1/chat/completions"

    def _request_completion(self, payload: Dict[str, Any]) -> str:
        """POST *payload* and return the first choice's message text, stripped.

        Shared by the public methods so the response handling lives in one
        place. It raises only low-level errors; each caller translates them
        into its own domain exception.

        Raises:
            requests.RequestException: On transport failure or an HTTP error
                status.
            ValueError: If the body is not valid JSON, is not a JSON object,
                or does not contain a usable first choice.
        """
        response = self.session.post(
            self.endpoint, json=payload, timeout=self.REQUEST_TIMEOUT
        )
        response.raise_for_status()
        data = response.json()
        if not isinstance(data, dict):
            raise ValueError("OpenAI response was not a JSON object")

        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            raise ValueError("OpenAI response did not include choices")

        # Guard against `"message": null` or an unexpected choice shape, which
        # would otherwise surface as a bare AttributeError.
        first_choice = choices[0]
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        content_text = message.get("content") if isinstance(message, dict) else None
        if not content_text or not isinstance(content_text, str):
            raise ValueError("OpenAI response missing content")
        return content_text.strip()

    def generate_alt_text(
        self,
        image_source: str,
        notes: Optional[str] = None,
        captured_at: Optional[str] = None,
        location: Optional[str] = None,
        coordinates: Optional[str] = None,
    ) -> str:
        """Ask the model for alt text describing *image_source*.

        Args:
            image_source: Value sent as the ``image_url`` of the request.
            notes: Optional creator notes; stripped and appended to the prompt.
            captured_at: Optional capture time, appended to the prompt verbatim.
            location: Optional location, appended to the prompt verbatim.
            coordinates: Optional coordinates, appended to the prompt verbatim.

        Returns:
            The model's reply with surrounding whitespace removed.

        Raises:
            AltTextGenerationError: If *image_source* is empty, the request
                fails, or the response carries no usable content.
        """
        if not image_source:
            raise AltTextGenerationError("Image URL required for alt text generation")

        prompt_lines = [
            "You write vivid but concise Mastodon alt text.",
            "Keep it under 400 characters and mention key visual details, colours, "
            "actions, and text. No need to mention the mood unless you think it is "
            "super relevant.",
            "Avoid speculation beyond what is visible. Use UK English spelling.",
        ]
        if notes:
            prompt_lines.append(f"Creator notes: {notes.strip()}")
        if captured_at:
            prompt_lines.append(f"Captured: {captured_at}")
        if location:
            prompt_lines.append(f"Location: {location}")
        if coordinates:
            prompt_lines.append(f"Coordinates: {coordinates}")
        text_prompt = "\n".join(prompt_lines)

        content: List[Dict[str, Any]] = [
            {"type": "text", "text": text_prompt},
            {"type": "image_url", "image_url": {"url": image_source}},
        ]

        payload = {
            "model": self.model,
            "temperature": 0.2,
            "max_tokens": 300,
            "messages": [
                {
                    "role": "system",
                    "content": "You help write accessible alt text for social media posts.",
                },
                {
                    "role": "user",
                    "content": content,
                },
            ],
        }
        try:
            return self._request_completion(payload)
        except (requests.RequestException, ValueError) as exc:
            raise AltTextGenerationError(str(exc)) from exc

    def improve_post_text(
        self,
        draft_text: str,
        instructions: Optional[str] = None,
        hashtag_counts: Optional[str] = None,
    ) -> str:
        """Ask the model to polish *draft_text* for posting.

        Args:
            draft_text: The draft post; must contain non-whitespace text.
            instructions: Optional extra guidance added to the system prompt.
            hashtag_counts: Optional hashtag history text; when given it is
                appended to the user message and the system prompt asks the
                model to prefer those tags.

        Returns:
            The model's reply with surrounding whitespace removed.

        Raises:
            TextImprovementError: If *draft_text* is blank, the request fails,
                or the response carries no usable content.
        """
        if not draft_text or not draft_text.strip():
            raise TextImprovementError("Post text cannot be empty")

        prompt_parts = [
            "You review Mastodon drafts and improve them. Use UK English spelling.",
            "Keep the tone warm, accessible, and descriptive without exaggeration.",
            "Ensure clarity, fix spelling or grammar, and keep content suitable for social media.",
            "Don't add #TravelVibes if I'm at home in Bristol, okay to add if setting off on a trip.",
            "Don't add #BetterByRail if not on a rail journey, equally don't remove it if present.",
            "Always format hashtags in CamelCase (e.g. #BetterByRail).",
        ]
        if instructions:
            prompt_parts.append(f"Additional instructions: {instructions.strip()}")
        if hashtag_counts:
            prompt_parts.append(
                "When you add hashtags, prefer the ones from my history below, "
                "but feel free to invent new ones."
            )
        prompt_parts.append("Return only the improved post text.")

        # A real newline separates the label from the draft; the previous
        # "\\n" sent a literal backslash-n to the model.
        user_content = f"Draft post:\n{draft_text.strip()}"
        if hashtag_counts:
            user_content = (
                f"{user_content}\n\nHashtag history (tag with past uses):\n"
                f"{hashtag_counts.strip()}"
            )

        payload = {
            "model": self.model,
            "temperature": 0.4,
            "max_tokens": 400,
            "messages": [
                {"role": "system", "content": "\n".join(prompt_parts)},
                {"role": "user", "content": user_content},
            ],
        }
        try:
            return self._request_completion(payload)
        except (requests.RequestException, ValueError) as exc:
            raise TextImprovementError(str(exc)) from exc
|