189 lines
7.1 KiB
Python
189 lines
7.1 KiB
Python
"""Thin wrapper around the OpenAI API for generating alt text."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import requests
|
|
import typing
|
|
|
|
from .config import DEFAULT_OPENAI_MODEL
|
|
|
|
class OpenAIClientError(RuntimeError):
    """Base error for any request the OpenAI API could not fulfil.

    More specific failures in this module derive from this class, so callers
    can catch it to handle every OpenAI-related error in one place.
    """
|
|
|
|
|
|
class AltTextGenerationError(OpenAIClientError):
    """Signal that an alt text request to the OpenAI API did not succeed."""
|
|
|
|
|
|
class TextImprovementError(OpenAIClientError):
    """Signal that a post-text improvement request to the OpenAI API did not succeed."""
|
|
|
|
|
|
class AltTextGenerator:
    """Request alt text and post-text improvements from an OpenAI chat endpoint.

    A single ``requests.Session`` carrying the bearer token is created per
    instance and reused for every request.
    """

    # Seconds to wait for the API before giving up; override per instance or
    # subclass if a different limit is needed.
    request_timeout: float = 30

    def __init__(self, api_key: str, model: str = DEFAULT_OPENAI_MODEL) -> None:
        """Store the model name and prepare an authenticated HTTP session.

        Args:
            api_key: API key sent as a bearer token on every request.
            model: Model identifier placed in each request payload.

        Raises:
            ValueError: If ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError("OPENAI_API_KEY is required")
        self.model = model
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )
        self.endpoint = "https://api.openai.com/v1/chat/completions"

    @staticmethod
    def _clean_entries(entries: list[str] | None) -> list[str]:
        """Return the stripped, non-blank entries (``None`` entries are skipped)."""
        return [entry.strip() for entry in entries or [] if entry and entry.strip()]

    @staticmethod
    def _numbered(entries: list[str]) -> str:
        """Format entries as a ``1. first`` / ``2. second`` list, one per line."""
        return "\n".join(
            f"{position}. {value}" for position, value in enumerate(entries, start=1)
        )

    def _request_completion(
        self,
        payload: dict[str, typing.Any],
        error_cls: type[OpenAIClientError],
    ) -> str:
        """POST ``payload`` and return the stripped text of the first choice.

        Args:
            payload: JSON body for the chat completions endpoint.
            error_cls: Exception type raised for every failure, so each public
                method keeps reporting its own error class.

        Raises:
            error_cls: On transport/HTTP errors, an invalid JSON body, or a
                response without usable text content.
        """
        try:
            response = self.session.post(
                self.endpoint, json=payload, timeout=self.request_timeout
            )
            response.raise_for_status()
        except requests.RequestException as exc:  # pragma: no cover
            raise error_cls(str(exc)) from exc

        try:
            data = response.json()
        except ValueError as exc:  # pragma: no cover
            # Depending on the requests version, a malformed body may surface
            # as a plain ValueError rather than a RequestException.
            raise error_cls(f"OpenAI response was not valid JSON: {exc}") from exc

        choices = data.get("choices") if isinstance(data, dict) else None
        if not isinstance(choices, list) or not choices:
            raise error_cls("OpenAI response did not include choices")

        # Guard each level: the API may send ``null`` where a dict is expected.
        first_choice = choices[0]
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        content_text = message.get("content") if isinstance(message, dict) else None
        if not isinstance(content_text, str) or not content_text.strip():
            raise error_cls("OpenAI response missing content")
        return content_text.strip()

    def generate_alt_text(
        self,
        image_source: str,
        notes: str | None = None,
        captured_at: str | None = None,
        location: str | None = None,
        coordinates: str | None = None,
    ) -> str:
        """Ask the model for alt text describing ``image_source``.

        Args:
            image_source: Value sent as the ``image_url`` of the request.
            notes: Optional creator notes appended to the prompt.
            captured_at: Optional capture time appended to the prompt.
            location: Optional location appended to the prompt.
            coordinates: Optional coordinates appended to the prompt.

        Returns:
            The model's reply with surrounding whitespace removed.

        Raises:
            AltTextGenerationError: If ``image_source`` is empty, the request
                fails, or the response carries no usable content.
        """
        if not image_source:
            raise AltTextGenerationError("Image URL required for alt text generation")

        prompt_lines = [
            "You write vivid but concise Mastodon alt text.",
            "Keep it under 400 characters and mention key visual details, colours, "
            "actions, and text. No need to mention the mood unless you think it is "
            "super relevant.",
            "Avoid speculation beyond what is visible. Use UK English spelling.",
        ]
        # Whitespace-only notes would otherwise add an empty "Creator notes:" line.
        cleaned_notes = notes.strip() if notes else ""
        if cleaned_notes:
            prompt_lines.append(f"Creator notes: {cleaned_notes}")
        if captured_at:
            prompt_lines.append(f"Captured: {captured_at}")
        if location:
            prompt_lines.append(f"Location: {location}")
        if coordinates:
            prompt_lines.append(f"Coordinates: {coordinates}")
        text_prompt = "\n".join(prompt_lines)

        content: list[dict[str, typing.Any]] = [
            {"type": "text", "text": text_prompt},
            {"type": "image_url", "image_url": {"url": image_source}},
        ]

        payload = {
            "model": self.model,
            "temperature": 0.2,
            "max_tokens": 300,
            "messages": [
                {
                    "role": "system",
                    "content": "You help write accessible alt text for social media posts.",
                },
                {
                    "role": "user",
                    "content": content,
                },
            ],
        }
        return self._request_completion(payload, AltTextGenerationError)

    def improve_post_text(
        self,
        draft_text: str,
        instructions: str | None = None,
        hashtag_counts: str | None = None,
        thread_context: list[str] | None = None,
        alt_texts: list[str] | None = None,
    ) -> str:
        """Ask the model to polish a draft post.

        Args:
            draft_text: The draft to improve; must contain non-blank text.
            instructions: Optional extra guidance added to the system prompt.
            hashtag_counts: Optional hashtag history text included in the user
                message.
            thread_context: Optional earlier posts of the thread; blank and
                ``None`` entries are ignored.
            alt_texts: Optional alt texts of attached images; blank and
                ``None`` entries are ignored.

        Returns:
            The model's reply with surrounding whitespace removed.

        Raises:
            TextImprovementError: If the draft is blank, the request fails, or
                the response carries no usable content.
        """
        if not draft_text or not draft_text.strip():
            raise TextImprovementError("Post text cannot be empty")

        # Normalise the optional inputs once so the system prompt and the user
        # message agree on which sections are actually present.
        cleaned_instructions = instructions.strip() if instructions else ""
        cleaned_hashtags = hashtag_counts.strip() if hashtag_counts else ""
        cleaned_context = self._clean_entries(thread_context)
        cleaned_alt_texts = self._clean_entries(alt_texts)

        prompt_parts = [
            "You review Mastodon drafts and improve them. Use UK English spelling.",
            "Keep the tone warm, accessible, and descriptive without exaggeration.",
            "Ensure clarity, fix spelling or grammar, and keep content suitable for social media.",
            "Don't add #TravelVibes if I'm at home in Bristol, okay to add if setting off on a trip.",
            "Don't add #BetterByRail if not on a rail journey, equally don't remove it if present.",
            "Always format hashtags in CamelCase (e.g. #BetterByRail).",
        ]
        if cleaned_instructions:
            prompt_parts.append(f"Additional instructions: {cleaned_instructions}")
        if cleaned_hashtags:
            prompt_parts.append(
                "When you add hashtags, prefer the ones from my history below and "
                "only invent new ones if absolutely necessary."
            )
        if cleaned_context:
            prompt_parts.append(
                "You are continuing a Mastodon thread. Use the previous posts for "
                "context, avoid repeating them verbatim, and keep the narrative flowing."
            )
        prompt_parts.append("Return only the improved post text.")

        # Sections of the user message, separated by a blank line.
        sections = []
        if cleaned_context:
            sections.append(
                f"Previous thread posts:\n{self._numbered(cleaned_context)}"
            )
        sections.append(f"Draft post:\n{draft_text.strip()}")
        if cleaned_hashtags:
            sections.append(
                f"Hashtag history (tag with past uses):\n{cleaned_hashtags}"
            )
        if cleaned_alt_texts:
            sections.append(
                f"Alt text for attached images:\n{self._numbered(cleaned_alt_texts)}"
            )
        user_content = "\n\n".join(sections)

        payload = {
            "model": self.model,
            "temperature": 0.4,
            "max_tokens": 400,
            "messages": [
                {"role": "system", "content": "\n".join(prompt_parts)},
                {"role": "user", "content": user_content},
            ],
        }
        return self._request_completion(payload, TextImprovementError)
|