diff --git a/agenda/trip.py b/agenda/trip.py index a3cc35c..e730ad4 100644 --- a/agenda/trip.py +++ b/agenda/trip.py @@ -225,51 +225,33 @@ def add_coordinates_for_unbooked_flights( ) -def get_locations(trip: Trip) -> dict[str, StrDict]: - """Collect locations of all travel locations in trip.""" - locations: dict[str, StrDict] = { - "station": {}, - "airport": {}, - "ferry_terminal": {}, - } - +def collect_trip_coordinates(trip: Trip) -> list[StrDict]: + """Extract and de-duplicate airport and station coordinates from trip.""" + stations = {} station_list = [] + airports = {} + ferry_terminals = {} for t in trip.travel: - match t["type"]: - case "train": - station_list += [t["from_station"], t["to_station"]] - for leg in t["legs"]: - station_list.append(leg["from_station"]) - station_list.append(leg["to_station"]) - case "flight": - for field in "from_airport", "to_airport": - if field in t: - locations["airport"][t[field]["iata"]] = t[field] - case "ferry": - for field in "from_terminal", "to_terminal": - terminal = t[field] - locations["ferry_terminal"][terminal["name"]] = terminal + if t["type"] == "train": + station_list += [t["from_station"], t["to_station"]] + for leg in t["legs"]: + station_list.append(leg["from_station"]) + station_list.append(leg["to_station"]) + elif t["type"] == "flight": + for field in "from_airport", "to_airport": + if field in t: + airports[t[field]["iata"]] = t[field] + else: + assert t["type"] == "ferry" + for field in "from_terminal", "to_terminal": + terminal = t[field] + ferry_terminals[terminal["name"]] = terminal for s in station_list: - if s["uic"] in locations["station"]: + if s["uic"] in stations: continue - locations["station"][s["uic"]] = s + stations[s["uic"]] = s - return locations - - -def coordinate_dict(item: StrDict, coord_type: str) -> StrDict: - """Build coodinate dict for item.""" - return { - "name": item["name"], - "type": coord_type, - "latitude": item["latitude"], - "longitude": item["longitude"], - } - - -def collect_trip_coordinates(trip: Trip) -> list[StrDict]: - """Extract and de-duplicate travel location coordinates from trip.""" coords = [] src = [ @@ -279,14 +261,31 @@ def collect_trip_coordinates(trip: Trip) -> list[StrDict]: ] for coord_type, item_list in src: coords += [ - coordinate_dict(item, coord_type) + { + "name": item["name"], + "type": coord_type, + "latitude": item["latitude"], + "longitude": item["longitude"], + } for item in item_list if "latitude" in item and "longitude" in item ] - locations = get_locations(trip) - for coord_type, coord_dict in locations.items(): - coords += [coordinate_dict(s, coord_type) for s in coord_dict.values()] + locations = [ + ("station", stations), + ("airport", airports), + ("ferry_terminal", ferry_terminals), + ] + for coord_type, coord_dict in locations: + coords += [ + { + "name": s["name"], + "type": coord_type, + "latitude": s["latitude"], + "longitude": s["longitude"], + } + for s in coord_dict.values() + ] return coords