agenda/agenda/holidays.py

160 lines
4.6 KiB
Python

"""Holidays."""
import collections
from datetime import date, timedelta
import agenda.uk_holiday
import holidays
from .types import Event, Holiday
def us_holidays(start_date: date, end_date: date) -> list[Holiday]:
    """Get US holidays between start_date and end_date, both inclusive.

    In addition to the holidays reported by the ``holidays`` package, two
    shopping days are derived from every holiday named "Thanksgiving":
    "Black Friday" (one day later) and "Cyber Monday" (four days later).

    Args:
        start_date: First day of the range (inclusive).
        end_date: Last day of the range (inclusive).

    Returns:
        The library-provided holidays for every year touched by the range,
        followed by the derived shopping days.  Every entry has country "us"
        and a date inside the requested range.
    """
    # (name, days after Thanksgiving) for the derived shopping days.
    derived_days = (("Black Friday", 1), ("Cyber Monday", 4))

    found: list[Holiday] = []
    extra: list[Holiday] = []
    for year in range(start_date.year, end_date.year + 1):
        hols = holidays.country_holidays("US", years=year, language="en")
        for hol_date, title in hols.items():
            # Inclusive bounds, matching get_holidays and get_nyse_holidays,
            # so a holiday on a boundary day is not lost for the US only.
            if start_date <= hol_date <= end_date:
                found.append(Holiday(date=hol_date, name=title, country="us"))

            if title != "Thanksgiving":
                continue

            # Range-check each derived day on its own date rather than on
            # Thanksgiving's: Thanksgiving may sit just outside the range
            # while Black Friday is inside it, or the other way round.
            for derived_name, offset in derived_days:
                derived_date = hol_date + timedelta(days=offset)
                if start_date <= derived_date <= end_date:
                    extra.append(
                        Holiday(date=derived_date, name=derived_name, country="us")
                    )

    return found + extra
def get_nyse_holidays(
    start_date: date, end_date: date, us_hols: list[Holiday]
) -> list[Event]:
    """Return NYSE closures that are not already listed as US holidays.

    Args:
        start_date: First day of the range (inclusive).
        end_date: Last day of the range (inclusive).
        us_hols: US holidays already known; an NYSE holiday with the same
            date and (renamed) title as one of these is left out.

    Returns:
        "holiday" events whose titles carry a " (NYSE)" suffix.
    """
    # NYSE titles that differ from the name used in the US holiday list.
    renames = {"Thanksgiving Day": "Thanksgiving"}
    already_listed = {(us_hol.date, us_hol.name) for us_hol in us_hols}

    candidates: list[Event] = []
    for year in range(start_date.year, end_date.year + 1):
        nyse_calendar = holidays.financial_holidays("NYSE", years=year)
        for day, label in nyse_calendar.items():
            if not (start_date <= day <= end_date):
                continue
            candidates.append(
                Event(name="holiday", date=day, title=renames.get(label, label))
            )

    nyse_only: list[Event] = []
    for event in candidates:
        # Compare on the un-suffixed title; the suffix is added afterwards.
        if (event.date, event.title) in already_listed:
            continue
        assert event.title
        event.title += " (NYSE)"
        nyse_only.append(event)

    return nyse_only
def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]:
    """Get the holidays of one country between two dates, both inclusive.

    Each holiday is looked up twice: once with language "en_US" for the
    ``name`` field and once with the country's default language for the
    ``local_name`` field.

    Args:
        country: Country code; upper-cased for the ``holidays`` package and
            lower-cased for the ``country`` field of the result.
        start_date: First day of the range (inclusive).
        end_date: Last day of the range (inclusive).

    Returns:
        One Holiday per matching date, year by year.
    """
    country_code = country.upper()
    # The country class exposes the language used for the local names.
    native_language = getattr(holidays, country_code).default_language

    results: list[Holiday] = []
    for year in range(start_date.year, end_date.year + 1):
        english = holidays.country_holidays(
            country_code, years=year, language="en_US"
        )
        native = holidays.country_holidays(
            country_code, years=year, language=native_language
        )
        for day, english_name in english.items():
            if day < start_date or day > end_date:
                continue
            results.append(
                Holiday(
                    date=day,
                    name=english_name,
                    local_name=native[day],
                    country=country.lower(),
                )
            )

    return results
def combine_holidays(holidays: list[Holiday]) -> list[Event]:
    """Merge holidays sharing a date and name into one event per pair.

    Holidays on a handful of fixed calendar days are grouped under a standard
    name regardless of the name each country uses.  The event title is
    annotated with the countries involved:

    * no annotation when every country in the input observes the holiday;
    * the observing countries when fewer than half of them do;
    * "not ..." followed by the missing countries otherwise.

    Args:
        holidays: Holidays from any number of countries.

    Returns:
        One "holiday" event per distinct (date, name) pair, in first-seen
        order.
    """
    # (month, day) -> name used for grouping on that calendar day.
    fixed_day_names = {
        (1, 1): "New Year's Day",
        (1, 6): "Epiphany",
        (5, 1): "Labour Day",
        (8, 15): "Assumption Day",
        (12, 8): "Immaculate conception",
        (12, 25): "Christmas Day",
        (12, 26): "Boxing Day",
    }
    every_country = {hol.country for hol in holidays}

    observers: collections.defaultdict[
        tuple[date, str], set[str]
    ] = collections.defaultdict(set)
    for hol in holidays:
        assert isinstance(hol.name, str) and isinstance(hol.date, date)
        month_day = (hol.date.month, hol.date.day)
        label = fixed_day_names.get(month_day, hol.name)
        observers[(hol.date, label)].add(hol.country)

    merged: list[Event] = []
    for (day, label), observing in observers.items():
        if len(observing) == len(every_country):
            suffix = ""
        elif len(observing) < len(every_country) / 2:
            suffix = ", ".join(sorted(code.upper() for code in observing))
        else:
            # Shorter to name the countries that do not observe it.
            missing = every_country - observing
            suffix = "not " + ", ".join(sorted(code.upper() for code in missing))

        title = f"{label} ({suffix})" if suffix else label
        merged.append(Event(name="holiday", date=day, title=title))

    return merged
def get_all(last_year: date, next_year: date, data_dir: str) -> list[Holiday]:
    """Collect holidays for the UK, the US and a fixed set of other countries.

    Args:
        last_year: Start of the date range handed to every lookup.
        next_year: End of the date range handed to every lookup.
        data_dir: Passed through to ``agenda.uk_holiday.bank_holiday_list``.

    Returns:
        UK bank holidays, then US holidays, then the holidays of each
        remaining country in the order listed below.
    """
    other_countries = (
        "at",
        "be",
        "br",
        "ch",
        "cz",
        "de",
        "dk",
        "ee",
        "es",
        "fi",
        "fr",
        "gr",
        "it",
        "ke",
        "nl",
        "pl",
    )

    us_hols = us_holidays(last_year, next_year)
    bank_holidays = agenda.uk_holiday.bank_holiday_list(last_year, next_year, data_dir)

    # Concatenate into a fresh list so neither source list is mutated below.
    everything: list[Holiday] = bank_holidays + us_hols
    for code in other_countries:
        everything.extend(get_holidays(code, last_year, next_year))

    return everything