"""Thin wrapper around the OpenAI API for generating alt text."""

from __future__ import annotations

import typing

import requests

# Default chat-completions URL; the constructor accepts an override so the
# client can target any GPT compatible server.
DEFAULT_ENDPOINT = "https://api.openai.com/v1/chat/completions"


class OpenAIClientError(RuntimeError):
    """Raised when the OpenAI API cannot fulfill a request."""


class AltTextGenerationError(OpenAIClientError):
    """Raised when the OpenAI API cannot generate alt text."""


class TextImprovementError(OpenAIClientError):
    """Raised when the OpenAI API cannot improve post text."""


def _clean(value: str | None) -> str:
    """Return *value* with surrounding whitespace removed ("" for None)."""
    return "" if value is None else str(value).strip()


class AltTextGenerator:
    """Request alt text and post rewrites from a GPT compatible chat endpoint.

    The instance owns a :class:`requests.Session`; call :meth:`close`, or use
    the instance as a context manager, to release its pooled connections.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4.1",
        *,
        endpoint: str = DEFAULT_ENDPOINT,
        timeout: float = 30,
    ) -> None:
        """Create a client bound to one API key and model.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model name placed in every request payload.
            endpoint: Chat-completions URL to POST to.
            timeout: Per-request timeout in seconds, passed to ``requests``.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("OPENAI_API_KEY is required")
        self.model = model
        self.endpoint = endpoint
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )

    def close(self) -> None:
        """Close the underlying HTTP session and its pooled connections."""
        self.session.close()

    def __enter__(self) -> AltTextGenerator:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def generate_alt_text(
        self,
        image_source: str,
        notes: str | None = None,
        captured_at: str | None = None,
        location: str | None = None,
        coordinates: str | None = None,
    ) -> str:
        """Ask the model to describe an image as Mastodon alt text.

        Args:
            image_source: Image URL, sent as an ``image_url`` content part.
            notes: Optional creator notes added to the prompt.
            captured_at: Optional capture time, added to the prompt as text.
            location: Optional place name, added to the prompt as text.
            coordinates: Optional coordinates, added to the prompt as text.

        Returns:
            The model's alt text with surrounding whitespace removed.

        Raises:
            AltTextGenerationError: If ``image_source`` is blank, the request
                fails, or the response carries no usable content.
        """
        image_url = _clean(image_source)
        if not image_url:
            raise AltTextGenerationError("Image URL required for alt text generation")

        prompt_lines = [
            "You write vivid but concise Mastodon alt text.",
            "Keep it under 400 characters and mention key visual details, colours, "
            "actions, and text. No need to mention the mood unless you think it is "
            "super relevant.",
            "Avoid speculation beyond what is visible. Use UK English spelling.",
        ]
        # Blank metadata is skipped so the prompt never carries empty
        # "Label: " lines.
        metadata = (
            ("Creator notes", notes),
            ("Captured", captured_at),
            ("Location", location),
            ("Coordinates", coordinates),
        )
        for label, raw_value in metadata:
            value = _clean(raw_value)
            if value:
                prompt_lines.append(f"{label}: {value}")

        content: list[dict[str, typing.Any]] = [
            {"type": "text", "text": "\n".join(prompt_lines)},
            {"type": "image_url", "image_url": {"url": image_url}},
        ]
        payload = {
            "model": self.model,
            "temperature": 0.2,
            "max_tokens": 300,
            "messages": [
                {
                    "role": "system",
                    "content": "You help write accessible alt text for social media posts.",
                },
                {"role": "user", "content": content},
            ],
        }
        return self._request_completion(payload, AltTextGenerationError)

    def improve_post_text(
        self,
        draft_text: str,
        instructions: str | None = None,
        hashtag_counts: str | None = None,
        thread_context: list[str] | None = None,
    ) -> str:
        """Ask the model to polish a Mastodon draft.

        Args:
            draft_text: The post text to revise; must not be blank.
            instructions: Optional extra guidance appended to the system prompt.
            hashtag_counts: Optional pre-formatted hashtag history, appended to
                the user message.
            thread_context: Optional earlier posts in the thread, numbered in
                the user message. Blank entries are ignored.

        Returns:
            The improved post text with surrounding whitespace removed.

        Raises:
            TextImprovementError: If the draft is blank, the request fails, or
                the response carries no usable content.
        """
        draft = _clean(draft_text)
        if not draft:
            raise TextImprovementError("Post text cannot be empty")
        extra_instructions = _clean(instructions)
        hashtag_history = _clean(hashtag_counts)
        # Clean the thread first: the "continuing a thread" instruction is only
        # added when at least one real previous post will actually be sent.
        previous_posts = [_clean(entry) for entry in thread_context or ()]
        previous_posts = [post for post in previous_posts if post]

        prompt_parts = [
            "You review Mastodon drafts and improve them. Use UK English spelling.",
            "Keep the tone warm, accessible, and descriptive without exaggeration.",
            "Ensure clarity, fix spelling or grammar, and keep content suitable for social media.",
            "Don't add #TravelVibes if I'm at home in Bristol, okay to add if setting off on a trip.",
            "Don't add #BetterByRail if not on a rail journey, equally don't remove it if present.",
            "Always format hashtags in CamelCase (e.g. #BetterByRail).",
        ]
        if extra_instructions:
            prompt_parts.append(f"Additional instructions: {extra_instructions}")
        if hashtag_history:
            prompt_parts.append(
                "When you add hashtags, prefer the ones from my history below and "
                "only invent new ones if absolutely necessary."
            )
        if previous_posts:
            prompt_parts.append(
                "You are continuing a Mastodon thread. Use the previous posts for "
                "context, avoid repeating them verbatim, and keep the narrative flowing."
            )
        prompt_parts.append("Return only the improved post text.")

        context_block = ""
        if previous_posts:
            formatted_history = "\n".join(
                f"{number}. {post}" for number, post in enumerate(previous_posts, start=1)
            )
            context_block = f"Previous thread posts:\n{formatted_history}\n\n"
        user_content = f"{context_block}Draft post:\n{draft}"
        if hashtag_history:
            user_content = (
                f"{user_content}\n\nHashtag history (tag with past uses):\n"
                f"{hashtag_history}"
            )

        payload = {
            "model": self.model,
            "temperature": 0.4,
            "max_tokens": 400,
            "messages": [
                {"role": "system", "content": "\n".join(prompt_parts)},
                {"role": "user", "content": user_content},
            ],
        }
        return self._request_completion(payload, TextImprovementError)

    def _request_completion(
        self,
        payload: dict[str, typing.Any],
        error_cls: type[OpenAIClientError],
    ) -> str:
        """POST *payload* and return the first choice's stripped message content.

        Every failure is re-raised as *error_cls*, so each public method
        surfaces a single error type. This covers transport or HTTP errors, a
        body that is not JSON, and a response without a non-blank
        ``choices[0].message.content`` string.
        """
        try:
            response = self.session.post(
                self.endpoint, json=payload, timeout=self.timeout
            )
            response.raise_for_status()
        except requests.RequestException as exc:  # pragma: no cover
            raise error_cls(str(exc)) from exc

        try:
            data = response.json()
        except ValueError as exc:
            # requests' JSON decode errors subclass ValueError in every version.
            raise error_cls("OpenAI response was not valid JSON") from exc

        choices = data.get("choices") if isinstance(data, dict) else None
        if not isinstance(choices, list) or not choices:
            raise error_cls("OpenAI response did not include choices")
        first_choice = choices[0]
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        content_text = message.get("content") if isinstance(message, dict) else None
        if not isinstance(content_text, str) or not content_text.strip():
            raise error_cls("OpenAI response missing content")
        return content_text.strip()