Split up train and flight loading
Reduce complexity of train and flight loading functions by splitting code out into separate functions.
This commit is contained in:
parent
a130a85a48
commit
f423fcdcbe
|
@ -28,6 +28,19 @@ def load_travel(travel_type: str, plural: str, data_dir: str) -> list[StrDict]:
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def process_train_leg(
|
||||||
|
leg: StrDict,
|
||||||
|
by_name: StrDict,
|
||||||
|
route_distances: travel.RouteDistances | None,
|
||||||
|
) -> None:
|
||||||
|
"""Process train leg."""
|
||||||
|
assert leg["from"] in by_name and leg["to"] in by_name
|
||||||
|
leg["from_station"], leg["to_station"] = by_name[leg["from"]], by_name[leg["to"]]
|
||||||
|
|
||||||
|
if route_distances:
|
||||||
|
travel.add_leg_route_distance(leg, route_distances)
|
||||||
|
|
||||||
|
|
||||||
def load_trains(
|
def load_trains(
|
||||||
data_dir: str, route_distances: travel.RouteDistances | None = None
|
data_dir: str, route_distances: travel.RouteDistances | None = None
|
||||||
) -> list[StrDict]:
|
) -> list[StrDict]:
|
||||||
|
@ -45,13 +58,7 @@ def load_trains(
|
||||||
train["to_station"] = by_name[train["to"]]
|
train["to_station"] = by_name[train["to"]]
|
||||||
|
|
||||||
for leg in train["legs"]:
|
for leg in train["legs"]:
|
||||||
assert leg["from"] in by_name
|
process_train_leg(leg, by_name=by_name, route_distances=route_distances)
|
||||||
assert leg["to"] in by_name
|
|
||||||
leg["from_station"] = by_name[leg["from"]]
|
|
||||||
leg["to_station"] = by_name[leg["to"]]
|
|
||||||
|
|
||||||
if route_distances:
|
|
||||||
travel.add_leg_route_distance(leg, route_distances)
|
|
||||||
|
|
||||||
if all("distance" in leg for leg in train["legs"]):
|
if all("distance" in leg for leg in train["legs"]):
|
||||||
train["distance"] = sum(leg["distance"] for leg in train["legs"])
|
train["distance"] = sum(leg["distance"] for leg in train["legs"])
|
||||||
|
@ -90,14 +97,10 @@ def depart_datetime(item: StrDict) -> datetime:
|
||||||
return datetime.combine(depart, time.min).replace(tzinfo=ZoneInfo("UTC"))
|
return datetime.combine(depart, time.min).replace(tzinfo=ZoneInfo("UTC"))
|
||||||
|
|
||||||
|
|
||||||
def load_flight_bookings(data_dir: str) -> list[StrDict]:
|
def process_flight(
|
||||||
"""Load flight bookings."""
|
flight: StrDict, iata: dict[str, str], airports: list[StrDict]
|
||||||
bookings = load_travel("flight", "flights", data_dir)
|
) -> None:
|
||||||
airlines = yaml.safe_load(open(os.path.join(data_dir, "airlines.yaml")))
|
"""Add airport detail, airline name and distance to flight."""
|
||||||
iata = {a["iata"]: a["name"] for a in airlines}
|
|
||||||
airports = travel.parse_yaml("airports", data_dir)
|
|
||||||
for booking in bookings:
|
|
||||||
for flight in booking["flights"]:
|
|
||||||
if flight["from"] in airports:
|
if flight["from"] in airports:
|
||||||
flight["from_airport"] = airports[flight["from"]]
|
flight["from_airport"] = airports[flight["from"]]
|
||||||
if flight["to"] in airports:
|
if flight["to"] in airports:
|
||||||
|
@ -106,6 +109,17 @@ def load_flight_bookings(data_dir: str) -> list[StrDict]:
|
||||||
flight["airline_name"] = iata.get(flight["airline"], "[unknown]")
|
flight["airline_name"] = iata.get(flight["airline"], "[unknown]")
|
||||||
|
|
||||||
flight["distance"] = travel.flight_distance(flight)
|
flight["distance"] = travel.flight_distance(flight)
|
||||||
|
|
||||||
|
|
||||||
|
def load_flight_bookings(data_dir: str) -> list[StrDict]:
|
||||||
|
"""Load flight bookings."""
|
||||||
|
bookings = load_travel("flight", "flights", data_dir)
|
||||||
|
airlines = yaml.safe_load(open(os.path.join(data_dir, "airlines.yaml")))
|
||||||
|
iata = {a["iata"]: a["name"] for a in airlines}
|
||||||
|
airports = travel.parse_yaml("airports", data_dir)
|
||||||
|
for booking in bookings:
|
||||||
|
for flight in booking["flights"]:
|
||||||
|
process_flight(flight, iata, airports)
|
||||||
return bookings
|
return bookings
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue