Include stock market holidays as events

Closes: #76
This commit is contained in:
Edward Betts 2023-12-23 16:12:49 +00:00
parent 4ddd65357f
commit cd575afe68

View file

@ -65,7 +65,7 @@ def us_holidays(start_date: date, end_date: date) -> list[Holiday]:
for year in range(start_date.year, end_date.year + 1): for year in range(start_date.year, end_date.year + 1):
hols = holidays.country_holidays("US", years=year, language="en") hols = holidays.country_holidays("US", years=year, language="en")
found += [ found += [
Holiday(date=hol_date, name=title.replace("'", ""), country="us") Holiday(date=hol_date, name=title, country="us")
for hol_date, title in hols.items() for hol_date, title in hols.items()
if start_date < hol_date < end_date if start_date < hol_date < end_date
] ]
@ -82,6 +82,31 @@ def us_holidays(start_date: date, end_date: date) -> list[Holiday]:
return found + extra return found + extra
def get_nyse_holidays(
start_date: date, end_date: date, us_hols: list[Holiday]
) -> list[Event]:
"""NYSE holidays."""
known_us_hols = {(h.date, h.name) for h in us_hols}
found: list[Event] = []
rename = {"Thanksgiving Day": "Thanksgiving"}
for year in range(start_date.year, end_date.year + 1):
hols = holidays.financial_holidays("NYSE", years=year)
found += [
Event(
name="holiday",
date=hol_date,
title=rename.get(title, title),
)
for hol_date, title in hols.items()
if start_date < hol_date < end_date
]
found = [hol for hol in found if (hol.date, hol.title) not in known_us_hols]
for hol in found:
assert hol.title
hol.title += " (NYSE)"
return found
def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]: def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]:
"""Get holidays.""" """Get holidays."""
found: list[Holiday] = [] found: list[Holiday] = []
@ -296,7 +321,9 @@ async def get_data(
if gwr_advance_tickets: if gwr_advance_tickets:
events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets)) events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets))
holidays: list[Holiday] = bank_holiday + us_holidays(last_year, next_year) us_hols = us_holidays(last_year, next_year)
holidays: list[Holiday] = bank_holiday + us_hols
for country in ( for country in (
"at", "at",
"be", "be",
@ -315,6 +342,8 @@ async def get_data(
): ):
holidays += get_holidays(country, last_year, next_year) holidays += get_holidays(country, last_year, next_year)
events += get_nyse_holidays(last_year, next_year, us_hols)
accommodation_events = accommodation.get_events( accommodation_events = accommodation.get_events(
os.path.join(my_data, "accommodation.yaml") os.path.join(my_data, "accommodation.yaml")
) )