From 67c2c72ba6e6a6c8306fafcf19658f6210b06643 Mon Sep 17 00:00:00 2001 From: Jacob Date: Sun, 30 Jul 2023 17:53:56 -0500 Subject: [PATCH 1/2] provide bad group indices --- torchcast/state_space/predictions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchcast/state_space/predictions.py b/torchcast/state_space/predictions.py index bc5abe4..6f3cd0c 100644 --- a/torchcast/state_space/predictions.py +++ b/torchcast/state_space/predictions.py @@ -82,7 +82,11 @@ def state_means(self) -> torch.Tensor: if not isinstance(self._state_means, torch.Tensor): self._state_means = torch.stack(self._state_means, 1) if torch.isnan(self._state_means).any(): - raise ValueError("`nans` in `state_means`") + if torch.isnan(self._state_means).all(): + raise ValueError("`nans` in all groups' `state_means`") + else: + groups, *_ = zip(*torch.isnan(self._state_means).nonzero().tolist()) + raise ValueError(f"`nans` in `state_means` for group-indices: {set(groups)}") return self._state_means @cached_property From 634a681116a93d1bf61563c0b72a671892d3127f Mon Sep 17 00:00:00 2001 From: Jacob Date: Tue, 8 Aug 2023 14:36:35 -0500 Subject: [PATCH 2/2] fix global_max --- torchcast/utils/data.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchcast/utils/data.py b/torchcast/utils/data.py index 9ad00e0..18ae5df 100644 --- a/torchcast/utils/data.py +++ b/torchcast/utils/data.py @@ -628,8 +628,13 @@ def complete_times(data: 'DataFrame', # (e.g. 
does not match behavior of `my_dates.to_period('W').dt.to_timestamp()`) dt_unit = pd.Timedelta('7 days 00:00:00') + max_time = data[time_colname].max() + if global_max is True: # they can specify a specific value, or pass True for the max in the data + global_max = max_time + # or they can leave global_max=None, in which case we filter to each group's own max below + df_grid = pd.DataFrame( - {time_colname: pd.date_range(data[time_colname].min(), data[time_colname].max(), freq=dt_unit)} + {time_colname: pd.date_range(data[time_colname].min(), global_max or max_time, freq=dt_unit)} ) df_group_summary = data. \ @@ -638,8 +643,6 @@ def complete_times(data: 'DataFrame', _max=(time_colname, 'max')). \ reset_index() if global_max: - if global_max is True: - global_max = df_group_summary['_max'].max() df_group_summary['_max'] = global_max # cross-join for all times to all groups (todo: not very memory efficient)