Commit 8694115 ("fix merge")
jwdink committed Jul 14, 2024 · 1 parent 4f80b5f

Showing 3 changed files with 84 additions and 53 deletions.
11 changes: 9 additions & 2 deletions torchcast/kalman_filter/kalman_filter.py
@@ -56,13 +56,20 @@ def _mask_mats(self,
})
return masked_input, new_kwargs

def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
def _update(self,
input: Tensor,
mean: Tensor,
cov: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
H = kwargs['H']
R = kwargs['R']
Ht = H.permute(0, 2, 1)

# residuals:
measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1)
if 'measured_mean' in kwargs: # calculated by super
measured_mean = kwargs['measured_mean']
else:
measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1)
resid = input - measured_mean

HcHt = H @ cov @ Ht
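For context, here is a minimal standalone sketch of the update step this hunk modifies (not torchcast's actual class; shapes assumed to be group × state/measure dims): the residual now uses a pre-computed `measured_mean` when a superclass supplies one, falling back to the usual linear projection.

```python
import torch
from torch import Tensor
from typing import Dict, Tuple

def kalman_update(input: Tensor, mean: Tensor, cov: Tensor,
                  kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
    H, R = kwargs['H'], kwargs['R']
    Ht = H.permute(0, 2, 1)
    # use the superclass-computed measurement mean if one was passed through:
    if 'measured_mean' in kwargs:
        measured_mean = kwargs['measured_mean']
    else:
        measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1)
    resid = input - measured_mean                 # innovation
    S = H @ cov @ Ht + R                          # innovation covariance
    K = cov @ Ht @ torch.inverse(S)               # Kalman gain
    new_mean = mean + (K @ resid.unsqueeze(-1)).squeeze(-1)
    new_cov = cov - K @ H @ cov                   # posterior covariance
    return new_mean, new_cov
```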
63 changes: 37 additions & 26 deletions torchcast/state_space/base.py
@@ -45,7 +45,9 @@ def __init__(self,
warn(f"`measures` should be a list of strings not a string; interpreted as `{measures}`.")
self._validate(processes, measures)

self.measure_covariance = measure_covariance.set_id('measure_covariance')
self.measure_covariance = measure_covariance
if self.measure_covariance:
self.measure_covariance.set_id('measure_covariance')

self.ss_step = self.ss_step_cls()

@@ -66,8 +68,7 @@ def __init__(self,
# the initial mean
self.initial_mean = torch.nn.Parameter(.1 * torch.randn(self.state_rank))

# can disable for debugging/tests:
self._scale_by_measure_var = True
self._scale_by_measure_var = bool(self.measure_covariance)

@torch.jit.ignore()
def fit(self,
@@ -102,9 +103,10 @@ def fit(self,
:param get_loss: A function that takes the ``Predictions`` object and the input data and returns the loss.
Default is ``lambda pred, y: -pred.log_prob(y).mean()``.
:param loss_callback: Deprecated; use ``get_loss`` instead.
:param set_initial_values: Default is to set ``initial_mean`` to sensible value given ``y``. This helps speed
up training if the data are not centered. Set to ``False`` if you're resuming training from a previous
``fit()`` call.
:param set_initial_values: Will set ``initial_mean`` to a sensible value given ``y``, which helps speed
up training if the data are not centered. This argument determines the number of timesteps of ``y`` to use
when doing so (default 1). Set to 0/``False`` if you're resuming training from a previous ``fit()`` call;
set to a larger value for sparse data where the first timestep alone isn't informative enough.
:param kwargs: Further keyword-arguments passed to :func:`StateSpaceModel.forward()`.
:param callable_kwargs: The kwargs passed to the forward pass are static, but sometimes you want to recompute
them each iteration. The values in this dictionary are functions that will be called each iteration to
@@ -123,8 +125,7 @@ def fit(self,
if self.outlier_threshold and verbose:
print("``outlier_threshold`` is experimental")

if set_initial_values:
self.set_initial_values(y)
self.set_initial_values(y, n=set_initial_values, verbose=verbose > 1)

if not get_loss:
get_loss = lambda pred, y: -pred.log_prob(y).mean()
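A hypothetical call illustrating the reworked argument (the model and data here are invented; assumes torchcast's standard `KalmanFilter`/`LocalLevel`):

```python
import torch
from torchcast.kalman_filter import KalmanFilter
from torchcast.process import LocalLevel

model = KalmanFilter(processes=[LocalLevel(id='level')], measures=['y'])
y = 10 + torch.randn(5, 40, 1)      # (groups, times, measures); not centered

model.fit(y, set_initial_values=3)  # initial_mean from the first 3 timesteps
model.fit(y, set_initial_values=0)  # resuming: leave initial_mean untouched
```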
@@ -181,10 +182,15 @@ def closure():
return self

@torch.jit.ignore()
def set_initial_values(self, y: Tensor):
def set_initial_values(self, y: Tensor, n: int, ilink: Optional[callable] = None, verbose: bool = True):
if not n:
return
if 'initial_mean' not in self.state_dict():
return

if ilink is None:
ilink = lambda x: x

assert len(self.measures) == y.shape[-1]

hits = {m: [] for m in self.measures}
@@ -197,11 +203,15 @@ def set_initial_values(self, y: Tensor):
assert process.measure

hits[process.measure].append(process.id)
se_idx = process.state_elements.index('position')
measure_idx = list(self.measures).index(process.measure)
with torch.no_grad():
t0 = y[:, 0, measure_idx]
self.state_dict()['initial_mean'][self.process_to_slice[pid][0]] = \
t0[~torch.isnan(t0) & ~torch.isinf(t0)].mean()
t0 = y[:, 0:n, measure_idx]
init_mean = ilink(t0[~torch.isnan(t0) & ~torch.isinf(t0)].mean())
if verbose:
print(f"Initializing {pid}.position to {init_mean.item()}")
# TODO instead of [0], should actually get index of 'position->position'
self.state_dict()['initial_mean'][self.process_to_slice[pid][se_idx]] = init_mean

for measure, procs in hits.items():
if len(procs) > 1:
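The new `ilink` argument presumably serves subclasses whose states are observed through a link function: it maps the early-data mean back onto the state scale before it is written into `initial_mean`. A hypothetical call, reusing `model`/`y` from the sketch above:

```python
# hypothetical: an exp-linked model observes roughly exp(position),
# so initialize position at the log of the early-data mean
model.set_initial_values(y, n=3, ilink=torch.log)
```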
@@ -244,7 +254,8 @@ def _validate(processes: Sequence[Process], measures: Sequence[str]):
def design_modules(self) -> Iterable[Tuple[str, nn.Module]]:
for pid in self.processes:
yield pid, self.processes[pid]
yield 'measure_covariance', self.measure_covariance
if self.measure_covariance:
yield 'measure_covariance', self.measure_covariance

@torch.jit.ignore()
def forward(self,
@@ -314,7 +325,7 @@ def forward(self,
raise ValueError("`out_timesteps` must be an int.")
out_timesteps = int(out_timesteps)

preds, updates, R, H = self._script_forward(
preds, updates, design_mats = self._script_forward(
input=input,
initial_state=initial_state,
n_step=n_step,
Expand All @@ -326,23 +337,26 @@ def forward(self,
**kwargs
)
)
return self._generate_predictions(preds, R, H, updates if include_updates_in_output else None)
return self._generate_predictions(
preds=preds,
updates=updates if include_updates_in_output else None,
**design_mats,
)

@torch.jit.ignore
def _generate_predictions(self,
preds: Tuple[List[Tensor], List[Tensor]],
R: List[Tensor],
H: List[Tensor],
updates: Optional[Tuple[List[Tensor], List[Tensor]]] = None) -> 'Predictions':
updates: Optional[Tuple[List[Tensor], List[Tensor]]] = None,
**kwargs) -> 'Predictions':
"""
StateSpace subclasses may pass subclasses of `Predictions` (e.g. for custom log-prob)
"""

kwargs = {
'state_means': preds[0],
'state_covs': preds[1],
'R': R,
'H': H,
'R': kwargs['R'],
'H': kwargs['H'],
'model': self
}
if updates is not None:
@@ -400,8 +414,7 @@ def _script_forward(self,
) -> Tuple[
Tuple[List[Tensor], List[Tensor]],
Tuple[List[Tensor], List[Tensor]],
List[Tensor],
List[Tensor]
Dict[str, List[Tensor]]
]:
"""
:param input: A (group X time X measures) tensor. Optional if `initial_state` is specified.
@@ -506,10 +519,8 @@ def _script_forward(self,

preds = [meanps[t] for t in range(out_timesteps)], [covps[t] for t in range(out_timesteps)]
updates = meanus, covus
R = update_kwargs['R']
H = update_kwargs['H']

return preds, updates, R, H
return preds, updates, update_kwargs
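A sketch of what the dict-based return enables (the subclass, key, and attribute names here are hypothetical): a subclass can thread extra design matrices from `_script_forward` through to its predictions without changing `forward`'s plumbing.

```python
from torchcast.kalman_filter import KalmanFilter

class ExpLinkFilter(KalmanFilter):
    def _generate_predictions(self, preds, updates=None, **kwargs):
        # pop a key our own _script_forward added; 'R'/'H' pass through as before
        measured_means = kwargs.pop('measured_mean', None)
        out = super()._generate_predictions(preds, updates=updates, **kwargs)
        out.measured_means = measured_means  # stash for custom log-prob, plotting, etc.
        return out
```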

def _build_design_mats(self,
kwargs_per_process: Dict[str, Dict[str, Tensor]],
@@ -624,7 +635,7 @@ def simulate(self,
from tqdm.auto import tqdm
progress = tqdm
except ImportError:
warn("`progress=True` requires package `tqdm`.")
warn("The progress bar shown when verbose>1 requires package `tqdm`.")
progress = lambda x: x
times = progress(times)

63 changes: 38 additions & 25 deletions torchcast/state_space/predictions.py
@@ -1,13 +1,20 @@
from functools import cached_property
from typing import Tuple, Union, Optional, Dict, Iterator, Sequence
from typing import Tuple, Union, Optional, Dict, Iterator, Sequence, TYPE_CHECKING
from warnings import warn

import numpy as np
import torch
from scipy.stats import norm as ScipyNorm
from torch import nn, Tensor

from torchcast.internals.utils import get_nan_groups, is_near_zero
import numpy as np

from backports.cached_property import cached_property

from torchcast.internals.utils import get_nan_groups, is_near_zero, transpose_last_dims
from torchcast.utils.data import TimeSeriesDataset
from torchcast.utils.outliers import get_outlier_multi

if TYPE_CHECKING:
from pandas import DataFrame
from torchcast.state_space import StateSpaceModel


class Predictions(nn.Module):
@@ -105,7 +112,7 @@ def update_means(self) -> Optional[torch.Tensor]:
return self._update_means

@cached_property
def update_covs(self) -> torch.Tensor:
def update_covs(self) -> Optional[torch.Tensor]:
if self._update_covs is None:
return None
if not isinstance(self._update_covs, torch.Tensor):
@@ -165,10 +172,7 @@ def observe(cls, state_means: Tensor, state_covs: Tensor, R: Tensor, H: Tensor)
:return: A tuple of `means`, `covs`.
"""
means = H.matmul(state_means.unsqueeze(-1)).squeeze(-1)
pargs = list(range(len(H.shape)))
pargs[-2:] = reversed(pargs[-2:])
Ht = H.permute(*pargs)
assert R.shape[-1] == R.shape[-2], f"R is not symmetrical (shape is {R.shape})"
Ht = transpose_last_dims(H)
covs = H.matmul(state_covs).matmul(Ht) + R
return means, covs
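A quick self-contained check of what `observe` computes (shapes invented for illustration): the measurement distribution implied by the state distribution, i.e. `means = H @ state_means` and `covs = H @ state_covs @ H' + R`.

```python
import torch

ngroups, state_rank, nmeasures = 4, 3, 2
state_means = torch.randn(ngroups, state_rank)
A = torch.randn(ngroups, state_rank, state_rank)
state_covs = A @ A.transpose(-1, -2)                    # any PSD covariance
H = torch.randn(ngroups, nmeasures, state_rank)
R = 0.5 * torch.eye(nmeasures).expand(ngroups, -1, -1)  # measurement noise

means = (H @ state_means.unsqueeze(-1)).squeeze(-1)     # (ngroups, nmeasures)
covs = H @ state_covs @ H.transpose(-1, -2) + R         # (ngroups, nmeasures, nmeasures)
```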

@@ -194,9 +198,9 @@ def sample(self) -> Tensor:

def log_prob(self, obs: Tensor) -> Tensor:
"""
Compute the log-probability of data (e.g. data that was originally fed into the KalmanFilter).
Compute the log-probability of data (e.g. data that was originally fed into the ``StateSpaceModel``).
:param obs: A Tensor that could be used in the KalmanFilter.forward pass.
:param obs: A Tensor that could be used in the ``StateSpaceModel`` forward pass.
:return: A tensor with one element for each group X timestep indicating the log-probability.
"""
assert len(obs.shape) == 3
@@ -253,22 +257,33 @@ def log_prob(self, obs: Tensor) -> Tensor:
def _log_prob(self, obs: Tensor, means: Tensor, covs: Tensor) -> Tensor:
return self.distribution_cls(means, covs, validate_args=False).log_prob(obs)

@classmethod
def _get_quantiles(cls, mean, std, conf: float, observed: bool) -> tuple:
assert conf >= .50
multi = -ScipyNorm.ppf((1 - conf) / 2)
lower = mean - multi * std
upper = mean + multi * std
return lower, upper
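Why `conf=.95` reproduces the old hard-coded `multi=1.96`: the multiplier is just the two-sided normal quantile. A quick check:

```python
from scipy.stats import norm

conf = 0.95
multi = -norm.ppf((1 - conf) / 2)  # 1.959964... ≈ 1.96
lower, upper = 10.0 - multi * 2.0, 10.0 + multi * 2.0  # e.g. mean=10, std=2
```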

def to_dataframe(self,
dataset: Union[TimeSeriesDataset, dict],
type: str = 'predictions',
group_colname: str = 'group',
time_colname: str = 'time',
multi: Optional[float] = 1.96) -> 'DataFrame':
conf: Optional[float] = .95,
multi: Optional[float] = None) -> 'DataFrame':
"""
:param dataset: Either a :class:`.TimeSeriesDataset`, or a dictionary with 'start_times', 'group_names', &
'dt_unit'
:param type: Either 'predictions' or 'components'.
:param group_colname: Column-name for 'group'
:param time_colname: Column-name for 'time'
:param multi: Multiplier on std-dev for lower/upper CIs. Default 1.96.
:param conf: The confidence level that the lower/upper CIs will target. The default of 0.95 means these are
the 0.025 and 0.975 quantiles.
:return: A pandas DataFrame with group, 'time', 'measure', 'mean', 'lower', 'upper'. For ``type='components'``
additionally includes: 'process' and 'state_element'.
"""
if multi is not None:
warn("Ignoring `multi` as it is deprecated, please use `conf` instead.", DeprecationWarning)

from pandas import concat

@@ -321,9 +336,8 @@ def _tensor_to_df(tens, measures):
for i, measure in enumerate(self.measures):
# predicted:
df = _tensor_to_df(torch.stack([self.means[..., i], stds[..., i]], 2), measures=['mean', 'std'])
if multi is not None:
df['lower'] = df['mean'] - multi * df['std']
df['upper'] = df['mean'] + multi * df.pop('std')
if conf is not None:
df['lower'], df['upper'] = self._get_quantiles(df['mean'], df.pop('std'), conf=conf, observed=True)

# actual:
orig_tensor = batch_info.get('named_tensors', {}).get(measure, None)
@@ -337,9 +351,8 @@ def _tensor_to_df(tens, measures):
# components:
for (measure, process, state_element), (m, std) in self._components().items():
df = _tensor_to_df(torch.stack([m, std], 2), measures=['mean', 'std'])
if multi is not None:
df['lower'] = df['mean'] - multi * df['std']
df['upper'] = df['mean'] + multi * df.pop('std')
if conf is not None:
df['lower'], df['upper'] = self._get_quantiles(df['mean'], df.pop('std'), conf=conf, observed=False)
df['process'], df['state_element'], df['measure'] = process, state_element, measure
out.append(df)

@@ -378,8 +391,9 @@ def _components(self) -> Dict[Tuple[str, str, str], Tuple[Tensor, Tensor]]:

return out

@staticmethod
def plot(df: 'DataFrame',
@classmethod
def plot(cls,
df: 'DataFrame',
group_colname: str = None,
time_colname: str = None,
max_num_groups: int = 1,
@@ -415,8 +429,7 @@ def plot(df: 'DataFrame',

df = df.copy()
if 'upper' not in df.columns and 'std' in df.columns:
df['upper'] = df['mean'] + 1.96 * df['std']
df['lower'] = df['mean'] - 1.96 * df['std']
df['lower'], df['upper'] = cls._get_quantiles(df['mean'], df['std'], conf=.95, observed=not is_components)
if df[group_colname].nunique() > max_num_groups:
subset_groups = df[group_colname].drop_duplicates().sample(max_num_groups).tolist()
if len(subset_groups) < df[group_colname].nunique():
@@ -482,7 +495,7 @@ def _subset_to_times(self,

if start_times is not None:
if isinstance(dt_unit, str):
dt_unit = np.datetime64(1, dt_unit)
dt_unit = np.timedelta64(1, dt_unit)
times = times - start_times
if dt_unit is not None:
times = times // dt_unit # todo: validate int?
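End-to-end, the renamed `to_dataframe` argument would be used like this (a sketch: the model, group names, and dates are invented, and it assumes the documented dict form of `dataset`):

```python
import numpy as np
import torch
from torchcast.kalman_filter import KalmanFilter
from torchcast.process import LocalLevel

model = KalmanFilter(processes=[LocalLevel(id='level')], measures=['y'])
y = 10 + torch.randn(3, 30, 1)
pred = model(y, out_timesteps=40)  # forecast 10 steps past the data

df = pred.to_dataframe(
    {'start_times': np.array(['2024-01-01'] * 3, dtype='datetime64[D]'),
     'group_names': ['a', 'b', 'c'],
     'dt_unit': 'D'},
    conf=.95,  # replaces the deprecated multi=1.96
)
pred.plot(df, max_num_groups=2)
```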
