Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: move expensive TimezoneFinder call out of top level. #4034

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/pudl/transform/eia.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@
logger = pudl.logging_helpers.get_logger(__name__)


TZ_FINDER = timezonefinder.TimezoneFinder()
"""A global TimezoneFinder to cache geographies in memory for faster access."""


class EiaEntity(StrEnum):
"""Enum for the different types of EIA entities."""

Expand All @@ -58,7 +54,7 @@ class EiaEntity(StrEnum):
GENERATORS = auto()


def find_timezone(*, lng=None, lat=None, state=None, strict=True):
def find_timezone(*, lng=None, lat=None, state=None, strict=True, tz_finder=None):
"""Find the timezone associated with the a specified input location.

Note that this function requires named arguments. The names are lng, lat,
Expand All @@ -84,10 +80,10 @@ def find_timezone(*, lng=None, lat=None, state=None, strict=True):
Update docstring.
"""
try:
tz = TZ_FINDER.timezone_at(lng=lng, lat=lat)
tz = tz_finder.timezone_at(lng=lng, lat=lat)
if tz is None: # Try harder
# Could change the search radius as well
tz = TZ_FINDER.closest_timezone_at(lng=lng, lat=lat)
tz = tz_finder.closest_timezone_at(lng=lng, lat=lat)
# For some reason w/ Python 3.6 we get a ValueError here, but with
# Python 3.7 we get an OverflowError...
except (OverflowError, ValueError) as err:
Expand Down Expand Up @@ -312,9 +308,16 @@ def _add_timezone(plants_entity: pd.DataFrame) -> pd.DataFrame:
A DataFrame containing the same table, with a "timezone" column added.
Timezone may be missing if lat / lon is missing or invalid.
"""
# Takes 300ms. Used to be at module level so we wouldn't have to initialize every time, but that makes import time slow.
tz_finder = timezonefinder.TimezoneFinder()

plants_entity["timezone"] = plants_entity.apply(
lambda row: find_timezone(
lng=row["longitude"], lat=row["latitude"], state=row["state"], strict=False
lng=row["longitude"],
lat=row["latitude"],
state=row["state"],
strict=False,
tz_finder=tz_finder,
),
axis=1,
)
Expand Down