Refactor
This commit is contained in:
parent
8df94aaafb
commit
69e10db8ef
|
@ -36,7 +36,7 @@ from . import (
|
|||
uk_tz,
|
||||
waste_schedule,
|
||||
)
|
||||
from .types import Event, Holiday, StrDict
|
||||
from .types import Event, StrDict
|
||||
|
||||
here = dateutil.tz.tzlocal()
|
||||
|
||||
|
@ -274,9 +274,6 @@ async def get_data(
|
|||
result_list = await asyncio.gather(
|
||||
time_function("gbpusd", fx.get_gbpusd, config),
|
||||
time_function("gwr_advance_tickets", gwr.advance_ticket_date, data_dir),
|
||||
time_function(
|
||||
"bank_holiday", uk_holiday.bank_holiday_list, last_year, next_year, data_dir
|
||||
),
|
||||
time_function("rockets", thespacedevs.get_launches, rocket_dir, limit=40),
|
||||
time_function("backwell_bins", waste_collection_events, data_dir),
|
||||
time_function("bristol_bins", bristol_waste_collection_events, data_dir, today),
|
||||
|
@ -318,34 +315,13 @@ async def get_data(
|
|||
events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets))
|
||||
|
||||
us_hols = holidays.us_holidays(last_year, next_year)
|
||||
|
||||
holiday_list: list[Holiday] = results["bank_holiday"] + us_hols
|
||||
for country in (
|
||||
"at",
|
||||
"be",
|
||||
"br",
|
||||
"ch",
|
||||
"cz",
|
||||
"de",
|
||||
"dk",
|
||||
"ee",
|
||||
"es",
|
||||
"fi",
|
||||
"fr",
|
||||
"gr",
|
||||
"it",
|
||||
"ke",
|
||||
"nl",
|
||||
"pl",
|
||||
):
|
||||
holiday_list += holidays.get_holidays(country, last_year, next_year)
|
||||
|
||||
events += holidays.get_nyse_holidays(last_year, next_year, us_hols)
|
||||
|
||||
accommodation_events = accommodation.get_events(
|
||||
os.path.join(my_data, "accommodation.yaml")
|
||||
)
|
||||
|
||||
holiday_list = holidays.get_all(last_year, next_year, data_dir)
|
||||
events += holidays.combine_holidays(holiday_list)
|
||||
events += birthday.get_birthdays(last_year, os.path.join(my_data, "entities.yaml"))
|
||||
events += accommodation_events
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import collections
|
||||
from datetime import date, timedelta
|
||||
|
||||
import agenda.uk_holiday
|
||||
import holidays
|
||||
|
||||
from .types import Event, Holiday
|
||||
|
@ -117,3 +118,33 @@ def combine_holidays(holidays: list[Holiday]) -> list[Event]:
|
|||
events.append(e)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def get_all(last_year: date, next_year: date, data_dir: str) -> list[Holiday]:
|
||||
"""Get holidays for various countries and return as a list."""
|
||||
us_hols = us_holidays(last_year, next_year)
|
||||
|
||||
bank_holidays = agenda.uk_holiday.bank_holiday_list(last_year, next_year, data_dir)
|
||||
|
||||
holiday_list: list[Holiday] = bank_holidays + us_hols
|
||||
for country in (
|
||||
"at",
|
||||
"be",
|
||||
"br",
|
||||
"ch",
|
||||
"cz",
|
||||
"de",
|
||||
"dk",
|
||||
"ee",
|
||||
"es",
|
||||
"fi",
|
||||
"fr",
|
||||
"gr",
|
||||
"it",
|
||||
"ke",
|
||||
"nl",
|
||||
"pl",
|
||||
):
|
||||
holiday_list += get_holidays(country, last_year, next_year)
|
||||
|
||||
return holiday_list
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
import json
|
||||
import os
|
||||
from datetime import date, datetime, timedelta
|
||||
from time import time
|
||||
|
||||
import httpx
|
||||
from dateutil.easter import easter
|
||||
|
@ -13,9 +12,15 @@ from .types import Holiday, StrDict
|
|||
url = "https://www.gov.uk/bank-holidays.json"
|
||||
|
||||
|
||||
def json_filename(data_dir: str) -> str:
|
||||
"""Filename for cached bank holidays."""
|
||||
assert os.path.exists(data_dir)
|
||||
return os.path.join(data_dir, "bank-holidays.json")
|
||||
|
||||
|
||||
async def get_holiday_list(data_dir: str) -> list[StrDict]:
|
||||
"""Download holiday list and save cache."""
|
||||
filename = os.path.join(data_dir, "bank-holidays.json")
|
||||
filename = json_filename(data_dir)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(url)
|
||||
|
@ -24,27 +29,12 @@ async def get_holiday_list(data_dir: str) -> list[StrDict]:
|
|||
return events
|
||||
|
||||
|
||||
async def bank_holiday_list(
|
||||
start_date: date, end_date: date, data_dir: str
|
||||
) -> list[Holiday]:
|
||||
def bank_holiday_list(start_date: date, end_date: date, data_dir: str) -> list[Holiday]:
|
||||
"""Date and name of the next UK bank holiday."""
|
||||
filename = os.path.join(data_dir, "bank-holidays.json")
|
||||
use_cached = False
|
||||
events: list[StrDict]
|
||||
if os.path.exists(filename):
|
||||
mtime = os.path.getmtime(filename)
|
||||
if (time() - mtime) < 60 * 60 * 6: # six hours
|
||||
use_cached = True
|
||||
try:
|
||||
events = json.load(open(filename))["england-and-wales"]["events"]
|
||||
except json.decoder.JSONDecodeError:
|
||||
use_cached = False
|
||||
|
||||
if not use_cached:
|
||||
events = await get_holiday_list(data_dir)
|
||||
filename = json_filename(data_dir)
|
||||
|
||||
hols: list[Holiday] = []
|
||||
for event in events:
|
||||
for event in json.load(open(filename))["england-and-wales"]["events"]:
|
||||
event_date = datetime.strptime(event["date"], "%Y-%m-%d").date()
|
||||
if event_date < start_date:
|
||||
continue
|
||||
|
|
Loading…
Reference in a new issue