Combine holidays on same for different countries

Closes: #83
This commit is contained in:
Edward Betts 2023-12-09 13:51:24 +00:00
parent 2cc4d553bf
commit 25eb8fa23e
3 changed files with 64 additions and 35 deletions

View file

@ -1,6 +1,7 @@
"""Agenda data."""
import asyncio
import collections
import os
import typing
from datetime import date, datetime, timedelta
@ -34,7 +35,7 @@ from . import (
uk_midnight,
waste_schedule,
)
from .types import Event
from .types import Event, Holiday
here = dateutil.tz.tzlocal()
@ -58,43 +59,39 @@ def timezone_transition(
]
def us_holidays(start_date: date, end_date: date) -> list[Event]:
def us_holidays(start_date: date, end_date: date) -> list[Holiday]:
"""Get US holidays."""
found: list[Event] = []
found: list[Holiday] = []
for year in range(start_date.year, end_date.year + 1):
hols = holidays.country_holidays("US", years=year, language="en")
found += [
Event(name="us_holiday", date=hol_date, title=title.replace("'", ""))
Holiday(date=hol_date, name=title.replace("'", ""), country="us")
for hol_date, title in hols.items()
if start_date < hol_date < end_date
]
extra = []
for e in found:
if e.title != "Thanksgiving":
for h in found:
if h.name != "Thanksgiving":
continue
extra += [
Event(
name="us_holiday", date=e.date + timedelta(days=1), title="Black Friday"
),
Event(
name="us_holiday", date=e.date + timedelta(days=4), title="Cyber Monday"
),
Holiday(date=h.date + timedelta(days=1), name="Black Friday", country="us"),
Holiday(date=h.date + timedelta(days=4), name="Cyber Monday", country="us"),
]
return found + extra
def get_holidays(country: str, start_date: date, end_date: date) -> list[Event]:
def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]:
"""Get holidays."""
found: list[Event] = []
found: list[Holiday] = []
for year in range(start_date.year, end_date.year + 1):
hols = holidays.country_holidays(country.upper(), years=year, language="en")
found += [
Event(
name=country.lower() + "_holiday",
Holiday(
date=hol_date,
title=f"{title} ({country.upper()})",
name=title,
country=country.lower(),
)
for hol_date, title in hols.items()
if start_date < hol_date < end_date
@ -137,20 +134,25 @@ async def bristol_waste_collection_events(
return await waste_schedule.get_bristol_gov_uk(start_date, data_dir, uprn)
def combine_holidays(events: list[Event]) -> list[Event]:
def combine_holidays(holidays: list[Holiday]) -> list[Event]:
"""Combine UK and US holidays with the same date and title."""
combined: dict[tuple[date, str], Event] = {}
combined: collections.defaultdict[
tuple[date, str], set[str]
] = collections.defaultdict(set)
for e in events:
assert isinstance(e.title, str) and isinstance(e.date, date)
event_key = (e.date, e.title)
combined[event_key] = (
Event(name="bank_holiday", date=e.date, title=e.title + " (UK & US)")
if event_key in combined
else e
)
for h in holidays:
assert isinstance(h.name, str) and isinstance(h.date, date)
return list(combined.values())
event_key = (h.date, h.name)
combined[event_key].add(h.country)
events: list[Event] = []
for (d, name), countries in combined.items():
title = f'{name} ({", ".join(country.upper() for country in countries)})'
e = Event(name="holiday", date=d, title=title)
events.append(e)
return events
def get_yaml_event_date_field(item: dict[str, str]) -> str:
@ -257,9 +259,27 @@ async def get_data(
if gwr_advance_tickets:
events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets))
events += combine_holidays(bank_holiday + us_holidays(last_year, next_year))
for country in "be", "de", "fr", "nl":
events += get_holidays(country, last_year, next_year)
holidays: list[Holiday] = bank_holiday + us_holidays(last_year, next_year)
for country in (
"at",
"be",
"ch",
"cz",
"de",
"dk",
"ee",
"es",
"fr",
"gr",
"it",
"ke",
"nl",
"pl",
):
holidays += get_holidays(country, last_year, next_year)
events += combine_holidays(holidays)
events += birthday.get_birthdays(last_year, os.path.join(my_data, "entities.yaml"))
events += accommodation.get_events(os.path.join(my_data, "accommodation.yaml"))
events += travel.all_events(my_data)

View file

@ -4,6 +4,15 @@ import dataclasses
import datetime
@dataclasses.dataclass
class Holiday:
"""Holiay."""
name: str
country: str
date: datetime.date
@dataclasses.dataclass
class Event:
"""Event."""

View file

@ -8,12 +8,12 @@ from time import time
import httpx
from dateutil.easter import easter
from .types import Event
from .types import Holiday
async def bank_holiday_list(
start_date: date, end_date: date, data_dir: str
) -> list[Event]:
) -> list[Holiday]:
"""Date and name of the next UK bank holiday."""
url = "https://www.gov.uk/bank-holidays.json"
filename = os.path.join(data_dir, "bank-holidays.json")
@ -24,14 +24,14 @@ async def bank_holiday_list(
open(filename, "w").write(r.text)
events = json.load(open(filename))["england-and-wales"]["events"]
hols: list[Event] = []
hols: list[Holiday] = []
for event in events:
event_date = datetime.strptime(event["date"], "%Y-%m-%d").date()
if event_date < start_date:
continue
if event_date > end_date:
break
hols.append(Event(name="bank_holiday", date=event_date, title=event["title"]))
hols.append(Holiday(date=event_date, name=event["title"], country="gb"))
return hols