Add travel booking and place import scripts
This commit is contained in:
parent
a87c9f993e
commit
5f6cb57c2a
10 changed files with 1177 additions and 3 deletions
257
agenda/build_place_yaml.py
Normal file
257
agenda/build_place_yaml.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""Build airport and station YAML entries from Wikidata."""
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
API_URL = "https://www.wikidata.org/w/api.php"
|
||||
PERSONAL_DATA_DIR = Path("~/src/personal-data").expanduser()
|
||||
USER_AGENT = "agenda-build-place-yaml/0.1"
|
||||
|
||||
Entity = dict[str, typing.Any]
|
||||
Entities = dict[str, Entity]
|
||||
|
||||
|
||||
class WikidataClient:
|
||||
"""Small Wikidata API client for place lookups."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a Wikidata API session."""
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({"User-Agent": USER_AGENT})
|
||||
|
||||
def get_json(self, params: dict[str, str]) -> dict[str, typing.Any]:
|
||||
"""Fetch JSON from Wikidata."""
|
||||
response = self.session.get(API_URL, params=params)
|
||||
response.raise_for_status()
|
||||
return typing.cast(dict[str, typing.Any], response.json())
|
||||
|
||||
def get_alpha2_country_code(self, qid: str) -> str:
|
||||
"""Query the Wikidata API for alpha-2 country code."""
|
||||
data = self.get_json(
|
||||
{
|
||||
"action": "wbgetclaims",
|
||||
"entity": qid,
|
||||
"property": "P297",
|
||||
"format": "json",
|
||||
}
|
||||
)
|
||||
p297 = data["claims"]["P297"]
|
||||
return typing.cast(str, p297[0]["mainsnak"]["datavalue"]["value"]).lower()
|
||||
|
||||
def search_entities(self, query: str) -> Entities:
|
||||
"""Search Wikidata and return detailed entities."""
|
||||
search_data = self.get_json(
|
||||
{
|
||||
"action": "query",
|
||||
"list": "search",
|
||||
"format": "json",
|
||||
"srsearch": query,
|
||||
}
|
||||
)
|
||||
search_results = search_data["query"]["search"]
|
||||
if not search_results:
|
||||
return {}
|
||||
|
||||
ids = [result["title"] for result in search_results]
|
||||
entity_data = self.get_json(
|
||||
{
|
||||
"action": "wbgetentities",
|
||||
"format": "json",
|
||||
"ids": "|".join(ids),
|
||||
}
|
||||
)
|
||||
return typing.cast(Entities, entity_data["entities"])
|
||||
|
||||
|
||||
def entity_names(entity: Entity) -> set[str]:
|
||||
"""Return labels and aliases for a Wikidata entity."""
|
||||
names: set[str] = {lang["value"] for lang in entity.get("labels", {}).values()}
|
||||
for alias_list in entity.get("aliases", {}).values():
|
||||
for alias in alias_list:
|
||||
names.add(alias["value"])
|
||||
return names
|
||||
|
||||
|
||||
def exact_station_match(station_name: str, entity: Entity) -> bool:
|
||||
"""Return whether the entity names match the requested station name."""
|
||||
names = entity_names(entity)
|
||||
return station_name in names or f"{station_name} station" in names
|
||||
|
||||
|
||||
def entity_claim_value(
|
||||
entity: Entity, property_id: str, default: typing.Any = None
|
||||
) -> typing.Any:
|
||||
"""Return the first Wikidata claim value for a property."""
|
||||
claims = entity.get("claims", {})
|
||||
if property_id not in claims:
|
||||
return default
|
||||
return claims[property_id][0]["mainsnak"]["datavalue"]["value"]
|
||||
|
||||
|
||||
def build_station_info(
|
||||
client: WikidataClient, station_name: str, entity_id: str, entity: Entity
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Build a stations.yaml entry."""
|
||||
coords = entity_claim_value(entity, "P625")
|
||||
country_value = entity_claim_value(entity, "P17")
|
||||
uic = entity_claim_value(entity, "P722")
|
||||
uk_station_code = entity_claim_value(entity, "P4755")
|
||||
|
||||
station_info: dict[str, typing.Any] = {
|
||||
"name": station_name,
|
||||
"latitude": coords["latitude"],
|
||||
"longitude": coords["longitude"],
|
||||
"country": client.get_alpha2_country_code(country_value["id"]),
|
||||
"wikidata": entity_id,
|
||||
"routes": {},
|
||||
}
|
||||
|
||||
if uic is not None:
|
||||
station_info["uic"] = uic
|
||||
if uk_station_code is not None:
|
||||
station_info["alpha3"] = uk_station_code
|
||||
|
||||
return station_info
|
||||
|
||||
|
||||
def search_for_station(
|
||||
client: WikidataClient, station_name: str
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Search for a station and return a stations.yaml entry."""
|
||||
haswbstatement = "P31=Q55488|P31=Q18543139|P31=Q1147171"
|
||||
entities = client.search_entities(f"{station_name} haswbstatement:{haswbstatement}")
|
||||
|
||||
for entity_id, entity in entities.items():
|
||||
if exact_station_match(station_name, entity):
|
||||
return build_station_info(client, station_name, entity_id, entity)
|
||||
|
||||
if entities:
|
||||
entity_id, entity = next(iter(entities.items()))
|
||||
return build_station_info(client, station_name, entity_id, entity)
|
||||
|
||||
raise ValueError(f"No Wikidata station found for {station_name!r}")
|
||||
|
||||
|
||||
def build_airport_info(
|
||||
client: WikidataClient, iata: str, entity_id: str, entity: Entity
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Build an airports.yaml entry."""
|
||||
label = entity["labels"]["en"]["value"]
|
||||
claims = entity["claims"]
|
||||
coords = claims["P625"][0]["mainsnak"]["datavalue"]["value"]
|
||||
country_qid = claims["P17"][0]["mainsnak"]["datavalue"]["value"]["id"]
|
||||
|
||||
info: dict[str, typing.Any] = {
|
||||
"iata": iata,
|
||||
"name": label,
|
||||
"city": label,
|
||||
"country": client.get_alpha2_country_code(country_qid),
|
||||
"latitude": coords["latitude"],
|
||||
"longitude": coords["longitude"],
|
||||
"qid": entity_id,
|
||||
}
|
||||
|
||||
website = entity_claim_value(entity, "P856")
|
||||
if website is not None:
|
||||
info["website"] = website
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def search_for_airport(client: WikidataClient, iata: str) -> dict[str, typing.Any]:
|
||||
"""Search for an airport by IATA code and return an airports.yaml entry."""
|
||||
entities = client.search_entities(f"haswbstatement:P238={iata.upper()}")
|
||||
if not entities:
|
||||
raise ValueError(f"No Wikidata airport found for IATA code {iata!r}")
|
||||
|
||||
entity_id, entity = next(iter(entities.items()))
|
||||
return build_airport_info(client, iata.upper(), entity_id, entity)
|
||||
|
||||
|
||||
def load_yaml(path: Path) -> typing.Any:
|
||||
"""Load a YAML file."""
|
||||
return yaml.safe_load(path.read_text())
|
||||
|
||||
|
||||
def dump_yaml(data: typing.Any) -> str:
|
||||
"""Dump YAML using the local personal-data style."""
|
||||
return yaml.dump(data, sort_keys=False, allow_unicode=True)
|
||||
|
||||
|
||||
def dump_yaml_list_with_blank_lines(items: list[dict[str, typing.Any]]) -> str:
|
||||
"""Dump a YAML list with a blank line between top-level items."""
|
||||
text = dump_yaml(items).lstrip()
|
||||
return text.replace("\n- ", "\n\n- ")
|
||||
|
||||
|
||||
def upsert_station(data_dir: Path, station_info: dict[str, typing.Any]) -> bool:
|
||||
"""Add or replace a station entry. Return True when an existing entry changed."""
|
||||
path = data_dir / "stations.yaml"
|
||||
stations = typing.cast(list[dict[str, typing.Any]], load_yaml(path))
|
||||
|
||||
for index, station in enumerate(stations):
|
||||
if station.get("name") == station_info["name"]:
|
||||
stations[index] = station_info
|
||||
path.write_text(dump_yaml_list_with_blank_lines(stations))
|
||||
return True
|
||||
|
||||
stations.append(station_info)
|
||||
path.write_text(dump_yaml_list_with_blank_lines(stations))
|
||||
return False
|
||||
|
||||
|
||||
def upsert_airport(data_dir: Path, airport_info: dict[str, typing.Any]) -> bool:
|
||||
"""Add or replace an airport entry. Return True when an existing entry changed."""
|
||||
path = data_dir / "airports.yaml"
|
||||
airports = typing.cast(dict[str, dict[str, typing.Any]], load_yaml(path))
|
||||
iata = typing.cast(str, airport_info["iata"])
|
||||
existed = iata in airports
|
||||
airports[iata] = airport_info
|
||||
path.write_text(dump_yaml(airports))
|
||||
return existed
|
||||
|
||||
|
||||
def station_main(argv: list[str] | None = None) -> int:
|
||||
"""CLI entrypoint for building and importing a station YAML entry."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Add or update a station in personal-data/stations.yaml."
|
||||
)
|
||||
parser.add_argument("station_name")
|
||||
parser.add_argument("--data-dir", default=str(PERSONAL_DATA_DIR))
|
||||
parser.add_argument("--print-only", action="store_true")
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
station_info = search_for_station(WikidataClient(), args.station_name)
|
||||
if args.print_only:
|
||||
print(dump_yaml_list_with_blank_lines([station_info]).strip())
|
||||
return 0
|
||||
|
||||
replaced = upsert_station(Path(args.data_dir), station_info)
|
||||
action = "Updated" if replaced else "Added"
|
||||
print(f"{action} station {station_info['name']} in {args.data_dir}/stations.yaml")
|
||||
return 0
|
||||
|
||||
|
||||
def airport_main(argv: list[str] | None = None) -> int:
|
||||
"""CLI entrypoint for building and importing an airport YAML entry."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Add or update an airport in personal-data/airports.yaml."
|
||||
)
|
||||
parser.add_argument("iata")
|
||||
parser.add_argument("--data-dir", default=str(PERSONAL_DATA_DIR))
|
||||
parser.add_argument("--print-only", action="store_true")
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
airport_info = search_for_airport(WikidataClient(), args.iata)
|
||||
if args.print_only:
|
||||
print(dump_yaml({airport_info["iata"]: airport_info}).strip())
|
||||
return 0
|
||||
|
||||
replaced = upsert_airport(Path(args.data_dir), airport_info)
|
||||
action = "Updated" if replaced else "Added"
|
||||
print(f"{action} airport {airport_info['iata']} in {args.data_dir}/airports.yaml")
|
||||
return 0
|
||||
457
agenda/generate_booking_yaml.py
Normal file
457
agenda/generate_booking_yaml.py
Normal file
|
|
@ -0,0 +1,457 @@
|
|||
"""Generate travel booking YAML from booking text or a booking URL."""
|
||||
|
||||
import argparse
|
||||
import configparser
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
|
||||
import html2text
|
||||
import lxml.html
|
||||
import openai
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
USER_AGENT = "generate-booking-yaml/0.1"
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
SPEC_PATH = REPO_ROOT / "docs" / "personal-data-yaml.md"
|
||||
PERSONAL_DATA_DIR = Path("~/src/personal-data").expanduser()
|
||||
|
||||
|
||||
class TripLike(typing.Protocol):
|
||||
"""Trip attributes needed for booking import matching."""
|
||||
|
||||
start: date
|
||||
end: date | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BookingConfig:
|
||||
"""Configuration for one travel booking YAML generator."""
|
||||
|
||||
booking_type: str
|
||||
yaml_filename: str
|
||||
spec_heading: str
|
||||
json_key: str = "booking"
|
||||
|
||||
|
||||
BOOKING_CONFIGS: dict[str, BookingConfig] = {
|
||||
"flight": BookingConfig(
|
||||
booking_type="flight",
|
||||
yaml_filename="flights.yaml",
|
||||
spec_heading="flights.yaml",
|
||||
),
|
||||
"train": BookingConfig(
|
||||
booking_type="train",
|
||||
yaml_filename="trains.yaml",
|
||||
spec_heading="trains.yaml",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def read_api_key() -> str:
|
||||
"""Read API key from ~/.config/openai/config."""
|
||||
config_path = os.path.expanduser("~/.config/openai/config")
|
||||
parser = configparser.ConfigParser()
|
||||
parser.read(config_path)
|
||||
return parser["openai"]["api_key"]
|
||||
|
||||
|
||||
def read_markdown_section(markdown_text: str, heading: str) -> str:
|
||||
"""Return one second-level markdown section by heading text."""
|
||||
section_headings = (f"## `{heading}`", f"## {heading}")
|
||||
start = -1
|
||||
matched_heading = ""
|
||||
for section_heading in section_headings:
|
||||
start = markdown_text.find(section_heading)
|
||||
if start != -1:
|
||||
matched_heading = section_heading
|
||||
break
|
||||
if start == -1:
|
||||
raise ValueError(f"Could not find section for heading {heading!r}")
|
||||
|
||||
next_heading = markdown_text.find("\n## ", start + len(matched_heading))
|
||||
if next_heading == -1:
|
||||
return markdown_text[start:].strip()
|
||||
return markdown_text[start:next_heading].strip()
|
||||
|
||||
|
||||
def yaml_format_description(config: BookingConfig) -> str:
|
||||
"""Return relevant personal-data YAML documentation for the prompt."""
|
||||
spec_text = SPEC_PATH.read_text()
|
||||
sections = [
|
||||
read_markdown_section(spec_text, "General Rules"),
|
||||
read_markdown_section(spec_text, "Cross-File References"),
|
||||
read_markdown_section(spec_text, config.spec_heading),
|
||||
]
|
||||
return "\n\n".join(sections)
|
||||
|
||||
|
||||
def read_existing_bookings(
|
||||
config: BookingConfig, data_dir: Path = PERSONAL_DATA_DIR
|
||||
) -> str:
|
||||
"""Read the existing YAML file for examples and local style."""
|
||||
path = data_dir / config.yaml_filename
|
||||
return path.read_text()
|
||||
|
||||
|
||||
def build_prompt(
|
||||
booking_text: str,
|
||||
config: BookingConfig,
|
||||
current_bookings: str | None = None,
|
||||
) -> str:
|
||||
"""Build prompt to pass to the LLM."""
|
||||
bookings = current_bookings
|
||||
if bookings is None:
|
||||
bookings = read_existing_bookings(config)
|
||||
|
||||
return f"""
|
||||
I keep a record of all my {config.booking_type} bookings in a YAML file.
|
||||
|
||||
Use this YAML format specification:
|
||||
|
||||
{yaml_format_description(config)}
|
||||
|
||||
Here's my current list of bookings for examples of local style and known
|
||||
references.
|
||||
===
|
||||
{bookings}
|
||||
===
|
||||
Here's a new booking I just made.
|
||||
|
||||
Return the YAML representation for this booking using the documented format and
|
||||
the same local style as my existing bookings.
|
||||
|
||||
Rules:
|
||||
- Wrap the response in a JSON object with a single key "{config.json_key}" that
|
||||
contains the booking in YAML.
|
||||
- The value of "{config.json_key}" must be YAML text, not JSON.
|
||||
- Exclude the top-level "trip" key from the YAML.
|
||||
- Do not invent details that are not present in the booking text.
|
||||
- Quote prices and identifiers that might otherwise be parsed as numbers.
|
||||
|
||||
===
|
||||
{booking_text}
|
||||
"""
|
||||
|
||||
|
||||
def get_from_open_ai(prompt: str, model: str = "gpt-5.4") -> dict[str, str]:
|
||||
"""Pass prompt to OpenAI and get reply."""
|
||||
client = openai.OpenAI(api_key=read_api_key())
|
||||
|
||||
response = client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model=model,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
reply = response.choices[0].message.content
|
||||
assert isinstance(reply, str)
|
||||
return typing.cast(dict[str, str], json.loads(reply))
|
||||
|
||||
|
||||
def fetch_webpage(url: str) -> lxml.html.HtmlElement:
|
||||
"""Fetch webpage HTML and parse it."""
|
||||
response = requests.get(url, headers={"User-Agent": USER_AGENT})
|
||||
response.raise_for_status()
|
||||
return lxml.html.fromstring(response.content)
|
||||
|
||||
|
||||
def webpage_to_text(root: lxml.html.HtmlElement) -> str:
|
||||
"""Convert parsed HTML into readable text content."""
|
||||
root_copy = lxml.html.fromstring(lxml.html.tostring(root))
|
||||
|
||||
for script_or_style in root_copy.xpath("//script|//style"):
|
||||
script_or_style.drop_tree()
|
||||
|
||||
text_maker = html2text.HTML2Text()
|
||||
text_maker.ignore_links = False
|
||||
text_maker.ignore_images = True
|
||||
return text_maker.handle(lxml.html.tostring(root_copy, encoding="unicode"))
|
||||
|
||||
|
||||
def url_to_booking_text(url: str) -> str:
|
||||
"""Fetch a URL and convert it to source text for the model."""
|
||||
return webpage_to_text(fetch_webpage(url))
|
||||
|
||||
|
||||
def booking_text_from_args(args: list[str]) -> str:
|
||||
"""Return booking text from a URL argument or stdin."""
|
||||
if args:
|
||||
if len(args) != 1:
|
||||
raise SystemExit("Usage: generate-BOOKING-booking-yaml [URL]")
|
||||
return url_to_booking_text(args[0])
|
||||
return sys.stdin.read()
|
||||
|
||||
|
||||
def generate_booking_yaml(
|
||||
booking_text: str, config: BookingConfig, model: str = "gpt-5.4"
|
||||
) -> str:
|
||||
"""Generate booking YAML from source text."""
|
||||
prompt = build_prompt(booking_text, config)
|
||||
return get_from_open_ai(prompt, model=model)[config.json_key]
|
||||
|
||||
|
||||
def datetime_from_yaml_value(value: typing.Any) -> datetime:
|
||||
"""Convert a YAML date/datetime/string value into a datetime."""
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
if isinstance(value, date):
|
||||
return datetime.combine(value, datetime.min.time())
|
||||
if isinstance(value, str):
|
||||
parsed = datetime.fromisoformat(value)
|
||||
return parsed
|
||||
raise TypeError(f"Unsupported departure value: {value!r}")
|
||||
|
||||
|
||||
def first_departure(booking: dict[str, typing.Any], config: BookingConfig) -> datetime:
|
||||
"""Return the first departure datetime for a generated booking."""
|
||||
if config.booking_type == "flight":
|
||||
flights = booking["flights"]
|
||||
assert isinstance(flights, list)
|
||||
first_flight = flights[0]
|
||||
assert isinstance(first_flight, dict)
|
||||
return datetime_from_yaml_value(first_flight["depart"])
|
||||
|
||||
return datetime_from_yaml_value(booking["depart"])
|
||||
|
||||
|
||||
def comparable_departure(
|
||||
booking: dict[str, typing.Any], config: BookingConfig
|
||||
) -> datetime:
|
||||
"""Return a timezone-naive departure datetime for sorting."""
|
||||
return first_departure(booking, config).replace(tzinfo=None)
|
||||
|
||||
|
||||
def generated_bookings_from_yaml(yaml_text: str) -> list[dict[str, typing.Any]]:
|
||||
"""Parse generated booking YAML into a list of booking mappings."""
|
||||
loaded = yaml.safe_load(yaml_text)
|
||||
if isinstance(loaded, dict):
|
||||
return [typing.cast(dict[str, typing.Any], loaded)]
|
||||
if isinstance(loaded, list) and all(isinstance(item, dict) for item in loaded):
|
||||
return typing.cast(list[dict[str, typing.Any]], loaded)
|
||||
raise ValueError("Generated booking YAML must be a mapping or list of mappings.")
|
||||
|
||||
|
||||
def trip_key_position(booking: dict[str, typing.Any], config: BookingConfig) -> int:
|
||||
"""Return the preferred insertion position for the top-level trip key."""
|
||||
keys = list(booking)
|
||||
if config.booking_type == "flight":
|
||||
if "booking_reference" in booking:
|
||||
return keys.index("booking_reference") + 1
|
||||
return 0
|
||||
|
||||
if "to" in booking:
|
||||
return keys.index("to") + 1
|
||||
if "from" in booking:
|
||||
return keys.index("from") + 1
|
||||
return 0
|
||||
|
||||
|
||||
def set_trip_key(
|
||||
booking: dict[str, typing.Any], config: BookingConfig, trip_date: date
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Set trip in the usual top-level position while preserving other key order."""
|
||||
without_trip = {key: value for key, value in booking.items() if key != "trip"}
|
||||
keys = list(without_trip)
|
||||
position = trip_key_position(without_trip, config)
|
||||
reordered: dict[str, typing.Any] = {}
|
||||
|
||||
for index, key in enumerate(keys):
|
||||
if index == position:
|
||||
reordered["trip"] = trip_date
|
||||
reordered[key] = without_trip[key]
|
||||
|
||||
if "trip" not in reordered:
|
||||
reordered["trip"] = trip_date
|
||||
|
||||
booking.clear()
|
||||
booking.update(reordered)
|
||||
return booking
|
||||
|
||||
|
||||
def build_trips(data_dir: Path) -> list[TripLike]:
|
||||
"""Build trips from personal data without importing agenda.trip at module load."""
|
||||
trip_module = importlib.import_module("agenda.trip")
|
||||
build_trip_list = typing.cast(
|
||||
typing.Callable[..., list[TripLike]], trip_module.build_trip_list
|
||||
)
|
||||
return build_trip_list(data_dir=str(data_dir))
|
||||
|
||||
|
||||
def matching_trip_date(depart: datetime, data_dir: Path = PERSONAL_DATA_DIR) -> date:
|
||||
"""Find the trip grouping date for a departure, falling back to departure date."""
|
||||
depart_date = depart.date()
|
||||
matching_starts: list[date] = []
|
||||
|
||||
for trip in build_trips(data_dir):
|
||||
trip_end = trip.end or trip.start
|
||||
if trip.start <= depart_date <= trip_end:
|
||||
matching_starts.append(trip.start)
|
||||
|
||||
if matching_starts:
|
||||
return max(matching_starts)
|
||||
return depart_date
|
||||
|
||||
|
||||
def add_trip_dates(
|
||||
bookings: list[dict[str, typing.Any]],
|
||||
config: BookingConfig,
|
||||
data_dir: Path = PERSONAL_DATA_DIR,
|
||||
) -> None:
|
||||
"""Add the top-level trip key to generated booking mappings."""
|
||||
for booking in bookings:
|
||||
trip_date = matching_trip_date(first_departure(booking, config), data_dir)
|
||||
set_trip_key(booking, config, trip_date)
|
||||
|
||||
|
||||
def dump_generated_bookings(bookings: list[dict[str, typing.Any]]) -> str:
|
||||
"""Dump only generated bookings for insertion into an existing YAML list."""
|
||||
text = yaml.dump(bookings, sort_keys=False, allow_unicode=True)
|
||||
return text.lstrip()
|
||||
|
||||
|
||||
def join_yaml_list_blocks(preamble: str, blocks: list[str], trailing: str = "") -> str:
|
||||
"""Join top-level YAML list blocks with a blank line between items."""
|
||||
body = "\n\n".join(block.rstrip("\n") for block in blocks)
|
||||
return preamble + body + "\n" + trailing
|
||||
|
||||
|
||||
def split_yaml_list_blocks(text: str) -> tuple[str, list[str], str]:
|
||||
"""Split a top-level YAML list into preamble, item blocks, and trailing text."""
|
||||
lines = text.splitlines(keepends=True)
|
||||
first_item = next(
|
||||
(index for index, line in enumerate(lines) if line.startswith("- ")), None
|
||||
)
|
||||
if first_item is None:
|
||||
return text, [], ""
|
||||
|
||||
item_starts = [
|
||||
index
|
||||
for index, line in enumerate(lines[first_item:], start=first_item)
|
||||
if line.startswith("- ")
|
||||
]
|
||||
blocks = [
|
||||
"".join(lines[start:end])
|
||||
for start, end in zip(item_starts, item_starts[1:] + [len(lines)])
|
||||
]
|
||||
return "".join(lines[:first_item]), blocks, ""
|
||||
|
||||
|
||||
def existing_bookings_from_blocks(blocks: list[str]) -> list[dict[str, typing.Any]]:
|
||||
"""Parse split YAML item blocks into booking mappings."""
|
||||
bookings = []
|
||||
for block in blocks:
|
||||
loaded = yaml.safe_load(block)
|
||||
if not isinstance(loaded, list) or len(loaded) != 1:
|
||||
raise ValueError("Could not parse existing booking block.")
|
||||
item = loaded[0]
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError("Existing booking block is not a mapping.")
|
||||
bookings.append(typing.cast(dict[str, typing.Any], item))
|
||||
return bookings
|
||||
|
||||
|
||||
def insertion_index(
|
||||
existing_bookings: list[dict[str, typing.Any]],
|
||||
new_bookings: list[dict[str, typing.Any]],
|
||||
config: BookingConfig,
|
||||
) -> int:
|
||||
"""Return the chronological insertion index for generated bookings."""
|
||||
new_depart = min(comparable_departure(booking, config) for booking in new_bookings)
|
||||
for index, booking in enumerate(existing_bookings):
|
||||
if comparable_departure(booking, config) > new_depart:
|
||||
return index
|
||||
return len(existing_bookings)
|
||||
|
||||
|
||||
def insert_booking_text(
|
||||
existing_text: str,
|
||||
new_yaml_text: str,
|
||||
config: BookingConfig,
|
||||
) -> str:
|
||||
"""Insert generated booking YAML into an existing top-level YAML list."""
|
||||
preamble, blocks, trailing = split_yaml_list_blocks(existing_text)
|
||||
new_bookings = generated_bookings_from_yaml(new_yaml_text)
|
||||
_, new_blocks, _ = split_yaml_list_blocks(dump_generated_bookings(new_bookings))
|
||||
if len(new_blocks) != len(new_bookings):
|
||||
raise ValueError("Could not split generated booking YAML into item blocks.")
|
||||
|
||||
existing_bookings = existing_bookings_from_blocks(blocks)
|
||||
|
||||
new_items = sorted(
|
||||
zip(new_bookings, new_blocks),
|
||||
key=lambda item: comparable_departure(item[0], config),
|
||||
)
|
||||
for booking, block in new_items:
|
||||
insert_at = insertion_index(existing_bookings, [booking], config)
|
||||
existing_bookings.insert(insert_at, booking)
|
||||
blocks.insert(insert_at, block)
|
||||
|
||||
return join_yaml_list_blocks(preamble, blocks, trailing)
|
||||
|
||||
|
||||
def import_booking_yaml(
|
||||
generated_yaml: str,
|
||||
config: BookingConfig,
|
||||
data_dir: Path = PERSONAL_DATA_DIR,
|
||||
) -> int:
|
||||
"""Add generated booking YAML to the configured personal-data file."""
|
||||
bookings = generated_bookings_from_yaml(generated_yaml)
|
||||
add_trip_dates(bookings, config, data_dir)
|
||||
new_yaml = dump_generated_bookings(bookings)
|
||||
|
||||
yaml_path = data_dir / config.yaml_filename
|
||||
existing_text = yaml_path.read_text()
|
||||
updated_text = insert_booking_text(existing_text, new_yaml, config)
|
||||
yaml_path.write_text(updated_text)
|
||||
return len(bookings)
|
||||
|
||||
|
||||
def main_for_type(booking_type: str, argv: list[str] | None = None) -> int:
|
||||
"""CLI entrypoint for a specific booking type."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
f"Generate {booking_type} booking YAML from stdin or a URL and import it."
|
||||
)
|
||||
)
|
||||
parser.add_argument("url", nargs="?", help="Booking URL to fetch")
|
||||
parser.add_argument("--model", default=os.environ.get("OPENAI_MODEL", "gpt-5.4"))
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
default=str(PERSONAL_DATA_DIR),
|
||||
help="Directory containing personal-data YAML files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-only",
|
||||
action="store_true",
|
||||
help="Print generated YAML instead of editing the personal-data file.",
|
||||
)
|
||||
parsed = parser.parse_args(argv)
|
||||
|
||||
config = BOOKING_CONFIGS[booking_type]
|
||||
args = [parsed.url] if parsed.url else []
|
||||
booking_text = booking_text_from_args(args)
|
||||
new_yaml = generate_booking_yaml(booking_text, config, model=parsed.model)
|
||||
if parsed.print_only:
|
||||
print(new_yaml)
|
||||
return 0
|
||||
|
||||
count = import_booking_yaml(new_yaml, config, data_dir=Path(parsed.data_dir))
|
||||
print(f"Imported {count} {booking_type} booking(s).")
|
||||
return 0
|
||||
|
||||
|
||||
def train_main(argv: list[str] | None = None) -> int:
|
||||
"""CLI entrypoint for train booking YAML generation."""
|
||||
return main_for_type("train", argv)
|
||||
|
||||
|
||||
def flight_main(argv: list[str] | None = None) -> int:
|
||||
"""CLI entrypoint for flight booking YAML generation."""
|
||||
return main_for_type("flight", argv)
|
||||
|
|
@ -119,6 +119,47 @@ def get_unbooked_flight_origin_iata(
|
|||
return "LHR"
|
||||
|
||||
|
||||
UNBOOKED_RAIL_DESTINATIONS: dict[tuple[str, str], tuple[str, str]] = {
|
||||
("fr", "paris"): ("London St Pancras", "Paris Gare du Nord"),
|
||||
}
|
||||
|
||||
|
||||
def load_station_lookup(data_dir: str) -> dict[str, StrDict]:
|
||||
"""Load stations keyed by station name."""
|
||||
stations = travel.parse_yaml("stations", data_dir)
|
||||
return {station["name"]: station for station in stations}
|
||||
|
||||
|
||||
def get_unbooked_rail_route(item: StrDict, data_dir: str) -> StrDict | None:
|
||||
"""Return an assumed rail route for a conference without booked travel."""
|
||||
location = item.get("location")
|
||||
country = item.get("country")
|
||||
if not isinstance(location, str) or not isinstance(country, str):
|
||||
return None
|
||||
|
||||
station_names = UNBOOKED_RAIL_DESTINATIONS.get(
|
||||
(country.casefold(), normalize_place_name(location))
|
||||
)
|
||||
if station_names is None:
|
||||
return None
|
||||
|
||||
stations = load_station_lookup(data_dir)
|
||||
from_station = stations.get(station_names[0])
|
||||
to_station = stations.get(station_names[1])
|
||||
if from_station is None or to_station is None:
|
||||
return None
|
||||
|
||||
key = "_".join(["train"] + sorted([from_station["name"], to_station["name"]]))
|
||||
route: StrDict = {"type": "train", "key": key}
|
||||
geojson_filename = from_station.get("routes", {}).get(to_station["name"])
|
||||
if geojson_filename:
|
||||
route["geojson_filename"] = os.path.join("train_routes", geojson_filename)
|
||||
else:
|
||||
route["from"] = latlon_tuple(from_station)
|
||||
route["to"] = latlon_tuple(to_station)
|
||||
return route
|
||||
|
||||
|
||||
def load_travel(travel_type: str, plural: str, data_dir: str) -> list[StrDict]:
|
||||
"""Read flight and train journeys."""
|
||||
items: list[StrDict] = travel.parse_yaml(plural, data_dir)
|
||||
|
|
@ -473,10 +514,10 @@ def conference_free_days(trip: Trip) -> dict[str, tuple[int, int]]:
|
|||
return {}
|
||||
|
||||
def conf_attend_start(c: StrDict) -> date:
|
||||
return typing.cast(date, as_date(c.get("attend_start") or c["start"]))
|
||||
return as_date(c.get("attend_start") or c["start"])
|
||||
|
||||
def conf_attend_end(c: StrDict) -> date:
|
||||
return typing.cast(date, as_date(c.get("attend_end") or c["end"]))
|
||||
return as_date(c.get("attend_end") or c["end"])
|
||||
|
||||
sorted_confs = sorted(trip.conferences, key=conf_attend_start)
|
||||
result: dict[str, tuple[int, int]] = {}
|
||||
|
|
@ -634,6 +675,11 @@ def get_trip_routes(trip: Trip, data_dir: str) -> list[StrDict]:
|
|||
|
||||
unbooked_routes = []
|
||||
for item in trip.conferences:
|
||||
unbooked_rail_route = get_unbooked_rail_route(item, data_dir)
|
||||
if unbooked_rail_route is not None:
|
||||
unbooked_routes.append(unbooked_rail_route)
|
||||
continue
|
||||
|
||||
if item["country"] in {"gb", "be"}: # not flying to Belgium
|
||||
continue
|
||||
destination_iata = preferred_airport_iata(item, airports)
|
||||
|
|
@ -870,7 +916,7 @@ def _trip_element_label(element: TripElement) -> str:
|
|||
return start_loc
|
||||
if isinstance(end_loc, str):
|
||||
return end_loc
|
||||
return typing.cast(str, element.title)
|
||||
return element.title
|
||||
|
||||
|
||||
def _trip_element_summary(trip: Trip, element: TripElement) -> str:
|
||||
|
|
|
|||
15
scripts/build_airport_yaml
Executable file
15
scripts/build_airport_yaml
Executable file
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
SCRIPT_PATH = os.path.realpath(__file__)
|
||||
SCRIPT_DIR = os.path.dirname(SCRIPT_PATH)
|
||||
REPO_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||
if REPO_ROOT not in sys.path:
|
||||
sys.path.insert(0, REPO_ROOT)
|
||||
|
||||
from agenda.build_place_yaml import airport_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(airport_main())
|
||||
15
scripts/build_station_yaml
Executable file
15
scripts/build_station_yaml
Executable file
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
SCRIPT_PATH = os.path.realpath(__file__)
|
||||
SCRIPT_DIR = os.path.dirname(SCRIPT_PATH)
|
||||
REPO_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||
if REPO_ROOT not in sys.path:
|
||||
sys.path.insert(0, REPO_ROOT)
|
||||
|
||||
from agenda.build_place_yaml import station_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(station_main())
|
||||
15
scripts/generate-flight-booking-yaml
Executable file
15
scripts/generate-flight-booking-yaml
Executable file
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
SCRIPT_PATH = os.path.realpath(__file__)
|
||||
SCRIPT_DIR = os.path.dirname(SCRIPT_PATH)
|
||||
REPO_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||
if REPO_ROOT not in sys.path:
|
||||
sys.path.insert(0, REPO_ROOT)
|
||||
|
||||
from agenda.generate_booking_yaml import flight_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(flight_main())
|
||||
15
scripts/generate-train-booking-yaml
Executable file
15
scripts/generate-train-booking-yaml
Executable file
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/python3
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
SCRIPT_PATH = os.path.realpath(__file__)
|
||||
SCRIPT_DIR = os.path.dirname(SCRIPT_PATH)
|
||||
REPO_ROOT = os.path.dirname(SCRIPT_DIR)
|
||||
if REPO_ROOT not in sys.path:
|
||||
sys.path.insert(0, REPO_ROOT)
|
||||
|
||||
from agenda.generate_booking_yaml import train_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(train_main())
|
||||
101
tests/test_build_place_yaml.py
Normal file
101
tests/test_build_place_yaml.py
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
"""Tests for agenda.build_place_yaml."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from agenda import build_place_yaml
|
||||
|
||||
|
||||
def test_upsert_station_adds_new_station(tmp_path: Path) -> None:
|
||||
"""Station upsert should add a new station to stations.yaml."""
|
||||
path = tmp_path / "stations.yaml"
|
||||
path.write_text("""- name: London St Pancras
|
||||
latitude: 51.531921
|
||||
longitude: -0.126361
|
||||
country: gb
|
||||
wikidata: Q720102
|
||||
routes: {}
|
||||
""")
|
||||
|
||||
replaced = build_place_yaml.upsert_station(
|
||||
tmp_path,
|
||||
{
|
||||
"name": "Paris Gare du Nord",
|
||||
"latitude": 48.8809,
|
||||
"longitude": 2.3553,
|
||||
"country": "fr",
|
||||
"wikidata": "Q624511",
|
||||
"routes": {},
|
||||
},
|
||||
)
|
||||
|
||||
stations = yaml.safe_load(path.read_text())
|
||||
assert replaced is False
|
||||
assert [station["name"] for station in stations] == [
|
||||
"London St Pancras",
|
||||
"Paris Gare du Nord",
|
||||
]
|
||||
assert "\n\n- name: Paris Gare du Nord\n" in path.read_text()
|
||||
|
||||
|
||||
def test_upsert_station_replaces_existing_station(tmp_path: Path) -> None:
|
||||
"""Station upsert should replace an existing station with the same name."""
|
||||
path = tmp_path / "stations.yaml"
|
||||
path.write_text("""- name: Paris Gare du Nord
|
||||
latitude: 0
|
||||
longitude: 0
|
||||
country: fr
|
||||
wikidata: Q624511
|
||||
routes: {}
|
||||
""")
|
||||
|
||||
replaced = build_place_yaml.upsert_station(
|
||||
tmp_path,
|
||||
{
|
||||
"name": "Paris Gare du Nord",
|
||||
"latitude": 48.8809,
|
||||
"longitude": 2.3553,
|
||||
"country": "fr",
|
||||
"wikidata": "Q624511",
|
||||
"routes": {},
|
||||
},
|
||||
)
|
||||
|
||||
stations = yaml.safe_load(path.read_text())
|
||||
assert replaced is True
|
||||
assert len(stations) == 1
|
||||
assert stations[0]["latitude"] == 48.8809
|
||||
assert "\n\n- name:" not in path.read_text()
|
||||
|
||||
|
||||
def test_upsert_airport_adds_mapping_entry(tmp_path: Path) -> None:
|
||||
"""Airport upsert should add a new IATA-keyed airport entry."""
|
||||
path = tmp_path / "airports.yaml"
|
||||
path.write_text("""LHR:
|
||||
iata: LHR
|
||||
name: Heathrow Airport
|
||||
city: London
|
||||
country: gb
|
||||
latitude: 51.47
|
||||
longitude: -0.4543
|
||||
qid: Q8691
|
||||
""")
|
||||
|
||||
replaced = build_place_yaml.upsert_airport(
|
||||
tmp_path,
|
||||
{
|
||||
"iata": "ORY",
|
||||
"name": "Paris Orly Airport",
|
||||
"city": "Paris Orly Airport",
|
||||
"country": "fr",
|
||||
"latitude": 48.723333,
|
||||
"longitude": 2.379444,
|
||||
"qid": "Q193353",
|
||||
},
|
||||
)
|
||||
|
||||
airports = yaml.safe_load(path.read_text())
|
||||
assert replaced is False
|
||||
assert list(airports) == ["LHR", "ORY"]
|
||||
assert airports["ORY"]["country"] == "fr"
|
||||
187
tests/test_generate_booking_yaml.py
Normal file
187
tests/test_generate_booking_yaml.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for agenda.generate_booking_yaml."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from agenda import generate_booking_yaml
|
||||
|
||||
|
||||
@dataclass
|
||||
class FakeTrip:
|
||||
"""Small stand-in for agenda.types.Trip in matching tests."""
|
||||
|
||||
start: date
|
||||
end: date | None
|
||||
|
||||
|
||||
def test_yaml_format_description_includes_train_spec() -> None:
|
||||
"""Train prompt docs should come from personal-data-yaml.md."""
|
||||
description = generate_booking_yaml.yaml_format_description(
|
||||
generate_booking_yaml.BOOKING_CONFIGS["train"]
|
||||
)
|
||||
|
||||
assert "## General Rules" in description
|
||||
assert "## Cross-File References" in description
|
||||
assert "## `trains.yaml`" in description
|
||||
assert "`operator`: booking/operator label" in description
|
||||
|
||||
|
||||
def test_build_prompt_uses_spec_and_keeps_trip_exclusion() -> None:
|
||||
"""The generated prompt should use the documented schema."""
|
||||
prompt = generate_booking_yaml.build_prompt(
|
||||
"Eurostar booking details",
|
||||
generate_booking_yaml.BOOKING_CONFIGS["train"],
|
||||
current_bookings="- operator: eurostar\n",
|
||||
)
|
||||
|
||||
assert "Use this YAML format specification" in prompt
|
||||
assert "## `trains.yaml`" in prompt
|
||||
assert 'Exclude the top-level "trip" key' in prompt
|
||||
assert "Eurostar booking details" in prompt
|
||||
|
||||
|
||||
def test_booking_text_from_args_fetches_url(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""URL arguments should be fetched instead of reading stdin."""
|
||||
monkeypatch.setattr(
|
||||
generate_booking_yaml,
|
||||
"url_to_booking_text",
|
||||
lambda url: f"fetched {url}",
|
||||
)
|
||||
|
||||
assert generate_booking_yaml.booking_text_from_args(["https://example.com"]) == (
|
||||
"fetched https://example.com"
|
||||
)
|
||||
|
||||
|
||||
def test_matching_trip_date_uses_built_trip_end(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Trip matching should use build_trip_list-derived end dates."""
|
||||
monkeypatch.setattr(
|
||||
generate_booking_yaml,
|
||||
"build_trips",
|
||||
lambda data_dir: [ # noqa: ARG005
|
||||
FakeTrip(date(2026, 2, 6), date(2026, 2, 9)),
|
||||
FakeTrip(date(2026, 3, 3), date(2026, 3, 5)),
|
||||
],
|
||||
)
|
||||
|
||||
assert generate_booking_yaml.matching_trip_date(
|
||||
generate_booking_yaml.datetime_from_yaml_value("2026-02-08 12:00:00+01:00"),
|
||||
Path("/tmp/personal-data"),
|
||||
) == date(2026, 2, 6)
|
||||
|
||||
|
||||
def test_matching_trip_date_falls_back_to_departure_date(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A booking outside known trips should use its first departure date."""
|
||||
monkeypatch.setattr(
|
||||
generate_booking_yaml, "build_trips", lambda data_dir: [] # noqa: ARG005
|
||||
)
|
||||
|
||||
assert generate_booking_yaml.matching_trip_date(
|
||||
generate_booking_yaml.datetime_from_yaml_value("2026-04-10 12:00:00+01:00"),
|
||||
Path("/tmp/personal-data"),
|
||||
) == date(2026, 4, 10)
|
||||
|
||||
|
||||
def test_import_train_booking_adds_trip_and_inserts_chronologically(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Generated train bookings should be written into trains.yaml in order."""
|
||||
(tmp_path / "trains.yaml").write_text("""- operator: early
|
||||
from: A
|
||||
to: B
|
||||
trip: 2026-02-01
|
||||
depart: 2026-02-01 10:00:00+00:00
|
||||
arrive: 2026-02-01 11:00:00+00:00
|
||||
legs: []
|
||||
|
||||
- operator: late
|
||||
from: C
|
||||
to: D
|
||||
trip: 2026-02-10
|
||||
depart: 2026-02-10 10:00:00+00:00
|
||||
arrive: 2026-02-10 11:00:00+00:00
|
||||
legs: []
|
||||
""")
|
||||
monkeypatch.setattr(
|
||||
generate_booking_yaml,
|
||||
"matching_trip_date",
|
||||
lambda depart, data_dir: date(2026, 2, 6),
|
||||
)
|
||||
|
||||
count = generate_booking_yaml.import_booking_yaml(
|
||||
"""- operator: eurostar
|
||||
from: London St Pancras
|
||||
to: Brussels Midi
|
||||
depart: 2026-02-06 15:04:00+00:00
|
||||
arrive: 2026-02-06 18:12:00+01:00
|
||||
legs: []
|
||||
""",
|
||||
generate_booking_yaml.BOOKING_CONFIGS["train"],
|
||||
data_dir=tmp_path,
|
||||
)
|
||||
|
||||
written = yaml.safe_load((tmp_path / "trains.yaml").read_text())
|
||||
assert count == 1
|
||||
assert [item["operator"] for item in written] == ["early", "eurostar", "late"]
|
||||
assert written[1]["trip"] == date(2026, 2, 6)
|
||||
assert list(written[1]).index("trip") == 3
|
||||
assert "\n\n- operator: eurostar\n" in (tmp_path / "trains.yaml").read_text()
|
||||
|
||||
|
||||
def test_import_flight_booking_adds_trip_and_inserts_chronologically(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Generated flight bookings should be written into flights.yaml in order."""
|
||||
(tmp_path / "flights.yaml").write_text("""---
|
||||
- booking_reference: OLD
|
||||
trip: 2026-02-01
|
||||
flights:
|
||||
- depart: 2026-02-01 10:00:00+00:00
|
||||
from: BRS
|
||||
to: AMS
|
||||
flight_number: '1'
|
||||
airline: U2
|
||||
|
||||
- booking_reference: NEWER
|
||||
trip: 2026-02-10
|
||||
flights:
|
||||
- depart: 2026-02-10 10:00:00+00:00
|
||||
from: BRS
|
||||
to: AMS
|
||||
flight_number: '2'
|
||||
airline: U2
|
||||
""")
|
||||
monkeypatch.setattr(
|
||||
generate_booking_yaml,
|
||||
"matching_trip_date",
|
||||
lambda depart, data_dir: date(2026, 2, 6),
|
||||
)
|
||||
|
||||
count = generate_booking_yaml.import_booking_yaml(
|
||||
"""booking_reference: MID
|
||||
flights:
|
||||
- depart: 2026-02-06 10:00:00+00:00
|
||||
from: BRS
|
||||
to: AMS
|
||||
flight_number: '3'
|
||||
airline: U2
|
||||
""",
|
||||
generate_booking_yaml.BOOKING_CONFIGS["flight"],
|
||||
data_dir=tmp_path,
|
||||
)
|
||||
|
||||
written = yaml.safe_load((tmp_path / "flights.yaml").read_text())
|
||||
assert count == 1
|
||||
assert [item["booking_reference"] for item in written] == ["OLD", "MID", "NEWER"]
|
||||
assert written[1]["trip"] == date(2026, 2, 6)
|
||||
assert list(written[1]).index("trip") == 1
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
from datetime import date
|
||||
|
||||
import agenda.trip
|
||||
import pytest
|
||||
from agenda.types import Trip
|
||||
from web_view import app
|
||||
|
||||
|
|
@ -101,3 +102,68 @@ def test_get_coordinates_and_routes_adds_unbooked_flight_airports() -> None:
|
|||
coord["name"] for coord in coordinates if coord["type"] == "airport"
|
||||
}
|
||||
assert airport_names == {"Heathrow Airport", "Paris Charles de Gaulle Airport"}
|
||||
|
||||
|
||||
def test_get_trip_routes_assumes_unbooked_paris_trip_is_by_train(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Paris conferences without booked travel should show rail, not flight."""
|
||||
trip = Trip(
|
||||
start=date(2026, 7, 20),
|
||||
conferences=[
|
||||
{
|
||||
"name": "Paris Conf",
|
||||
"location": "Paris",
|
||||
"country": "fr",
|
||||
}
|
||||
],
|
||||
)
|
||||
stations = [
|
||||
{
|
||||
"name": "London St Pancras",
|
||||
"latitude": 51.531921,
|
||||
"longitude": -0.126361,
|
||||
"routes": {"Paris Gare du Nord": "London_St_Pancras_to_Paris_Gare_du_Nord"},
|
||||
},
|
||||
{
|
||||
"name": "Paris Gare du Nord",
|
||||
"latitude": 48.88111111111111,
|
||||
"longitude": 2.355277777777778,
|
||||
"routes": {"London St Pancras": "London_St_Pancras_to_Paris_Gare_du_Nord"},
|
||||
},
|
||||
]
|
||||
|
||||
def fake_parse_yaml(name: str, data_dir: str) -> object:
|
||||
if name == "stations":
|
||||
return stations
|
||||
if name == "airports":
|
||||
return {
|
||||
"LHR": {
|
||||
"name": "Heathrow Airport",
|
||||
"latitude": 51.47,
|
||||
"longitude": -0.45,
|
||||
},
|
||||
"CDG": {
|
||||
"name": "Paris Charles de Gaulle Airport",
|
||||
"city": "Paris",
|
||||
"country": "fr",
|
||||
"latitude": 49.01,
|
||||
"longitude": 2.55,
|
||||
},
|
||||
}
|
||||
raise AssertionError(f"unexpected YAML load: {name}")
|
||||
|
||||
monkeypatch.setattr(agenda.trip.travel, "parse_yaml", fake_parse_yaml)
|
||||
monkeypatch.setattr(
|
||||
agenda.trip, "load_flight_destination_rules", lambda _data_dir: []
|
||||
)
|
||||
|
||||
routes = agenda.trip.get_trip_routes(trip, "/tmp/personal-data")
|
||||
|
||||
assert routes == [
|
||||
{
|
||||
"type": "train",
|
||||
"key": "train_London St Pancras_Paris Gare du Nord",
|
||||
"geojson_filename": "train_routes/London_St_Pancras_to_Paris_Gare_du_Nord",
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue