diff --git a/torchcast/exp_smooth/exp_smooth.py b/torchcast/exp_smooth/exp_smooth.py index d47ea33..cf58d24 100644 --- a/torchcast/exp_smooth/exp_smooth.py +++ b/torchcast/exp_smooth/exp_smooth.py @@ -51,6 +51,9 @@ def predict(self, cov: Tensor, mask: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + if mask.all(): + mask = slice(None) + F = kwargs['F'][mask] new_mean = update_tensor(mean, new=(F @ mean[mask].unsqueeze(-1)).squeeze(-1), mask=mask) diff --git a/torchcast/internals/utils.py b/torchcast/internals/utils.py index ae4afc1..8034e4a 100644 --- a/torchcast/internals/utils.py +++ b/torchcast/internals/utils.py @@ -6,7 +6,7 @@ import numpy as np -def update_tensor(orig: torch.Tensor, new: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: +def update_tensor(orig: torch.Tensor, new: torch.Tensor, mask: Optional[Union[slice, torch.Tensor]]) -> torch.Tensor: """ In some cases we will want to save compute by only performing operations on a subset of a tensor, leaving the other elements as-is. In this case we'd need: @@ -28,13 +28,16 @@ def update_tensor(orig: torch.Tensor, new: torch.Tensor, mask: torch.Tensor) -> :return: If ``mask`` is all True, returns ``new`` (not a copy). Otherwise, returns a new tensor with the same shape as ``orig`` where the elements in ``mask`` are replaced with the elements in ``new``. """ - if mask.all(): + if isinstance(mask, slice) and not mask.start and not mask.stop and not mask.step: + mask = None + if mask is None or mask.all(): return new else: out = orig.clone() out[mask] = new return out + def transpose_last_dims(x: torch.Tensor) -> torch.Tensor: args = list(range(len(x.shape))) args[-2], args[-1] = args[-1], args[-2] diff --git a/torchcast/kalman_filter/kalman_filter.py b/torchcast/kalman_filter/kalman_filter.py index a639ed0..fc21a67 100644 --- a/torchcast/kalman_filter/kalman_filter.py +++ b/torchcast/kalman_filter/kalman_filter.py @@ -31,6 +31,8 @@ def predict(self, cov: Tensor, mask: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + if mask.all(): + mask = slice(None) F = kwargs['F'][mask] Q = kwargs['Q'][mask]