Skip to content

Commit

Permalink
Merge pull request #16 from strongio/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
jwdink authored Aug 8, 2023
2 parents 59bd742 + 634a681 commit 74ffd26
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 5 additions & 1 deletion torchcast/state_space/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ def state_means(self) -> torch.Tensor:
if not isinstance(self._state_means, torch.Tensor):
self._state_means = torch.stack(self._state_means, 1)
if torch.isnan(self._state_means).any():
raise ValueError("`nans` in `state_means`")
if torch.isnan(self._state_means).all():
raise ValueError("`nans` in all groups' `state_means`")
else:
groups, *_ = zip(*torch.isnan(self._state_means).nonzero().tolist())
raise ValueError(f"`nans` in `state_means` for group-indices: {set(groups)}")
return self._state_means

@cached_property
Expand Down
9 changes: 6 additions & 3 deletions torchcast/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,13 @@ def complete_times(data: 'DataFrame',
# (e.g. does not match behavior of `my_dates.to_period('W').dt.to_timestamp()`)
dt_unit = pd.Timedelta('7 days 00:00:00')

max_time = data[time_colname].max()
if global_max is True: # they can specify a specific value, or pass True for the max in the data
global_max = max_time
# or they can leave global_max=None, in which case will filter to group-specific max below

df_grid = pd.DataFrame(
{time_colname: pd.date_range(data[time_colname].min(), data[time_colname].max(), freq=dt_unit)}
{time_colname: pd.date_range(data[time_colname].min(), global_max or max_time, freq=dt_unit)}
)

df_group_summary = data. \
Expand All @@ -638,8 +643,6 @@ def complete_times(data: 'DataFrame',
_max=(time_colname, 'max')). \
reset_index()
if global_max:
if global_max is True:
global_max = df_group_summary['_max'].max()
df_group_summary['_max'] = global_max

# cross-join for all times to all groups (todo: not very memory efficient)
Expand Down

0 comments on commit 74ffd26

Please sign in to comment.