Skip to content

Commit

Permalink
move to _get_update_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
jwdink committed May 18, 2023
1 parent 7f511d5 commit 7c1b3c4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 18 deletions.
11 changes: 7 additions & 4 deletions torchcast/exp_smooth/exp_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,21 @@ def _mask_mats(self,
input: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
# torchscript doesn't support super, see: https://github.com/pytorch/pytorch/issues/42885
new_kwargs = kwargs.copy()
if val_idx is None:
return input[groups], {k: v[groups] for k, v in kwargs.items()}
for k in ['H', 'R', 'K']:
new_kwargs[k] = kwargs[k][groups]
return input[groups], new_kwargs
else:
m1d = torch.meshgrid(groups, val_idx, indexing='ij')
m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
masked_input = input[m1d[0], m1d[1]]
masked_kwargs = {
new_kwargs.update({
'H': kwargs['H'][m1d[0], m1d[1]],
'R': kwargs['R'][m2d[0], m2d[1], m2d[2]],
'K': kwargs['K'][m1d[0], m1d[1]],
}
return masked_input, masked_kwargs
})
return masked_input, new_kwargs

def _update(self,
input: Tensor,
Expand Down
25 changes: 15 additions & 10 deletions torchcast/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,20 @@ def _mask_mats(self,
val_idx: Optional[Tensor],
input: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
new_kwargs = kwargs.copy()
if val_idx is None:
new_kwargs = kwargs.copy()
for k in ['H', 'R']:
new_kwargs[k] = kwargs[k][groups]
return input[groups], new_kwargs
else:
m1d = torch.meshgrid(groups, val_idx, indexing='ij')
m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
masked_input = input[m1d[0], m1d[1]]
masked_kwargs = {
new_kwargs.update({
'H': kwargs['H'][m1d[0], m1d[1]],
'R': kwargs['R'][m2d[0], m2d[1], m2d[2]]
}
return masked_input, masked_kwargs
})
return masked_input, new_kwargs

def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
H = kwargs['H']
Expand All @@ -69,12 +69,7 @@ def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Te
resid = input - measured_mean

# outlier-rejection:
valid_mask = torch.ones(len(input), dtype=torch.bool, device=input.device)
if 'outlier_threshold' in kwargs.keys() and kwargs['outlier_threshold'] > 0:
mdist = mahalanobis_dist(resid, system_covariance)
valid_mask = mdist <= kwargs['outlier_threshold']
# if (~valid_mask).any():
# print('outlier idxs:', torch.where(~valid_mask)[0])
valid_mask = self._get_update_mask(resid, system_covariance, outlier_threshold=kwargs['outlier_threshold'])

# update:
new_mean = mean.clone()
Expand All @@ -86,6 +81,16 @@ def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Te

return new_mean, new_cov

@staticmethod
def _get_update_mask(resid: torch.Tensor,
system_covariance: torch.Tensor,
outlier_threshold: torch.Tensor) -> torch.Tensor:
if outlier_threshold > 0:
mdist = mahalanobis_dist(resid, system_covariance)
return mdist <= outlier_threshold
else:
return torch.ones(len(resid), dtype=torch.bool, device=resid.device)

def _covariance_update(self, cov: Tensor, K: Tensor, H: Tensor, R: Tensor) -> Tensor:
I = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device).unsqueeze(0)
ikh = I - K @ H
Expand Down
8 changes: 4 additions & 4 deletions torchcast/state_space/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self,
self.outlier_threshold = outlier_threshold
if self.outlier_threshold and outlier_burnin is None:
raise ValueError("If `outlier_threshold` is set, `outlier_burnin` must be set as well.")
self.outlier_burnin = outlier_burnin
self.outlier_burnin = outlier_burnin or 0

if isinstance(measures, str):
measures = [measures]
Expand Down Expand Up @@ -452,9 +452,9 @@ def _script_forward(self,
cov1s.append(cov1step)
if t < len(inputs):
update_kwargs_t = {k: v[t] for k, v in update_kwargs.items()}
if self.outlier_burnin is not None:
if t > self.outlier_burnin: # short-circuiting and doesn't work with jit
update_kwargs_t['outlier_threshold'] = torch.tensor(self.outlier_threshold)
update_kwargs_t['outlier_threshold'] = torch.tensor(
self.outlier_threshold if t > self.outlier_burnin else 0.
)
meanu, covu = self.ss_step.update(
inputs[t],
mean1step,
Expand Down

0 comments on commit 7c1b3c4

Please sign in to comment.