Skip to content

Commit

Permalink
fix: move expensive TimezoneFinder call out of top level.
Browse files Browse the repository at this point in the history
  • Loading branch information
jdangerx committed Jan 28, 2025
1 parent c0eb11d commit 421d982
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/pudl/transform/eia.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@
logger = pudl.logging_helpers.get_logger(__name__)


TZ_FINDER = timezonefinder.TimezoneFinder()
"""A global TimezoneFinder to cache geographies in memory for faster access."""


class EiaEntity(StrEnum):
"""Enum for the different types of EIA entities."""

Expand All @@ -58,7 +54,7 @@ class EiaEntity(StrEnum):
GENERATORS = auto()


def find_timezone(*, lng=None, lat=None, state=None, strict=True):
def find_timezone(*, lng=None, lat=None, state=None, strict=True, tz_finder=None):
"""Find the timezone associated with the a specified input location.
Note that this function requires named arguments. The names are lng, lat,
Expand All @@ -84,10 +80,10 @@ def find_timezone(*, lng=None, lat=None, state=None, strict=True):
Update docstring.
"""
try:
tz = TZ_FINDER.timezone_at(lng=lng, lat=lat)
tz = tz_finder.timezone_at(lng=lng, lat=lat)
if tz is None: # Try harder
# Could change the search radius as well
tz = TZ_FINDER.closest_timezone_at(lng=lng, lat=lat)
tz = tz_finder.closest_timezone_at(lng=lng, lat=lat)
# For some reason w/ Python 3.6 we get a ValueError here, but with
# Python 3.7 we get an OverflowError...
except (OverflowError, ValueError) as err:
Expand Down Expand Up @@ -312,9 +308,16 @@ def _add_timezone(plants_entity: pd.DataFrame) -> pd.DataFrame:
A DataFrame containing the same table, with a "timezone" column added.
Timezone may be missing if lat / lon is missing or invalid.
"""
# Takes 300ms. Used to be at module level so we wouldn't have to initialize every time, but that makes import time slow.
tz_finder = timezonefinder.TimezoneFinder()

plants_entity["timezone"] = plants_entity.apply(
lambda row: find_timezone(
lng=row["longitude"], lat=row["latitude"], state=row["state"], strict=False
lng=row["longitude"],
lat=row["latitude"],
state=row["state"],
strict=False,
tz_finder=tz_finder,
),
axis=1,
)
Expand Down

0 comments on commit 421d982

Please sign in to comment.