Reduce complexity of agent.trip.get_locations().

This commit is contained in:
Edward Betts 2025-02-09 13:33:11 +00:00
parent 6444229694
commit b2cef3933d

View file

@ -211,6 +211,26 @@ def add_coordinates_for_unbooked_flights(
) )
def stations_from_travel(t: StrDict) -> list[StrDict]:
"""Stations from train journey."""
station_list = [t["from_station"], t["to_station"]]
for leg in t["legs"]:
station_list.append(leg["from_station"])
station_list.append(leg["to_station"])
return station_list
def process_station_list(station_list: list[StrDict]) -> StrDict:
"""Proess sation list."""
stations = {}
for s in station_list:
if s["name"] in stations:
continue
stations[s["name"]] = s
return stations
def get_locations(trip: Trip) -> dict[str, StrDict]: def get_locations(trip: Trip) -> dict[str, StrDict]:
"""Collect locations of all travel locations in trip.""" """Collect locations of all travel locations in trip."""
locations: dict[str, StrDict] = { locations: dict[str, StrDict] = {
@ -223,10 +243,7 @@ def get_locations(trip: Trip) -> dict[str, StrDict]:
for t in trip.travel: for t in trip.travel:
match t["type"]: match t["type"]:
case "train": case "train":
station_list += [t["from_station"], t["to_station"]] station_list += stations_from_travel(t)
for leg in t["legs"]:
station_list.append(leg["from_station"])
station_list.append(leg["to_station"])
case "flight": case "flight":
for field in "from_airport", "to_airport": for field in "from_airport", "to_airport":
if field in t: if field in t:
@ -236,11 +253,7 @@ def get_locations(trip: Trip) -> dict[str, StrDict]:
terminal = t[field] terminal = t[field]
locations["ferry_terminal"][terminal["name"]] = terminal locations["ferry_terminal"][terminal["name"]] = terminal
for s in station_list: locations["station"] = process_station_list(station_list)
if s["name"] in locations["station"]:
continue
locations["station"][s["name"]] = s
return locations return locations