diff --git a/torchcast/exp_smooth/exp_smooth.py b/torchcast/exp_smooth/exp_smooth.py index 2a3f59c..e67886a 100644 --- a/torchcast/exp_smooth/exp_smooth.py +++ b/torchcast/exp_smooth/exp_smooth.py @@ -43,7 +43,6 @@ def _update(self, resid = input - measured_mean new_mean = mean + (kwargs['K'] @ resid.unsqueeze(-1)).squeeze(-1) # _update doesn't waste compute creating new_cov; in predict cov will be replaced by cov1step - # TODO: why not replace it here? new_cov = torch.tensor(0.0, dtype=mean.dtype, device=mean.device) return new_mean, new_cov @@ -57,7 +56,11 @@ def predict(self, new_mean = update_tensor(mean, new=(F @ mean[mask].unsqueeze(-1)).squeeze(-1), mask=mask) new_cov = kwargs['cov1step'] if len(cov.shape): # see note in _update() above - new_cov = update_tensor(new_cov.clone(), new=F @ cov[mask] @ F.permute(0, 2, 1), mask=mask) + new_cov = update_tensor( + orig=new_cov, + new=new_cov[mask] + F @ cov[mask] @ F.permute(0, 2, 1), + mask=mask + ) return new_mean, new_cov