"""Travel events: flights and train journeys loaded from YAML files."""

import os
import typing
from datetime import date

import yaml

from .types import Event

# One journey (or leg) as loaded from the travel YAML files. Values are not
# all strings: ``depart`` is a timestamp (callers use its ``.date()``), and
# YAML may load other fields, such as unquoted numbers, as non-str types.
Leg = dict[str, typing.Any]


def get(
    from_date: date,
    method: str,
    filepath: str,
    extra: typing.Callable[[Leg], str] | None = None,
) -> list[Event]:
    """Get travel events from a YAML file of journeys.

    The file holds a list of journeys. Each journey has ``from``, ``to``,
    ``depart`` and ``arrive`` keys and an optional ``url``. Journeys whose
    ``depart`` date is before ``from_date`` are skipped. ``depart`` must be
    a timestamp, because its ``.date()`` is compared with ``from_date``.

    :param from_date: earliest departure date to include.
    :param method: transport mode, used at the start of each title
        (e.g. ``"flight"``).
    :param filepath: path to the YAML file.
    :param extra: optional callback whose result is appended to the title
        in parentheses.
    :returns: one ``"transport"`` event per included journey, in file order.
    """

    def title(item: Leg) -> str:
        ret = f'{method} from {item["from"]} to {item["to"]}'
        if extra is not None:
            ret += f" ({extra(item)})"
        return ret

    # Use a context manager so the file handle is closed, and decode as
    # UTF-8 rather than relying on the locale's default encoding.
    with open(filepath, encoding="utf-8") as f:
        # An empty YAML file loads as None; treat it as having no journeys.
        items = yaml.safe_load(f) or []

    return [
        Event(
            date=item["depart"],
            end_date=item["arrive"],
            name="transport",
            title=title(item),
            url=item.get("url"),
        )
        for item in items
        if item["depart"].date() >= from_date
    ]


def get_trains(from_date: date, filepath: str) -> list[Event]:
    """Get train events, one per leg of each rail journey.

    The YAML file holds a list of journeys. Each journey has a ``depart``
    timestamp, an optional ``url``, and a ``legs`` list whose entries have
    ``from``, ``to``, ``depart`` and ``arrive`` keys. The journey's ``depart``
    date is used for filtering: journeys departing before ``from_date`` are
    skipped with all of their legs.

    :param from_date: earliest journey departure date to include.
    :param filepath: path to the YAML file.
    :returns: ``"transport"`` events for every leg of every included
        journey, in file order.
    """
    # Use a context manager so the file handle is closed, and decode as
    # UTF-8 rather than relying on the locale's default encoding.
    with open(filepath, encoding="utf-8") as f:
        # An empty YAML file loads as None; treat it as having no journeys.
        journeys = yaml.safe_load(f) or []

    events: list[Event] = []
    for journey in journeys:
        if journey["depart"].date() < from_date:
            continue
        events.extend(
            Event(
                date=leg["depart"],
                end_date=leg["arrive"],
                name="transport",
                title=f'train from {leg["from"]} to {leg["to"]}',
                # Every leg links to its journey's URL, not a per-leg one.
                url=journey.get("url"),
            )
            for leg in journey["legs"]
        )
    return events


def flight_number(flight: Leg) -> str:
    """Return the flight number: airline code plus number, e.g. ``"BA123"``.

    The ``flight_number`` value is passed through ``str()`` because YAML
    loads an unquoted value such as ``flight_number: 123`` as an int, and
    concatenating it to the airline code would otherwise raise TypeError.

    Raises AssertionError (unless Python runs with ``-O``) if ``airline``
    contains a space or a lowercase letter, i.e. looks like an airline name
    rather than a code.
    """
    airline_code = flight["airline"]
    # make sure this is the airline code, not the airline name
    assert " " not in airline_code and not any(
        c.islower() for c in airline_code
    ), f"expected an airline code, not an airline name: {airline_code!r}"

    return airline_code + str(flight["flight_number"])


def all_events(from_date: date, data_dir: str) -> list[Event]:
    """Return rail journeys followed by flights departing on or after ``from_date``.

    Reads ``trains.yaml`` and ``flights.yaml`` from ``data_dir``. Flight
    titles get the flight number appended in parentheses.
    """
    trains_path = os.path.join(data_dir, "trains.yaml")
    flights_path = os.path.join(data_dir, "flights.yaml")

    train_events = get_trains(from_date, trains_path)
    flight_events = get(from_date, "flight", flights_path, extra=flight_number)
    return [*train_events, *flight_events]