From cd575afe68afa2d919a5fdbaf2542a102775167e Mon Sep 17 00:00:00 2001 From: Edward Betts Date: Sat, 23 Dec 2023 16:12:49 +0000 Subject: [PATCH] Include stock market holidays as events Closes: #76 --- agenda/data.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/agenda/data.py b/agenda/data.py index a9ceae3..f48e2eb 100644 --- a/agenda/data.py +++ b/agenda/data.py @@ -65,7 +65,7 @@ def us_holidays(start_date: date, end_date: date) -> list[Holiday]: for year in range(start_date.year, end_date.year + 1): hols = holidays.country_holidays("US", years=year, language="en") found += [ - Holiday(date=hol_date, name=title.replace("'", "’"), country="us") + Holiday(date=hol_date, name=title, country="us") for hol_date, title in hols.items() if start_date < hol_date < end_date ] @@ -82,6 +82,31 @@ def us_holidays(start_date: date, end_date: date) -> list[Holiday]: return found + extra +def get_nyse_holidays( + start_date: date, end_date: date, us_hols: list[Holiday] +) -> list[Event]: + """NYSE holidays.""" + known_us_hols = {(h.date, h.name) for h in us_hols} + found: list[Event] = [] + rename = {"Thanksgiving Day": "Thanksgiving"} + for year in range(start_date.year, end_date.year + 1): + hols = holidays.financial_holidays("NYSE", years=year) + found += [ + Event( + name="holiday", + date=hol_date, + title=rename.get(title, title), + ) + for hol_date, title in hols.items() + if start_date < hol_date < end_date + ] + found = [hol for hol in found if (hol.date, hol.title) not in known_us_hols] + for hol in found: + assert hol.title + hol.title += " (NYSE)" + return found + + def get_holidays(country: str, start_date: date, end_date: date) -> list[Holiday]: """Get holidays.""" found: list[Holiday] = [] @@ -296,7 +321,9 @@ async def get_data( if gwr_advance_tickets: events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets)) - holidays: list[Holiday] = bank_holiday + us_holidays(last_year, next_year) + us_hols = us_holidays(last_year, next_year) + + holidays: list[Holiday] = bank_holiday + us_hols for country in ( "at", "be", @@ -315,6 +342,8 @@ async def get_data( ): holidays += get_holidays(country, last_year, next_year) + events += get_nyse_holidays(last_year, next_year, us_hols) + accommodation_events = accommodation.get_events( os.path.join(my_data, "accommodation.yaml") )