diff --git a/agenda/schengen.py b/agenda/schengen.py index 7f1d355..016cc01 100644 --- a/agenda/schengen.py +++ b/agenda/schengen.py @@ -2,8 +2,8 @@ from datetime import date, datetime, timedelta +from .trip import depart_datetime from .types import SchengenCalculation, SchengenStay, StrDict -from .utils import depart_datetime # Schengen Area countries as of 2025 SCHENGEN_COUNTRIES = { diff --git a/agenda/trip.py b/agenda/trip.py index 9aad250..978434d 100644 --- a/agenda/trip.py +++ b/agenda/trip.py @@ -3,14 +3,14 @@ import decimal import os import typing -from datetime import date +from datetime import date, datetime, time +from zoneinfo import ZoneInfo import flask import yaml -from agenda import travel, trip_schengen +from agenda import travel from agenda.types import StrDict, Trip -from agenda.utils import depart_datetime class Airline(typing.TypedDict, total=False): @@ -122,6 +122,18 @@ def load_coaches( return coaches +def depart_datetime(item: StrDict) -> datetime: + """Return a datetime for this travel item. + + If the travel item already has a datetime return that, otherwise if the + departure time is just a date return midnight UTC for that date. + """ + depart = item["depart"] + if isinstance(depart, datetime): + return depart + return datetime.combine(depart, time.min).replace(tzinfo=ZoneInfo("UTC")) + + def process_flight( flight: StrDict, by_iata: dict[str, Airline], airports: list[StrDict] ) -> None: @@ -397,9 +409,7 @@ def get_trip_routes(trip: Trip, data_dir: str) -> list[StrDict]: # Use GeoJSON route when available, otherwise draw straight line if t.get("geojson_filename"): filename = os.path.join("coach_routes", t["geojson_filename"]) - routes.append( - {"type": "coach", "key": key, "geojson_filename": filename} - ) + routes.append({"type": "coach", "key": key, "geojson_filename": filename}) else: routes.append( { @@ -414,9 +424,7 @@ def get_trip_routes(trip: Trip, data_dir: str) -> list[StrDict]: for leg in t["legs"]: train_from, train_to = leg["from_station"], leg["to_station"] geojson_filename = train_from.get("routes", {}).get(train_to["name"]) - key = "_".join( - ["train"] + sorted([train_from["name"], train_to["name"]]) - ) + key = "_".join(["train"] + sorted([train_from["name"], train_to["name"]])) if not geojson_filename: routes.append( { @@ -436,9 +444,7 @@ def get_trip_routes(trip: Trip, data_dir: str) -> list[StrDict]: { "type": "train", "key": key, - "geojson_filename": os.path.join( - "train_routes", geojson_filename - ), + "geojson_filename": os.path.join("train_routes", geojson_filename), } ) @@ -490,33 +496,3 @@ def get_coordinates_and_routes( route["geojson"] = read_geojson(data_dir, route.pop("geojson_filename")) return (coordinates, routes) - - -def get_trip_list( - route_distances: travel.RouteDistances | None = None, -) -> list[Trip]: - """Get list of trips respecting current authentication status.""" - trips = [ - trip - for trip in build_trip_list(route_distances=route_distances) - if flask.g.user.is_authenticated or not trip.private - ] - - # Add Schengen compliance information to each trip - for trip in trips: - trip_schengen.add_schengen_compliance_to_trip(trip) - - return trips - - -def get_current_trip(today: date) -> Trip | None: - """Get current trip.""" - trip_list = get_trip_list(route_distances=None) - - current = [ - item - for item in trip_list - if item.start <= today and (item.end or item.start) >= today - ] - assert len(current) < 2 - return current[0] if current else None diff --git a/agenda/trip_schengen.py b/agenda/trip_schengen.py index 6e34044..8d3cd8b 100644 --- a/agenda/trip_schengen.py +++ b/agenda/trip_schengen.py @@ -7,37 +7,26 @@ from datetime import date, timedelta import flask from . import get_country, trip -from .schengen import ( - SCHENGEN_COUNTRIES, - calculate_schengen_time, - extract_schengen_stays_from_travel, -) +from .schengen import calculate_schengen_time, extract_schengen_stays_from_travel from .types import SchengenCalculation, SchengenStay, StrDict, Trip -def trip_includes_schengen(trip: Trip) -> bool: - return bool({c.alpha_2.lower() for c in trip.countries} & SCHENGEN_COUNTRIES) - - -def add_schengen_compliance_to_trip(trip: Trip) -> Trip: +def add_schengen_compliance_to_trip(trip_obj: Trip) -> Trip: """Add Schengen compliance information to a trip object.""" - if not trip_includes_schengen(trip): - return trip - try: # Calculate Schengen compliance for the trip - calculation = calculate_schengen_time(trip.travel) + calculation = calculate_schengen_time(trip_obj.travel) # Add the calculation to the trip object - trip.schengen_compliance = calculation + trip_obj.schengen_compliance = calculation except Exception as e: # Log the error but don't fail the trip loading logging.warning( - f"Failed to calculate Schengen compliance for trip {trip.start}: {e}" + f"Failed to calculate Schengen compliance for trip {trip_obj.start}: {e}" ) - trip.schengen_compliance = None + trip_obj.schengen_compliance = None - return trip + return trip_obj def get_schengen_compliance_for_all_trips( @@ -138,9 +127,7 @@ def schengen_dashboard_data(data_dir: str | None = None) -> dict[str, typing.Any data_dir = flask.current_app.config["PERSONAL_DATA"] # Load all trips - trip_list = [ - trip for trip in trip.build_trip_list(data_dir) if trip_includes_schengen(trip) - ] + trip_list = trip.build_trip_list(data_dir) # Calculate current compliance with trip information all_travel_items = [] diff --git a/agenda/types.py b/agenda/types.py index d6041ed..d1ecdda 100644 --- a/agenda/types.py +++ b/agenda/types.py @@ -55,43 +55,6 @@ def airport_label(airport: StrDict) -> str: return f"{name} ({airport['iata']})" -@dataclass -class SchengenStay: - """Represents a stay in the Schengen area.""" - - entry_date: date - exit_date: date | None # None if currently in Schengen - country: str - days: int - trip_date: date | None = None # Trip start date for linking - trip_name: str | None = None # Trip name for display - - def __post_init__(self) -> None: - """Post init.""" - if self.exit_date is None: - # Currently in Schengen, calculate days up to today - self.days = (date.today() - self.entry_date).days + 1 - else: - self.days = (self.exit_date - self.entry_date).days + 1 - - -@dataclass -class SchengenCalculation: - """Result of Schengen time calculation.""" - - total_days_used: int - days_remaining: int - is_compliant: bool - current_180_day_period: tuple[date, date] # (start, end) - stays_in_period: SchengenStay - next_reset_date: typing.Optional[date] # When the 180-day window resets - - @property - def days_over_limit(self) -> int: - """Days over the 90-day limit.""" - return max(0, self.total_days_used - 90) - - @dataclass class Trip: """Trip.""" @@ -104,7 +67,7 @@ class Trip: flight_bookings: list[StrDict] = field(default_factory=list) name: str | None = None private: bool = False - schengen_compliance: SchengenCalculation | None = None + schengen_compliance: typing.Optional["SchengenCalculation"] = None @property def title(self) -> str: @@ -446,3 +409,39 @@ class Holiday: if self.local_name and self.local_name != self.name else self.name ) + + +@dataclass +class SchengenStay: + """Represents a stay in the Schengen area.""" + + entry_date: date + exit_date: typing.Optional[date] # None if currently in Schengen + country: str + days: int + trip_date: typing.Optional[date] = None # Trip start date for linking + trip_name: typing.Optional[str] = None # Trip name for display + + def __post_init__(self) -> None: + if self.exit_date is None: + # Currently in Schengen, calculate days up to today + self.days = (date.today() - self.entry_date).days + 1 + else: + self.days = (self.exit_date - self.entry_date).days + 1 + + +@dataclass +class SchengenCalculation: + """Result of Schengen time calculation.""" + + total_days_used: int + days_remaining: int + is_compliant: bool + current_180_day_period: tuple[date, date] # (start, end) + stays_in_period: list["SchengenStay"] + next_reset_date: typing.Optional[date] # When the 180-day window resets + + @property + def days_over_limit(self) -> int: + """Days over the 90-day limit.""" + return max(0, self.total_days_used - 90) diff --git a/agenda/utils.py b/agenda/utils.py index 15c9100..5f8573c 100644 --- a/agenda/utils.py +++ b/agenda/utils.py @@ -2,10 +2,8 @@ import os import typing -from datetime import date, datetime, time, timedelta, timezone -from zoneinfo import ZoneInfo - -from .types import StrDict +from datetime import date, datetime, timedelta, timezone +from time import time def as_date(d: datetime | date) -> date: @@ -120,15 +118,3 @@ async def time_function( exception = e end_time = time() return name, result, end_time - start_time, exception - - -def depart_datetime(item: StrDict) -> datetime: - """Return a datetime for this travel item. - - If the travel item already has a datetime return that, otherwise if the - departure time is just a date return midnight UTC for that date. - """ - depart = item["depart"] - if isinstance(depart, datetime): - return depart - return datetime.combine(depart, time.min).replace(tzinfo=ZoneInfo("UTC")) diff --git a/templates/schengen_report.html b/templates/schengen_report.html index b34c02d..9222785 100644 --- a/templates/schengen_report.html +++ b/templates/schengen_report.html @@ -162,8 +162,8 @@