Split up train and flight loading

Reduce complexity of train and flight loading functions by splitting
code out into separate functions.
This commit is contained in:
Edward Betts 2024-08-03 14:49:21 +08:00
parent a130a85a48
commit f423fcdcbe

View file

@ -28,6 +28,19 @@ def load_travel(travel_type: str, plural: str, data_dir: str) -> list[StrDict]:
return items
def process_train_leg(
leg: StrDict,
by_name: StrDict,
route_distances: travel.RouteDistances | None,
) -> None:
"""Process train leg."""
assert leg["from"] in by_name and leg["to"] in by_name
leg["from_station"], leg["to_station"] = by_name[leg["from"]], by_name[leg["to"]]
if route_distances:
travel.add_leg_route_distance(leg, route_distances)
def load_trains(
data_dir: str, route_distances: travel.RouteDistances | None = None
) -> list[StrDict]:
@ -45,13 +58,7 @@ def load_trains(
train["to_station"] = by_name[train["to"]]
for leg in train["legs"]:
assert leg["from"] in by_name
assert leg["to"] in by_name
leg["from_station"] = by_name[leg["from"]]
leg["to_station"] = by_name[leg["to"]]
if route_distances:
travel.add_leg_route_distance(leg, route_distances)
process_train_leg(leg, by_name=by_name, route_distances=route_distances)
if all("distance" in leg for leg in train["legs"]):
train["distance"] = sum(leg["distance"] for leg in train["legs"])
@ -90,14 +97,10 @@ def depart_datetime(item: StrDict) -> datetime:
return datetime.combine(depart, time.min).replace(tzinfo=ZoneInfo("UTC"))
def load_flight_bookings(data_dir: str) -> list[StrDict]:
"""Load flight bookings."""
bookings = load_travel("flight", "flights", data_dir)
airlines = yaml.safe_load(open(os.path.join(data_dir, "airlines.yaml")))
iata = {a["iata"]: a["name"] for a in airlines}
airports = travel.parse_yaml("airports", data_dir)
for booking in bookings:
for flight in booking["flights"]:
def process_flight(
flight: StrDict, iata: dict[str, str], airports: list[StrDict]
) -> None:
"""Add airport detail, airline name and distance to flight."""
if flight["from"] in airports:
flight["from_airport"] = airports[flight["from"]]
if flight["to"] in airports:
@ -106,6 +109,17 @@ def load_flight_bookings(data_dir: str) -> list[StrDict]:
flight["airline_name"] = iata.get(flight["airline"], "[unknown]")
flight["distance"] = travel.flight_distance(flight)
def load_flight_bookings(data_dir: str) -> list[StrDict]:
"""Load flight bookings."""
bookings = load_travel("flight", "flights", data_dir)
airlines = yaml.safe_load(open(os.path.join(data_dir, "airlines.yaml")))
iata = {a["iata"]: a["name"] for a in airlines}
airports = travel.parse_yaml("airports", data_dir)
for booking in bookings:
for flight in booking["flights"]:
process_flight(flight, iata, airports)
return bookings