Refactor
This commit is contained in:
		
							parent
							
								
									8df94aaafb
								
							
						
					
					
						commit
						69e10db8ef
					
				| 
						 | 
					@ -36,7 +36,7 @@ from . import (
 | 
				
			||||||
    uk_tz,
 | 
					    uk_tz,
 | 
				
			||||||
    waste_schedule,
 | 
					    waste_schedule,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from .types import Event, Holiday, StrDict
 | 
					from .types import Event, StrDict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
here = dateutil.tz.tzlocal()
 | 
					here = dateutil.tz.tzlocal()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -274,9 +274,6 @@ async def get_data(
 | 
				
			||||||
    result_list = await asyncio.gather(
 | 
					    result_list = await asyncio.gather(
 | 
				
			||||||
        time_function("gbpusd", fx.get_gbpusd, config),
 | 
					        time_function("gbpusd", fx.get_gbpusd, config),
 | 
				
			||||||
        time_function("gwr_advance_tickets", gwr.advance_ticket_date, data_dir),
 | 
					        time_function("gwr_advance_tickets", gwr.advance_ticket_date, data_dir),
 | 
				
			||||||
        time_function(
 | 
					 | 
				
			||||||
            "bank_holiday", uk_holiday.bank_holiday_list, last_year, next_year, data_dir
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
        time_function("rockets", thespacedevs.get_launches, rocket_dir, limit=40),
 | 
					        time_function("rockets", thespacedevs.get_launches, rocket_dir, limit=40),
 | 
				
			||||||
        time_function("backwell_bins", waste_collection_events, data_dir),
 | 
					        time_function("backwell_bins", waste_collection_events, data_dir),
 | 
				
			||||||
        time_function("bristol_bins", bristol_waste_collection_events, data_dir, today),
 | 
					        time_function("bristol_bins", bristol_waste_collection_events, data_dir, today),
 | 
				
			||||||
| 
						 | 
					@ -318,34 +315,13 @@ async def get_data(
 | 
				
			||||||
        events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets))
 | 
					        events.append(Event(name="gwr_advance_tickets", date=gwr_advance_tickets))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    us_hols = holidays.us_holidays(last_year, next_year)
 | 
					    us_hols = holidays.us_holidays(last_year, next_year)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    holiday_list: list[Holiday] = results["bank_holiday"] + us_hols
 | 
					 | 
				
			||||||
    for country in (
 | 
					 | 
				
			||||||
        "at",
 | 
					 | 
				
			||||||
        "be",
 | 
					 | 
				
			||||||
        "br",
 | 
					 | 
				
			||||||
        "ch",
 | 
					 | 
				
			||||||
        "cz",
 | 
					 | 
				
			||||||
        "de",
 | 
					 | 
				
			||||||
        "dk",
 | 
					 | 
				
			||||||
        "ee",
 | 
					 | 
				
			||||||
        "es",
 | 
					 | 
				
			||||||
        "fi",
 | 
					 | 
				
			||||||
        "fr",
 | 
					 | 
				
			||||||
        "gr",
 | 
					 | 
				
			||||||
        "it",
 | 
					 | 
				
			||||||
        "ke",
 | 
					 | 
				
			||||||
        "nl",
 | 
					 | 
				
			||||||
        "pl",
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        holiday_list += holidays.get_holidays(country, last_year, next_year)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    events += holidays.get_nyse_holidays(last_year, next_year, us_hols)
 | 
					    events += holidays.get_nyse_holidays(last_year, next_year, us_hols)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    accommodation_events = accommodation.get_events(
 | 
					    accommodation_events = accommodation.get_events(
 | 
				
			||||||
        os.path.join(my_data, "accommodation.yaml")
 | 
					        os.path.join(my_data, "accommodation.yaml")
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    holiday_list = holidays.get_all(last_year, next_year, data_dir)
 | 
				
			||||||
    events += holidays.combine_holidays(holiday_list)
 | 
					    events += holidays.combine_holidays(holiday_list)
 | 
				
			||||||
    events += birthday.get_birthdays(last_year, os.path.join(my_data, "entities.yaml"))
 | 
					    events += birthday.get_birthdays(last_year, os.path.join(my_data, "entities.yaml"))
 | 
				
			||||||
    events += accommodation_events
 | 
					    events += accommodation_events
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,6 +3,7 @@
 | 
				
			||||||
import collections
 | 
					import collections
 | 
				
			||||||
from datetime import date, timedelta
 | 
					from datetime import date, timedelta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import agenda.uk_holiday
 | 
				
			||||||
import holidays
 | 
					import holidays
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .types import Event, Holiday
 | 
					from .types import Event, Holiday
 | 
				
			||||||
| 
						 | 
					@ -117,3 +118,33 @@ def combine_holidays(holidays: list[Holiday]) -> list[Event]:
 | 
				
			||||||
        events.append(e)
 | 
					        events.append(e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return events
 | 
					    return events
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_all(last_year: date, next_year: date, data_dir: str) -> list[Holiday]:
 | 
				
			||||||
 | 
					    """Get holidays for various countries and return as a list."""
 | 
				
			||||||
 | 
					    us_hols = us_holidays(last_year, next_year)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bank_holidays = agenda.uk_holiday.bank_holiday_list(last_year, next_year, data_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    holiday_list: list[Holiday] = bank_holidays + us_hols
 | 
				
			||||||
 | 
					    for country in (
 | 
				
			||||||
 | 
					        "at",
 | 
				
			||||||
 | 
					        "be",
 | 
				
			||||||
 | 
					        "br",
 | 
				
			||||||
 | 
					        "ch",
 | 
				
			||||||
 | 
					        "cz",
 | 
				
			||||||
 | 
					        "de",
 | 
				
			||||||
 | 
					        "dk",
 | 
				
			||||||
 | 
					        "ee",
 | 
				
			||||||
 | 
					        "es",
 | 
				
			||||||
 | 
					        "fi",
 | 
				
			||||||
 | 
					        "fr",
 | 
				
			||||||
 | 
					        "gr",
 | 
				
			||||||
 | 
					        "it",
 | 
				
			||||||
 | 
					        "ke",
 | 
				
			||||||
 | 
					        "nl",
 | 
				
			||||||
 | 
					        "pl",
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        holiday_list += get_holidays(country, last_year, next_year)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return holiday_list
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,7 +3,6 @@
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
from datetime import date, datetime, timedelta
 | 
					from datetime import date, datetime, timedelta
 | 
				
			||||||
from time import time
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import httpx
 | 
					import httpx
 | 
				
			||||||
from dateutil.easter import easter
 | 
					from dateutil.easter import easter
 | 
				
			||||||
| 
						 | 
					@ -13,9 +12,15 @@ from .types import Holiday, StrDict
 | 
				
			||||||
url = "https://www.gov.uk/bank-holidays.json"
 | 
					url = "https://www.gov.uk/bank-holidays.json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def json_filename(data_dir: str) -> str:
 | 
				
			||||||
 | 
					    """Filename for cached bank holidays."""
 | 
				
			||||||
 | 
					    assert os.path.exists(data_dir)
 | 
				
			||||||
 | 
					    return os.path.join(data_dir, "bank-holidays.json")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def get_holiday_list(data_dir: str) -> list[StrDict]:
 | 
					async def get_holiday_list(data_dir: str) -> list[StrDict]:
 | 
				
			||||||
    """Download holiday list and save cache."""
 | 
					    """Download holiday list and save cache."""
 | 
				
			||||||
    filename = os.path.join(data_dir, "bank-holidays.json")
 | 
					    filename = json_filename(data_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async with httpx.AsyncClient() as client:
 | 
					    async with httpx.AsyncClient() as client:
 | 
				
			||||||
        r = await client.get(url)
 | 
					        r = await client.get(url)
 | 
				
			||||||
| 
						 | 
					@ -24,27 +29,12 @@ async def get_holiday_list(data_dir: str) -> list[StrDict]:
 | 
				
			||||||
    return events
 | 
					    return events
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def bank_holiday_list(
 | 
					def bank_holiday_list(start_date: date, end_date: date, data_dir: str) -> list[Holiday]:
 | 
				
			||||||
    start_date: date, end_date: date, data_dir: str
 | 
					 | 
				
			||||||
) -> list[Holiday]:
 | 
					 | 
				
			||||||
    """Date and name of the next UK bank holiday."""
 | 
					    """Date and name of the next UK bank holiday."""
 | 
				
			||||||
    filename = os.path.join(data_dir, "bank-holidays.json")
 | 
					    filename = json_filename(data_dir)
 | 
				
			||||||
    use_cached = False
 | 
					 | 
				
			||||||
    events: list[StrDict]
 | 
					 | 
				
			||||||
    if os.path.exists(filename):
 | 
					 | 
				
			||||||
        mtime = os.path.getmtime(filename)
 | 
					 | 
				
			||||||
        if (time() - mtime) < 60 * 60 * 6:  # six hours
 | 
					 | 
				
			||||||
            use_cached = True
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                events = json.load(open(filename))["england-and-wales"]["events"]
 | 
					 | 
				
			||||||
            except json.decoder.JSONDecodeError:
 | 
					 | 
				
			||||||
                use_cached = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if not use_cached:
 | 
					 | 
				
			||||||
        events = await get_holiday_list(data_dir)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    hols: list[Holiday] = []
 | 
					    hols: list[Holiday] = []
 | 
				
			||||||
    for event in events:
 | 
					    for event in json.load(open(filename))["england-and-wales"]["events"]:
 | 
				
			||||||
        event_date = datetime.strptime(event["date"], "%Y-%m-%d").date()
 | 
					        event_date = datetime.strptime(event["date"], "%Y-%m-%d").date()
 | 
				
			||||||
        if event_date < start_date:
 | 
					        if event_date < start_date:
 | 
				
			||||||
            continue
 | 
					            continue
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue