def calculate_yearly_stats(
    trips: list[Trip], *, excluded_country: str = "GB"
) -> dict[int, StrDict]:
    """Calculate per-year trip statistics.

    Each trip is filed under the year of ``trip.start``. A per-year dict only
    gains a key once there is something to record under it:

    * ``count``: number of trips.
    * ``conferences``: number of conferences. Unlike every other key this is
      filed under the year of the conference's own ``start``, which may
      differ from the year of the trip it belongs to.
    * ``total_distance``: sum of ``trip.total_distance()`` over trips whose
      distance is truthy.
    * ``distances_by_transport_type``: dict mapping transport type to the
      summed distance from ``trip.distances_by_transport_type()``.
    * ``countries``: set of the trips' country objects, leaving out any whose
      ``alpha_2`` equals ``excluded_country``.
    * ``flight_count`` / ``train_count``: number of travel legs whose
      ``type`` is ``"flight"`` / ``"train"``.

    Args:
        trips: Trips to summarise.
        excluded_country: ``alpha_2`` code left out of ``countries``; the
            default keeps the previous hard-coded ``"GB"`` behaviour.

    Returns:
        A plain ``dict`` mapping year to that year's statistics.
    """
    # Leg types that get their own counter, and the key each one is stored in.
    leg_counter_keys = {"flight": "flight_count", "train": "train_count"}

    yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)
    for trip in trips:
        # Bind the year's dict once; later defaultdict insertions for other
        # years do not invalidate this reference.
        stats = yearly_stats[trip.start.year]
        # Computed once and reused below instead of calling it a second time.
        dist = trip.total_distance()
        stats["count"] = stats.get("count", 0) + 1

        for conference in trip.conferences:
            # Conferences are attributed to their own start year, which may
            # be a different year from the trip's.
            conference_stats = yearly_stats[conference["start"].year]
            conference_stats["conferences"] = (
                conference_stats.get("conferences", 0) + 1
            )

        if dist:
            stats["total_distance"] = stats.get("total_distance", 0) + dist

        for transport_type, distance in trip.distances_by_transport_type():
            by_type = stats.setdefault("distances_by_transport_type", {})
            by_type[transport_type] = by_type.get(transport_type, 0) + distance

        for country in trip.countries:
            if country.alpha_2 == excluded_country:
                continue
            stats.setdefault("countries", set()).add(country)

        for leg in trip.travel:
            counter_key = leg_counter_keys.get(leg["type"])
            if counter_key is not None:
                stats[counter_key] = stats.get(counter_key, 0) + 1

    # Hand back a plain dict so callers cannot create years by accident.
    return dict(yearly_stats)
def depart_datetime(item: StrDict) -> datetime: + """Return a datetime for this travel item. + + If the travel item already has a datetime return that, otherwise if the + departure time is just a date return midnight UTC for that date. + """ depart = item["depart"] if isinstance(depart, datetime): return depart @@ -389,43 +393,3 @@ def get_coordinates_and_routes( route["geojson"] = read_geojson(data_dir, route.pop("geojson_filename")) return (coordinates, routes) - - -def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]: - """Calculate total distance and distance by transport type grouped by year.""" - yearly_stats: defaultdict[int, StrDict] = defaultdict(dict) - for trip in trips: - year = trip.start.year - dist = trip.total_distance() - yearly_stats[year].setdefault("count", 0) - yearly_stats[year]["count"] += 1 - - for c in trip.conferences: - yearly_stats[c["start"].year].setdefault("conferences", 0) - yearly_stats[c["start"].year]["conferences"] += 1 - - if dist: - yearly_stats[year]["total_distance"] = ( - yearly_stats[year].get("total_distance", 0) + trip.total_distance() - ) - - for transport_type, distance in trip.distances_by_transport_type(): - yearly_stats[year].setdefault("distances_by_transport_type", {}) - yearly_stats[year]["distances_by_transport_type"][transport_type] = ( - yearly_stats[year]["distances_by_transport_type"].get(transport_type, 0) - + distance - ) - - for country in trip.countries: - if country.alpha_2 == "GB": - continue - yearly_stats[year].setdefault("countries", set()) - yearly_stats[year]["countries"].add(country) - for leg in trip.travel: - if leg["type"] == "flight": - yearly_stats[year].setdefault("flight_count", 0) - yearly_stats[year]["flight_count"] += 1 - if leg["type"] == "train": - yearly_stats[year].setdefault("train_count", 0) - yearly_stats[year]["train_count"] += 1 - return dict(yearly_stats)