Split code for holidays into separate file
This commit is contained in:
parent
a6a78d18e5
commit
b061262120
125
agenda/data.py
125
agenda/data.py
|
@ -1,7 +1,6 @@
|
|||
"""Agenda data."""
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import itertools
|
||||
import os
|
||||
import typing
|
||||
|
@ -11,7 +10,6 @@ from time import time
|
|||
import dateutil.rrule
|
||||
import dateutil.tz
|
||||
import flask
|
||||
import holidays
|
||||
import isodate # type: ignore
|
||||
import lxml
|
||||
import pytz
|
||||
|
@ -27,6 +25,7 @@ from . import (
|
|||
fx,
|
||||
gwr,
|
||||
hn,
|
||||
holidays,
|
||||
meetup,
|
||||
stock_market,
|
||||
subscription,
|
||||
|
@ -61,72 +60,6 @@ def timezone_transition(
|
|||
]
|
||||
|
||||
|
||||
def us_holidays(start_date: date, end_date: date) -> list[Holiday]:
|
||||
"""Get US holidays."""
|
||||
found: list[Holiday] = []
|
||||
for year in range(start_date.year, end_date.year + 1):
|
||||
hols = holidays.country_holidays("US", years=year, language="en")
|
||||
found += [
|
||||
Holiday(date=hol_date, name=title, country="us")
|
||||
for hol_date, title in hols.items()
|
||||
if start_date < hol_date < end_date
|
||||
]
|
||||
|
||||
extra = []
|
||||
for h in found:
|
||||
if h.name != "Thanksgiving":
|
||||
continue
|
||||
extra += [
|
||||
Holiday(date=h.date + timedelta(days=1), name="Black Friday", country="us"),
|
||||
Holiday(date=h.date + timedelta(days=4), name="Cyber Monday", country="us"),
|
||||
]
|
||||
|
||||
return found + extra
|
||||
|
||||
|
||||
def get_nyse_holidays(
|
||||
start_date: date, end_date: date, us_hols: list[Holiday]
|
||||
) -> list[Event]:
|
||||
"""NYSE holidays."""
|
||||
known_us_hols = {(h.date, h.name) for h in us_hols}
|
||||
found: list[Event] = []
|
||||
rename = {"Thanksgiving Day": "Thanksgiving"}
|
||||
for year in range(start_date.year, end_date.year + 1):
|
||||
hols = holidays.financial_holidays("NYSE", years=year)
|
||||
found += [
|
||||
Event(
|
||||
name="holiday",
|
||||
date=hol_date,
|
||||
title=rename.get(title, title),
|
||||
)
|
||||
for hol_date, title in hols.items()
|
||||
if start_date < hol_date < end_date
|
||||
]
|
||||
found = [hol for hol in found if (hol.date, hol.title) not in known_us_hols]
|
||||
for hol in found:
|
||||
assert hol.title
|
||||
hol.title += " (NYSE)"
|
||||
return found
|
||||
|
||||
|
||||
def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]:
|
||||
"""Get holidays."""
|
||||
found: list[Holiday] = []
|
||||
for year in range(start_date.year, end_date.year + 1):
|
||||
hols = holidays.country_holidays(country.upper(), years=year, language="en_US")
|
||||
found += [
|
||||
Holiday(
|
||||
date=hol_date,
|
||||
name=title,
|
||||
country=country.lower(),
|
||||
)
|
||||
for hol_date, title in hols.items()
|
||||
if start_date < hol_date < end_date
|
||||
]
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def midnight(d: date) -> datetime:
|
||||
"""Convert from date to midnight on that day."""
|
||||
return datetime.combine(d, datetime.min.time())
|
||||
|
@ -166,52 +99,6 @@ async def bristol_waste_collection_events(
|
|||
return await waste_schedule.get_bristol_gov_uk(start_date, data_dir, uprn)
|
||||
|
||||
|
||||
def combine_holidays(holidays: list[Holiday]) -> list[Event]:
|
||||
"""Combine UK and US holidays with the same date and title."""
|
||||
|
||||
all_countries = {h.country for h in holidays}
|
||||
|
||||
standard_name = {
|
||||
(1, 1): "New Year's Day",
|
||||
(1, 6): "Epiphany",
|
||||
(5, 1): "Labour Day",
|
||||
(8, 15): "Assumption Day",
|
||||
(12, 8): "Immaculate conception",
|
||||
(12, 25): "Christmas Day",
|
||||
(12, 26): "Boxing Day",
|
||||
}
|
||||
|
||||
combined: collections.defaultdict[
|
||||
tuple[date, str], set[str]
|
||||
] = collections.defaultdict(set)
|
||||
|
||||
for h in holidays:
|
||||
assert isinstance(h.name, str) and isinstance(h.date, date)
|
||||
|
||||
event_key = (h.date, standard_name.get((h.date.month, h.date.day), h.name))
|
||||
combined[event_key].add(h.country)
|
||||
|
||||
events: list[Event] = []
|
||||
for (d, name), countries in combined.items():
|
||||
if len(countries) == len(all_countries):
|
||||
country_list = ""
|
||||
elif len(countries) < len(all_countries) / 2:
|
||||
country_list = ", ".join(sorted(country.upper() for country in countries))
|
||||
else:
|
||||
country_list = "not " + ", ".join(
|
||||
sorted(country.upper() for country in all_countries - set(countries))
|
||||
)
|
||||
|
||||
e = Event(
|
||||
name="holiday",
|
||||
date=d,
|
||||
title=f"{name} ({country_list})" if country_list else name,
|
||||
)
|
||||
events.append(e)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def get_yaml_event_date_field(item: dict[str, str]) -> str:
|
||||
"""Event date field name."""
|
||||
return (
|
||||
|
@ -430,9 +317,9 @@ async def get_data(
|
|||
if gwr_advance_tickets:
|
||||
events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets))
|
||||
|
||||
us_hols = us_holidays(last_year, next_year)
|
||||
us_hols = holidays.us_holidays(last_year, next_year)
|
||||
|
||||
holidays: list[Holiday] = results["bank_holiday"] + us_hols
|
||||
holiday_list: list[Holiday] = results["bank_holiday"] + us_hols
|
||||
for country in (
|
||||
"at",
|
||||
"be",
|
||||
|
@ -451,15 +338,15 @@ async def get_data(
|
|||
"nl",
|
||||
"pl",
|
||||
):
|
||||
holidays += get_holidays(country, last_year, next_year)
|
||||
holiday_list += holidays.get_holidays(country, last_year, next_year)
|
||||
|
||||
events += get_nyse_holidays(last_year, next_year, us_hols)
|
||||
events += holidays.get_nyse_holidays(last_year, next_year, us_hols)
|
||||
|
||||
accommodation_events = accommodation.get_events(
|
||||
os.path.join(my_data, "accommodation.yaml")
|
||||
)
|
||||
|
||||
events += combine_holidays(holidays)
|
||||
events += holidays.combine_holidays(holiday_list)
|
||||
events += birthday.get_birthdays(last_year, os.path.join(my_data, "entities.yaml"))
|
||||
events += accommodation_events
|
||||
events += travel.all_events(my_data)
|
||||
|
|
117
agenda/holidays.py
Normal file
117
agenda/holidays.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
import collections
|
||||
from datetime import date, timedelta
|
||||
|
||||
import holidays
|
||||
|
||||
from .types import Event, Holiday
|
||||
|
||||
|
||||
def us_holidays(start_date: date, end_date: date) -> list[Holiday]:
|
||||
"""Get US holidays."""
|
||||
found: list[Holiday] = []
|
||||
for year in range(start_date.year, end_date.year + 1):
|
||||
hols = holidays.country_holidays("US", years=year, language="en")
|
||||
found += [
|
||||
Holiday(date=hol_date, name=title, country="us")
|
||||
for hol_date, title in hols.items()
|
||||
if start_date < hol_date < end_date
|
||||
]
|
||||
|
||||
extra = []
|
||||
for h in found:
|
||||
if h.name != "Thanksgiving":
|
||||
continue
|
||||
extra += [
|
||||
Holiday(date=h.date + timedelta(days=1), name="Black Friday", country="us"),
|
||||
Holiday(date=h.date + timedelta(days=4), name="Cyber Monday", country="us"),
|
||||
]
|
||||
|
||||
return found + extra
|
||||
|
||||
|
||||
def get_nyse_holidays(
|
||||
start_date: date, end_date: date, us_hols: list[Holiday]
|
||||
) -> list[Event]:
|
||||
"""NYSE holidays."""
|
||||
known_us_hols = {(h.date, h.name) for h in us_hols}
|
||||
found: list[Event] = []
|
||||
rename = {"Thanksgiving Day": "Thanksgiving"}
|
||||
for year in range(start_date.year, end_date.year + 1):
|
||||
hols = holidays.financial_holidays("NYSE", years=year)
|
||||
found += [
|
||||
Event(
|
||||
name="holiday",
|
||||
date=hol_date,
|
||||
title=rename.get(title, title),
|
||||
)
|
||||
for hol_date, title in hols.items()
|
||||
if start_date < hol_date < end_date
|
||||
]
|
||||
found = [hol for hol in found if (hol.date, hol.title) not in known_us_hols]
|
||||
for hol in found:
|
||||
assert hol.title
|
||||
hol.title += " (NYSE)"
|
||||
return found
|
||||
|
||||
|
||||
def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]:
|
||||
"""Get holidays."""
|
||||
found: list[Holiday] = []
|
||||
for year in range(start_date.year, end_date.year + 1):
|
||||
hols = holidays.country_holidays(country.upper(), years=year, language="en_US")
|
||||
found += [
|
||||
Holiday(
|
||||
date=hol_date,
|
||||
name=title,
|
||||
country=country.lower(),
|
||||
)
|
||||
for hol_date, title in hols.items()
|
||||
if start_date < hol_date < end_date
|
||||
]
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def combine_holidays(holidays: list[Holiday]) -> list[Event]:
|
||||
"""Combine UK and US holidays with the same date and title."""
|
||||
all_countries = {h.country for h in holidays}
|
||||
|
||||
standard_name = {
|
||||
(1, 1): "New Year's Day",
|
||||
(1, 6): "Epiphany",
|
||||
(5, 1): "Labour Day",
|
||||
(8, 15): "Assumption Day",
|
||||
(12, 8): "Immaculate conception",
|
||||
(12, 25): "Christmas Day",
|
||||
(12, 26): "Boxing Day",
|
||||
}
|
||||
|
||||
combined: collections.defaultdict[
|
||||
tuple[date, str], set[str]
|
||||
] = collections.defaultdict(set)
|
||||
|
||||
for h in holidays:
|
||||
assert isinstance(h.name, str) and isinstance(h.date, date)
|
||||
|
||||
event_key = (h.date, standard_name.get((h.date.month, h.date.day), h.name))
|
||||
combined[event_key].add(h.country)
|
||||
|
||||
events: list[Event] = []
|
||||
for (d, name), countries in combined.items():
|
||||
if len(countries) == len(all_countries):
|
||||
country_list = ""
|
||||
elif len(countries) < len(all_countries) / 2:
|
||||
country_list = ", ".join(sorted(country.upper() for country in countries))
|
||||
else:
|
||||
country_list = "not " + ", ".join(
|
||||
sorted(country.upper() for country in all_countries - set(countries))
|
||||
)
|
||||
|
||||
e = Event(
|
||||
name="holiday",
|
||||
date=d,
|
||||
title=f"{name} ({country_list})" if country_list else name,
|
||||
)
|
||||
events.append(e)
|
||||
|
||||
return events
|
Loading…
Reference in a new issue