station-announcer/station_announcer/cache.py

97 lines
3.2 KiB
Python

"""SQLite-backed cache for generated alt text."""
from __future__ import annotations
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Iterable, Iterator, Optional
class AltTextCache:
    """Minimal cache of generated alt text, backed by SQLite.

    Two tables live in the database file at ``db_path``:

    * ``alt_text`` -- one alt-text string per ``asset_id`` together with the
      UTC ISO-8601 timestamp of its most recent write.
    * ``asset_usage`` -- the UTC ISO-8601 timestamp and optional URL of the
      most recent post recorded for each ``asset_id``.

    Every public method opens its own short-lived connection, runs inside a
    single transaction, and closes the connection before returning.

    Note: because each call opens a fresh connection, ``db_path=":memory:"``
    gives every call its own empty database; pass a file path instead.
    """

    # SQLite builds older than 3.32 cap bound parameters per statement at 999
    # (SQLITE_MAX_VARIABLE_NUMBER); keep each IN (...) list below that.
    _MAX_QUERY_PARAMS = 900

    def __init__(self, db_path: str) -> None:
        """Remember *db_path* and create the schema if it does not exist yet."""
        self.db_path = db_path
        self._ensure_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _session(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection inside one transaction, then always close it.

        ``with conn:`` on a sqlite3 connection commits on success and rolls
        back on error, but it never closes the connection -- so close it
        explicitly here to avoid leaking a file handle per call.
        """
        conn = self._connect()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _ensure_db(self) -> None:
        """Create the database's parent directory and both tables if missing."""
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
        with self._session() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS alt_text (
                    asset_id TEXT PRIMARY KEY,
                    alt_text TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS asset_usage (
                    asset_id TEXT PRIMARY KEY,
                    last_posted_at TEXT NOT NULL,
                    last_post_url TEXT
                )
                """
            )

    def get(self, asset_id: str) -> Optional[str]:
        """Return the cached alt text for *asset_id*, or ``None`` if absent."""
        with self._session() as conn:
            row = conn.execute(
                "SELECT alt_text FROM alt_text WHERE asset_id = ?",
                (asset_id,),
            ).fetchone()
        return row[0] if row else None

    def set(self, asset_id: str, alt_text: str) -> None:
        """Insert or overwrite the alt text for *asset_id*.

        ``updated_at`` is set to the current UTC time in ISO-8601 form.
        """
        timestamp = datetime.now(timezone.utc).isoformat()
        with self._session() as conn:
            conn.execute(
                """
                INSERT INTO alt_text(asset_id, alt_text, updated_at)
                VALUES(?, ?, ?)
                ON CONFLICT(asset_id)
                DO UPDATE SET alt_text = excluded.alt_text,
                              updated_at = excluded.updated_at
                """,
                (asset_id, alt_text, timestamp),
            )

    def mark_posted(self, asset_id: str, post_url: Optional[str] = None) -> None:
        """Record that *asset_id* was posted now, optionally at *post_url*.

        Any previous usage row for the asset is overwritten, including its
        URL (a ``None`` *post_url* clears a previously stored URL).
        """
        timestamp = datetime.now(timezone.utc).isoformat()
        with self._session() as conn:
            conn.execute(
                """
                INSERT INTO asset_usage(asset_id, last_posted_at, last_post_url)
                VALUES(?, ?, ?)
                ON CONFLICT(asset_id)
                DO UPDATE SET last_posted_at = excluded.last_posted_at,
                              last_post_url = excluded.last_post_url
                """,
                (asset_id, timestamp, post_url),
            )

    def get_usage_map(self, asset_ids: Iterable[str]) -> Dict[str, sqlite3.Row]:
        """Return usage rows for the given assets, keyed by ``asset_id``.

        Falsy ids (``""``, ``None``) are ignored, and ids with no recorded
        usage are simply absent from the result. Each value is a
        ``sqlite3.Row`` with ``asset_id``, ``last_posted_at`` and
        ``last_post_url`` columns.
        """
        # Drop falsy ids and duplicates (keeping first-seen order) so we never
        # bind more parameters than necessary.
        ids = list(dict.fromkeys(asset_id for asset_id in asset_ids if asset_id))
        if not ids:
            return {}
        usage: Dict[str, sqlite3.Row] = {}
        step = self._MAX_QUERY_PARAMS
        with self._session() as conn:
            # Query in chunks to stay under SQLite's bound-parameter limit.
            for start in range(0, len(ids), step):
                chunk = ids[start:start + step]
                placeholders = ",".join("?" * len(chunk))
                query = (
                    "SELECT asset_id, last_posted_at, last_post_url FROM asset_usage"
                    f" WHERE asset_id IN ({placeholders})"
                )
                for row in conn.execute(query, chunk):
                    usage[row["asset_id"]] = row
        return usage