diff --git a/agenda/build_place_yaml.py b/agenda/build_place_yaml.py new file mode 100644 index 0000000..bc346f4 --- /dev/null +++ b/agenda/build_place_yaml.py @@ -0,0 +1,257 @@ +"""Build airport and station YAML entries from Wikidata.""" + +import argparse +import typing +from pathlib import Path + +import requests +import yaml + +API_URL = "https://www.wikidata.org/w/api.php" +PERSONAL_DATA_DIR = Path("~/src/personal-data").expanduser() +USER_AGENT = "agenda-build-place-yaml/0.1" + +Entity = dict[str, typing.Any] +Entities = dict[str, Entity] + + +class WikidataClient: + """Small Wikidata API client for place lookups.""" + + def __init__(self) -> None: + """Create a Wikidata API session.""" + self.session = requests.Session() + self.session.headers.update({"User-Agent": USER_AGENT}) + + def get_json(self, params: dict[str, str]) -> dict[str, typing.Any]: + """Fetch JSON from Wikidata.""" + response = self.session.get(API_URL, params=params) + response.raise_for_status() + return typing.cast(dict[str, typing.Any], response.json()) + + def get_alpha2_country_code(self, qid: str) -> str: + """Query the Wikidata API for alpha-2 country code.""" + data = self.get_json( + { + "action": "wbgetclaims", + "entity": qid, + "property": "P297", + "format": "json", + } + ) + p297 = data["claims"]["P297"] + return typing.cast(str, p297[0]["mainsnak"]["datavalue"]["value"]).lower() + + def search_entities(self, query: str) -> Entities: + """Search Wikidata and return detailed entities.""" + search_data = self.get_json( + { + "action": "query", + "list": "search", + "format": "json", + "srsearch": query, + } + ) + search_results = search_data["query"]["search"] + if not search_results: + return {} + + ids = [result["title"] for result in search_results] + entity_data = self.get_json( + { + "action": "wbgetentities", + "format": "json", + "ids": "|".join(ids), + } + ) + return typing.cast(Entities, entity_data["entities"]) + + +def entity_names(entity: Entity) -> set[str]: + 
def entity_claim_value(
    entity: Entity, property_id: str, default: typing.Any = None
) -> typing.Any:
    """Return the first usable Wikidata claim value for a property.

    Args:
        entity: Entity document containing an optional ``claims`` mapping.
        property_id: Property id to look up, e.g. ``"P625"``.
        default: Value returned when no claim carries a value.

    Returns:
        The ``value`` of the first claim whose main snak has a ``datavalue``,
        or ``default`` when the property is absent, its claim list is empty,
        or every claim is valueless.
    """
    for claim in entity.get("claims", {}).get(property_id, ()):
        # "novalue"/"somevalue" snaks have no datavalue at all; skip them
        # instead of failing with KeyError.
        datavalue = claim.get("mainsnak", {}).get("datavalue")
        if datavalue is not None:
            return datavalue["value"]
    return default
def build_airport_info(
    client: WikidataClient, iata: str, entity_id: str, entity: Entity
) -> dict[str, typing.Any]:
    """Assemble the airports.yaml mapping for one Wikidata airport entity.

    The English label is used for both ``name`` and ``city``; coordinates come
    from the first P625 claim and the country from the first P17 claim, which
    is resolved to an alpha-2 code through ``client``. A ``website`` key is
    added only when the entity has a P856 value.
    """
    display_name = entity["labels"]["en"]["value"]
    all_claims = entity["claims"]
    position = all_claims["P625"][0]["mainsnak"]["datavalue"]["value"]
    country_entity = all_claims["P17"][0]["mainsnak"]["datavalue"]["value"]["id"]
    country_code = client.get_alpha2_country_code(country_entity)

    airport: dict[str, typing.Any] = {
        "iata": iata,
        "name": display_name,
        "city": display_name,
        "country": country_code,
        "latitude": position["latitude"],
        "longitude": position["longitude"],
        "qid": entity_id,
    }

    homepage = entity_claim_value(entity, "P856")
    if homepage is None:
        return airport

    airport["website"] = homepage
    return airport
