From cee6f1bebdad05038f1a4d1a0315bdbb05975067 Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 7 Mar 2024 08:25:58 -0600 Subject: [PATCH] abstract into EKFStep --- torchcast/kalman_filter/ekf.py | 28 +++++++++++++ torchcast/kalman_filter/kalman_filter.py | 17 +++++--- torchcast/kalman_filter/poisson_filter.py | 49 +++++++++-------------- torchcast/state_space/ss_step.py | 13 ++++-- 4 files changed, 70 insertions(+), 37 deletions(-) create mode 100644 torchcast/kalman_filter/ekf.py diff --git a/torchcast/kalman_filter/ekf.py b/torchcast/kalman_filter/ekf.py new file mode 100644 index 0000000..7ab46aa --- /dev/null +++ b/torchcast/kalman_filter/ekf.py @@ -0,0 +1,28 @@ +from typing import Dict, Tuple +from torch import Tensor + +from .kalman_filter import KalmanStep + + +class EKFStep(KalmanStep): + def _get_correction(self, mean: Tensor, H: Tensor) -> Tensor: + raise NotImplementedError + + def _update(self, + input: Tensor, + mean: Tensor, + cov: Tensor, + kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + if kwargs['outlier_threshold'] > 0: + raise NotImplementedError("Outlier rejection is not yet supported for EKF") + + orig_H = kwargs['H'] + correction = self._get_correction(mean, orig_H) + newH = orig_H - correction + + return super()._update( + input=input, + mean=mean, + cov=cov, + kwargs={'H': newH} + ) diff --git a/torchcast/kalman_filter/kalman_filter.py b/torchcast/kalman_filter/kalman_filter.py index 0bb424b..8b8bbbc 100644 --- a/torchcast/kalman_filter/kalman_filter.py +++ b/torchcast/kalman_filter/kalman_filter.py @@ -55,19 +55,26 @@ def _mask_mats(self, }) return masked_input, new_kwargs - def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + def _update(self, + input: Tensor, + mean: Tensor, + cov: Tensor, + kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: H = kwargs['H'] R = kwargs['R'] Ht = H.permute(0, 2, 1) system_covariance = torch.baddbmm(R, H @ cov, Ht) + # measured-mean -> residuals: + if 'measured_mean' in kwargs: + measured_mean = kwargs['measured_mean'] + else: + measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1) + resid = input - measured_mean + # kalman-gain: K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance) - # residuals: - measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1) - resid = input - measured_mean - # outlier-rejection: valid_mask = self._get_update_mask(resid, system_covariance, outlier_threshold=kwargs['outlier_threshold']) diff --git a/torchcast/kalman_filter/poisson_filter.py b/torchcast/kalman_filter/poisson_filter.py index 0f87908..fb082ff 100644 --- a/torchcast/kalman_filter/poisson_filter.py +++ b/torchcast/kalman_filter/poisson_filter.py @@ -18,7 +18,7 @@ from backports.cached_property import cached_property from torch.nn import Softplus -from .kalman_filter import KalmanStep +from .ekf import EKFStep from ..covariance import Covariance from ..process import Process from ..state_space import StateSpaceModel, Predictions @@ -39,44 +39,35 @@ def inverse_softplus(x: torch.Tensor, eps: float = .001) -> torch.Tensor: return out -class PoissonStep(KalmanStep): +class PoissonStep(EKFStep): # this would ideally be a class-attribute but torch.jit.trace strips them @torch.jit.ignore() def get_distribution(self) -> Type[torch.distributions.Distribution]: return torch.distributions.Poisson + def _get_correction(self, mean: Tensor, H: Tensor) -> Tensor: + raw_mmean = (H @ mean.unsqueeze(-1)).squeeze(-1) + + correction = torch.zeros_like(H) + _do_cor = raw_mmean < POISSON_SMALL_THRESH + + # derivative of softplus: + correction[_do_cor] = H[_do_cor] / (torch.exp(raw_mmean[_do_cor]) + 1).unsqueeze(-1) + return correction + def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: orig_H = kwargs['H'] orig_mmean = (orig_H @ mean.unsqueeze(-1)).squeeze(-1) - measured_mean = softplus(orig_mmean) - # variance = mean - R = torch.diag_embed(measured_mean) - # use EKF: - correction = torch.zeros_like(orig_H) - _do_cor = orig_mmean < POISSON_SMALL_THRESH - # derivative of softplus: - correction[_do_cor] = orig_H[_do_cor] / (torch.exp(orig_mmean[_do_cor]) + 1).unsqueeze(-1) - newH = orig_H - correction - - # standard: - Ht = newH.permute(0, 2, 1) - system_covariance = torch.baddbmm(R, newH @ cov, Ht) - K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance) - resid = input - measured_mean - - # outlier-rejection: - # TODO: does this make sense for EKF? - valid_mask = self._get_update_mask(resid, system_covariance, outlier_threshold=kwargs['outlier_threshold']) - - # update: - new_mean = mean.clone() - new_mean[valid_mask] = mean[valid_mask] + (K[valid_mask] @ resid[valid_mask].unsqueeze(-1)).squeeze(-1) - new_cov = cov.clone() - new_cov[valid_mask] = self._covariance_update( - cov=cov[valid_mask], K=K[valid_mask], H=H[valid_mask], R=R[valid_mask] + kwargs['measured_mean'] = softplus(orig_mmean) + kwargs['R'] = torch.diag_embed(kwargs['measured_mean']) # variance = mean + + return super()._update( + input=input, + mean=mean, + cov=cov, + kwargs=kwargs ) - return new_mean, new_cov class PoissonFilter(StateSpaceModel): diff --git a/torchcast/state_space/ss_step.py b/torchcast/state_space/ss_step.py index 903898c..e5a8810 100644 --- a/torchcast/state_space/ss_step.py +++ b/torchcast/state_space/ss_step.py @@ -29,7 +29,11 @@ def forward(self, def predict(self, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: raise NotImplementedError - def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + def _update(self, + input: Tensor, + mean: Tensor, + cov: Tensor, + kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: raise NotImplementedError def update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: @@ -55,8 +59,11 @@ def update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Ten new_cov = cov.clone() for groups, val_idx in get_nan_groups(isnan): masked_input, masked_kwargs = self._mask_mats(groups, val_idx, input=input, kwargs=kwargs) - m,c = self._update( - input=masked_input, mean=mean[groups], cov=cov[groups], kwargs=masked_kwargs + m, c = self._update( + input=masked_input, + mean=mean[groups], + cov=cov[groups], + kwargs=masked_kwargs ) new_mean[groups] = m if c is None: