Skip to content

Commit

Permalink
reduce number of calls to mask.all()
Browse files Browse the repository at this point in the history
  • Loading branch information
jwdink committed Jan 24, 2025
1 parent 0768012 commit 2c5c9d7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions torchcast/exp_smooth/exp_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def predict(self,
cov: Tensor,
mask: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
if mask.all():
mask = slice(None)

F = kwargs['F'][mask]

new_mean = update_tensor(mean, new=(F @ mean[mask].unsqueeze(-1)).squeeze(-1), mask=mask)
Expand Down
7 changes: 5 additions & 2 deletions torchcast/internals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np


def update_tensor(orig: torch.Tensor, new: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
def update_tensor(orig: torch.Tensor, new: torch.Tensor, mask: Optional[Union[slice, torch.Tensor]]) -> torch.Tensor:
"""
In some cases we will want to save compute by only performing operations on a subset of a tensor, leaving the other
elements as-is. In this case we'd need:
Expand All @@ -28,13 +28,16 @@ def update_tensor(orig: torch.Tensor, new: torch.Tensor, mask: torch.Tensor) ->
:return: If ``mask`` is all True, returns ``new`` (not a copy). Otherwise, returns a new tensor with the same shape
as ``orig`` where the elements in ``mask`` are replaced with the elements in ``new``.
"""
if mask.all():
if isinstance(mask, slice) and not mask.start and not mask.stop and not mask.step:
mask = None
if mask is None or mask.all():
return new
else:
out = orig.clone()
out[mask] = new
return out


def transpose_last_dims(x: torch.Tensor) -> torch.Tensor:
args = list(range(len(x.shape)))
args[-2], args[-1] = args[-1], args[-2]
Expand Down
2 changes: 2 additions & 0 deletions torchcast/kalman_filter/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def predict(self,
cov: Tensor,
mask: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
if mask.all():
mask = slice(None)
F = kwargs['F'][mask]
Q = kwargs['Q'][mask]

Expand Down

0 comments on commit 2c5c9d7

Please sign in to comment.