Add Mastodon thread context.
This commit is contained in:
parent
08618ee9a9
commit
e00207e986
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -115,3 +115,17 @@ class MastodonClient:
|
|||
"url": status.get("url"),
|
||||
}
|
||||
return None
|
||||
|
||||
def get_status(self, status_id: str) -> Dict[str, Any]:
    """Fetch a single status payload from the Mastodon API.

    Raises MastodonError when *status_id* is empty; otherwise issues a
    GET against the v1 statuses endpoint and returns the decoded body.
    """
    if not status_id:
        raise MastodonError("status_id is required")
    endpoint = f"/api/v1/statuses/{status_id}"
    return self._request("GET", endpoint)
|
||||
|
||||
def get_status_ancestors(self, status_id: str) -> List[Dict[str, Any]]:
    """Return the ancestor statuses of *status_id* (oldest first).

    Queries the status context endpoint and extracts its "ancestors"
    field. Falls back to an empty list when the field is missing or is
    not a list, so callers can iterate unconditionally.

    Raises MastodonError when *status_id* is empty.
    """
    if not status_id:
        raise MastodonError("status_id is required")
    context = self._request("GET", f"/api/v1/statuses/{status_id}/context")
    ancestors = context.get("ancestors")
    return ancestors if isinstance(ancestors, list) else []
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import typing
|
||||
|
||||
|
||||
class OpenAIClientError(RuntimeError):
|
||||
|
|
@ -20,7 +19,7 @@ class TextImprovementError(OpenAIClientError):
|
|||
|
||||
|
||||
class AltTextGenerator:
|
||||
"""Request alt text from a GPT-4o compatible OpenAI endpoint."""
|
||||
"""Request alt text from a GPT compatible OpenAI endpoint."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "gpt-4.1") -> None:
|
||||
if not api_key:
|
||||
|
|
@ -38,10 +37,10 @@ class AltTextGenerator:
|
|||
def generate_alt_text(
|
||||
self,
|
||||
image_source: str,
|
||||
notes: Optional[str] = None,
|
||||
captured_at: Optional[str] = None,
|
||||
location: Optional[str] = None,
|
||||
coordinates: Optional[str] = None,
|
||||
notes: str | None = None,
|
||||
captured_at: str | None = None,
|
||||
location: str | None = None,
|
||||
coordinates: str | None = None,
|
||||
) -> str:
|
||||
if not image_source:
|
||||
raise AltTextGenerationError("Image URL required for alt text generation")
|
||||
|
|
@ -63,7 +62,7 @@ class AltTextGenerator:
|
|||
prompt_lines.append(f"Coordinates: {coordinates}")
|
||||
text_prompt = "\n".join(prompt_lines)
|
||||
|
||||
content: List[Dict[str, Any]] = [
|
||||
content: list[dict[str, typing.Any]] = [
|
||||
{"type": "text", "text": text_prompt},
|
||||
{"type": "image_url", "image_url": {"url": image_source}},
|
||||
]
|
||||
|
|
@ -104,8 +103,9 @@ class AltTextGenerator:
|
|||
def improve_post_text(
|
||||
self,
|
||||
draft_text: str,
|
||||
instructions: Optional[str] = None,
|
||||
hashtag_counts: Optional[str] = None,
|
||||
instructions: str | None = None,
|
||||
hashtag_counts: str | None = None,
|
||||
thread_context: list[str] | None = None,
|
||||
) -> str:
|
||||
if not draft_text or not draft_text.strip():
|
||||
raise TextImprovementError("Post text cannot be empty")
|
||||
|
|
@ -122,11 +122,29 @@ class AltTextGenerator:
|
|||
prompt_parts.append(f"Additional instructions: {instructions.strip()}")
|
||||
if hashtag_counts:
|
||||
prompt_parts.append(
|
||||
"When you add hashtags, prefer the ones from my history below, "
|
||||
"but feel free to invent new ones."
|
||||
"When you add hashtags, prefer the ones from my history below and "
|
||||
"only invent new ones if absolutely necessary."
|
||||
)
|
||||
if thread_context:
|
||||
prompt_parts.append(
|
||||
"You are continuing a Mastodon thread. Use the previous posts for "
|
||||
"context, avoid repeating them verbatim, and keep the narrative flowing."
|
||||
)
|
||||
prompt_parts.append("Return only the improved post text.")
|
||||
user_content = f"Draft post:\\n{draft_text.strip()}"
|
||||
|
||||
context_block = ""
|
||||
if thread_context:
|
||||
cleaned_context = [
|
||||
entry.strip() for entry in thread_context if entry and entry.strip()
|
||||
]
|
||||
if cleaned_context:
|
||||
formatted_history = "\n".join(
|
||||
f"{index + 1}. {value}"
|
||||
for index, value in enumerate(cleaned_context)
|
||||
)
|
||||
context_block = f"Previous thread posts:\n{formatted_history}\n\n"
|
||||
|
||||
user_content = f"{context_block}Draft post:\n{draft_text.strip()}"
|
||||
if hashtag_counts:
|
||||
user_content = (
|
||||
f"{user_content}\n\nHashtag history (tag with past uses):\n"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import re
|
||||
from datetime import datetime
|
||||
from html import unescape
|
||||
|
||||
import UniAuth.auth
|
||||
from flask import (
|
||||
|
|
@ -25,6 +27,17 @@ from .openai_client import AltTextGenerationError, TextImprovementError
|
|||
bp = Blueprint("main", __name__)
|
||||
|
||||
|
||||
_TAG_RE = re.compile(r"<[^>]+>")
|
||||
|
||||
|
||||
def _html_to_plain_text(raw: str | None) -> str:
    """Convert a Mastodon HTML content string to collapsed plain text.

    Tags are replaced with spaces (so adjacent words do not fuse), HTML
    entities are unescaped, and all runs of whitespace collapse to a
    single space. Empty or None input yields the empty string.
    """
    if not raw:
        return ""
    without_tags = _TAG_RE.sub(" ", raw)
    decoded = unescape(without_tags)
    return " ".join(decoded.split()).strip()
|
||||
|
||||
|
||||
def _parse_timestamp(raw: str | None) -> datetime | None:
|
||||
if not raw:
|
||||
return None
|
||||
|
|
@ -351,9 +364,42 @@ def compose_draft():
|
|||
if request.method == "POST":
|
||||
action = request.form.get("action")
|
||||
if action == "refine":
|
||||
thread_context: list[str] | None = None
|
||||
if (
|
||||
reply_to_latest
|
||||
and mastodon_client
|
||||
and latest_status
|
||||
and latest_status.get("id")
|
||||
):
|
||||
status_id = str(latest_status["id"])
|
||||
context_entries: list[str] = []
|
||||
try:
|
||||
ancestors = mastodon_client.get_status_ancestors(status_id)
|
||||
except MastodonError as exc:
|
||||
if not error_message:
|
||||
error_message = str(exc)
|
||||
ancestors = []
|
||||
for item in ancestors:
|
||||
ancestor_text = _html_to_plain_text(item.get("content"))
|
||||
if ancestor_text:
|
||||
context_entries.append(ancestor_text)
|
||||
latest_content = latest_status.get("content")
|
||||
if not latest_content:
|
||||
try:
|
||||
latest_payload = mastodon_client.get_status(status_id)
|
||||
latest_content = latest_payload.get("content")
|
||||
except MastodonError as exc:
|
||||
if not error_message:
|
||||
error_message = str(exc)
|
||||
latest_content = None
|
||||
latest_text = _html_to_plain_text(latest_content)
|
||||
if latest_text:
|
||||
context_entries.append(latest_text)
|
||||
if context_entries:
|
||||
thread_context = context_entries
|
||||
try:
|
||||
post_text = generator.improve_post_text(
|
||||
post_text, instructions, hashtag_counts
|
||||
post_text, instructions, hashtag_counts, thread_context
|
||||
)
|
||||
flash("Post refined with ChatGPT.")
|
||||
except TextImprovementError as exc:
|
||||
|
|
|
|||
Loading…
Reference in a new issue