dont use POISSON_SMALL_THRESH for log_prob
jwdink committed May 18, 2023
1 parent 9ee27b1 commit ad762d8
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions torchcast/kalman_filter/poisson_filter.py
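In outline: before this change, _log_prob used POISSON_SMALL_THRESH to switch between a multivariate-normal approximation (means at or above the threshold, where state covariance can be folded in) and an exact Poisson (means below it, ignoring state covariance); after it, every observation is scored under the Poisson, with a once-per-session warning that state covariance is ignored. A minimal standalone sketch of the two behaviors, in plain PyTorch with 1-d obs/means/vars_ stand-ins rather than torchcast's batched tensors:

    import torch

    POISSON_SMALL_THRESH = 10

    def log_prob_before(obs, means, vars_):
        # normal approximation (folding in state variance) at/above the
        # threshold; exact Poisson (variance == mean) below it
        use_norm = means >= POISSON_SMALL_THRESH
        out = torch.zeros_like(means)
        out[use_norm] = torch.distributions.Normal(
            means[use_norm], vars_[use_norm].sqrt()
        ).log_prob(obs[use_norm])
        out[~use_norm] = torch.distributions.Poisson(
            means[~use_norm]
        ).log_prob(obs[~use_norm])
        return out

    def log_prob_after(obs, means, vars_):
        # state variance is ignored for every element, hence the new warning
        return torch.distributions.Poisson(means).log_prob(obs)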
@@ -7,9 +7,10 @@
 ----------
 """
-from typing import Dict, Tuple, Optional, Type, Sequence, Iterable, List
+from typing import Dict, Tuple, Optional, Type, Sequence, Iterable, List, TYPE_CHECKING
 from warnings import warn
 
+import numpy as np
 import torch
 from scipy import special as scipy_special
 from torch import Tensor, nn
@@ -22,6 +23,9 @@
 from ..process import Process
 from ..state_space import StateSpaceModel, Predictions
 
+if TYPE_CHECKING:
+    from pandas import DataFrame
+
 POISSON_SMALL_THRESH = 10
 
 softplus = Softplus()
@@ -43,7 +47,6 @@ def get_distribution(self) -> Type[torch.distributions.Distribution]:
 
     def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
         orig_H = kwargs['H']
-        #
         orig_mmean = (orig_H @ mean.unsqueeze(-1)).squeeze(-1)
         measured_mean = softplus(orig_mmean)
         # variance = mean
@@ -57,10 +60,22 @@ def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
         newH = orig_H - correction
 
         # standard:
-        K = self._kalman_gain(cov=cov, H=newH, R=R)
+        Ht = newH.permute(0, 2, 1)
+        system_covariance = torch.baddbmm(R, newH @ cov, Ht)
+        K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance)
         resid = input - measured_mean
-        new_mean = mean + (K @ resid.unsqueeze(-1)).squeeze(-1)
-        new_cov = self._covariance_update(cov=cov, K=K, H=newH, R=R)
 
+        # outlier-rejection:
+        # TODO: does this make sense for EKF?
+        valid_mask = self._get_update_mask(resid, system_covariance, outlier_threshold=kwargs['outlier_threshold'])
+
+        # update:
+        new_mean = mean.clone()
+        new_mean[valid_mask] = mean[valid_mask] + (K[valid_mask] @ resid[valid_mask].unsqueeze(-1)).squeeze(-1)
+        new_cov = cov.clone()
+        new_cov[valid_mask] = self._covariance_update(
+            cov=cov[valid_mask], K=K[valid_mask], H=newH[valid_mask], R=R[valid_mask]
+        )
         return new_mean, new_cov
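The refactored gain computation above factors the innovation (system) covariance out of _kalman_gain so the new outlier mask can reuse it. A standalone sketch of the batched algebra, using generic EKF names (P, H, R are illustrative stand-ins, not torchcast's internals):

    import torch

    def ekf_gain(P, H, R):
        # innovation covariance S = R + H @ P @ H^T, in one batched kernel
        Ht = H.permute(0, 2, 1)
        S = torch.baddbmm(R, H @ P, Ht)
        # gain K = P @ H^T @ S^{-1}; since S is symmetric, transposing
        # solve(S, H @ P) gives the same product without forming S^{-1}
        K = torch.linalg.solve(S, H @ P).permute(0, 2, 1)
        return K, S

Here torch.baddbmm(R, H @ P, Ht) computes R + (H @ P) @ Ht in a single call, which is exactly the shape of the system_covariance line in the diff.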


@@ -161,38 +176,44 @@ def observe(cls,
             H=H
         )
         means = softplus(means)
-        if not _warn_once.get('poisson_predictions', False):
-            _warn_once['poisson_predictions'] = True
+        return means, covs
+
+    @classmethod
+    def plot(cls,
+             df: 'DataFrame',
+             group_colname: str = None,
+             time_colname: str = None,
+             max_num_groups: int = 1,
+             split_dt: Optional[np.datetime64] = None,
+             **kwargs) -> 'DataFrame':
+        if not _warn_once.get('poisson_plot', False):
+            _warn_once['poisson_plot'] = True
             warn(
-                "Poisson implementation is experimental. Currently, this means that, (1) in plotting, will "
-                "over-estimate uncertainty for small values, (2) in log-prob, will ignore state-covariance for small "
-                "values."
+                "Poisson implementation is experimental. Currently plotting will over-estimate uncertainty."
             )
-        return means, covs
+        return super().plot(
+            df=df,
+            group_colname=group_colname,
+            time_colname=time_colname,
+            max_num_groups=max_num_groups,
+            split_dt=split_dt,
+            **kwargs
+        )
 
     @cached_property
     def R(self) -> torch.Tensor:
         means, _ = self.observe(self.state_means, self.state_covs, R=0.0, H=self.H)
         return torch.diag_embed(means)
 
     def _log_prob(self, obs: Tensor, means: Tensor, covs: Tensor) -> Tensor:
+        if not _warn_once.get('poisson_log_prob', False):
+            _warn_once['poisson_log_prob'] = True
+            warn(
+                "Poisson implementation is experimental. Currently log-prob will ignore state-covariance."
+            )
         # TODO: use monte-carlo instead.
-        # aside from the problem of ignoring state cov when means < POISSON_SMALL_THRESH, the other problem is
-        # that means is itself an estimate, so even when means > POISSON_SMALL_THRESH, the true value might be less
-        if means.shape[-1] > 1:
-            raise NotImplementedError("log-prob not currently implemented for poisson when there are multiple measures")
-        use_mvnorm = (means >= POISSON_SMALL_THRESH).any(-1)
-        out = torch.zeros_like(obs[..., 0])
-        # use normal approximation for larger values:
-        out[use_mvnorm] = torch.distributions.MultivariateNormal(
-            loc=means[use_mvnorm],
-            covariance_matrix=covs[use_mvnorm],
-            validate_args=False
-        ).log_prob(obs[use_mvnorm])
-        # ignore state-cov and use poisson for smaller values:
-        out[~use_mvnorm] = self.distribution_cls(
-            means[~use_mvnorm].squeeze(-1), validate_args=False
-        ).log_prob(obs[~use_mvnorm].squeeze(-1))
+        out = self.distribution_cls(means.squeeze(-1), validate_args=False).log_prob(obs.squeeze(-1))
         return out
 
     def sample(self) -> Tensor:

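The _get_update_mask call introduced in _update gates the Kalman update on how extreme each group's residual is under the innovation covariance; its implementation is not shown in this diff. One plausible shape for such a mask, as a hypothetical squared-Mahalanobis threshold (illustrative only; torchcast's actual helper may differ):

    import torch

    def get_update_mask(resid, S, outlier_threshold):
        # squared Mahalanobis distance of each group's innovation under S
        mahal = (
            resid.unsqueeze(-2) @ torch.linalg.solve(S, resid.unsqueeze(-1))
        ).squeeze(-1).squeeze(-1)
        # keep (i.e. apply the update for) groups whose residual is not extreme
        return mahal <= outlier_threshold ** 2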