Add model config option and pass alt text to improve_post_text()
This commit is contained in:
parent
e00207e986
commit
8db4cc8c0b
|
|
@ -10,10 +10,14 @@ Station Announcer is a small Flask app that keeps your Immich library and Mastod
|
||||||
IMMICH_API_URL=https://photos.4angle.com/
|
IMMICH_API_URL=https://photos.4angle.com/
|
||||||
IMMICH_API_KEY=your-immich-api-key
|
IMMICH_API_KEY=your-immich-api-key
|
||||||
OPENAI_API_KEY=your-openai-api-key
|
OPENAI_API_KEY=your-openai-api-key
|
||||||
|
OPENAI_MODEL=gpt-4o-mini
|
||||||
MASTODON_BASE_URL=http://localhost:3000
|
MASTODON_BASE_URL=http://localhost:3000
|
||||||
MASTODON_ACCESS_TOKEN=your-mastodon-token
|
MASTODON_ACCESS_TOKEN=your-mastodon-token
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`OPENAI_MODEL` is optional; set it to override the default model used for both
|
||||||
|
alt-text generation and copy refinement.
|
||||||
|
|
||||||
2. Install dependencies and run the development server:
|
2. Install dependencies and run the development server:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from flask import Flask
|
||||||
sys.path.append("/home/edward/src/2024/UniAuth") # isort:skip
|
sys.path.append("/home/edward/src/2024/UniAuth") # isort:skip
|
||||||
|
|
||||||
from .cache import AltTextCache
|
from .cache import AltTextCache
|
||||||
from .config import load_settings
|
from .config import DEFAULT_OPENAI_MODEL, load_settings
|
||||||
from .immich import ImmichClient
|
from .immich import ImmichClient
|
||||||
from .mastodon import MastodonClient
|
from .mastodon import MastodonClient
|
||||||
from .openai_client import AltTextGenerator
|
from .openai_client import AltTextGenerator
|
||||||
|
|
@ -60,7 +60,7 @@ def create_app() -> Flask:
|
||||||
app.alt_text_cache = AltTextCache(db_path)
|
app.alt_text_cache = AltTextCache(db_path)
|
||||||
app.alt_text_generator = AltTextGenerator(
|
app.alt_text_generator = AltTextGenerator(
|
||||||
api_key=app.config["OPENAI_API_KEY"],
|
api_key=app.config["OPENAI_API_KEY"],
|
||||||
model=app.config.get("OPENAI_MODEL", "gpt-4o-mini"),
|
model=app.config.get("OPENAI_MODEL", DEFAULT_OPENAI_MODEL),
|
||||||
)
|
)
|
||||||
mastodon_token = app.config.get("MASTODON_ACCESS_TOKEN")
|
mastodon_token = app.config.get("MASTODON_ACCESS_TOKEN")
|
||||||
if mastodon_token:
|
if mastodon_token:
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from typing import Dict
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
DEFAULT_IMMICH_URL = "https://photos.4angle.com/"
|
DEFAULT_IMMICH_URL = "https://photos.4angle.com/"
|
||||||
|
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
class ConfigError(RuntimeError):
|
class ConfigError(RuntimeError):
|
||||||
|
|
@ -26,7 +27,7 @@ def load_settings() -> Dict[str, str]:
|
||||||
"IMMICH_API_KEY": os.getenv("IMMICH_API_KEY", ""),
|
"IMMICH_API_KEY": os.getenv("IMMICH_API_KEY", ""),
|
||||||
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY", ""),
|
"OPENAI_API_KEY": os.getenv("OPENAI_API_KEY", ""),
|
||||||
"RECENT_DAYS": int(os.getenv("RECENT_DAYS", "3")),
|
"RECENT_DAYS": int(os.getenv("RECENT_DAYS", "3")),
|
||||||
"OPENAI_MODEL": os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
|
"OPENAI_MODEL": os.getenv("OPENAI_MODEL", DEFAULT_OPENAI_MODEL),
|
||||||
"STATION_DB": db_path,
|
"STATION_DB": db_path,
|
||||||
"SECRET_KEY": os.getenv("SECRET_KEY", ""),
|
"SECRET_KEY": os.getenv("SECRET_KEY", ""),
|
||||||
"MASTODON_BASE_URL": os.getenv(
|
"MASTODON_BASE_URL": os.getenv(
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||||
import requests
|
import requests
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
from .config import DEFAULT_OPENAI_MODEL
|
||||||
|
|
||||||
class OpenAIClientError(RuntimeError):
|
class OpenAIClientError(RuntimeError):
|
||||||
"""Raised when the OpenAI API cannot fulfill a request."""
|
"""Raised when the OpenAI API cannot fulfill a request."""
|
||||||
|
|
@ -21,7 +22,7 @@ class TextImprovementError(OpenAIClientError):
|
||||||
class AltTextGenerator:
|
class AltTextGenerator:
|
||||||
"""Request alt text from a GPT compatible OpenAI endpoint."""
|
"""Request alt text from a GPT compatible OpenAI endpoint."""
|
||||||
|
|
||||||
def __init__(self, api_key: str, model: str = "gpt-4.1") -> None:
|
def __init__(self, api_key: str, model: str = DEFAULT_OPENAI_MODEL) -> None:
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("OPENAI_API_KEY is required")
|
raise ValueError("OPENAI_API_KEY is required")
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
@ -106,6 +107,7 @@ class AltTextGenerator:
|
||||||
instructions: str | None = None,
|
instructions: str | None = None,
|
||||||
hashtag_counts: str | None = None,
|
hashtag_counts: str | None = None,
|
||||||
thread_context: list[str] | None = None,
|
thread_context: list[str] | None = None,
|
||||||
|
alt_texts: list[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
if not draft_text or not draft_text.strip():
|
if not draft_text or not draft_text.strip():
|
||||||
raise TextImprovementError("Post text cannot be empty")
|
raise TextImprovementError("Post text cannot be empty")
|
||||||
|
|
@ -150,6 +152,16 @@ class AltTextGenerator:
|
||||||
f"{user_content}\n\nHashtag history (tag with past uses):\n"
|
f"{user_content}\n\nHashtag history (tag with past uses):\n"
|
||||||
f"{hashtag_counts.strip()}"
|
f"{hashtag_counts.strip()}"
|
||||||
)
|
)
|
||||||
|
if alt_texts:
|
||||||
|
cleaned_alt_texts = [entry.strip() for entry in alt_texts if entry.strip()]
|
||||||
|
if cleaned_alt_texts:
|
||||||
|
alt_text_block = "\n".join(
|
||||||
|
f"{index + 1}. {value}"
|
||||||
|
for index, value in enumerate(cleaned_alt_texts)
|
||||||
|
)
|
||||||
|
user_content = (
|
||||||
|
f"{user_content}\n\nAlt text for attached images:\n{alt_text_block}"
|
||||||
|
)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
|
|
|
||||||
|
|
@ -397,9 +397,18 @@ def compose_draft():
|
||||||
context_entries.append(latest_text)
|
context_entries.append(latest_text)
|
||||||
if context_entries:
|
if context_entries:
|
||||||
thread_context = context_entries
|
thread_context = context_entries
|
||||||
|
alt_texts = [
|
||||||
|
entry["alt_text"].strip()
|
||||||
|
for entry in asset_entries
|
||||||
|
if entry.get("alt_text")
|
||||||
|
]
|
||||||
try:
|
try:
|
||||||
post_text = generator.improve_post_text(
|
post_text = generator.improve_post_text(
|
||||||
post_text, instructions, hashtag_counts, thread_context
|
post_text,
|
||||||
|
instructions,
|
||||||
|
hashtag_counts,
|
||||||
|
thread_context,
|
||||||
|
alt_texts,
|
||||||
)
|
)
|
||||||
flash("Post refined with ChatGPT.")
|
flash("Post refined with ChatGPT.")
|
||||||
except TextImprovementError as exc:
|
except TextImprovementError as exc:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue