Compare commits

...

2 commits

Author SHA1 Message Date
Edward Betts 0b23f71aa6 Refactor and add some docstrings. 2024-10-02 10:16:30 +01:00
Edward Betts 8cbfb745c4 Split code into new file stats.py 2024-10-02 09:09:39 +01:00
2 changed files with 91 additions and 66 deletions

45
agenda/stats.py Normal file
View file

@ -0,0 +1,45 @@
"""Trip statistic functions."""
from collections import defaultdict
from agenda.types import StrDict, Trip
def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]:
    """Calculate per-year travel statistics for a list of trips.

    Each trip is recorded under the year of ``trip.start``. Conferences are
    recorded under the year of their own ``start`` value, which may differ
    from the year of the trip they belong to.

    A key is only added to a year's stats once there is something to record:

    * ``count``: number of trips.
    * ``conferences``: number of conferences.
    * ``total_distance``: sum of ``trip.total_distance()`` over trips where
      that value is truthy.
    * ``distances_by_transport_type``: distance summed per transport type.
    * ``countries``: set of countries on the trips, skipping any whose
      ``alpha_2`` is ``"GB"``.
    * ``flight_count`` / ``train_count``: number of travel legs whose
      ``type`` is ``"flight"`` / ``"train"``.

    Returns a plain ``dict`` mapping year to that year's stats.
    """
    yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)

    for trip in trips:
        stats = yearly_stats[trip.start.year]
        stats["count"] = stats.get("count", 0) + 1

        for conf in trip.conferences:
            # Keyed on the conference's own year, not the trip's.
            conf_stats = yearly_stats[conf["start"].year]
            conf_stats["conferences"] = conf_stats.get("conferences", 0) + 1

        # Compute the trip distance once and reuse it.
        dist = trip.total_distance()
        if dist:
            stats["total_distance"] = stats.get("total_distance", 0) + dist

        for transport_type, distance in trip.distances_by_transport_type():
            by_type = stats.setdefault("distances_by_transport_type", {})
            by_type[transport_type] = by_type.get(transport_type, 0) + distance

        for country in trip.countries:
            if country.alpha_2 == "GB":
                continue
            stats.setdefault("countries", set()).add(country)

        for leg in trip.travel:
            leg_type = leg["type"]
            # Only flights and trains get a per-year leg counter.
            if leg_type in ("flight", "train"):
                key = f"{leg_type}_count"
                stats[key] = stats.get(key, 0) + 1

    return dict(yearly_stats)

View file

