Compare commits

..

2 commits

Author SHA1 Message Date
Edward Betts 0b23f71aa6 Refactor and add some docstrings. 2024-10-02 10:16:30 +01:00
Edward Betts 8cbfb745c4 Split code into new file stats.py 2024-10-02 09:09:39 +01:00
2 changed files with 91 additions and 66 deletions

45
agenda/stats.py Normal file
View file

@ -0,0 +1,45 @@
"""Trip statistic functions."""
from collections import defaultdict
from agenda.types import StrDict, Trip
def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]:
"""Calculate total distance and distance by transport type grouped by year."""
yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)
for trip in trips:
year = trip.start.year
dist = trip.total_distance()
yearly_stats[year].setdefault("count", 0)
yearly_stats[year]["count"] += 1
for c in trip.conferences:
yearly_stats[c["start"].year].setdefault("conferences", 0)
yearly_stats[c["start"].year]["conferences"] += 1
if dist:
yearly_stats[year]["total_distance"] = (
yearly_stats[year].get("total_distance", 0) + trip.total_distance()
)
for transport_type, distance in trip.distances_by_transport_type():
yearly_stats[year].setdefault("distances_by_transport_type", {})
yearly_stats[year]["distances_by_transport_type"][transport_type] = (
yearly_stats[year]["distances_by_transport_type"].get(transport_type, 0)
+ distance
)
for country in trip.countries:
if country.alpha_2 == "GB":
continue
yearly_stats[year].setdefault("countries", set())
yearly_stats[year]["countries"].add(country)
for leg in trip.travel:
if leg["type"] == "flight":
yearly_stats[year].setdefault("flight_count", 0)
yearly_stats[year]["flight_count"] += 1
if leg["type"] == "train":
yearly_stats[year].setdefault("train_count", 0)
yearly_stats[year]["train_count"] += 1
return dict(yearly_stats)

View file

@ -3,7 +3,6 @@
import decimal
import os
import typing
from collections import defaultdict
from datetime import date, datetime, time
from zoneinfo import ZoneInfo
@ -88,6 +87,11 @@ def load_ferries(
def depart_datetime(item: StrDict) -> datetime:
"""Return a datetime for this travel item.
If the travel item already has a datetime return that, otherwise if the
departure time is just a date return midnight UTC for that date.
"""
depart = item["depart"]
if isinstance(depart, datetime):
return depart
@ -132,42 +136,29 @@ def load_flights(flight_bookings: list[StrDict]) -> list[StrDict]:
return flights
def build_trip_list(
def collect_travel_items(
flight_bookings: list[StrDict],
data_dir: str | None = None,
route_distances: travel.RouteDistances | None = None,
) -> list[Trip]:
) -> list[StrDict]:
"""Generate list of trips."""
trips: dict[date, Trip] = {}
if data_dir is None:
data_dir = flask.current_app.config["PERSONAL_DATA"]
yaml_trip_list = travel.parse_yaml("trips", data_dir)
yaml_trip_lookup = {item["trip"]: item for item in yaml_trip_list}
flight_bookings = load_flight_bookings(data_dir)
travel_items = sorted(
load_flights(flight_bookings)
return sorted(
load_flights(load_flight_bookings(data_dir))
+ load_trains(data_dir, route_distances=route_distances)
+ load_ferries(data_dir, route_distances=route_distances),
key=depart_datetime,
)
data = {
"flight_bookings": flight_bookings,
"travel": travel_items,
"accommodation": travel.parse_yaml("accommodation", data_dir),
"conferences": travel.parse_yaml("conferences", data_dir),
"events": travel.parse_yaml("events", data_dir),
}
for item in data["accommodation"]:
price = item.get("price")
if price:
item["price"] = decimal.Decimal(price)
def group_travel_items_into_trips(
data: StrDict, yaml_trip_list: list[StrDict]
) -> list[Trip]:
"""Group travel items into trips."""
trips: dict[date, Trip] = {}
yaml_trip_lookup = {item["trip"]: item for item in yaml_trip_list}
for key, item_list in data.items():
assert isinstance(item_list, list)
for item in item_list:
@ -183,6 +174,34 @@ def build_trip_list(
return [trip for _, trip in sorted(trips.items())]
def build_trip_list(
data_dir: str | None = None,
route_distances: travel.RouteDistances | None = None,
) -> list[Trip]:
"""Generate list of trips."""
if data_dir is None:
data_dir = flask.current_app.config["PERSONAL_DATA"]
yaml_trip_list = travel.parse_yaml("trips", data_dir)
flight_bookings = load_flight_bookings(data_dir)
data = {
"flight_bookings": flight_bookings,
"travel": collect_travel_items(flight_bookings, data_dir, route_distances),
"accommodation": travel.parse_yaml("accommodation", data_dir),
"conferences": travel.parse_yaml("conferences", data_dir),
"events": travel.parse_yaml("events", data_dir),
}
for item in data["accommodation"]:
price = item.get("price")
if price:
item["price"] = decimal.Decimal(price)
return group_travel_items_into_trips(data, yaml_trip_list)
def add_coordinates_for_unbooked_flights(
routes: list[StrDict], coordinates: list[StrDict]
) -> None:
@ -207,7 +226,7 @@ def add_coordinates_for_unbooked_flights(
def collect_trip_coordinates(trip: Trip) -> list[StrDict]:
"""Extract and deduplicate airport and station coordinates from trip."""
"""Extract and de-duplicate airport and station coordinates from trip."""
stations = {}
station_list = []
airports = {}
@ -364,6 +383,7 @@ def get_trip_routes(trip: Trip) -> list[StrDict]:
def get_coordinates_and_routes(
trip_list: list[Trip], data_dir: str | None = None
) -> tuple[list[StrDict], list[StrDict]]:
"""Given a list of trips return the associated coordinates and routes."""
if data_dir is None:
data_dir = flask.current_app.config["PERSONAL_DATA"]
coordinates = []
@ -389,43 +409,3 @@ def get_coordinates_and_routes(
route["geojson"] = read_geojson(data_dir, route.pop("geojson_filename"))
return (coordinates, routes)
def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]:
"""Calculate total distance and distance by transport type grouped by year."""
yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)
for trip in trips:
year = trip.start.year
dist = trip.total_distance()
yearly_stats[year].setdefault("count", 0)
yearly_stats[year]["count"] += 1
for c in trip.conferences:
yearly_stats[c["start"].year].setdefault("conferences", 0)
yearly_stats[c["start"].year]["conferences"] += 1
if dist:
yearly_stats[year]["total_distance"] = (
yearly_stats[year].get("total_distance", 0) + trip.total_distance()
)
for transport_type, distance in trip.distances_by_transport_type():
yearly_stats[year].setdefault("distances_by_transport_type", {})
yearly_stats[year]["distances_by_transport_type"][transport_type] = (
yearly_stats[year]["distances_by_transport_type"].get(transport_type, 0)
+ distance
)
for country in trip.countries:
if country.alpha_2 == "GB":
continue
yearly_stats[year].setdefault("countries", set())
yearly_stats[year]["countries"].add(country)
for leg in trip.travel:
if leg["type"] == "flight":
yearly_stats[year].setdefault("flight_count", 0)
yearly_stats[year]["flight_count"] += 1
if leg["type"] == "train":
yearly_stats[year].setdefault("train_count", 0)
yearly_stats[year]["train_count"] += 1
return dict(yearly_stats)