From 0c88ad46387b4838a8e2cc5756934a8bf31f20b1 Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Sat, 7 Mar 2026 13:10:33 +0000 Subject: [PATCH] Validate weekends year range and fix trip page unbooked-flight helper call --- tests/test_trip.py | 4 ++- tests/test_trip_page_route.py | 66 +++++++++++++++++++++++++++++++++++ tests/test_weekends_route.py | 31 ++++++++++++++++ web_view.py | 27 +++++++++++--- 4 files changed, 123 insertions(+), 5 deletions(-) create mode 100644 tests/test_trip_page_route.py create mode 100644 tests/test_weekends_route.py diff --git a/tests/test_trip.py b/tests/test_trip.py index 8c07a3c..69c0e6e 100644 --- a/tests/test_trip.py +++ b/tests/test_trip.py @@ -44,7 +44,9 @@ def test_add_coordinates_for_unbooked_flights_adds_missing_airports() -> None: original_parse_yaml = agenda.trip.travel.parse_yaml try: agenda.trip.travel.parse_yaml = lambda _name, _data_dir: airports - agenda.trip.add_coordinates_for_unbooked_flights(routes, coordinates) + agenda.trip.add_coordinates_for_unbooked_flights( + routes, coordinates, app.config["PERSONAL_DATA"] + ) finally: agenda.trip.travel.parse_yaml = original_parse_yaml diff --git a/tests/test_trip_page_route.py b/tests/test_trip_page_route.py new file mode 100644 index 0000000..db7abb4 --- /dev/null +++ b/tests/test_trip_page_route.py @@ -0,0 +1,66 @@ +"""Regression tests for trip page route wiring.""" + +from datetime import date +import typing + +import web_view +from agenda.types import Trip + + +def test_trip_page_passes_data_dir_to_unbooked_flight_helper() -> None: + """Trip page should call helper with routes, coordinates and data_dir.""" + trip = Trip(start=date(2025, 1, 28)) + captured: dict[str, str] = {} + + with web_view.app.app_context(): + original_get_trip_list = web_view.get_trip_list + original_add_schengen = ( + web_view.agenda.trip_schengen.add_schengen_compliance_to_trip + ) + original_collect_trip_coordinates = ( + web_view.agenda.trip.collect_trip_coordinates + ) + original_get_trip_routes = web_view.agenda.trip.get_trip_routes + original_add_coordinates = ( + web_view.agenda.trip.add_coordinates_for_unbooked_flights + ) + original_get_trip_weather = web_view.agenda.weather.get_trip_weather + original_render_template = web_view.flask.render_template + try: + web_view.get_trip_list = lambda: [trip] + web_view.agenda.trip_schengen.add_schengen_compliance_to_trip = lambda t: t + web_view.agenda.trip.collect_trip_coordinates = lambda _trip: [] + web_view.agenda.trip.get_trip_routes = lambda _trip, _data_dir: [] + + def fake_add_coordinates( + _routes: list[typing.Any], + _coordinates: list[typing.Any], + data_dir: str, + ) -> None: + captured["data_dir"] = data_dir + + web_view.agenda.trip.add_coordinates_for_unbooked_flights = ( + fake_add_coordinates + ) + web_view.agenda.weather.get_trip_weather = lambda *_args, **_kwargs: [] + web_view.flask.render_template = lambda *_args, **_kwargs: "ok" + + with web_view.app.test_request_context("/trip/2025-01-28"): + result = web_view.trip_page("2025-01-28") + + assert result == "ok" + assert captured["data_dir"] == web_view.app.config["PERSONAL_DATA"] + finally: + web_view.get_trip_list = original_get_trip_list + web_view.agenda.trip_schengen.add_schengen_compliance_to_trip = ( + original_add_schengen + ) + web_view.agenda.trip.collect_trip_coordinates = ( + original_collect_trip_coordinates + ) + web_view.agenda.trip.get_trip_routes = original_get_trip_routes + web_view.agenda.trip.add_coordinates_for_unbooked_flights = ( + original_add_coordinates + ) + web_view.agenda.weather.get_trip_weather = original_get_trip_weather + web_view.flask.render_template = original_render_template diff --git a/tests/test_weekends_route.py b/tests/test_weekends_route.py new file mode 100644 index 0000000..46de40c --- /dev/null +++ b/tests/test_weekends_route.py @@ -0,0 +1,31 @@ +"""Tests for weekends route query validation.""" + +from datetime import date +import typing + +import pytest + +import web_view + + +@pytest.fixture # type: ignore[untyped-decorator] +def client() -> typing.Any: + """Flask test client.""" + web_view.app.config["TESTING"] = True + with web_view.app.test_client() as c: + yield c + + +def test_weekends_rejects_year_before_2020(client: typing.Any) -> None: + """Years before 2020 should return HTTP 400.""" + response = client.get("/weekends?year=2019&week=1") + assert response.status_code == 400 + assert b"Year must be between 2020" in response.data + + +def test_weekends_rejects_year_more_than_five_years_ahead(client: typing.Any) -> None: + """Years beyond current year + 5 should return HTTP 400.""" + too_far = date.today().year + 6 + response = client.get(f"/weekends?year={too_far}&week=1") + assert response.status_code == 400 + assert b"Year must be between 2020" in response.data diff --git a/web_view.py b/web_view.py index 4fa2d7c..18a62dc 100755 --- a/web_view.py +++ b/web_view.py @@ -268,6 +268,16 @@ async def gaps_page() -> str: async def weekends() -> str: """List of available weekends using an optional date, week, or year parameter.""" today = datetime.now().date() + min_year = 2020 + max_year = today.year + 5 + + def validate_year(year: int) -> None: + """Validate year parameter range for weekends page.""" + if year < min_year or year > max_year: + flask.abort( + 400, description=f"Year must be between {min_year} and {max_year}." + ) + date_str = flask.request.args.get("date") week_str = flask.request.args.get("week") year_str = flask.request.args.get("year") @@ -275,12 +285,14 @@ async def weekends() -> str: if date_str: try: start = datetime.strptime(date_str, "%Y-%m-%d").date() + validate_year(start.year) except ValueError: return flask.abort(400, description="Invalid date format. Use YYYY-MM-DD.") elif week_str: try: week = int(week_str) year = int(year_str) if year_str else today.year + validate_year(year) if week < 1 or week > 53: return flask.abort( 400, description="Week number must be between 1 and 53." @@ -293,6 +305,13 @@ async def weekends() -> str: return flask.abort( 400, description="Invalid week or year format. Use integers." ) + elif year_str: + try: + year = int(year_str) + validate_year(year) + start = date(year, 1, 1) + except ValueError: + return flask.abort(400, description="Invalid year format. Use an integer.") else: start = date(today.year, 1, 1) @@ -914,9 +933,7 @@ def get_destination_timezones(trip: Trip) -> list[StrDict]: if flight_country: flight_locations.append((city, flight_country)) - existing_location_keys = { - (loc, c.alpha_2.lower()) for loc, c in trip.locations() - } + existing_location_keys = {(loc, c.alpha_2.lower()) for loc, c in trip.locations()} all_locations = list(trip.locations()) + [ (city, country) for city, country in flight_locations @@ -1018,7 +1035,9 @@ def trip_page(start: str) -> str: coordinates = agenda.trip.collect_trip_coordinates(trip) routes = agenda.trip.get_trip_routes(trip, app.config["PERSONAL_DATA"]) - agenda.trip.add_coordinates_for_unbooked_flights(routes, coordinates) + agenda.trip.add_coordinates_for_unbooked_flights( + routes, coordinates, app.config["PERSONAL_DATA"] + ) for route in routes: if "geojson_filename" in route: