Skip to content

Commit

Permalink
abstract into EKFStep
Browse files Browse the repository at this point in the history
  • Loading branch information
jwdink committed Mar 7, 2024
1 parent ad762d8 commit cee6f1b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 37 deletions.
28 changes: 28 additions & 0 deletions torchcast/kalman_filter/ekf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Dict, Tuple
from torch import Tensor

from .kalman_filter import KalmanStep


class EKFStep(KalmanStep):
def _get_correction(self, mean: Tensor, H: Tensor) -> Tensor:
raise NotImplementedError

def _update(self,
input: Tensor,
mean: Tensor,
cov: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
if kwargs['outlier_threshold'] > 0:
raise NotImplementedError("Outlier rejection is not yet supported for EKF")

orig_H = kwargs['H']
correction = self._get_correction(mean, orig_H)
newH = orig_H - correction

return super()._update(
input=input,
mean=mean,
cov=cov,
kwargs={'H': newH}
)
17 changes: 12 additions & 5 deletions torchcast/kalman_filter/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,26 @@ def _mask_mats(self,
})
return masked_input, new_kwargs

def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
def _update(self,
input: Tensor,
mean: Tensor,
cov: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
H = kwargs['H']
R = kwargs['R']
Ht = H.permute(0, 2, 1)
system_covariance = torch.baddbmm(R, H @ cov, Ht)

# measured-mean -> residuals:
if 'measured_mean' in kwargs:
measured_mean = kwargs['measured_mean']
else:
measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1)
resid = input - measured_mean

# kalman-gain:
K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance)

# residuals:
measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1)
resid = input - measured_mean

# outlier-rejection:
valid_mask = self._get_update_mask(resid, system_covariance, outlier_threshold=kwargs['outlier_threshold'])

Expand Down
49 changes: 20 additions & 29 deletions torchcast/kalman_filter/poisson_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from backports.cached_property import cached_property
from torch.nn import Softplus

from .kalman_filter import KalmanStep
from .ekf import EKFStep
from ..covariance import Covariance
from ..process import Process
from ..state_space import StateSpaceModel, Predictions
Expand All @@ -39,44 +39,35 @@ def inverse_softplus(x: torch.Tensor, eps: float = .001) -> torch.Tensor:
return out


class PoissonStep(KalmanStep):
class PoissonStep(EKFStep):
# this would ideally be a class-attribute but torch.jit.trace strips them
@torch.jit.ignore()
def get_distribution(self) -> Type[torch.distributions.Distribution]:
return torch.distributions.Poisson

def _get_correction(self, mean: Tensor, H: Tensor) -> Tensor:
raw_mmean = (H @ mean.unsqueeze(-1)).squeeze(-1)

correction = torch.zeros_like(H)
_do_cor = raw_mmean < POISSON_SMALL_THRESH

# derivative of softplus:
correction[_do_cor] = H[_do_cor] / (torch.exp(raw_mmean[_do_cor]) + 1).unsqueeze(-1)
return correction

def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
orig_H = kwargs['H']
orig_mmean = (orig_H @ mean.unsqueeze(-1)).squeeze(-1)
measured_mean = softplus(orig_mmean)
# variance = mean
R = torch.diag_embed(measured_mean)

# use EKF:
correction = torch.zeros_like(orig_H)
_do_cor = orig_mmean < POISSON_SMALL_THRESH
# derivative of softplus:
correction[_do_cor] = orig_H[_do_cor] / (torch.exp(orig_mmean[_do_cor]) + 1).unsqueeze(-1)
newH = orig_H - correction

# standard:
Ht = newH.permute(0, 2, 1)
system_covariance = torch.baddbmm(R, newH @ cov, Ht)
K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance)
resid = input - measured_mean

# outlier-rejection:
# TODO: does this make sense for EKF?
valid_mask = self._get_update_mask(resid, system_covariance, outlier_threshold=kwargs['outlier_threshold'])

# update:
new_mean = mean.clone()
new_mean[valid_mask] = mean[valid_mask] + (K[valid_mask] @ resid[valid_mask].unsqueeze(-1)).squeeze(-1)
new_cov = cov.clone()
new_cov[valid_mask] = self._covariance_update(
cov=cov[valid_mask], K=K[valid_mask], H=H[valid_mask], R=R[valid_mask]
kwargs['measured_mean'] = softplus(orig_mmean)
kwargs['R'] = torch.diag_embed(kwargs['measured_mean']) # variance = mean

return super()._update(
input=input,
mean=mean,
cov=cov,
kwargs=kwargs
)
return new_mean, new_cov


class PoissonFilter(StateSpaceModel):
Expand Down
13 changes: 10 additions & 3 deletions torchcast/state_space/ss_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def forward(self,
def predict(self, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
raise NotImplementedError

def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
def _update(self,
input: Tensor,
mean: Tensor,
cov: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
raise NotImplementedError

def update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
Expand All @@ -55,8 +59,11 @@ def update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Ten
new_cov = cov.clone()
for groups, val_idx in get_nan_groups(isnan):
masked_input, masked_kwargs = self._mask_mats(groups, val_idx, input=input, kwargs=kwargs)
m,c = self._update(
input=masked_input, mean=mean[groups], cov=cov[groups], kwargs=masked_kwargs
m, c = self._update(
input=masked_input,
mean=mean[groups],
cov=cov[groups],
kwargs=masked_kwargs
)
new_mean[groups] = m
if c is None:
Expand Down

0 comments on commit cee6f1b

Please sign in to comment.