def upsert_airport(data_dir: Path, airport_info: dict[str, typing.Any]) -> bool:
    """Add or replace an airport entry in ``airports.yaml``.

    Args:
        data_dir: Directory containing ``airports.yaml``.
        airport_info: Airport mapping; its ``iata`` value is the key used in
            the file.

    Returns:
        True when an entry with the same IATA code was already present (it is
        overwritten), False when a new entry was added.
    """
    path = data_dir / "airports.yaml"
    loaded = load_yaml(path)
    # An empty or comment-only YAML file loads as None; start from an empty
    # mapping rather than failing on the membership test below.
    airports = typing.cast(
        dict[str, dict[str, typing.Any]], {} if loaded is None else loaded
    )
    iata = typing.cast(str, airport_info["iata"])
    existed = iata in airports
    airports[iata] = airport_info
    path.write_text(dump_yaml(airports))
    return existed
def airport_main(argv: list[str] | None = None) -> int:
    """Run the airport command: look up an IATA code and print or store it.

    With ``--print-only`` the YAML entry is written to stdout; otherwise it is
    upserted into ``airports.yaml`` under ``--data-dir``. Returns 0.
    """
    cli = argparse.ArgumentParser(
        description="Add or update an airport in personal-data/airports.yaml."
    )
    cli.add_argument("iata")
    cli.add_argument("--data-dir", default=str(PERSONAL_DATA_DIR))
    cli.add_argument("--print-only", action="store_true")
    options = cli.parse_args(argv)

    airport_info = search_for_airport(WikidataClient(), options.iata)

    if not options.print_only:
        existed = upsert_airport(Path(options.data_dir), airport_info)
        verb = "Updated" if existed else "Added"
        print(
            f"{verb} airport {airport_info['iata']} in {options.data_dir}/airports.yaml"
        )
        return 0

    # Print the entry keyed by IATA code, the same shape airports.yaml uses.
    print(dump_yaml({airport_info["iata"]: airport_info}).strip())
    return 0
def read_markdown_section(markdown_text: str, heading: str) -> str:
    """Return one second-level markdown section by heading text.

    The heading is looked up first in its backticked form (``## `heading```)
    and then in its plain form (``## heading``). A match only counts when it
    begins a line and is followed by whitespace or the end of the text, so a
    third-level heading such as ``### heading`` or a mid-line mention is not
    mistaken for the section start.

    Args:
        markdown_text: Full markdown document.
        heading: Heading text without the leading ``## ``.

    Returns:
        The section from its heading line up to (not including) the next
        second-level heading, stripped of surrounding whitespace.

    Raises:
        ValueError: If no matching second-level heading exists.
    """
    for candidate in (f"## `{heading}`", f"## {heading}"):
        search_from = 0
        while (start := markdown_text.find(candidate, search_from)) != -1:
            end = start + len(candidate)
            starts_line = start == 0 or markdown_text[start - 1] == "\n"
            ends_token = end == len(markdown_text) or markdown_text[end].isspace()
            if starts_line and ends_token:
                next_heading = markdown_text.find("\n## ", end)
                if next_heading == -1:
                    return markdown_text[start:].strip()
                return markdown_text[start:next_heading].strip()
            # Not a real heading line; keep scanning after this occurrence.
            search_from = start + 1

    raise ValueError(f"Could not find section for heading {heading!r}")
def read_existing_bookings(
    config: BookingConfig, data_dir: Path = PERSONAL_DATA_DIR
) -> str:
    """Return the raw text of the booking file named by ``config``.

    The text is used verbatim in the prompt as a sample of the local style.
    """
    return (data_dir / config.yaml_filename).read_text()
