From 96ec2b7d897c7e1b0c0dad078a3a2b06fa70d3ab Mon Sep 17 00:00:00 2001
From: Edward Betts <edward@4angle.com>
Date: Wed, 25 Sep 2024 12:15:13 +0100
Subject: [PATCH] Catch missing station in train leg

Raise UnknownStation for missing stations in train journey leg.
---
 agenda/trip.py | 33 +++++++++++++++------------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/agenda/trip.py b/agenda/trip.py
index 4ce6eca..ca0e283 100644
--- a/agenda/trip.py
+++ b/agenda/trip.py
@@ -28,17 +28,18 @@ def load_travel(travel_type: str, plural: str, data_dir: str) -> list[StrDict]:
     return items
 
 
-def process_train_leg(
-    leg: StrDict,
-    by_name: StrDict,
-    route_distances: travel.RouteDistances | None,
-) -> None:
-    """Process train leg."""
-    assert leg["from"] in by_name and leg["to"] in by_name
-    leg["from_station"], leg["to_station"] = by_name[leg["from"]], by_name[leg["to"]]
+def get_station(name: str, by_name: dict[str, StrDict]) -> StrDict:
+    """Get station by name."""
+    try:
+        return by_name[name]
+    except IndexError:
+        raise UnknownStation(name)
 
-    if route_distances:
-        travel.add_leg_route_distance(leg, route_distances)
+
+def add_station_objects(item: StrDict, by_name: dict[str, StrDict]) -> None:
+    """Lookup stations and add to train or leg."""
+    item["from_station"] = get_station(item["from"], by_name)
+    item["to_station"] = get_station(item["to"], by_name)
 
 
 def load_trains(
@@ -50,15 +51,11 @@ def load_trains(
     by_name = {station["name"]: station for station in stations}
 
     for train in trains:
-        if train["from"] not in by_name:
-            raise UnknownStation(train["from"])
-        if train["to"] not in by_name:
-            raise UnknownStation(train["to"])
-        train["from_station"] = by_name[train["from"]]
-        train["to_station"] = by_name[train["to"]]
-
+        add_station_objects(train, by_name)
         for leg in train["legs"]:
-            process_train_leg(leg, by_name=by_name, route_distances=route_distances)
+            add_station_objects(leg, by_name)
+            if route_distances:
+                travel.add_leg_route_distance(leg, route_distances)
 
         if all("distance" in leg for leg in train["legs"]):
             train["distance"] = sum(leg["distance"] for leg in train["legs"])