Add mastodon thread context.
This commit is contained in:
parent
08618ee9a9
commit
e00207e986
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Sequence
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
@ -115,3 +115,17 @@ class MastodonClient:
|
||||||
"url": status.get("url"),
|
"url": status.get("url"),
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_status(self, status_id: str) -> Dict[str, Any]:
|
||||||
|
if not status_id:
|
||||||
|
raise MastodonError("status_id is required")
|
||||||
|
return self._request("GET", f"/api/v1/statuses/{status_id}")
|
||||||
|
|
||||||
|
def get_status_ancestors(self, status_id: str) -> List[Dict[str, Any]]:
|
||||||
|
if not status_id:
|
||||||
|
raise MastodonError("status_id is required")
|
||||||
|
data = self._request("GET", f"/api/v1/statuses/{status_id}/context")
|
||||||
|
ancestors = data.get("ancestors")
|
||||||
|
if isinstance(ancestors, list):
|
||||||
|
return ancestors
|
||||||
|
return []
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,8 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClientError(RuntimeError):
|
class OpenAIClientError(RuntimeError):
|
||||||
|
|
@ -20,7 +19,7 @@ class TextImprovementError(OpenAIClientError):
|
||||||
|
|
||||||
|
|
||||||
class AltTextGenerator:
|
class AltTextGenerator:
|
||||||
"""Request alt text from a GPT-4o compatible OpenAI endpoint."""
|
"""Request alt text from a GPT compatible OpenAI endpoint."""
|
||||||
|
|
||||||
def __init__(self, api_key: str, model: str = "gpt-4.1") -> None:
|
def __init__(self, api_key: str, model: str = "gpt-4.1") -> None:
|
||||||
if not api_key:
|
if not api_key:
|
||||||
|
|
@ -38,10 +37,10 @@ class AltTextGenerator:
|
||||||
def generate_alt_text(
|
def generate_alt_text(
|
||||||
self,
|
self,
|
||||||
image_source: str,
|
image_source: str,
|
||||||
notes: Optional[str] = None,
|
notes: str | None = None,
|
||||||
captured_at: Optional[str] = None,
|
captured_at: str | None = None,
|
||||||
location: Optional[str] = None,
|
location: str | None = None,
|
||||||
coordinates: Optional[str] = None,
|
coordinates: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
if not image_source:
|
if not image_source:
|
||||||
raise AltTextGenerationError("Image URL required for alt text generation")
|
raise AltTextGenerationError("Image URL required for alt text generation")
|
||||||
|
|
@ -63,7 +62,7 @@ class AltTextGenerator:
|
||||||
prompt_lines.append(f"Coordinates: {coordinates}")
|
prompt_lines.append(f"Coordinates: {coordinates}")
|
||||||
text_prompt = "\n".join(prompt_lines)
|
text_prompt = "\n".join(prompt_lines)
|
||||||
|
|
||||||
content: List[Dict[str, Any]] = [
|
content: list[dict[str, typing.Any]] = [
|
||||||
{"type": "text", "text": text_prompt},
|
{"type": "text", "text": text_prompt},
|
||||||
{"type": "image_url", "image_url": {"url": image_source}},
|
{"type": "image_url", "image_url": {"url": image_source}},
|
||||||
]
|
]
|
||||||
|
|
@ -104,8 +103,9 @@ class AltTextGenerator:
|
||||||
def improve_post_text(
|
def improve_post_text(
|
||||||
self,
|
self,
|
||||||
draft_text: str,
|
draft_text: str,
|
||||||
instructions: Optional[str] = None,
|
instructions: str | None = None,
|
||||||
hashtag_counts: Optional[str] = None,
|
hashtag_counts: str | None = None,
|
||||||
|
thread_context: list[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
if not draft_text or not draft_text.strip():
|
if not draft_text or not draft_text.strip():
|
||||||
raise TextImprovementError("Post text cannot be empty")
|
raise TextImprovementError("Post text cannot be empty")
|
||||||
|
|
@ -122,11 +122,29 @@ class AltTextGenerator:
|
||||||
prompt_parts.append(f"Additional instructions: {instructions.strip()}")
|
prompt_parts.append(f"Additional instructions: {instructions.strip()}")
|
||||||
if hashtag_counts:
|
if hashtag_counts:
|
||||||
prompt_parts.append(
|
prompt_parts.append(
|
||||||
"When you add hashtags, prefer the ones from my history below, "
|
"When you add hashtags, prefer the ones from my history below and "
|
||||||
"but feel free to invent new ones."
|
"only invent new ones if absolutely necessary."
|
||||||
|
)
|
||||||
|
if thread_context:
|
||||||
|
prompt_parts.append(
|
||||||
|
"You are continuing a Mastodon thread. Use the previous posts for "
|
||||||
|
"context, avoid repeating them verbatim, and keep the narrative flowing."
|
||||||
)
|
)
|
||||||
prompt_parts.append("Return only the improved post text.")
|
prompt_parts.append("Return only the improved post text.")
|
||||||
user_content = f"Draft post:\\n{draft_text.strip()}"
|
|
||||||
|
context_block = ""
|
||||||
|
if thread_context:
|
||||||
|
cleaned_context = [
|
||||||
|
entry.strip() for entry in thread_context if entry and entry.strip()
|
||||||
|
]
|
||||||
|
if cleaned_context:
|
||||||
|
formatted_history = "\n".join(
|
||||||
|
f"{index + 1}. {value}"
|
||||||
|
for index, value in enumerate(cleaned_context)
|
||||||
|
)
|
||||||
|
context_block = f"Previous thread posts:\n{formatted_history}\n\n"
|
||||||
|
|
||||||
|
user_content = f"{context_block}Draft post:\n{draft_text.strip()}"
|
||||||
if hashtag_counts:
|
if hashtag_counts:
|
||||||
user_content = (
|
user_content = (
|
||||||
f"{user_content}\n\nHashtag history (tag with past uses):\n"
|
f"{user_content}\n\nHashtag history (tag with past uses):\n"
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from html import unescape
|
||||||
|
|
||||||
import UniAuth.auth
|
import UniAuth.auth
|
||||||
from flask import (
|
from flask import (
|
||||||
|
|
@ -25,6 +27,17 @@ from .openai_client import AltTextGenerationError, TextImprovementError
|
||||||
bp = Blueprint("main", __name__)
|
bp = Blueprint("main", __name__)
|
||||||
|
|
||||||
|
|
||||||
|
_TAG_RE = re.compile(r"<[^>]+>")
|
||||||
|
|
||||||
|
|
||||||
|
def _html_to_plain_text(raw: str | None) -> str:
|
||||||
|
if not raw:
|
||||||
|
return ""
|
||||||
|
text = _TAG_RE.sub(" ", raw)
|
||||||
|
text = unescape(text)
|
||||||
|
return " ".join(text.split()).strip()
|
||||||
|
|
||||||
|
|
||||||
def _parse_timestamp(raw: str | None) -> datetime | None:
|
def _parse_timestamp(raw: str | None) -> datetime | None:
|
||||||
if not raw:
|
if not raw:
|
||||||
return None
|
return None
|
||||||
|
|
@ -351,9 +364,42 @@ def compose_draft():
|
||||||
if request.method == "POST":
|
if request.method == "POST":
|
||||||
action = request.form.get("action")
|
action = request.form.get("action")
|
||||||
if action == "refine":
|
if action == "refine":
|
||||||
|
thread_context: list[str] | None = None
|
||||||
|
if (
|
||||||
|
reply_to_latest
|
||||||
|
and mastodon_client
|
||||||
|
and latest_status
|
||||||
|
and latest_status.get("id")
|
||||||
|
):
|
||||||
|
status_id = str(latest_status["id"])
|
||||||
|
context_entries: list[str] = []
|
||||||
|
try:
|
||||||
|
ancestors = mastodon_client.get_status_ancestors(status_id)
|
||||||
|
except MastodonError as exc:
|
||||||
|
if not error_message:
|
||||||
|
error_message = str(exc)
|
||||||
|
ancestors = []
|
||||||
|
for item in ancestors:
|
||||||
|
ancestor_text = _html_to_plain_text(item.get("content"))
|
||||||
|
if ancestor_text:
|
||||||
|
context_entries.append(ancestor_text)
|
||||||
|
latest_content = latest_status.get("content")
|
||||||
|
if not latest_content:
|
||||||
|
try:
|
||||||
|
latest_payload = mastodon_client.get_status(status_id)
|
||||||
|
latest_content = latest_payload.get("content")
|
||||||
|
except MastodonError as exc:
|
||||||
|
if not error_message:
|
||||||
|
error_message = str(exc)
|
||||||
|
latest_content = None
|
||||||
|
latest_text = _html_to_plain_text(latest_content)
|
||||||
|
if latest_text:
|
||||||
|
context_entries.append(latest_text)
|
||||||
|
if context_entries:
|
||||||
|
thread_context = context_entries
|
||||||
try:
|
try:
|
||||||
post_text = generator.improve_post_text(
|
post_text = generator.improve_post_text(
|
||||||
post_text, instructions, hashtag_counts
|
post_text, instructions, hashtag_counts, thread_context
|
||||||
)
|
)
|
||||||
flash("Post refined with ChatGPT.")
|
flash("Post refined with ChatGPT.")
|
||||||
except TextImprovementError as exc:
|
except TextImprovementError as exc:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue