From b0612621202fc52b28635e176c42b7055bbc725c Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Tue, 16 Jan 2024 07:42:44 +0000 Subject: [PATCH] Split code for holidays into separate file --- agenda/data.py | 125 +++------------------------------------------ agenda/holidays.py | 117 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 119 deletions(-) create mode 100644 agenda/holidays.py diff --git a/agenda/data.py b/agenda/data.py index 536e5e8..e928c4d 100644 --- a/agenda/data.py +++ b/agenda/data.py @@ -1,7 +1,6 @@ """Agenda data.""" import asyncio -import collections import itertools import os import typing @@ -11,7 +10,6 @@ from time import time import dateutil.rrule import dateutil.tz import flask -import holidays import isodate # type: ignore import lxml import pytz @@ -27,6 +25,7 @@ from . import ( fx, gwr, hn, + holidays, meetup, stock_market, subscription, @@ -61,72 +60,6 @@ def timezone_transition( ] -def us_holidays(start_date: date, end_date: date) -> list[Holiday]: - """Get US holidays.""" - found: list[Holiday] = [] - for year in range(start_date.year, end_date.year + 1): - hols = holidays.country_holidays("US", years=year, language="en") - found += [ - Holiday(date=hol_date, name=title, country="us") - for hol_date, title in hols.items() - if start_date < hol_date < end_date - ] - - extra = [] - for h in found: - if h.name != "Thanksgiving": - continue - extra += [ - Holiday(date=h.date + timedelta(days=1), name="Black Friday", country="us"), - Holiday(date=h.date + timedelta(days=4), name="Cyber Monday", country="us"), - ] - - return found + extra - - -def get_nyse_holidays( - start_date: date, end_date: date, us_hols: list[Holiday] -) -> list[Event]: - """NYSE holidays.""" - known_us_hols = {(h.date, h.name) for h in us_hols} - found: list[Event] = [] - rename = {"Thanksgiving Day": "Thanksgiving"} - for year in range(start_date.year, end_date.year + 1): - hols = holidays.financial_holidays("NYSE", years=year) - found += [ - Event( - name="holiday", - date=hol_date, - title=rename.get(title, title), - ) - for hol_date, title in hols.items() - if start_date < hol_date < end_date - ] - found = [hol for hol in found if (hol.date, hol.title) not in known_us_hols] - for hol in found: - assert hol.title - hol.title += " (NYSE)" - return found - - -def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]: - """Get holidays.""" - found: list[Holiday] = [] - for year in range(start_date.year, end_date.year + 1): - hols = holidays.country_holidays(country.upper(), years=year, language="en_US") - found += [ - Holiday( - date=hol_date, - name=title, - country=country.lower(), - ) - for hol_date, title in hols.items() - if start_date < hol_date < end_date - ] - - return found - - def midnight(d: date) -> datetime: """Convert from date to midnight on that day.""" return datetime.combine(d, datetime.min.time()) @@ -166,52 +99,6 @@ async def bristol_waste_collection_events( return await waste_schedule.get_bristol_gov_uk(start_date, data_dir, uprn) -def combine_holidays(holidays: list[Holiday]) -> list[Event]: - """Combine UK and US holidays with the same date and title.""" - - all_countries = {h.country for h in holidays} - - standard_name = { - (1, 1): "New Year's Day", - (1, 6): "Epiphany", - (5, 1): "Labour Day", - (8, 15): "Assumption Day", - (12, 8): "Immaculate conception", - (12, 25): "Christmas Day", - (12, 26): "Boxing Day", - } - - combined: collections.defaultdict[ - tuple[date, str], set[str] - ] = collections.defaultdict(set) - - for h in holidays: - assert isinstance(h.name, str) and isinstance(h.date, date) - - event_key = (h.date, standard_name.get((h.date.month, h.date.day), h.name)) - combined[event_key].add(h.country) - - events: list[Event] = [] - for (d, name), countries in combined.items(): - if len(countries) == len(all_countries): - country_list = "" - elif len(countries) < len(all_countries) / 2: - country_list = ", ".join(sorted(country.upper() for country in countries)) - else: - country_list = "not " + ", ".join( - sorted(country.upper() for country in all_countries - set(countries)) - ) - - e = Event( - name="holiday", - date=d, - title=f"{name} ({country_list})" if country_list else name, - ) - events.append(e) - - return events - - def get_yaml_event_date_field(item: dict[str, str]) -> str: """Event date field name.""" return ( @@ -430,9 +317,9 @@ async def get_data( if gwr_advance_tickets: events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets)) - us_hols = us_holidays(last_year, next_year) + us_hols = holidays.us_holidays(last_year, next_year) - holidays: list[Holiday] = results["bank_holiday"] + us_hols + holiday_list: list[Holiday] = results["bank_holiday"] + us_hols for country in ( "at", "be", @@ -451,15 +338,15 @@ async def get_data( "nl", "pl", ): - holidays += get_holidays(country, last_year, next_year) + holiday_list += holidays.get_holidays(country, last_year, next_year) - events += get_nyse_holidays(last_year, next_year, us_hols) + events += holidays.get_nyse_holidays(last_year, next_year, us_hols) accommodation_events = accommodation.get_events( os.path.join(my_data, "accommodation.yaml") ) - events += combine_holidays(holidays) + events += holidays.combine_holidays(holiday_list) events += birthday.get_birthdays(last_year, os.path.join(my_data, "entities.yaml")) events += accommodation_events events += travel.all_events(my_data) diff --git a/agenda/holidays.py b/agenda/holidays.py new file mode 100644 index 0000000..9ec7dc7 --- /dev/null +++ b/agenda/holidays.py @@ -0,0 +1,117 @@ +import collections +from datetime import date, timedelta + +import holidays + +from .types import Event, Holiday + + +def us_holidays(start_date: date, end_date: date) -> list[Holiday]: + """Get US holidays.""" + found: list[Holiday] = [] + for year in range(start_date.year, end_date.year + 1): + hols = holidays.country_holidays("US", years=year, language="en") + found += [ + Holiday(date=hol_date, name=title, country="us") + for hol_date, title in hols.items() + if start_date < hol_date < end_date + ] + + extra = [] + for h in found: + if h.name != "Thanksgiving": + continue + extra += [ + Holiday(date=h.date + timedelta(days=1), name="Black Friday", country="us"), + Holiday(date=h.date + timedelta(days=4), name="Cyber Monday", country="us"), + ] + + return found + extra + + +def get_nyse_holidays( + start_date: date, end_date: date, us_hols: list[Holiday] +) -> list[Event]: + """NYSE holidays.""" + known_us_hols = {(h.date, h.name) for h in us_hols} + found: list[Event] = [] + rename = {"Thanksgiving Day": "Thanksgiving"} + for year in range(start_date.year, end_date.year + 1): + hols = holidays.financial_holidays("NYSE", years=year) + found += [ + Event( + name="holiday", + date=hol_date, + title=rename.get(title, title), + ) + for hol_date, title in hols.items() + if start_date < hol_date < end_date + ] + found = [hol for hol in found if (hol.date, hol.title) not in known_us_hols] + for hol in found: + assert hol.title + hol.title += " (NYSE)" + return found + + +def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]: + """Get holidays.""" + found: list[Holiday] = [] + for year in range(start_date.year, end_date.year + 1): + hols = holidays.country_holidays(country.upper(), years=year, language="en_US") + found += [ + Holiday( + date=hol_date, + name=title, + country=country.lower(), + ) + for hol_date, title in hols.items() + if start_date < hol_date < end_date + ] + + return found + + +def combine_holidays(holidays: list[Holiday]) -> list[Event]: + """Combine UK and US holidays with the same date and title.""" + all_countries = {h.country for h in holidays} + + standard_name = { + (1, 1): "New Year's Day", + (1, 6): "Epiphany", + (5, 1): "Labour Day", + (8, 15): "Assumption Day", + (12, 8): "Immaculate conception", + (12, 25): "Christmas Day", + (12, 26): "Boxing Day", + } + + combined: collections.defaultdict[ + tuple[date, str], set[str] + ] = collections.defaultdict(set) + + for h in holidays: + assert isinstance(h.name, str) and isinstance(h.date, date) + + event_key = (h.date, standard_name.get((h.date.month, h.date.day), h.name)) + combined[event_key].add(h.country) + + events: list[Event] = [] + for (d, name), countries in combined.items(): + if len(countries) == len(all_countries): + country_list = "" + elif len(countries) < len(all_countries) / 2: + country_list = ", ".join(sorted(country.upper() for country in countries)) + else: + country_list = "not " + ", ".join( + sorted(country.upper() for country in all_countries - set(countries)) + ) + + e = Event( + name="holiday", + date=d, + title=f"{name} ({country_list})" if country_list else name, + ) + events.append(e) + + return events