diff --git a/src/pudl/analysis/record_linkage/classify_plants_ferc1.py b/src/pudl/analysis/record_linkage/classify_plants_ferc1.py index ce9688f790..9cb41e4368 100644 --- a/src/pudl/analysis/record_linkage/classify_plants_ferc1.py +++ b/src/pudl/analysis/record_linkage/classify_plants_ferc1.py @@ -37,7 +37,7 @@ def __init__(self, plants_steam_df, metric="euclidean", penalty=100): metric: Distance metric to use in computation. penalty: Penalty to apply to records with the same report year. """ - self.df = plants_steam_df + self.plants_steam_df = plants_steam_df self.metric = metric self.penalty = penalty @@ -48,11 +48,14 @@ def fit(self, X, y=None, **fit_params): # noqa: N803 def transform(self, X, y=None, **fit_params): # noqa: N803 """Compute distance between records then add penalty to records from same year.""" dist_matrix = pairwise_distances(X, metric=self.metric) - report_years = range(self.df.report_year.min(), self.df.report_year.max() + 1) + report_years = range( + self.plants_steam_df.report_year.min(), + self.plants_steam_df.report_year.max() + 1, + ) penalty_matrix = np.full(dist_matrix.shape, 0) for yr in report_years: # get the indices of all the record pairs that have matching report years - yr_idx = self.df[self.df.report_year == yr].index + yr_idx = self.plants_steam_df[self.plants_steam_df.report_year == yr].index yr_match_pairs_idx = np.array(np.meshgrid(yr_idx, yr_idx)).T.reshape(-1, 2) idx_x = yr_match_pairs_idx[:, 0] idx_y = yr_match_pairs_idx[:, 1]