"""Trip statistic functions."""

from collections import defaultdict
from typing import Counter, Mapping

from agenda.types import StrDict, Trip


def travel_legs(trip: Trip, stats: StrDict) -> None:
    """Calculate stats for travel legs and accumulate them into ``stats``.

    For each leg in ``trip.travel``:

    - ``"flight"`` legs increment ``stats["flight_count"]`` and tally the
      leg's ``airline_name`` in the ``stats["airlines"]`` Counter.
    - ``"train"`` legs increment ``stats["train_count"]``.
    - Legs of any other type are ignored.

    Keys are created lazily, so each one only appears in ``stats`` once a
    matching leg has been seen. ``stats`` is mutated in place.
    """
    for leg in trip.travel:
        leg_type = leg["type"]
        if leg_type == "flight":
            stats["flight_count"] = stats.get("flight_count", 0) + 1
            airlines = stats.setdefault("airlines", Counter())
            airlines[leg["airline_name"]] += 1
        elif leg_type == "train":
            stats["train_count"] = stats.get("train_count", 0) + 1


def conferences(trip: Trip, yearly_stats: Mapping[int, StrDict]) -> None:
    """Tally the trip's conferences into per-year stats.

    Each conference is counted under the year of its own ``"start"`` value,
    which need not match the trip's start year. The count is stored under
    the ``"conferences"`` key of that year's stats dict.

    ``yearly_stats`` is indexed directly by year, so it must produce an entry
    for a year it has not seen yet (e.g. a ``defaultdict(dict)``, as used by
    ``calculate_yearly_stats``); a plain dict missing the year would raise
    ``KeyError``.
    """
    for conference in trip.conferences:
        bucket = yearly_stats[conference["start"].year]
        bucket["conferences"] = bucket.get("conferences", 0) + 1


def calculate_yearly_stats(trips: list[Trip]) -> dict[int, StrDict]:
    """Calculate per-year trip statistics, grouped by ``trip.start.year``.

    Each year's stats dict may contain:

    - ``count``: number of trips starting in that year.
    - ``conferences``: number of conferences, bucketed by each conference's
      own start year (see ``conferences``). A year can therefore appear in
      the result with only this key.
    - ``total_distance``: sum of ``trip.total_distance()`` over the year's
      trips, skipping trips where it returns a falsy value (e.g. 0 or None).
    - ``distances_by_transport_type``: transport type -> summed distance,
      from ``trip.distances_by_transport_type()``.
    - ``countries``: set of countries visited, excluding GB.
    - ``flight_count``, ``airlines``, ``train_count``: see ``travel_legs``.

    Keys are added lazily, so each one only appears once there is something
    to record for it.
    """
    yearly_stats: defaultdict[int, StrDict] = defaultdict(dict)
    for trip in trips:
        year = trip.start.year
        # Computed once here and reused for the total below.
        dist = trip.total_distance()
        stats = yearly_stats[year]
        stats["count"] = stats.get("count", 0) + 1

        # May also create/update entries for other years, because conferences
        # are keyed by their own start year.
        conferences(trip, yearly_stats)

        if dist:
            stats["total_distance"] = stats.get("total_distance", 0) + dist

        for transport_type, distance in trip.distances_by_transport_type():
            by_type = stats.setdefault("distances_by_transport_type", {})
            by_type[transport_type] = by_type.get(transport_type, 0) + distance

        for country in trip.countries:
            # GB is excluded (presumably the home country -- confirm).
            if country.alpha_2 == "GB":
                continue
            stats.setdefault("countries", set()).add(country)

        travel_legs(trip, stats)

    return dict(yearly_stats)