Split code into new file stats.py

This commit is contained in:
Edward Betts 2024-10-02 09:09:39 +01:00
parent a324046332
commit 8cbfb745c4
2 changed files with 50 additions and 41 deletions

45
agenda/stats.py Normal file
View file

@ -0,0 +1,45 @@
"""Trip statistic functions."""
from collections import defaultdict
from agenda.types import StrDict, Trip
def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]:
    """Aggregate per-year travel statistics from a list of trips.

    A trip is attributed to the year of ``trip.start``.  Conferences are
    attributed to the year of their own ``"start"`` value, which can differ
    from the year of the trip they belong to.

    A key is only present in a year's dict once something has been recorded
    for it.  Possible keys:

    * ``count``: number of trips.
    * ``conferences``: number of conferences.
    * ``total_distance``: sum of ``trip.total_distance()`` over trips whose
      distance is truthy.
    * ``distances_by_transport_type``: dict of transport type to summed
      distance.
    * ``countries``: set of the trips' countries, skipping any whose
      ``alpha_2`` is ``"GB"``.
    * ``flight_count`` / ``train_count``: number of travel legs whose
      ``"type"`` is ``"flight"`` / ``"train"``.

    Returns a plain ``dict`` (not a ``defaultdict``) keyed by year.
    """
    yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)
    for trip in trips:
        year = trip.start.year
        # Compute the distance once and reuse it below.
        dist = trip.total_distance()

        stats = yearly_stats[year]
        stats["count"] = stats.get("count", 0) + 1

        for conf in trip.conferences:
            # Keyed by the conference's own year, not the trip's.
            conf_stats = yearly_stats[conf["start"].year]
            conf_stats["conferences"] = conf_stats.get("conferences", 0) + 1

        # Falsy distances (0 or None) are skipped, so "total_distance" stays
        # absent for years with no measurable travel.
        if dist:
            stats["total_distance"] = stats.get("total_distance", 0) + dist

        for transport_type, distance in trip.distances_by_transport_type():
            by_type = stats.setdefault("distances_by_transport_type", {})
            by_type[transport_type] = by_type.get(transport_type, 0) + distance

        for country in trip.countries:
            # NOTE(review): GB is excluded, presumably as the home country.
            if country.alpha_2 == "GB":
                continue
            stats.setdefault("countries", set()).add(country)

        for leg in trip.travel:
            leg_type = leg["type"]
            if leg_type == "flight":
                stats["flight_count"] = stats.get("flight_count", 0) + 1
            elif leg_type == "train":
                stats["train_count"] = stats.get("train_count", 0) + 1

    return dict(yearly_stats)

View file

@ -3,7 +3,6 @@
import decimal import decimal
import os import os
import typing import typing
from collections import defaultdict
from datetime import date, datetime, time from datetime import date, datetime, time
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
@ -88,6 +87,11 @@ def load_ferries(
def depart_datetime(item: StrDict) -> datetime: def depart_datetime(item: StrDict) -> datetime:
"""Return a datetime for this travel item.
If the travel item already has a datetime return that, otherwise if the
departure time is just a date return midnight UTC for that date.
"""
depart = item["depart"] depart = item["depart"]
if isinstance(depart, datetime): if isinstance(depart, datetime):
return depart return depart
@ -389,43 +393,3 @@ def get_coordinates_and_routes(
route["geojson"] = read_geojson(data_dir, route.pop("geojson_filename")) route["geojson"] = read_geojson(data_dir, route.pop("geojson_filename"))
return (coordinates, routes) return (coordinates, routes)
def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]:
    """Aggregate per-year travel statistics from a list of trips.

    A trip is attributed to the year of ``trip.start``.  Conferences are
    attributed to the year of their own ``"start"`` value, which can differ
    from the year of the trip they belong to.

    A key is only present in a year's dict once something has been recorded
    for it.  Possible keys:

    * ``count``: number of trips.
    * ``conferences``: number of conferences.
    * ``total_distance``: sum of ``trip.total_distance()`` over trips whose
      distance is truthy.
    * ``distances_by_transport_type``: dict of transport type to summed
      distance.
    * ``countries``: set of the trips' countries, skipping any whose
      ``alpha_2`` is ``"GB"``.
    * ``flight_count`` / ``train_count``: number of travel legs whose
      ``"type"`` is ``"flight"`` / ``"train"``.

    Returns a plain ``dict`` (not a ``defaultdict``) keyed by year.
    """
    yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)
    for trip in trips:
        year = trip.start.year
        # Compute the distance once and reuse it below.
        dist = trip.total_distance()

        stats = yearly_stats[year]
        stats["count"] = stats.get("count", 0) + 1

        for conf in trip.conferences:
            # Keyed by the conference's own year, not the trip's.
            conf_stats = yearly_stats[conf["start"].year]
            conf_stats["conferences"] = conf_stats.get("conferences", 0) + 1

        # Falsy distances (0 or None) are skipped, so "total_distance" stays
        # absent for years with no measurable travel.
        if dist:
            stats["total_distance"] = stats.get("total_distance", 0) + dist

        for transport_type, distance in trip.distances_by_transport_type():
            by_type = stats.setdefault("distances_by_transport_type", {})
            by_type[transport_type] = by_type.get(transport_type, 0) + distance

        for country in trip.countries:
            # NOTE(review): GB is excluded, presumably as the home country.
            if country.alpha_2 == "GB":
                continue
            stats.setdefault("countries", set()).add(country)

        for leg in trip.travel:
            leg_type = leg["type"]
            if leg_type == "flight":
                stats["flight_count"] = stats.get("flight_count", 0) + 1
            elif leg_type == "train":
                stats["train_count"] = stats.get("train_count", 0) + 1

    return dict(yearly_stats)