def fetch_webpage(url: str, timeout: float = 30.0) -> lxml.html.HtmlElement:
    """Fetch webpage HTML and parse it.

    Args:
        url: Address of the page to download.
        timeout: Seconds to wait for the server. requests never times out
            unless told to, so an unresponsive host would otherwise hang the
            command indefinitely.

    Returns:
        The parsed document root.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    response = requests.get(
        url, headers={"User-Agent": USER_AGENT}, timeout=timeout
    )
    response.raise_for_status()
    # Parse the raw bytes (not response.text), as the original code does.
    return lxml.html.fromstring(response.content)
def comparable_departure(
    booking: dict[str, typing.Any], config: BookingConfig
) -> datetime:
    """Return a timezone-naive departure datetime for sorting.

    Timezone-aware departures are shifted to UTC before the timezone is
    dropped, so bookings recorded with different UTC offsets compare by the
    actual instant rather than by local wall-clock time. Naive departures are
    returned unchanged.
    """
    depart = first_departure(booking, config)
    offset = depart.utcoffset()
    if offset is not None:
        # Subtracting the offset gives the UTC wall-clock time; the tzinfo is
        # removed afterwards so naive and aware values can be compared.
        depart = depart - offset
    return depart.replace(tzinfo=None)
def matching_trip_date(depart: datetime, data_dir: Path = PERSONAL_DATA_DIR) -> date:
    """Pick the trip start date that a departure belongs to.

    Every trip whose start..end range (end defaults to start) contains the
    departure day is a candidate; the latest start wins. When no trip
    matches, the departure day itself is returned.
    """
    day = depart.date()
    candidate_starts: list[date] = [
        trip.start
        for trip in build_trips(data_dir)
        if trip.start <= day <= (trip.end or trip.start)
    ]
    return max(candidate_starts) if candidate_starts else day
def split_yaml_list_blocks(text: str) -> tuple[str, list[str], str]:
    """Cut a top-level YAML list into its preamble and per-item text blocks.

    An item starts on any line beginning with ``"- "`` at column 0 and runs
    up to the next such line (or the end of the text), so blank lines after
    an item stay attached to it. Everything before the first item is the
    preamble. The third element is always the empty string. When the text
    has no item lines the whole text is returned as the preamble.
    """
    lines = text.splitlines(keepends=True)
    starts = [number for number, line in enumerate(lines) if line.startswith("- ")]
    if not starts:
        return text, [], ""

    # Pair each item start with the start of the following item; the last
    # item ends at the end of the text.
    boundaries = starts + [len(lines)]
    blocks: list[str] = []
    for position in range(len(starts)):
        begin = boundaries[position]
        finish = boundaries[position + 1]
        blocks.append("".join(lines[begin:finish]))

    preamble = "".join(lines[: starts[0]])
    return preamble, blocks, ""
def import_booking_yaml(
    generated_yaml: str,
    config: BookingConfig,
    data_dir: Path = PERSONAL_DATA_DIR,
) -> int:
    """Merge generated bookings into the YAML file configured for this type.

    The generated YAML is parsed, each booking gets its ``trip`` key, and the
    result is inserted into the existing file text, which is then written
    back. Returns the number of bookings imported.
    """
    parsed = generated_bookings_from_yaml(generated_yaml)
    add_trip_dates(parsed, config, data_dir)
    fresh_yaml = dump_generated_bookings(parsed)

    target = data_dir / config.yaml_filename
    merged = insert_booking_text(target.read_text(), fresh_yaml, config)
    target.write_text(merged)
    return len(parsed)
def get_unbooked_rail_route(item: StrDict, data_dir: str) -> StrDict | None:
    """Build the assumed train route for a conference with no booked travel.

    Returns None unless the item has string ``location`` and ``country``
    values, that pair is listed in UNBOOKED_RAIL_DESTINATIONS, and both named
    stations exist in the station data. The route points at a stored GeoJSON
    file when the origin station lists one for the destination; otherwise it
    carries the two stations' coordinates.
    """
    place = item.get("location")
    country_code = item.get("country")
    if not (isinstance(place, str) and isinstance(country_code, str)):
        return None

    lookup_key = (country_code.casefold(), normalize_place_name(place))
    station_pair = UNBOOKED_RAIL_DESTINATIONS.get(lookup_key)
    if station_pair is None:
        return None

    by_name = load_station_lookup(data_dir)
    origin = by_name.get(station_pair[0])
    destination = by_name.get(station_pair[1])
    if origin is None or destination is None:
        return None

    # The key sorts the two station names so it is the same in both directions.
    endpoint_names = sorted([origin["name"], destination["name"]])
    route: StrDict = {"type": "train", "key": "_".join(["train"] + endpoint_names)}

    stored_geojson = origin.get("routes", {}).get(destination["name"])
    if not stored_geojson:
        route["from"] = latlon_tuple(origin)
        route["to"] = latlon_tuple(destination)
        return route

    route["geojson_filename"] = os.path.join("train_routes", stored_geojson)
    return route
#!/usr/bin/python3
"""Command-line wrapper that runs agenda.build_place_yaml.airport_main."""

import os
import sys

# realpath resolves symlinks, so the repository root is derived from the
# script's real location.
SCRIPT_PATH = os.path.realpath(__file__)
SCRIPT_DIR = os.path.dirname(SCRIPT_PATH)
REPO_ROOT = os.path.dirname(SCRIPT_DIR)
# Put the repository root first on sys.path so the `agenda` import below
# resolves to this checkout. This must happen before that import.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from agenda.build_place_yaml import airport_main

if __name__ == "__main__":
    # airport_main returns an int, which becomes the process exit status.
    raise SystemExit(airport_main())
0000000..1c0a7ba --- /dev/null +++ b/scripts/generate-train-booking-yaml @@ -0,0 +1,15 @@ +#!/usr/bin/python3 + +import os +import sys + +SCRIPT_PATH = os.path.realpath(__file__) +SCRIPT_DIR = os.path.dirname(SCRIPT_PATH) +REPO_ROOT = os.path.dirname(SCRIPT_DIR) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +from agenda.generate_booking_yaml import train_main + +if __name__ == "__main__": + raise SystemExit(train_main()) diff --git a/tests/test_build_place_yaml.py b/tests/test_build_place_yaml.py new file mode 100644 index 0000000..9e181dd --- /dev/null +++ b/tests/test_build_place_yaml.py @@ -0,0 +1,101 @@ +"""Tests for agenda.build_place_yaml.""" + +from pathlib import Path + +import yaml + +from agenda import build_place_yaml + + +def test_upsert_station_adds_new_station(tmp_path: Path) -> None: + """Station upsert should add a new station to stations.yaml.""" + path = tmp_path / "stations.yaml" + path.write_text("""- name: London St Pancras + latitude: 51.531921 + longitude: -0.126361 + country: gb + wikidata: Q720102 + routes: {} +""") + + replaced = build_place_yaml.upsert_station( + tmp_path, + { + "name": "Paris Gare du Nord", + "latitude": 48.8809, + "longitude": 2.3553, + "country": "fr", + "wikidata": "Q624511", + "routes": {}, + }, + ) + + stations = yaml.safe_load(path.read_text()) + assert replaced is False + assert [station["name"] for station in stations] == [ + "London St Pancras", + "Paris Gare du Nord", + ] + assert "\n\n- name: Paris Gare du Nord\n" in path.read_text() + + +def test_upsert_station_replaces_existing_station(tmp_path: Path) -> None: + """Station upsert should replace an existing station with the same name.""" + path = tmp_path / "stations.yaml" + path.write_text("""- name: Paris Gare du Nord + latitude: 0 + longitude: 0 + country: fr + wikidata: Q624511 + routes: {} +""") + + replaced = build_place_yaml.upsert_station( + tmp_path, + { + "name": "Paris Gare du Nord", + "latitude": 48.8809, + "longitude": 2.3553, 
+ "country": "fr", + "wikidata": "Q624511", + "routes": {}, + }, + ) + + stations = yaml.safe_load(path.read_text()) + assert replaced is True + assert len(stations) == 1 + assert stations[0]["latitude"] == 48.8809 + assert "\n\n- name:" not in path.read_text() + + +def test_upsert_airport_adds_mapping_entry(tmp_path: Path) -> None: + """Airport upsert should add a new IATA-keyed airport entry.""" + path = tmp_path / "airports.yaml" + path.write_text("""LHR: + iata: LHR + name: Heathrow Airport + city: London + country: gb + latitude: 51.47 + longitude: -0.4543 + qid: Q8691 +""") + + replaced = build_place_yaml.upsert_airport( + tmp_path, + { + "iata": "ORY", + "name": "Paris Orly Airport", + "city": "Paris Orly Airport", + "country": "fr", + "latitude": 48.723333, + "longitude": 2.379444, + "qid": "Q193353", + }, + ) + + airports = yaml.safe_load(path.read_text()) + assert replaced is False + assert list(airports) == ["LHR", "ORY"] + assert airports["ORY"]["country"] == "fr" diff --git a/tests/test_generate_booking_yaml.py b/tests/test_generate_booking_yaml.py new file mode 100644 index 0000000..51ff471 --- /dev/null +++ b/tests/test_generate_booking_yaml.py @@ -0,0 +1,187 @@ +"""Tests for agenda.generate_booking_yaml.""" + +from dataclasses import dataclass +from datetime import date +from pathlib import Path + +import pytest +import yaml + +from agenda import generate_booking_yaml + + +@dataclass +class FakeTrip: + """Small stand-in for agenda.types.Trip in matching tests.""" + + start: date + end: date | None + + +def test_yaml_format_description_includes_train_spec() -> None: + """Train prompt docs should come from personal-data-yaml.md.""" + description = generate_booking_yaml.yaml_format_description( + generate_booking_yaml.BOOKING_CONFIGS["train"] + ) + + assert "## General Rules" in description + assert "## Cross-File References" in description + assert "## `trains.yaml`" in description + assert "`operator`: booking/operator label" in description + + +def 
test_build_prompt_uses_spec_and_keeps_trip_exclusion() -> None: + """The generated prompt should use the documented schema.""" + prompt = generate_booking_yaml.build_prompt( + "Eurostar booking details", + generate_booking_yaml.BOOKING_CONFIGS["train"], + current_bookings="- operator: eurostar\n", + ) + + assert "Use this YAML format specification" in prompt + assert "## `trains.yaml`" in prompt + assert 'Exclude the top-level "trip" key' in prompt + assert "Eurostar booking details" in prompt + + +def test_booking_text_from_args_fetches_url(monkeypatch: pytest.MonkeyPatch) -> None: + """URL arguments should be fetched instead of reading stdin.""" + monkeypatch.setattr( + generate_booking_yaml, + "url_to_booking_text", + lambda url: f"fetched {url}", + ) + + assert generate_booking_yaml.booking_text_from_args(["https://example.com"]) == ( + "fetched https://example.com" + ) + + +def test_matching_trip_date_uses_built_trip_end( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Trip matching should use build_trip_list-derived end dates.""" + monkeypatch.setattr( + generate_booking_yaml, + "build_trips", + lambda data_dir: [ # noqa: ARG005 + FakeTrip(date(2026, 2, 6), date(2026, 2, 9)), + FakeTrip(date(2026, 3, 3), date(2026, 3, 5)), + ], + ) + + assert generate_booking_yaml.matching_trip_date( + generate_booking_yaml.datetime_from_yaml_value("2026-02-08 12:00:00+01:00"), + Path("/tmp/personal-data"), + ) == date(2026, 2, 6) + + +def test_matching_trip_date_falls_back_to_departure_date( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A booking outside known trips should use its first departure date.""" + monkeypatch.setattr( + generate_booking_yaml, "build_trips", lambda data_dir: [] # noqa: ARG005 + ) + + assert generate_booking_yaml.matching_trip_date( + generate_booking_yaml.datetime_from_yaml_value("2026-04-10 12:00:00+01:00"), + Path("/tmp/personal-data"), + ) == date(2026, 4, 10) + + +def test_import_train_booking_adds_trip_and_inserts_chronologically( + 
tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Generated train bookings should be written into trains.yaml in order.""" + (tmp_path / "trains.yaml").write_text("""- operator: early + from: A + to: B + trip: 2026-02-01 + depart: 2026-02-01 10:00:00+00:00 + arrive: 2026-02-01 11:00:00+00:00 + legs: [] + +- operator: late + from: C + to: D + trip: 2026-02-10 + depart: 2026-02-10 10:00:00+00:00 + arrive: 2026-02-10 11:00:00+00:00 + legs: [] +""") + monkeypatch.setattr( + generate_booking_yaml, + "matching_trip_date", + lambda depart, data_dir: date(2026, 2, 6), + ) + + count = generate_booking_yaml.import_booking_yaml( + """- operator: eurostar + from: London St Pancras + to: Brussels Midi + depart: 2026-02-06 15:04:00+00:00 + arrive: 2026-02-06 18:12:00+01:00 + legs: [] +""", + generate_booking_yaml.BOOKING_CONFIGS["train"], + data_dir=tmp_path, + ) + + written = yaml.safe_load((tmp_path / "trains.yaml").read_text()) + assert count == 1 + assert [item["operator"] for item in written] == ["early", "eurostar", "late"] + assert written[1]["trip"] == date(2026, 2, 6) + assert list(written[1]).index("trip") == 3 + assert "\n\n- operator: eurostar\n" in (tmp_path / "trains.yaml").read_text() + + +def test_import_flight_booking_adds_trip_and_inserts_chronologically( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Generated flight bookings should be written into flights.yaml in order.""" + (tmp_path / "flights.yaml").write_text("""--- +- booking_reference: OLD + trip: 2026-02-01 + flights: + - depart: 2026-02-01 10:00:00+00:00 + from: BRS + to: AMS + flight_number: '1' + airline: U2 + +- booking_reference: NEWER + trip: 2026-02-10 + flights: + - depart: 2026-02-10 10:00:00+00:00 + from: BRS + to: AMS + flight_number: '2' + airline: U2 +""") + monkeypatch.setattr( + generate_booking_yaml, + "matching_trip_date", + lambda depart, data_dir: date(2026, 2, 6), + ) + + count = generate_booking_yaml.import_booking_yaml( + 
"""booking_reference: MID +flights: + - depart: 2026-02-06 10:00:00+00:00 + from: BRS + to: AMS + flight_number: '3' + airline: U2 +""", + generate_booking_yaml.BOOKING_CONFIGS["flight"], + data_dir=tmp_path, + ) + + written = yaml.safe_load((tmp_path / "flights.yaml").read_text()) + assert count == 1 + assert [item["booking_reference"] for item in written] == ["OLD", "MID", "NEWER"] + assert written[1]["trip"] == date(2026, 2, 6) + assert list(written[1]).index("trip") == 1 diff --git a/tests/test_trip.py b/tests/test_trip.py index 69c0e6e..614d669 100644 --- a/tests/test_trip.py +++ b/tests/test_trip.py @@ -3,6 +3,7 @@ from datetime import date import agenda.trip +import pytest from agenda.types import Trip from web_view import app @@ -101,3 +102,68 @@ def test_get_coordinates_and_routes_adds_unbooked_flight_airports() -> None: coord["name"] for coord in coordinates if coord["type"] == "airport" } assert airport_names == {"Heathrow Airport", "Paris Charles de Gaulle Airport"} + + +def test_get_trip_routes_assumes_unbooked_paris_trip_is_by_train( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Paris conferences without booked travel should show rail, not flight.""" + trip = Trip( + start=date(2026, 7, 20), + conferences=[ + { + "name": "Paris Conf", + "location": "Paris", + "country": "fr", + } + ], + ) + stations = [ + { + "name": "London St Pancras", + "latitude": 51.531921, + "longitude": -0.126361, + "routes": {"Paris Gare du Nord": "London_St_Pancras_to_Paris_Gare_du_Nord"}, + }, + { + "name": "Paris Gare du Nord", + "latitude": 48.88111111111111, + "longitude": 2.355277777777778, + "routes": {"London St Pancras": "London_St_Pancras_to_Paris_Gare_du_Nord"}, + }, + ] + + def fake_parse_yaml(name: str, data_dir: str) -> object: + if name == "stations": + return stations + if name == "airports": + return { + "LHR": { + "name": "Heathrow Airport", + "latitude": 51.47, + "longitude": -0.45, + }, + "CDG": { + "name": "Paris Charles de Gaulle Airport", + "city": 
"Paris", + "country": "fr", + "latitude": 49.01, + "longitude": 2.55, + }, + } + raise AssertionError(f"unexpected YAML load: {name}") + + monkeypatch.setattr(agenda.trip.travel, "parse_yaml", fake_parse_yaml) + monkeypatch.setattr( + agenda.trip, "load_flight_destination_rules", lambda _data_dir: [] + ) + + routes = agenda.trip.get_trip_routes(trip, "/tmp/personal-data") + + assert routes == [ + { + "type": "train", + "key": "train_London St Pancras_Paris Gare du Nord", + "geojson_filename": "train_routes/London_St_Pancras_to_Paris_Gare_du_Nord", + } + ]