diff --git a/agenda/data.py b/agenda/data.py index e928c4d..473c3c2 100644 --- a/agenda/data.py +++ b/agenda/data.py @@ -36,7 +36,7 @@ from . import ( uk_tz, waste_schedule, ) -from .types import Event, Holiday, StrDict +from .types import Event, StrDict here = dateutil.tz.tzlocal() @@ -274,9 +274,6 @@ async def get_data( result_list = await asyncio.gather( time_function("gbpusd", fx.get_gbpusd, config), time_function("gwr_advance_tickets", gwr.advance_ticket_date, data_dir), - time_function( - "bank_holiday", uk_holiday.bank_holiday_list, last_year, next_year, data_dir - ), time_function("rockets", thespacedevs.get_launches, rocket_dir, limit=40), time_function("backwell_bins", waste_collection_events, data_dir), time_function("bristol_bins", bristol_waste_collection_events, data_dir, today), @@ -318,34 +315,13 @@ async def get_data( events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets)) us_hols = holidays.us_holidays(last_year, next_year) - - holiday_list: list[Holiday] = results["bank_holiday"] + us_hols - for country in ( - "at", - "be", - "br", - "ch", - "cz", - "de", - "dk", - "ee", - "es", - "fi", - "fr", - "gr", - "it", - "ke", - "nl", - "pl", - ): - holiday_list += holidays.get_holidays(country, last_year, next_year) - events += holidays.get_nyse_holidays(last_year, next_year, us_hols) accommodation_events = accommodation.get_events( os.path.join(my_data, "accommodation.yaml") ) + holiday_list = holidays.get_all(last_year, next_year, data_dir) events += holidays.combine_holidays(holiday_list) events += birthday.get_birthdays(last_year, os.path.join(my_data, "entities.yaml")) events += accommodation_events diff --git a/agenda/holidays.py b/agenda/holidays.py index b95da82..08782b0 100644 --- a/agenda/holidays.py +++ b/agenda/holidays.py @@ -3,6 +3,7 @@ import collections from datetime import date, timedelta +import agenda.uk_holiday import holidays from .types import Event, Holiday @@ -117,3 +118,33 @@ def combine_holidays(holidays: list[Holiday]) -> list[Event]: events.append(e) return events + + +def get_all(last_year: date, next_year: date, data_dir: str) -> list[Holiday]: + """Get holidays for various countries and return as a list.""" + us_hols = us_holidays(last_year, next_year) + + bank_holidays = agenda.uk_holiday.bank_holiday_list(last_year, next_year, data_dir) + + holiday_list: list[Holiday] = bank_holidays + us_hols + for country in ( + "at", + "be", + "br", + "ch", + "cz", + "de", + "dk", + "ee", + "es", + "fi", + "fr", + "gr", + "it", + "ke", + "nl", + "pl", + ): + holiday_list += get_holidays(country, last_year, next_year) + + return holiday_list diff --git a/agenda/uk_holiday.py b/agenda/uk_holiday.py index 1747f12..3bdbfd2 100644 --- a/agenda/uk_holiday.py +++ b/agenda/uk_holiday.py @@ -3,7 +3,6 @@ import json import os from datetime import date, datetime, timedelta -from time import time import httpx from dateutil.easter import easter @@ -13,9 +12,15 @@ from .types import Holiday, StrDict url = "https://www.gov.uk/bank-holidays.json" +def json_filename(data_dir: str) -> str: + """Filename for cached bank holidays.""" + assert os.path.exists(data_dir) + return os.path.join(data_dir, "bank-holidays.json") + + async def get_holiday_list(data_dir: str) -> list[StrDict]: """Download holiday list and save cache.""" - filename = os.path.join(data_dir, "bank-holidays.json") + filename = json_filename(data_dir) async with httpx.AsyncClient() as client: r = await client.get(url) @@ -24,27 +29,12 @@ async def get_holiday_list(data_dir: str) -> list[StrDict]: return events -async def bank_holiday_list( - start_date: date, end_date: date, data_dir: str -) -> list[Holiday]: +def bank_holiday_list(start_date: date, end_date: date, data_dir: str) -> list[Holiday]: """Date and name of the next UK bank holiday.""" - filename = os.path.join(data_dir, "bank-holidays.json") - use_cached = False - events: list[StrDict] - if os.path.exists(filename): - mtime = os.path.getmtime(filename) - if (time() - mtime) < 60 * 60 * 6: # six hours - use_cached = True - try: - events = json.load(open(filename))["england-and-wales"]["events"] - except json.decoder.JSONDecodeError: - use_cached = False - - if not use_cached: - events = await get_holiday_list(data_dir) + filename = json_filename(data_dir) hols: list[Holiday] = [] - for event in events: + for event in json.load(open(filename))["england-and-wales"]["events"]: event_date = datetime.strptime(event["date"], "%Y-%m-%d").date() if event_date < start_date: continue