Compare commits

...

2 commits

Author SHA1 Message Date
Edward Betts 868c1407b5 Use pattern matching: train/flight/ferry 2024-10-02 15:36:26 +01:00
Edward Betts 7d803e0267 Refactor 2024-10-02 14:12:34 +01:00

View file

@ -225,33 +225,51 @@ def add_coordinates_for_unbooked_flights(
) )
def collect_trip_coordinates(trip: Trip) -> list[StrDict]: def get_locations(trip: Trip) -> dict[str, StrDict]:
"""Extract and de-duplicate airport and station coordinates from trip.""" """Collect locations of all travel locations in trip."""
stations = {} locations: dict[str, StrDict] = {
"station": {},
"airport": {},
"ferry_terminal": {},
}
station_list = [] station_list = []
airports = {}
ferry_terminals = {}
for t in trip.travel: for t in trip.travel:
if t["type"] == "train": match t["type"]:
station_list += [t["from_station"], t["to_station"]] case "train":
for leg in t["legs"]: station_list += [t["from_station"], t["to_station"]]
station_list.append(leg["from_station"]) for leg in t["legs"]:
station_list.append(leg["to_station"]) station_list.append(leg["from_station"])
elif t["type"] == "flight": station_list.append(leg["to_station"])
for field in "from_airport", "to_airport": case "flight":
if field in t: for field in "from_airport", "to_airport":
airports[t[field]["iata"]] = t[field] if field in t:
else: locations["airport"][t[field]["iata"]] = t[field]
assert t["type"] == "ferry" case "ferry":
for field in "from_terminal", "to_terminal": for field in "from_terminal", "to_terminal":
terminal = t[field] terminal = t[field]
ferry_terminals[terminal["name"]] = terminal locations["ferry_terminal"][terminal["name"]] = terminal
for s in station_list: for s in station_list:
if s["uic"] in stations: if s["uic"] in locations["station"]:
continue continue
stations[s["uic"]] = s locations["station"][s["uic"]] = s
return locations
def coordinate_dict(item: StrDict, coord_type: str) -> StrDict:
"""Build coodinate dict for item."""
return {
"name": item["name"],
"type": coord_type,
"latitude": item["latitude"],
"longitude": item["longitude"],
}
def collect_trip_coordinates(trip: Trip) -> list[StrDict]:
"""Extract and de-duplicate travel location coordinates from trip."""
coords = [] coords = []
src = [ src = [
@ -261,31 +279,14 @@ def collect_trip_coordinates(trip: Trip) -> list[StrDict]:
] ]
for coord_type, item_list in src: for coord_type, item_list in src:
coords += [ coords += [
{ coordinate_dict(item, coord_type)
"name": item["name"],
"type": coord_type,
"latitude": item["latitude"],
"longitude": item["longitude"],
}
for item in item_list for item in item_list
if "latitude" in item and "longitude" in item if "latitude" in item and "longitude" in item
] ]
locations = [ locations = get_locations(trip)
("station", stations), for coord_type, coord_dict in locations.items():
("airport", airports), coords += [coordinate_dict(s, coord_type) for s in coord_dict.values()]
("ferry_terminal", ferry_terminals),
]
for coord_type, coord_dict in locations:
coords += [
{
"name": s["name"],
"type": coord_type,
"latitude": s["latitude"],
"longitude": s["longitude"],
}
for s in coord_dict.values()
]
return coords return coords