@ -3,7 +3,6 @@
import decimal import decimal
import os import os
import typing import typing
from collections import defaultdict
from datetime import date, datetime, time from datetime import date, datetime, time
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
@ -88,6 +87,11 @@ def load_ferries(
def depart_datetime(item: StrDict) -> datetime: def depart_datetime(item: StrDict) -> datetime:
"""Return a datetime for this travel item.
If the travel item already has a datetime return that, otherwise if the
departure time is just a date return midnight UTC for that date.
"""
depart = item["depart"] depart = item["depart"]
if isinstance(depart, datetime): if isinstance(depart, datetime):
return depart return depart
@ -132,42 +136,29 @@ def load_flights(flight_bookings: list[StrDict]) -> list[StrDict]:
return flights return flights
def build_trip_list( def collect_travel_items(
flight_bookings: list[StrDict],
data_dir: str | None = None, data_dir: str | None = None,
route_distances: travel.RouteDistances | None = None, route_distances: travel.RouteDistances | None = None,
) -> list[Trip]: ) -> list[StrDict]:
"""Generate list of trips.""" """Generate list of trips."""
trips: dict[date, Trip] = {}
if data_dir is None: if data_dir is None:
data_dir = flask.current_app.config["PERSONAL_DATA"] data_dir = flask.current_app.config["PERSONAL_DATA"]
yaml_trip_list = travel.parse_yaml("trips", data_dir) return sorted(
load_flights(load_flight_bookings(data_dir))
yaml_trip_lookup = {item["trip"]: item for item in yaml_trip_list}
flight_bookings = load_flight_bookings(data_dir)
travel_items = sorted(
load_flights(flight_bookings)
+ load_trains(data_dir, route_distances=route_distances) + load_trains(data_dir, route_distances=route_distances)
+ load_ferries(data_dir, route_distances=route_distances), + load_ferries(data_dir, route_distances=route_distances),
key=depart_datetime, key=depart_datetime,
) )
data = {
"flight_bookings": flight_bookings,
"travel": travel_items,
"accommodation": travel.parse_yaml("accommodation", data_dir),
"conferences": travel.parse_yaml("conferences", data_dir),
"events": travel.parse_yaml("events", data_dir),
}
for item in data["accommodation"]:
price = item.get("price")
if price:
item["price"] = decimal.Decimal(price)
def group_travel_items_into_trips(
data: StrDict, yaml_trip_list: list[StrDict]
) -> list[Trip]:
"""Group travel items into trips."""
trips: dict[date, Trip] = {}
yaml_trip_lookup = {item["trip"]: item for item in yaml_trip_list}
for key, item_list in data.items(): for key, item_list in data.items():
assert isinstance(item_list, list) assert isinstance(item_list, list)
for item in item_list: for item in item_list:
@ -183,6 +174,34 @@ def build_trip_list(
return [trip for _, trip in sorted(trips.items())] return [trip for _, trip in sorted(trips.items())]
def build_trip_list(
data_dir: str | None = None,
route_distances: travel.RouteDistances | None = None,
) -> list[Trip]:
"""Generate list of trips."""
if data_dir is None:
data_dir = flask.current_app.config["PERSONAL_DATA"]
yaml_trip_list = travel.parse_yaml("trips", data_dir)
flight_bookings = load_flight_bookings(data_dir)
data = {
"flight_bookings": flight_bookings,
"travel": collect_travel_items(flight_bookings, data_dir, route_distances),
"accommodation": travel.parse_yaml("accommodation", data_dir),
"conferences": travel.parse_yaml("conferences", data_dir),
"events": travel.parse_yaml("events", data_dir),
}
for item in data["accommodation"]:
price = item.get("price")
if price:
item["price"] = decimal.Decimal(price)
return group_travel_items_into_trips(data, yaml_trip_list)
def add_coordinates_for_unbooked_flights( def add_coordinates_for_unbooked_flights(
routes: list[StrDict], coordinates: list[StrDict] routes: list[StrDict], coordinates: list[StrDict]
) -> None: ) -> None:
@ -207,7 +226,7 @@ def add_coordinates_for_unbooked_flights(
def collect_trip_coordinates(trip: Trip) -> list[StrDict]: def collect_trip_coordinates(trip: Trip) -> list[StrDict]:
"""Extract and deduplicate airport and station coordinates from trip.""" """Extract and de-duplicate airport and station coordinates from trip."""
stations = {} stations = {}
station_list = [] station_list = []
airports = {} airports = {}
@ -364,6 +383,7 @@ def get_trip_routes(trip: Trip) -> list[StrDict]:
def get_coordinates_and_routes( def get_coordinates_and_routes(
trip_list: list[Trip], data_dir: str | None = None trip_list: list[Trip], data_dir: str | None = None
) -> tuple[list[StrDict], list[StrDict]]: ) -> tuple[list[StrDict], list[StrDict]]:
"""Given a list of trips return the associated coordinates and routes."""
if data_dir is None: if data_dir is None:
data_dir = flask.current_app.config["PERSONAL_DATA"] data_dir = flask.current_app.config["PERSONAL_DATA"]
coordinates = [] coordinates = []
@ -389,43 +409,3 @@ def get_coordinates_and_routes(
route["geojson"] = read_geojson(data_dir, route.pop("geojson_filename")) route["geojson"] = read_geojson(data_dir, route.pop("geojson_filename"))
return (coordinates, routes) return (coordinates, routes)
def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]:
"""Calculate total distance and distance by transport type grouped by year."""
yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)
for trip in trips:
year = trip.start.year
dist = trip.total_distance()
yearly_stats[year].setdefault("count", 0)
yearly_stats[year]["count"] += 1
for c in trip.conferences:
yearly_stats[c["start"].year].setdefault("conferences", 0)
yearly_stats[c["start"].year]["conferences"] += 1
if dist:
yearly_stats[year]["total_distance"] = (
yearly_stats[year].get("total_distance", 0) + trip.total_distance()
)
for transport_type, distance in trip.distances_by_transport_type():
yearly_stats[year].setdefault("distances_by_transport_type", {})
yearly_stats[year]["distances_by_transport_type"][transport_type] = (
yearly_stats[year]["distances_by_transport_type"].get(transport_type, 0)
+ distance
)
for country in trip.countries:
if country.alpha_2 == "GB":
continue
yearly_stats[year].setdefault("countries", set())
yearly_stats[year]["countries"].add(country)
for leg in trip.travel:
if leg["type"] == "flight":
yearly_stats[year].setdefault("flight_count", 0)
yearly_stats[year]["flight_count"] += 1
if leg["type"] == "train":
yearly_stats[year].setdefault("train_count", 0)
yearly_stats[year]["train_count"] += 1
return dict(yearly_stats)