From e00207e986cef2e4b399395d3df3e2165d92468a Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Sun, 16 Nov 2025 07:11:12 +0000 Subject: [PATCH] Add mastodon thread context. --- station_announcer/mastodon.py | 16 +++++++++- station_announcer/openai_client.py | 44 +++++++++++++++++++-------- station_announcer/routes.py | 48 +++++++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 15 deletions(-) diff --git a/station_announcer/mastodon.py b/station_announcer/mastodon.py index 884b3c3..78e3410 100644 --- a/station_announcer/mastodon.py +++ b/station_announcer/mastodon.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence import requests @@ -115,3 +115,17 @@ class MastodonClient: "url": status.get("url"), } return None + + def get_status(self, status_id: str) -> Dict[str, Any]: + if not status_id: + raise MastodonError("status_id is required") + return self._request("GET", f"/api/v1/statuses/{status_id}") + + def get_status_ancestors(self, status_id: str) -> List[Dict[str, Any]]: + if not status_id: + raise MastodonError("status_id is required") + data = self._request("GET", f"/api/v1/statuses/{status_id}/context") + ancestors = data.get("ancestors") + if isinstance(ancestors, list): + return ancestors + return [] diff --git a/station_announcer/openai_client.py b/station_announcer/openai_client.py index ab71f73..3e68a28 100644 --- a/station_announcer/openai_client.py +++ b/station_announcer/openai_client.py @@ -2,9 +2,8 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional - import requests +import typing class OpenAIClientError(RuntimeError): @@ -20,7 +19,7 @@ class TextImprovementError(OpenAIClientError): class AltTextGenerator: - """Request alt text from a GPT-4o compatible OpenAI endpoint.""" + """Request alt text from a GPT compatible OpenAI endpoint.""" def __init__(self, api_key: str, model: str = "gpt-4.1") -> None: if not api_key: @@ -38,10 +37,10 @@ class AltTextGenerator: def generate_alt_text( self, image_source: str, - notes: Optional[str] = None, - captured_at: Optional[str] = None, - location: Optional[str] = None, - coordinates: Optional[str] = None, + notes: str | None = None, + captured_at: str | None = None, + location: str | None = None, + coordinates: str | None = None, ) -> str: if not image_source: raise AltTextGenerationError("Image URL required for alt text generation") @@ -63,7 +62,7 @@ class AltTextGenerator: prompt_lines.append(f"Coordinates: {coordinates}") text_prompt = "\n".join(prompt_lines) - content: List[Dict[str, Any]] = [ + content: list[dict[str, typing.Any]] = [ {"type": "text", "text": text_prompt}, {"type": "image_url", "image_url": {"url": image_source}}, ] @@ -104,8 +103,9 @@ class AltTextGenerator: def improve_post_text( self, draft_text: str, - instructions: Optional[str] = None, - hashtag_counts: Optional[str] = None, + instructions: str | None = None, + hashtag_counts: str | None = None, + thread_context: list[str] | None = None, ) -> str: if not draft_text or not draft_text.strip(): raise TextImprovementError("Post text cannot be empty") @@ -122,11 +122,29 @@ class AltTextGenerator: prompt_parts.append(f"Additional instructions: {instructions.strip()}") if hashtag_counts: prompt_parts.append( - "When you add hashtags, prefer the ones from my history below, " - "but feel free to invent new ones." + "When you add hashtags, prefer the ones from my history below and " + "only invent new ones if absolutely necessary." + ) + if thread_context: + prompt_parts.append( + "You are continuing a Mastodon thread. Use the previous posts for " + "context, avoid repeating them verbatim, and keep the narrative flowing." ) prompt_parts.append("Return only the improved post text.") - user_content = f"Draft post:\\n{draft_text.strip()}" + + context_block = "" + if thread_context: + cleaned_context = [ + entry.strip() for entry in thread_context if entry and entry.strip() + ] + if cleaned_context: + formatted_history = "\n".join( + f"{index + 1}. {value}" + for index, value in enumerate(cleaned_context) + ) + context_block = f"Previous thread posts:\n{formatted_history}\n\n" + + user_content = f"{context_block}Draft post:\n{draft_text.strip()}" if hashtag_counts: user_content = ( f"{user_content}\n\nHashtag history (tag with past uses):\n" diff --git a/station_announcer/routes.py b/station_announcer/routes.py index a3c8dd1..948862d 100644 --- a/station_announcer/routes.py +++ b/station_announcer/routes.py @@ -3,7 +3,9 @@ from __future__ import annotations import base64 +import re from datetime import datetime +from html import unescape import UniAuth.auth from flask import ( @@ -25,6 +27,17 @@ from .openai_client import AltTextGenerationError, TextImprovementError bp = Blueprint("main", __name__) +_TAG_RE = re.compile(r"<[^>]+>") + + +def _html_to_plain_text(raw: str | None) -> str: + if not raw: + return "" + text = _TAG_RE.sub(" ", raw) + text = unescape(text) + return " ".join(text.split()).strip() + + def _parse_timestamp(raw: str | None) -> datetime | None: if not raw: return None @@ -351,9 +364,42 @@ def compose_draft(): if request.method == "POST": action = request.form.get("action") if action == "refine": + thread_context: list[str] | None = None + if ( + reply_to_latest + and mastodon_client + and latest_status + and latest_status.get("id") + ): + status_id = str(latest_status["id"]) + context_entries: list[str] = [] + try: + ancestors = mastodon_client.get_status_ancestors(status_id) + except MastodonError as exc: + if not error_message: + error_message = str(exc) + ancestors = [] + for item in ancestors: + ancestor_text = _html_to_plain_text(item.get("content")) + if ancestor_text: + context_entries.append(ancestor_text) + latest_content = latest_status.get("content") + if not latest_content: + try: + latest_payload = mastodon_client.get_status(status_id) + latest_content = latest_payload.get("content") + except MastodonError as exc: + if not error_message: + error_message = str(exc) + latest_content = None + latest_text = _html_to_plain_text(latest_content) + if latest_text: + context_entries.append(latest_text) + if context_entries: + thread_context = context_entries try: post_text = generator.improve_post_text( - post_text, instructions, hashtag_counts + post_text, instructions, hashtag_counts, thread_context ) flash("Post refined with ChatGPT.") except TextImprovementError as exc: