From 29d5145b871e9f30aad98dfa9e4dae2c6b1e367e Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Wed, 16 Jul 2025 12:08:19 +0200 Subject: [PATCH] Refactor get_location_for_date to use trip data directly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Simplify the location tracking function by extracting travel data directly from trip objects instead of requiring separate YAML file parameters. Changes: - Remove airport, train, and ferry location helper functions that required separate YAML data lookups - Update get_location_for_date signature to only take target_date and trips - Extract flight/train/ferry details directly from trip.travel items - Use embedded airport/station/terminal objects from trip data - Remove YAML file parsing from weekends function - Update test calls to use new simplified signature This eliminates duplicate data loading and simplifies the API while maintaining all existing functionality. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- agenda/busy.py | 556 +++++++++++++++++++++------------------------ tests/test_busy.py | 71 ------ 2 files changed, 260 insertions(+), 367 deletions(-) diff --git a/agenda/busy.py b/agenda/busy.py index 9d7e61d..8452903 100644 --- a/agenda/busy.py +++ b/agenda/busy.py @@ -92,125 +92,20 @@ def _parse_datetime_field(datetime_obj: datetime | date) -> tuple[datetime, date raise ValueError(f"Invalid datetime format: {datetime_obj}") -def _get_airport_location( - airport_code: str, airports: StrDict, uk_airports: set[str], on_trip: bool = False -) -> tuple[str | None, pycountry.db.Country | None]: - """Get location from airport code.""" - if airport_code in uk_airports: - if on_trip: - # When on a trip, show the actual location even for UK airports - airport_info = airports.get(airport_code) - if airport_info: - location_name = airport_info.get( - "city", airport_info.get("name", "London") - ) - return (location_name, get_country("gb")) - else: - return ("London", get_country("gb")) - else: - # When not on a trip, UK airports mean home - return (None, get_country("gb")) - else: - # Non-UK airports - airport_info = airports.get(airport_code) - if airport_info: - location_name = airport_info.get( - "city", airport_info.get("name", airport_code) - ) - return (location_name, get_country(airport_info.get("country", "gb"))) - else: - return (airport_code, get_country("gb")) - - def _get_accommodation_location( acc: StrDict, on_trip: bool = False -) -> tuple[str | None, pycountry.db.Country | None]: +) -> tuple[str | None, pycountry.db.Country]: """Get location from accommodation data.""" - if acc.get("country") == "gb": - if on_trip: - # When on a trip, show the actual location even for UK accommodations - return (acc.get("location", "London"), get_country("gb")) - else: - # When not on a trip, UK accommodation means home - return (None, get_country("gb")) - else: - return (acc.get("location", "Unknown"), get_country(acc.get("country", "gb"))) - - -def _get_train_location( - train_leg: StrDict, stations: StrDict, on_trip: bool = False -) -> tuple[str | None, pycountry.db.Country | None]: - """Get location from train leg data.""" - destination = train_leg.get("to") - if not destination: - return (None, get_country("gb")) - - # Find station info - station_info = None - for station in stations: - if station.get("name") == destination: - station_info = station - break - - if not station_info: - return (destination, get_country("gb")) - - station_country = station_info.get("country", "gb") - - if station_country == "gb": - if on_trip: - # When on a trip, show the actual location even for UK stations - return (destination, get_country("gb")) - else: - # When not on a trip, UK stations mean home - return (None, get_country("gb")) - else: - return (destination, get_country(station_country)) - - -def _get_ferry_location( - ferry: StrDict, terminals: StrDict, on_trip: bool = False -) -> tuple[str | None, pycountry.db.Country | None]: - """Get location from ferry data.""" - destination = ferry.get("to") - if not destination: - return (None, get_country("gb")) - - # Find terminal info - terminal_info = None - for terminal in terminals: - if terminal.get("name") == destination: - terminal_info = terminal - break - - if not terminal_info: - return (destination, get_country("gb")) - - terminal_country = terminal_info.get("country", "gb") - terminal_city = terminal_info.get("city", destination) - - if terminal_country == "gb": - if on_trip: - # When on a trip, show the actual location even for UK terminals - return (terminal_city, get_country("gb")) - else: - # When not on a trip, UK terminals mean home - return (None, get_country("gb")) - else: - return (terminal_city, get_country(terminal_country)) + c = get_country(acc["country"]) + assert c + assert isinstance(acc["location"], str) + return (acc["location"] if on_trip else None, c) def _find_most_recent_travel_within_trip( trip: Trip, target_date: date, - bookings: list[StrDict], - accommodations: list[StrDict], - airports: StrDict, - trains: list[StrDict] | None = None, - stations: StrDict | None = None, - ferries: list[StrDict] | None = None, - terminals: StrDict | None = None, -) -> tuple[str | None, pycountry.db.Country | None] | None: +) -> tuple[str | None, pycountry.db.Country] | None: """Find the most recent travel location within a trip.""" uk_airports = {"LHR", "LGW", "STN", "LTN", "BRS", "BHX", "MAN", "EDI", "GLA"} @@ -219,39 +114,54 @@ def _find_most_recent_travel_within_trip( trip_most_recent_datetime = None # Check flights within trip period - for booking in bookings: - for flight in booking.get("flights", []): - if "arrive" in flight: - try: - arrive_datetime, arrive_date = _parse_datetime_field( - flight["arrive"] - ) - except ValueError: - continue + for travel_item in trip.travel: + if travel_item["type"] == "flight" and "arrive" in travel_item: + arrive_datetime, arrive_date = _parse_datetime_field(travel_item["arrive"]) - # Only consider flights within this trip and before target date - if trip.start <= arrive_date <= target_date: - # Compare both date and time to handle same-day flights correctly - if ( - trip_most_recent_date is None - or arrive_date > trip_most_recent_date - or ( - arrive_date == trip_most_recent_date - and ( - trip_most_recent_datetime is None - or arrive_datetime > trip_most_recent_datetime - ) + # Only consider flights within this trip and before target date + if not (trip.start <= arrive_date <= target_date): + continue + # Compare both date and time to handle same-day flights correctly + if ( + trip_most_recent_date is None + or arrive_date > trip_most_recent_date + or ( + arrive_date == trip_most_recent_date + and ( + trip_most_recent_datetime is None + or arrive_datetime > trip_most_recent_datetime + ) + ) + ): + trip_most_recent_date = arrive_date + trip_most_recent_datetime = arrive_datetime + destination_airport = travel_item["to"] + assert "to_airport" in travel_item + airport_info = travel_item["to_airport"] + airport_country = airport_info["country"] + if airport_country == "gb": + if destination_airport in uk_airports: + # UK airport while on trip - show actual location + location_name = airport_info.get( + "city", airport_info.get("name", "London") ) - ): - trip_most_recent_date = arrive_date - trip_most_recent_datetime = arrive_datetime - destination_airport = flight["to"] - trip_most_recent_location = _get_airport_location( - destination_airport, airports, uk_airports, on_trip=True + trip_most_recent_location = ( + location_name, + get_country("gb"), ) + else: + trip_most_recent_location = (None, get_country("gb")) + else: + location_name = airport_info.get( + "city", airport_info.get("name", destination_airport) + ) + trip_most_recent_location = ( + location_name, + get_country(airport_country), + ) # Check accommodations within trip period - for acc in accommodations: + for acc in trip.accommodation: if "from" in acc: try: _, acc_date = _parse_datetime_field(acc["from"]) @@ -273,9 +183,9 @@ def _find_most_recent_travel_within_trip( ) # Check trains within trip period - if trains and stations: - for train in trains: - for leg in train.get("legs", []): + for travel_item in trip.travel: + if travel_item["type"] == "train": + for leg in travel_item.get("legs", []): if "arrive" in leg: try: arrive_datetime, arrive_date = _parse_datetime_field( @@ -300,39 +210,63 @@ def _find_most_recent_travel_within_trip( ): trip_most_recent_date = arrive_date trip_most_recent_datetime = arrive_datetime - trip_most_recent_location = _get_train_location( - leg, stations, on_trip=True - ) + # For trains, we can get station info from to_station if available + destination = leg.get("to") + assert "to_station" in leg + station_info = leg["to_station"] + station_country = station_info["country"] + if station_country == "gb": + trip_most_recent_location = ( + destination, + get_country("gb"), + ) + else: + trip_most_recent_location = ( + destination, + get_country(station_country), + ) # Check ferries within trip period - if ferries and terminals: - for ferry in ferries: - if "arrive" in ferry: - try: - arrive_datetime, arrive_date = _parse_datetime_field( - ferry["arrive"] - ) - except ValueError: - continue + for travel_item in trip.travel: + if travel_item["type"] == "ferry" and "arrive" in travel_item: + try: + arrive_datetime, arrive_date = _parse_datetime_field( + travel_item["arrive"] + ) + except ValueError: + continue - # Only consider ferries within this trip and before target date - if trip.start <= arrive_date <= target_date: - # Compare both date and time to handle same-day arrivals correctly - if ( - trip_most_recent_date is None - or arrive_date > trip_most_recent_date - or ( - arrive_date == trip_most_recent_date - and ( - trip_most_recent_datetime is None - or arrive_datetime > trip_most_recent_datetime - ) + # Only consider ferries within this trip and before target date + if trip.start <= arrive_date <= target_date: + # Compare both date and time to handle same-day arrivals correctly + if ( + trip_most_recent_date is None + or arrive_date > trip_most_recent_date + or ( + arrive_date == trip_most_recent_date + and ( + trip_most_recent_datetime is None + or arrive_datetime > trip_most_recent_datetime ) - ): - trip_most_recent_date = arrive_date - trip_most_recent_datetime = arrive_datetime - trip_most_recent_location = _get_ferry_location( - ferry, terminals, on_trip=True + ) + ): + trip_most_recent_date = arrive_date + trip_most_recent_datetime = arrive_datetime + # For ferries, we can get terminal info from to_terminal if available + destination = travel_item.get("to") + assert "to_terminal" in travel_item + terminal_info = travel_item["to_terminal"] + terminal_country = terminal_info.get("country", "gb") + terminal_city = terminal_info.get("city", destination) + if terminal_country == "gb": + trip_most_recent_location = ( + terminal_city, + get_country("gb"), + ) + else: + trip_most_recent_location = ( + terminal_city, + get_country(terminal_country), ) return trip_most_recent_location @@ -366,13 +300,7 @@ def _get_trip_location_by_progression( def _find_most_recent_travel_before_date( target_date: date, - bookings: list[StrDict], - accommodations: list[StrDict], - airports: StrDict, - trains: list[StrDict] | None = None, - stations: StrDict | None = None, - ferries: list[StrDict] | None = None, - terminals: StrDict | None = None, + trips: list[Trip], ) -> tuple[str | None, pycountry.db.Country | None] | None: """Find the most recent travel location before a given date.""" uk_airports = {"LHR", "LGW", "STN", "LTN", "BRS", "BHX", "MAN", "EDI", "GLA"} @@ -381,13 +309,14 @@ def _find_most_recent_travel_before_date( most_recent_date = None most_recent_datetime = None - # Check flights - for booking in bookings: - for flight in booking.get("flights", []): - if "arrive" in flight: + # Check all travel across all trips + for trip in trips: + # Check flights + for travel_item in trip.travel: + if travel_item["type"] == "flight" and "arrive" in travel_item: try: arrive_datetime, arrive_date = _parse_datetime_field( - flight["arrive"] + travel_item["arrive"] ) except ValueError: continue @@ -407,65 +336,105 @@ def _find_most_recent_travel_before_date( ): most_recent_date = arrive_date most_recent_datetime = arrive_datetime - destination_airport = flight["to"] - most_recent_location = _get_airport_location( - destination_airport, airports, uk_airports, on_trip=False + destination_airport = travel_item["to"] + # For flights, determine if we're "on trip" based on whether this is within any trip period + on_trip = any( + t.start <= arrive_date <= (t.end or t.start) for t in trips ) - # Check accommodation - only override if accommodation is more recent - for acc in accommodations: - if "from" in acc: - try: - _, acc_date = _parse_datetime_field(acc["from"]) - except ValueError: - continue - - if acc_date <= target_date: - # Only update if this accommodation is more recent than existing result - if most_recent_date is None or acc_date > most_recent_date: - most_recent_date = acc_date - most_recent_location = _get_accommodation_location( - acc, on_trip=False - ) - - # Check trains - if trains and stations: - for train in trains: - for leg in train.get("legs", []): - if "arrive" in leg: - try: - arrive_datetime, arrive_date = _parse_datetime_field( - leg["arrive"] - ) - except ValueError: - continue - - if arrive_date <= target_date: - # Compare both date and time to handle same-day arrivals correctly - if ( - most_recent_date is None - or arrive_date > most_recent_date - or ( - arrive_date == most_recent_date - and ( - most_recent_datetime is None - or arrive_datetime > most_recent_datetime + if "to_airport" in travel_item: + airport_info = travel_item["to_airport"] + airport_country = airport_info.get("country", "gb") + if airport_country == "gb": + if not on_trip: + # When not on a trip, UK airports mean home + most_recent_location = (None, get_country("gb")) + else: + # When on a trip, show the actual location even for UK airports + location_name = airport_info.get( + "city", airport_info.get("name", "London") + ) + most_recent_location = ( + location_name, + get_country("gb"), + ) + else: + location_name = airport_info.get( + "city", + airport_info.get("name", destination_airport), ) - ) - ): - most_recent_date = arrive_date - most_recent_datetime = arrive_datetime - most_recent_location = _get_train_location( - leg, stations, on_trip=False + most_recent_location = ( + location_name, + get_country(airport_country), + ) + else: + most_recent_location = ( + destination_airport, + get_country("gb"), ) - # Check ferries - if ferries and terminals: - for ferry in ferries: - if "arrive" in ferry: + # Check trains + elif travel_item["type"] == "train": + for leg in travel_item.get("legs", []): + if "arrive" in leg: + try: + arrive_datetime, arrive_date = _parse_datetime_field( + leg["arrive"] + ) + except ValueError: + continue + + if arrive_date <= target_date: + # Compare both date and time to handle same-day arrivals correctly + if ( + most_recent_date is None + or arrive_date > most_recent_date + or ( + arrive_date == most_recent_date + and ( + most_recent_datetime is None + or arrive_datetime > most_recent_datetime + ) + ) + ): + most_recent_date = arrive_date + most_recent_datetime = arrive_datetime + destination = leg.get("to") + on_trip = any( + t.start <= arrive_date <= (t.end or t.start) + for t in trips + ) + + if "to_station" in leg: + station_info = leg["to_station"] + station_country = station_info.get("country", "gb") + if station_country == "gb": + if not on_trip: + most_recent_location = ( + None, + get_country("gb"), + ) + else: + most_recent_location = ( + destination, + get_country("gb"), + ) + else: + most_recent_location = ( + destination, + get_country(station_country), + ) + else: + most_recent_location = ( + destination, + get_country("gb"), + ) + + # Check ferries + elif travel_item["type"] == "ferry" and "arrive" in travel_item: try: arrive_datetime, arrive_date = _parse_datetime_field( - ferry["arrive"] + travel_item["arrive"] ) except ValueError: continue @@ -485,8 +454,48 @@ def _find_most_recent_travel_before_date( ): most_recent_date = arrive_date most_recent_datetime = arrive_datetime - most_recent_location = _get_ferry_location( - ferry, terminals, on_trip=False + destination = travel_item.get("to") + on_trip = any( + t.start <= arrive_date <= (t.end or t.start) for t in trips + ) + + if "to_terminal" in travel_item: + terminal_info = travel_item["to_terminal"] + terminal_country = terminal_info.get("country", "gb") + terminal_city = terminal_info.get("city", destination) + if terminal_country == "gb": + if not on_trip: + most_recent_location = (None, get_country("gb")) + else: + most_recent_location = ( + terminal_city, + get_country("gb"), + ) + else: + most_recent_location = ( + terminal_city, + get_country(terminal_country), + ) + else: + most_recent_location = (destination, get_country("gb")) + + # Check accommodation - only override if accommodation is more recent + for acc in trip.accommodation: + if "from" in acc: + try: + _, acc_date = _parse_datetime_field(acc["from"]) + except ValueError: + continue + + if acc_date <= target_date: + # Only update if this accommodation is more recent than existing result + if most_recent_date is None or acc_date > most_recent_date: + most_recent_date = acc_date + on_trip = any( + t.start <= acc_date <= (t.end or t.start) for t in trips + ) + most_recent_location = _get_accommodation_location( + acc, on_trip=on_trip ) return most_recent_location @@ -514,49 +523,27 @@ def _check_return_home_heuristic( def get_location_for_date( target_date: date, trips: list[Trip], - bookings: list[StrDict], - accommodations: list[StrDict], - airports: StrDict, - trains: list[StrDict] | None = None, - stations: StrDict | None = None, - ferries: list[StrDict] | None = None, - terminals: StrDict | None = None, ) -> tuple[str | None, pycountry.db.Country | None]: """Get location (city, country) for a specific date using travel history.""" # First check if currently on a trip for trip in trips: - if trip.start <= target_date <= (trip.end or trip.start): - # For trips, find the most recent travel within the trip period - trip_location = _find_most_recent_travel_within_trip( - trip, - target_date, - bookings, - accommodations, - airports, - trains, - stations, - ferries, - terminals, - ) - if trip_location: - return trip_location + if not (trip.start <= target_date <= (trip.end or trip.start)): + continue + # For trips, find the most recent travel within the trip period + trip_location = _find_most_recent_travel_within_trip( + trip, + target_date, + ) + if trip_location: + return trip_location - # Fallback: determine location based on trip progression and date - progression_location = _get_trip_location_by_progression(trip, target_date) - if progression_location: - return progression_location + # Fallback: determine location based on trip progression and date + progression_location = _get_trip_location_by_progression(trip, target_date) + if progression_location: + return progression_location # Find most recent travel before this date - recent_travel = _find_most_recent_travel_before_date( - target_date, - bookings, - accommodations, - airports, - trains, - stations, - ferries, - terminals, - ) + recent_travel = _find_most_recent_travel_before_date(target_date, trips) # Check for recent trips that have ended - prioritize this over individual travel data # This handles cases where you're traveling home after a trip (e.g. stopovers, connections) @@ -583,15 +570,6 @@ def weekends( else: start_date = start + timedelta(days=(5 - weekday)) - # Parse YAML files once for all location lookups - bookings = travel.parse_yaml("flights", data_dir) - accommodations = travel.parse_yaml("accommodation", data_dir) - airports = travel.parse_yaml("airports", data_dir) - trains = travel.parse_yaml("trains", data_dir) - stations = travel.parse_yaml("stations", data_dir) - ferries = travel.parse_yaml("ferries", data_dir) - terminals = travel.parse_yaml("ferry_terminals", data_dir) - weekends_info = [] for i in range(52): saturday = start_date + timedelta(weeks=i) @@ -611,24 +589,10 @@ def weekends( saturday_location = get_location_for_date( saturday, trips, - bookings, - accommodations, - airports, - trains, - stations, - ferries, - terminals, ) sunday_location = get_location_for_date( sunday, trips, - bookings, - accommodations, - airports, - trains, - stations, - ferries, - terminals, ) weekends_info.append( diff --git a/tests/test_busy.py b/tests/test_busy.py index 026eeb9..8a12d8c 100644 --- a/tests/test_busy.py +++ b/tests/test_busy.py @@ -73,9 +73,6 @@ def test_specific_home_dates(travel_data): location = agenda.busy.get_location_for_date( test_date, trips, - travel_data["bookings"], - travel_data["accommodations"], - travel_data["airports"], ) assert not location[ 0 @@ -94,9 +91,6 @@ def test_specific_away_dates(travel_data): location = agenda.busy.get_location_for_date( test_date, trips, - travel_data["bookings"], - travel_data["accommodations"], - travel_data["airports"], ) assert ( location[0] == expected_city @@ -111,9 +105,6 @@ def test_get_location_for_date_basic(travel_data): location = agenda.busy.get_location_for_date( test_date, trips, - travel_data["bookings"], - travel_data["accommodations"], - travel_data["airports"], ) # Should return a tuple with (city|None, country) @@ -178,68 +169,6 @@ def test_parse_datetime_field(): assert parsed_dt.day == 1 -def test_train_location_helpers(): - """Test the train location helper functions.""" - from agenda.busy import _get_train_location - - # Mock station data - stations = [ - {"name": "London St Pancras", "country": "gb"}, - {"name": "Brussels Midi", "country": "be"}, - {"name": "Edinburgh Waverley", "country": "gb"}, - ] - - # Test UK station when not on trip (should return None for home) - train_leg = {"to": "London St Pancras"} - location = _get_train_location(train_leg, stations, on_trip=False) - assert location[0] is None # Should be home - assert location[1].alpha_2 == "GB" - - # Test UK station when on trip (should return city name) - location = _get_train_location(train_leg, stations, on_trip=True) - assert location[0] == "London St Pancras" - assert location[1].alpha_2 == "GB" - - # Test non-UK station - train_leg = {"to": "Brussels Midi"} - location = _get_train_location(train_leg, stations, on_trip=False) - assert location[0] == "Brussels Midi" - assert location[1].alpha_2 == "BE" - - -def test_ferry_location_helpers(): - """Test the ferry location helper functions.""" - from agenda.busy import _get_ferry_location - - # Mock terminal data - terminals = [ - {"name": "Dover Eastern Docks", "country": "gb", "city": "Dover"}, - {"name": "Calais Ferry Terminal", "country": "fr", "city": "Calais"}, - { - "name": "Portsmouth Continental Terminal", - "country": "gb", - "city": "Portsmouth", - }, - ] - - # Test UK terminal when not on trip (should return None for home) - ferry = {"to": "Dover Eastern Docks"} - location = _get_ferry_location(ferry, terminals, on_trip=False) - assert location[0] is None # Should be home - assert location[1].alpha_2 == "GB" - - # Test UK terminal when on trip (should return city name) - location = _get_ferry_location(ferry, terminals, on_trip=True) - assert location[0] == "Dover" - assert location[1].alpha_2 == "GB" - - # Test non-UK terminal - ferry = {"to": "Calais Ferry Terminal"} - location = _get_ferry_location(ferry, terminals, on_trip=False) - assert location[0] == "Calais" - assert location[1].alpha_2 == "FR" - - def test_get_busy_events(app_context, trips): """Test get_busy_events function.""" start_date = date(2023, 1, 1)