Commit 59bd742

Merge pull request #15 from strongio/develop
Develop
jwdink authored Jul 28, 2023
2 parents 6cfa727 + fd87973 commit 59bd742
Showing 9 changed files with 155 additions and 49 deletions.
Binary file removed docs/examples/electricity/es_nn25.pt
4 changes: 2 additions & 2 deletions setup.py
@@ -10,8 +10,8 @@
     zip_safe=False,
     install_requires=[
         'backports.cached-property',
-        'torch>=1.8',
-        'numpy>=1.4'
+        'torch>=1.12',
+        'numpy>=1.4',
     ],
     extras_require={
         'tests': ['parameterized>=0.7', 'filterpy>=1.4', 'pandas>=1.0'],

8 changes: 4 additions & 4 deletions torchcast/covariance/base.py
@@ -121,6 +121,9 @@ def from_measures(cls,
         if isinstance(measures, str):
             measures = [measures]
             warn(f"`measures` should be a list of strings not a string; interpreted as `{measures}`.")
+        elif not isinstance(measures[0], str):
+            # not good duck-typing, but too easy to accidentally pass dataset.measures instead of dataset.measures[0]
+            raise RuntimeError(f"`measures[0]` is {type(measures[0])}, expected str")
         if 'method' not in kwargs and len(measures) > 5:
             kwargs['method'] = 'low_rank'
         if 'init_diag_multi' not in kwargs:
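
The new `elif` guards against the slip its comment names: passing `dataset.measures` (a per-tensor sequence of measure lists) where `dataset.measures[0]` (a list of strings) was meant. A hypothetical illustration, with `dataset_measures` standing in for that attribute:

    from torchcast.covariance import Covariance

    dataset_measures = (['price', 'volume'],)      # stand-in for dataset.measures
    Covariance.from_measures(dataset_measures[0])  # correct: a list of str
    # Covariance.from_measures(dataset_measures)   # would now raise RuntimeError
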
@@ -262,17 +265,14 @@ def forward(self,
             mini_cov, num_groups=num_groups, num_times=num_times, trailing_dim=[self.param_rank, self.param_rank]
         )

-        pred = None
         if self.var_predict_module is not None and not _ignore_input:
             pred = self.var_predict_module(*[inputs[x] for x in self.expected_kwargs])
             if torch.isnan(pred).any() or torch.isinf(pred).any():
                 raise RuntimeError(f"{self.id}'s `predict_variance` produced nans/infs")
             if (pred < 0).any():
                 raise RuntimeError(f"{self.id}'s `predict_variance` produced values <0; needs exp/softplus layer.")
             pred = validate_gt_shape(pred, num_groups=num_groups, num_times=num_times, trailing_dim=[self.param_rank])
-
-        if pred is not None:
-            diag_multi = torch.diag_embed(torch.exp(pred))
+            diag_multi = torch.diag_embed(pred)
             mini_cov = diag_multi @ mini_cov @ diag_multi

         mask = self.mask.unsqueeze(0).unsqueeze(0)
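
With the `torch.exp` call removed from `forward`, a `var_predict_module` must now emit non-negative values itself, e.g. by ending in a softplus layer. A minimal sketch of a conforming module (the layer sizes are assumptions, not from this diff):

    import torch
    from torch import nn

    # assumed: 3 predictor features -> 1 variance multiplier
    var_predict_module = nn.Sequential(
        nn.Linear(3, 1),
        nn.Softplus(),  # output >= 0, so the new `(pred < 0)` check passes
    )
    pred = var_predict_module(torch.randn(8, 3))
    assert (pred >= 0).all()
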
14 changes: 8 additions & 6 deletions torchcast/exp_smooth/exp_smooth.py
@@ -10,7 +10,6 @@
 import torch
 from torch import Tensor

-from torchcast.state_space import Predictions
 from torchcast.exp_smooth.smoothing_matrix import SmoothingMatrix
 from torchcast.covariance import Covariance
 from torchcast.process import Process

@@ -26,24 +25,27 @@ def _mask_mats(self,
                    input: Tensor,
                    kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
         # torchscript doesn't support super, see: https://github.com/pytorch/pytorch/issues/42885
+        new_kwargs = kwargs.copy()
         if val_idx is None:
-            return input[groups], {k: v[groups] for k, v in kwargs.items()}
+            for k in ['H', 'R', 'K']:
+                new_kwargs[k] = kwargs[k][groups]
+            return input[groups], new_kwargs
         else:
             m1d = torch.meshgrid(groups, val_idx, indexing='ij')
             m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
             masked_input = input[m1d[0], m1d[1]]
-            masked_kwargs = {
+            new_kwargs.update({
                 'H': kwargs['H'][m1d[0], m1d[1]],
                 'R': kwargs['R'][m2d[0], m2d[1], m2d[2]],
                 'K': kwargs['K'][m1d[0], m1d[1]],
-            }
-            return masked_input, masked_kwargs
+            })
+            return masked_input, new_kwargs

     def _update(self,
                 input: Tensor,
                 mean: Tensor,
                 cov: Tensor,
-                kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
+                kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
         measured_mean = (kwargs['H'] @ mean.unsqueeze(-1)).squeeze(-1)
         resid = input - measured_mean
         new_mean = mean + (kwargs['K'] @ resid.unsqueeze(-1)).squeeze(-1)
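
Both this method and `KalmanStep._mask_mats` below use `torch.meshgrid` advanced indexing to pull out the valid rows and columns per group. A standalone sketch with assumed shapes:

    import torch

    # (num_groups, num_measures, num_measures) covariance; measures 0 and 2 observed:
    R = torch.arange(2 * 3 * 3, dtype=torch.float32).reshape(2, 3, 3)
    groups = torch.tensor([0, 1])
    val_idx = torch.tensor([0, 2])

    m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
    R_masked = R[m2d[0], m2d[1], m2d[2]]
    print(R_masked.shape)  # torch.Size([2, 2, 2]): rows/cols 0 and 2 of each group
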
48 changes: 38 additions & 10 deletions torchcast/kalman_filter.py
@@ -6,7 +6,6 @@
 ----------
 """

-from typing import Sequence, Dict, List, Iterable

 from torchcast.covariance import Covariance

@@ -20,6 +19,8 @@
 from torch import nn, Tensor
 from typing_extensions import Final

+from torchcast.utils.outliers import mahalanobis_dist
+

 class KalmanStep(StateSpaceStep):
     """

@@ -40,26 +41,44 @@ def _mask_mats(self,
                    val_idx: Optional[Tensor],
                    input: Tensor,
                    kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
+        new_kwargs = kwargs.copy()
         if val_idx is None:
-            return input[groups], {k: v[groups] for k, v in kwargs.items()}
+            for k in ['H', 'R']:
+                new_kwargs[k] = kwargs[k][groups]
+            return input[groups], new_kwargs
         else:
             m1d = torch.meshgrid(groups, val_idx, indexing='ij')
             m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
             masked_input = input[m1d[0], m1d[1]]
-            masked_kwargs = {
+            new_kwargs.update({
                 'H': kwargs['H'][m1d[0], m1d[1]],
                 'R': kwargs['R'][m2d[0], m2d[1], m2d[2]]
-            }
-            return masked_input, masked_kwargs
+            })
+            return masked_input, new_kwargs

     def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
         H = kwargs['H']
         R = kwargs['R']
-        K = self._kalman_gain(cov=cov, H=H, R=R)
+        Ht = H.permute(0, 2, 1)
+        system_covariance = torch.baddbmm(R, H @ cov, Ht)
+
+        # kalman-gain:
+        K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance)
+
+        # residuals:
         measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1)
         resid = input - measured_mean
+
+        # outlier-rejection:
+        if kwargs['outlier_threshold'] > 0:
+            mdist = mahalanobis_dist(resid, system_covariance)
+            multi = (mdist - kwargs['outlier_threshold']).clamp(min=0) + 1
+            R = R * multi.unsqueeze(-1).unsqueeze(-1)
+
+        # update:
         new_mean = mean + (K @ resid.unsqueeze(-1)).squeeze(-1)
         new_cov = self._covariance_update(cov=cov, K=K, H=H, R=R)

         return new_mean, new_cov

     def _covariance_update(self, cov: Tensor, K: Tensor, H: Tensor, R: Tensor) -> Tensor:
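
The multiplier in the outlier-rejection branch inflates the measurement covariance smoothly instead of discarding flagged observations outright. Worked numbers (values assumed for illustration):

    import torch

    outlier_threshold = 3.0
    mdist = torch.tensor([1.5, 3.0, 5.0])  # mahalanobis distance per group

    multi = (mdist - outlier_threshold).clamp(min=0) + 1
    print(multi)  # tensor([1., 1., 3.]): inliers untouched; a distance-5 point triples R
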
@@ -71,10 +90,8 @@ def _covariance_update(self, cov: Tensor, K: Tensor, H: Tensor, R: Tensor) -> Tensor:
         return ikh @ cov

     @staticmethod
-    def _kalman_gain(cov: Tensor, H: Tensor, R: Tensor) -> Tensor:
-        Ht = H.permute(0, 2, 1)
+    def _kalman_gain(cov: Tensor, Ht: Tensor, system_covariance: Tensor) -> Tensor:
         covs_measured = cov @ Ht
-        system_covariance = torch.baddbmm(R, H @ cov, Ht)
         A = system_covariance.permute(0, 2, 1)
         B = covs_measured.permute(0, 2, 1)
         Kt = torch.linalg.solve(A, B)
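
`_kalman_gain` still avoids an explicit matrix inverse: the textbook gain `K = cov @ Ht @ inv(system_covariance)` is recovered by solving the transposed system instead. A self-contained check with toy matrices (shapes assumed):

    import torch

    cov = torch.eye(2).unsqueeze(0)    # P: (1, 2, 2) state covariance
    H = torch.tensor([[[1., 0.]]])     # (1, 1, 2) observation matrix
    R = torch.tensor([[[0.5]]])        # (1, 1, 1) measurement noise
    Ht = H.permute(0, 2, 1)
    S = torch.baddbmm(R, H @ cov, Ht)  # system covariance: H P Ht + R

    Kt = torch.linalg.solve(S.permute(0, 2, 1), (cov @ Ht).permute(0, 2, 1))
    K = Kt.permute(0, 2, 1)
    assert torch.allclose(K, cov @ Ht @ torch.linalg.inv(S))  # matches the naive form
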
@@ -90,14 +107,21 @@ class KalmanFilter(StateSpaceModel):
     :param measures: A list of strings specifying the names of the dimensions of the time-series being measured.
     :param process_covariance: A module created with ``Covariance.from_processes(processes)``.
     :param measure_covariance: A module created with ``Covariance.from_measures(measures)``.
+    :param outlier_threshold: If specified, used as a threshold for outlier under-weighting during the `update` step,
+     using mahalanobis distance; outliers will also be under-weighted when evaluating the ``log_prob()`` of the output
+     ``Predictions``.
+    :param outlier_burnin: If outlier_threshold is specified, this specifies the number of timesteps to wait before
+     starting to reject outliers.
     """
     ss_step_cls = KalmanStep

     def __init__(self,
                  processes: Sequence[Process],
                  measures: Optional[Sequence[str]] = None,
                  process_covariance: Optional[Covariance] = None,
-                 measure_covariance: Optional[Covariance] = None):
+                 measure_covariance: Optional[Covariance] = None,
+                 outlier_threshold: float = 0.,
+                 outlier_burnin: Optional[int] = None):

         initial_covariance = Covariance.from_processes(processes, cov_type='initial')

@@ -106,11 +130,15 @@ def __init__(self,

         if measure_covariance is None:
             measure_covariance = Covariance.from_measures(measures)
+        else:
+            assert measure_covariance.rank == 1 or measure_covariance.rank == len(measures)

         super().__init__(
             processes=processes,
             measures=measures,
             measure_covariance=measure_covariance,
+            outlier_threshold=outlier_threshold,
+            outlier_burnin=outlier_burnin
         )
         self.process_covariance = process_covariance.set_id('process_covariance')
         self.initial_covariance = initial_covariance.set_id('initial_covariance')
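
A hypothetical construction with the new arguments (the process and measure names are assumptions for illustration, not taken from this diff):

    from torchcast.kalman_filter import KalmanFilter
    from torchcast.process import LocalTrend

    # down-weight observations more than 5 mahalanobis units out,
    # but only after 10 warm-up timesteps:
    kf = KalmanFilter(
        processes=[LocalTrend(id='trend')],
        measures=['y'],
        outlier_threshold=5.0,
        outlier_burnin=10,
    )
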
27 changes: 23 additions & 4 deletions torchcast/state_space/base.py
@@ -19,15 +19,27 @@ class StateSpaceModel(nn.Module):
     :param processes: A list of :class:`.Process` modules.
     :param measures: A list of strings specifying the names of the dimensions of the time-series being measured.
     :param measure_covariance: A module created with ``Covariance.from_measures(measures)``.
+    :param outlier_threshold: If specified, used as a threshold for outlier under-weighting during the `update` step,
+     using mahalanobis distance; outliers will also be under-weighted when evaluating the ``log_prob()`` of the output
+     ``Predictions``.
+    :param outlier_burnin: If outlier_threshold is specified, this specifies the number of steps to wait before
+     starting to reject outliers.
     """
     ss_step_cls: Type[StateSpaceStep]

     def __init__(self,
                  processes: Sequence[Process],
-                 measures: Optional[Sequence[str]] = None,
-                 measure_covariance: Optional[Covariance] = None):
+                 measures: Optional[Sequence[str]],
+                 measure_covariance: Covariance,
+                 outlier_threshold: float = 0.0,
+                 outlier_burnin: Optional[int] = None):
         super().__init__()

+        self.outlier_threshold = outlier_threshold
+        if self.outlier_threshold and outlier_burnin is None:
+            raise ValueError("If `outlier_threshold` is set, `outlier_burnin` must be set as well.")
+        self.outlier_burnin = outlier_burnin or 0
+
         if isinstance(measures, str):
             measures = [measures]
             warn(f"`measures` should be a list of strings not a string; interpreted as `{measures}`.")

@@ -105,6 +117,9 @@ def fit(self,
         optimizer = torch.optim.LBFGS([p for p in self.parameters() if p.requires_grad],
                                       max_iter=10, line_search_fn='strong_wolfe', lr=.5)

+        if self.outlier_threshold and verbose:
+            print("``outlier_threshold`` is experimental")
+
         if set_initial_values:
             self.set_initial_values(y)

@@ -441,18 +456,22 @@ def _script_forward(self,
                 mean1s.append(mean1step)
                 cov1s.append(cov1step)
                 if t < len(inputs):
+                    update_kwargs_t = {k: v[t] for k, v in update_kwargs.items()}
+                    update_kwargs_t['outlier_threshold'] = torch.tensor(
+                        self.outlier_threshold if t > self.outlier_burnin else 0.
+                    )
                     meanu, covu = self.ss_step.update(
                         inputs[t],
                         mean1step,
                         cov1step,
-                        {k: v[t] for k, v in update_kwargs.items()}
+                        update_kwargs_t,
                     )
                 else:
                     meanu, covu = mean1step, cov1step
                 meanus.append(meanu)
                 covus.append(covu)

-        # 2nd loop to get n_step updates:
+        # 2nd loop to get n_step predicts:
         # idx: Dict[int, int] = {}
         meanps: Dict[int, Tensor] = {}
         covps: Dict[int, Tensor] = {}
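
The per-timestep kwarg zeroes the threshold during burn-in, so updates made while state estimates are still poorly initialized never reject observations. A small trace of that logic (values assumed):

    import torch

    outlier_threshold, outlier_burnin = 5.0, 10
    for t in [0, 10, 11]:
        thresh_t = torch.tensor(outlier_threshold if t > outlier_burnin else 0.)
        print(t, thresh_t.item())  # 0.0 through t=10; 5.0 from t=11 onward
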
38 changes: 29 additions & 9 deletions torchcast/state_space/predictions.py
@@ -1,5 +1,4 @@
 from typing import Tuple, Union, Optional, Dict, Iterator, Sequence
-from warnings import warn

 import torch
 from torch import nn, Tensor

@@ -11,6 +10,7 @@
 from torchcast.internals.utils import get_nan_groups, is_near_zero

 from torchcast.utils.data import TimeSeriesDataset
+from torchcast.utils.outliers import mahalanobis_dist


 class Predictions(nn.Module):

@@ -51,11 +51,13 @@ def __init__(self,
             model = {
                 'distribution_cls': model.ss_step.get_distribution(),
                 'measures': model.measures,
-                'all_state_elements': all_state_elements
+                'all_state_elements': all_state_elements,
+                'outlier_threshold': model.outlier_threshold
             }
         self.distribution_cls = model['distribution_cls']
         self.measures = model['measures']
         self.all_state_elements = model['all_state_elements']
+        self.outlier_threshold = model['outlier_threshold']

         # for lazily populated properties:
         self._means = self._covs = None

@@ -202,6 +204,21 @@ def log_prob(self, obs: Tensor) -> Tensor:
         n_state_dim = self.state_means.shape[-1]

         obs_flat = obs.reshape(-1, n_measure_dim)
+        means_flat = self.means.view(-1, n_measure_dim)
+        covs_flat = self.covs.view(-1, n_measure_dim, n_measure_dim)
+
+        # if the model used an outlier threshold, under-weight outliers
+        weights = torch.ones(obs_flat.shape[0], dtype=self.state_means.dtype, device=self.state_means.device)
+        if self.outlier_threshold:
+            obs_flat = obs_flat.clone()
+            for gt_idx, valid_idx in get_nan_groups(torch.isnan(obs_flat)):
+                if valid_idx is None:
+                    mdist = mahalanobis_dist(obs_flat[gt_idx] - means_flat[gt_idx], covs_flat[gt_idx])
+                    multi = (mdist - self.outlier_threshold).clamp(min=0) + 1
+                    weights[gt_idx] = 1 / multi
+                else:
+                    raise NotImplementedError
+
         state_means_flat = self.state_means.view(-1, n_state_dim)
         state_covs_flat = self.state_covs.view(-1, n_state_dim, n_state_dim)
         H_flat = self.H.view(-1, n_measure_dim, n_state_dim)

@@ -211,8 +228,8 @@ def log_prob(self, obs: Tensor) -> Tensor:
         for gt_idx, valid_idx in get_nan_groups(torch.isnan(obs_flat)):
             if valid_idx is None:
                 gt_obs = obs_flat[gt_idx]
-                gt_means_flat = self.means.view(-1, n_measure_dim)[gt_idx]
-                gt_covs_flat = self.covs.view(-1, n_measure_dim, n_measure_dim)[gt_idx]
+                gt_means_flat = means_flat[gt_idx]
+                gt_covs_flat = covs_flat[gt_idx]
             else:
                 mask1d = torch.meshgrid(gt_idx, valid_idx, indexing='ij')
                 mask2d = torch.meshgrid(gt_idx, valid_idx, valid_idx, indexing='ij')

@@ -225,6 +242,8 @@ def log_prob(self, obs: Tensor) -> Tensor:
                 gt_obs = obs_flat[mask1d]
             lp_flat[gt_idx] = self._log_prob(gt_obs, gt_means_flat, gt_covs_flat)

+        lp_flat = lp_flat * weights
+
         return lp_flat.view(obs.shape[0:2])

     def _log_prob(self, obs: Tensor, means: Tensor, covs: Tensor) -> Tensor:
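
The weights computed above are the reciprocal of the multiplier `KalmanStep._update` uses to inflate `R`, so an outlier's log-probability is shrunk by the same factor its measurement covariance was inflated during filtering. Worked numbers (assumed):

    import torch

    outlier_threshold = 3.0
    mdist = torch.tensor([2.0, 5.0])
    weights = 1 / ((mdist - outlier_threshold).clamp(min=0) + 1)
    print(weights)  # tensor([1.0000, 0.3333]): the distance-5 outlier counts for 1/3
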
@@ -514,8 +533,9 @@ def _model_attributes(self) -> dict:
         """
         Has the attributes of a KalmanFilter that are needed in __init__
         """
-        return dict(
-            measures=self.measures,
-            distribution_cls=self.distribution_cls,
-            all_state_elements=self.all_state_elements
-        )
+        return {
+            'measures': self.measures,
+            'distribution_cls': self.distribution_cls,
+            'all_state_elements': self.all_state_elements,
+            'outlier_threshold': self.outlier_threshold
+        }

(Diffs for the remaining 2 changed files are not shown.)
