From 421d982634f2074bb95039452926fe3a6cde4304 Mon Sep 17 00:00:00 2001 From: Dazhong Xia Date: Tue, 28 Jan 2025 13:53:07 -0500 Subject: [PATCH] fix: move expensive TimezoneFinder call out of top level. --- src/pudl/transform/eia.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/pudl/transform/eia.py b/src/pudl/transform/eia.py index 4acbe38209..1499c770f7 100644 --- a/src/pudl/transform/eia.py +++ b/src/pudl/transform/eia.py @@ -45,10 +45,6 @@ logger = pudl.logging_helpers.get_logger(__name__) -TZ_FINDER = timezonefinder.TimezoneFinder() -"""A global TimezoneFinder to cache geographies in memory for faster access.""" - - class EiaEntity(StrEnum): """Enum for the different types of EIA entities.""" @@ -58,7 +54,7 @@ class EiaEntity(StrEnum): GENERATORS = auto() -def find_timezone(*, lng=None, lat=None, state=None, strict=True): +def find_timezone(*, lng=None, lat=None, state=None, strict=True, tz_finder=None): """Find the timezone associated with the a specified input location. Note that this function requires named arguments. The names are lng, lat, @@ -84,10 +80,10 @@ def find_timezone(*, lng=None, lat=None, state=None, strict=True): Update docstring. """ try: - tz = TZ_FINDER.timezone_at(lng=lng, lat=lat) + tz = tz_finder.timezone_at(lng=lng, lat=lat) if tz is None: # Try harder # Could change the search radius as well - tz = TZ_FINDER.closest_timezone_at(lng=lng, lat=lat) + tz = tz_finder.closest_timezone_at(lng=lng, lat=lat) # For some reason w/ Python 3.6 we get a ValueError here, but with # Python 3.7 we get an OverflowError... except (OverflowError, ValueError) as err: @@ -312,9 +308,16 @@ def _add_timezone(plants_entity: pd.DataFrame) -> pd.DataFrame: A DataFrame containing the same table, with a "timezone" column added. Timezone may be missing if lat / lon is missing or invalid. """ + # Takes 300ms. Used to be at module level so we wouldn't have to initialize every time, but that makes import time slow. + tz_finder = timezonefinder.TimezoneFinder() + plants_entity["timezone"] = plants_entity.apply( lambda row: find_timezone( - lng=row["longitude"], lat=row["latitude"], state=row["state"], strict=False + lng=row["longitude"], + lat=row["latitude"], + state=row["state"], + strict=False, + tz_finder=tz_finder, ), axis=1, )