diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..daf957c --- /dev/null +++ b/docs/README.md @@ -0,0 +1,11 @@ +To compile the docs, you must install the package with the `docs` extra: + +```bash +pip install git+https://github.com/strongio/torchcast.git#egg=torchcast[docs] +``` + +Then from project root run: + +```bash +sphinx-build -b html ./docs ./docs/_html +``` \ No newline at end of file diff --git a/docs/api/kalman_filter.rst b/docs/api/kalman_filter.rst index 51213db..8209830 100644 --- a/docs/api/kalman_filter.rst +++ b/docs/api/kalman_filter.rst @@ -1,7 +1,7 @@ Kalman Filter ============= -.. automodule:: torchcast.kalman_filter +.. automodule:: torchcast.kalman_filter.kalman_filter :members: KalmanFilter :exclude-members: ss_step_cls :show-inheritance: diff --git a/docs/examples/air_quality.py b/docs/examples/air_quality.py index e6196a0..3d1dc6c 100644 --- a/docs/examples/air_quality.py +++ b/docs/examples/air_quality.py @@ -81,7 +81,9 @@ def inverse_transform(df): df = df.copy() # bias-correction for log-transform (see https://otexts.com/fpp2/transformations.html#bias-adjustments) - df['mean'] = df['mean'] + .5 * (df['upper'] - df['lower']) / 1.96 ** 2 + df['mean'] = df['mean'] + .5 * df['std'] ** 2 + df['lower'] = df['mean'] - 1.96 * df['std'] + df['upper'] = df['mean'] + 1.96 * df['std'] # inverse the log10: df[['actual', 'mean', 'upper', 'lower']] = 10 ** df[['actual', 'mean', 'upper', 'lower']] df['measure'] = df['measure'].str.replace('_log10', '') @@ -94,7 +96,7 @@ def inverse_transform(df): out_timesteps=dataset_pm_univariate.tensors[0].shape[1] ) -df_forecast = inverse_transform(forecast.to_dataframe(dataset_pm_univariate)) +df_forecast = inverse_transform(forecast.to_dataframe(dataset_pm_univariate, conf=None)) print(forecast.plot(df_forecast, max_num_groups=3, split_dt=SPLIT_DT)) # %% [markdown] @@ -117,7 +119,7 @@ def inverse_transform(df): ) df_univariate_error = pred_4step.\ - to_dataframe(dataset_pm_univariate, group_colname='station', time_colname='week').\ + to_dataframe(dataset_pm_univariate, group_colname='station', time_colname='week', conf=None).\ pipe(inverse_transform).\ merge(df_aq.loc[:,['station', 'week', 'PM']]).\ assign( @@ -234,7 +236,7 @@ def mc_preds_to_dataframe(preds: Predictions, df_multivariate_error.groupby('validation')['error'].agg(['mean','std']) # %% [markdown] -# We see that this approach has reduced our error: substantially in the training period, and moderately in the validation period. We can look at the per-site differences to reduce common sources of noise and see that the reduction is consistent (it holds for all but one site): +# We see that this approach has reduced our error in the validation period. 
We can look at the per-site differences to reduce noise: # %% df_multivariate_error.\ @@ -304,7 +306,10 @@ def mc_preds_to_dataframe(preds: Predictions, start_offsets=dataset_pm_lm.start_datetimes, n_step=4 ) -pred_4step.plot(pred_4step.to_dataframe(dataset_pm_lm, type='components').query("process.str.contains('lm')"), split_dt=SPLIT_DT) +pred_4step.plot( + pred_4step.to_dataframe(dataset_pm_lm, type='components').query("process.str.contains('lm')"), + split_dt=SPLIT_DT +) # %% [markdown] # Now let's look at error: diff --git a/docs/quick_start.py b/docs/quick_start.py index 6523924..5481ecc 100644 --- a/docs/quick_start.py +++ b/docs/quick_start.py @@ -110,7 +110,7 @@ # `Predictions` can easily be converted to Pandas `DataFrames` for ease of inspecting predictions, comparing them to actuals, and visualizing: # %% -df_pred = pred.to_dataframe(dataset_all, multi=None) +df_pred = pred.to_dataframe(dataset_all, conf=None) # bias-correction for log-transform (see https://otexts.com/fpp2/transformations.html#bias-adjustments) df_pred['mean'] += .5 * df_pred['std'] ** 2 df_pred['lower'] = df_pred['mean'] - 1.96 * df_pred['std'] diff --git a/setup.py b/setup.py index 3fa7f9c..e006236 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ install_requires=[ 'torch>=1.12', 'numpy>=1.4', + 'scipy>=1.10' ], extras_require={ 'tests': ['parameterized>=0.7', 'filterpy>=1.4', 'pandas>=1.0'], diff --git a/tests/test_kalman_filter.py b/tests/test_kalman_filter.py index 30fac3f..3c2d958 100644 --- a/tests/test_kalman_filter.py +++ b/tests/test_kalman_filter.py @@ -1,8 +1,8 @@ import copy import itertools from collections import defaultdict -from typing import Callable, Optional, Dict -from unittest import TestCase +from typing import Callable, Dict +from unittest import TestCase, expectedFailure import torch from parameterized import parameterized @@ -66,7 +66,7 @@ def test_nans(self, ndim: int = 3, n_step: int = 1): processes=[LocalLevel(id=f'lm{i}', measure=str(i)) for i in range(ndim)], measures=[str(i) for i in range(ndim)] ) - kf = torch.jit.script(kf) +# kf = torch.jit.script(kf) obs_means, obs_covs = kf(data, n_step=n_step) self.assertFalse(torch.isnan(obs_means).any()) self.assertFalse(torch.isnan(obs_covs).any()) @@ -150,7 +150,7 @@ def test_equations_decay(self): # confirm decay works in forward pass # also tests that kf.forward works with `out_timesteps > input.shape[1]` pred = torch_kf( - initial_state=torch_kf._prepare_initial_state((None, None), start_offsets=np.zeros(1)), + initial_state=torch_kf._prepare_initial_state(None, start_offsets=np.zeros(1)), X=torch.randn(1, num_times, 3), out_timesteps=num_times ) @@ -168,7 +168,6 @@ def test_equations(self): processes=[LocalTrend(id='lt', decay_velocity=None, measure='y', velocity_multi=1.)], measures=['y'] ) - kf = torch.jit.script(torch_kf) expectedF = torch.tensor([[1., 1.], [0., 1.]]) expectedH = torch.tensor([[1., 0.]]) kwargs_per_process = torch_kf._parse_design_kwargs(input=data, out_timesteps=num_times) @@ -184,7 +183,7 @@ def test_equations(self): # make filterpy kf: filter_kf = filterpy_KalmanFilter(dim_x=2, dim_z=1) - filter_kf.x, filter_kf.P = torch_kf._prepare_initial_state((None, None)) + filter_kf.x, filter_kf.P = torch_kf._prepare_initial_state(None) filter_kf.x = filter_kf.x.detach().numpy().T filter_kf.P = filter_kf.P.detach().numpy().squeeze(0) filter_kf.Q = Q.numpy() @@ -228,7 +227,7 @@ def __init__(self, *args, **kwargs): ], measures=['y'] ) - kf._scale_by_measure_var = False + kf._get_measure_scaling = lambda: torch.ones(2) 
kf.state_dict()['initial_mean'][:] = torch.tensor([1.5, -0.5]) kf.state_dict()['measure_covariance.cholesky_log_diag'][0] = np.log(.1 ** .5) @@ -348,7 +347,7 @@ def _build_h_mat(self, inputs: Dict[str, Tensor], num_groups: int, num_times: in processes=[Season(id='s1')], measures=['y'] ) - kf._scale_by_measure_var = False + kf._get_measure_scaling = lambda: torch.ones(1) data = torch.arange(7).view(1, -1, 1).to(torch.float32) for init_state in [0., 1.]: kf.state_dict()['initial_mean'][:] = torch.ones(1) * init_state @@ -401,11 +400,10 @@ def test_no_proc_variance(self): self.assertTrue((cov == 0).all()) @parameterized.expand([ - (torch.float64, 2, True), (torch.float64, 2, False) ]) @torch.no_grad() - def test_dtype(self, dtype: torch.dtype, ndim: int = 2, compiled: bool = True): + def test_dtype(self, dtype: torch.dtype, ndim: int = 2, compiled: bool = False): data = torch.zeros((2, 5, ndim), dtype=dtype) kf = KalmanFilter( processes=[LocalLevel(id=f'll{i}', measure=str(i)) for i in range(ndim)], diff --git a/tests/test_process.py b/tests/test_process.py index c8974cc..eabf479 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -16,7 +16,7 @@ def test_fourier_season(self): processes=[Season(id='day_of_week', period='7D', dt_unit='D', K=3, fixed=True)], measures=['y'] ) - kf._scale_by_measure_var = False + kf._get_measure_scaling = lambda: torch.ones(6) kf.state_dict()['initial_mean'][:] = torch.tensor([1., 0., 0., 0., 0., 0.]) kf.state_dict()['measure_covariance.cholesky_log_diag'] -= 2 pred = kf(data, start_offsets=start_datetimes) diff --git a/tests/test_training.py b/tests/test_training.py index 9e39826..6c7b17c 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -14,21 +14,6 @@ class TestTraining(unittest.TestCase): - @parameterized.expand([(1,), (2,), (3,)]) - @torch.no_grad() - def test_gaussian_log_prob(self, ndim: int = 1): - data = torch.zeros((2, 5, ndim)) - kf = KalmanFilter( - processes=[LocalLevel(id=f'lm{i}', measure=str(i)) for i in range(ndim)], - measures=[str(i) for i in range(ndim)] - ) - dist = kf.ss_step.get_distribution() - pred = kf(data) - log_lik1 = dist(*pred).log_prob(data) - from torch.distributions import MultivariateNormal - mv = MultivariateNormal(*pred) - log_lik2 = mv.log_prob(data) - self.assertAlmostEqual(log_lik1.sum().item(), log_lik2.sum().item()) @parameterized.expand([(1,), (2,), (3,)]) @torch.no_grad() @@ -46,8 +31,6 @@ def test_log_prob_with_missings(self, ndim: int = 1, num_groups: int = 1, num_ti lp_method1 = pred.log_prob(data) lp_method1_sum = lp_method1.sum().item() - dist = kf.ss_step.get_distribution() - lp_method2_sum = 0 for g in range(num_groups): data_g = data[[g]] @@ -59,7 +42,7 @@ def test_log_prob_with_missings(self, ndim: int = 1, num_groups: int = 1, num_ti if not isvalid_gt.any(): continue if isvalid_gt.all(): - lp_gt = dist(*pred_gt).log_prob(data_gt).item() + lp_gt = torch.distributions.MultivariateNormal(*pred_gt).log_prob(data_gt).item() else: pred_gtm = pred_gt.observe( state_means=pred_gt.state_means, @@ -67,12 +50,12 @@ def test_log_prob_with_missings(self, ndim: int = 1, num_groups: int = 1, num_ti R=pred_gt.R[..., isvalid_gt, :][..., isvalid_gt], H=pred_gt.H[..., isvalid_gt, :] ) - lp_gt = dist(*pred_gtm).log_prob(data_gt[..., isvalid_gt]).item() + lp_gt = torch.distributions.MultivariateNormal(*pred_gtm).log_prob(data_gt[..., isvalid_gt]).item() self.assertAlmostEqual(lp_method1[g, t].item(), lp_gt, places=4) lp_method2_sum += lp_gt self.assertAlmostEqual(lp_method1_sum, lp_method2_sum, 
places=3) - def test_training1(self, ndim: int = 2, num_groups: int = 150, num_times: int = 24, compile: bool = True): + def test_training1(self, ndim: int = 2, num_groups: int = 150, num_times: int = 24, compile: bool = False): """ simulated data with known parameters, fitted loss should approach the loss given known params """ @@ -101,8 +84,8 @@ def _make_kf(): X = torch.randn((num_groups, num_times, 5)) kf_generator = _make_kf() with torch.no_grad(): - sim = kf_generator.simulate(out_timesteps=num_times, num_sims=num_groups, X=X) - y = sim.sample() + sim = kf_generator.simulate(out_timesteps=num_times, num_groups=num_groups, X=X) + y = torch.distributions.MultivariateNormal(*sim).sample() assert not y.requires_grad # train: @@ -136,7 +119,7 @@ def closure(): oracle_loss = -kf_generator(y, X=X).log_prob(y).mean() self.assertAlmostEqual(oracle_loss.item(), loss.item(), places=1) - def test_training2(self, num_groups: int = 50, compile: bool = True): + def test_training2(self, num_groups: int = 50, compile: bool = False): """ # manually generated data (sin-wave, trend, etc.) with virtually no noise: MSE should be near zero """ @@ -199,7 +182,7 @@ def closure(): # trend should be identified: self.assertAlmostEqual(pred.state_means[:, :, 1].mean().item(), 5., places=1) - def test_training3(self, compile: bool = True): + def test_training3(self, compile: bool = False): """ Test TBATS and TimeSeriesDataset integration """ diff --git a/torchcast/__init__.py b/torchcast/__init__.py index 0404d81..e1424ed 100644 --- a/torchcast/__init__.py +++ b/torchcast/__init__.py @@ -1 +1 @@ -__version__ = '0.3.0' +__version__ = '0.3.1' diff --git a/torchcast/exp_smooth/exp_smooth.py b/torchcast/exp_smooth/exp_smooth.py index a173271..f88060c 100644 --- a/torchcast/exp_smooth/exp_smooth.py +++ b/torchcast/exp_smooth/exp_smooth.py @@ -67,8 +67,6 @@ class ExpSmoother(StateSpaceModel): :param processes: A list of :class:`.Process` modules. :param measures: A list of strings specifying the names of the dimensions of the time-series being measured. :param measure_covariance: A module created with ``Covariance.from_measures(measures)``. - :param predict_smoothing: A ``torch.nn.Module`` which predicts the smoothing parameters. The module should predict - these as real-values and they will be constrained to 0-1 internally. """ ss_step_cls = ExpSmoothStep diff --git a/torchcast/internals/utils.py b/torchcast/internals/utils.py index 6c8258d..b5c4f33 100644 --- a/torchcast/internals/utils.py +++ b/torchcast/internals/utils.py @@ -5,6 +5,12 @@ import numpy as np +def transpose_last_dims(x: torch.Tensor) -> torch.Tensor: + args = list(range(len(x.shape))) + args[-2], args[-1] = args[-1], args[-2] + return x.permute(*args) + + def get_nan_groups(isnan: torch.Tensor) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]: """ Iterable of (group_idx, valid_idx) tuples that can be passed to torch.meshgrid. 
If no valid, then not returned; if @@ -148,3 +154,12 @@ def true1d_idx(arr: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: def is_near_zero(tens: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> torch.Tensor: z = torch.zeros(1, dtype=tens.dtype, device=tens.device) return torch.isclose(tens, other=z, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def repeat(x: Union[torch.Tensor, np.ndarray], times: int, dim: int) -> Union[torch.Tensor, np.ndarray]: + reps = [1 for _ in x.shape] + reps[dim] = times + if isinstance(x, np.ndarray): + return np.tile(x, reps=reps) + else: + return x.repeat(*reps) diff --git a/torchcast/kalman_filter/__init__.py b/torchcast/kalman_filter/__init__.py new file mode 100644 index 0000000..2eae6f3 --- /dev/null +++ b/torchcast/kalman_filter/__init__.py @@ -0,0 +1 @@ +from .kalman_filter import KalmanFilter diff --git a/torchcast/kalman_filter/ekf.py b/torchcast/kalman_filter/ekf.py new file mode 100644 index 0000000..8be5f08 --- /dev/null +++ b/torchcast/kalman_filter/ekf.py @@ -0,0 +1,121 @@ +from typing import Dict, Tuple, Optional, Union + +import numpy as np +import pandas as pd +import torch +from torch import Tensor + +from .kalman_filter import KalmanStep +from ..state_space import Predictions +from ..utils import TimeSeriesDataset + + +class EKFStep(KalmanStep): + """ + Implements update for the extended kalman-filter. Currently limited: + + 1. Does not implement any special ``predict``, so no custom functions for transition (only measurement). + 2. Assumes the custom measure function takes the form ``custom_fun(H @ state)``. + + This means that currently, this is primarily useful for approximating non-gaussian measures (e.g. poisson, + binomial) via a link function. + """ + + def _adjust_h(self, mean: Tensor, H: Tensor) -> Tensor: + return H + + def _adjust_r(self, measured_mean: Tensor, R: Optional[Tensor]) -> Tensor: + assert R is not None + return R + + def _adjust_measurement(self, x: Tensor) -> Tensor: + return x + + def _update(self, + input: Tensor, + mean: Tensor, + cov: Tensor, + kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + if (kwargs['outlier_threshold'] != 0).any(): + raise NotImplementedError("Outlier rejection is not yet supported for EKF") + + orig_H = kwargs['H'] + h_dot_state = (orig_H @ mean.unsqueeze(-1)).squeeze(-1) + kwargs['measured_mean'] = self._adjust_measurement(h_dot_state) + kwargs['H'] = self._adjust_h(mean, orig_H) + kwargs['R'] = self._adjust_r(kwargs['measured_mean'], kwargs.get('R', None)) + + return super()._update( + input=input, + mean=mean, + cov=cov, + kwargs=kwargs + ) + + +class EKFPredictions(Predictions): + @classmethod + def _adjust_measured_mean(cls, + x: Union[Tensor, np.ndarray, pd.Series], + std: Optional[Union[Tensor, np.ndarray, pd.Series]] = None, + conf: float = .95) -> Union[Tensor, pd.DataFrame]: + """ + In our EKF, the measured-mean is ``custom_fun(H @ state)``. + + - If only ``x`` (``= H @ state``) is passed, this method should apply the custom fun -- supporting x that is + either a tensor (with grad), a numpy array, or a pandas series. + - If both ``x`` and ``std`` is passed, this method should return a dataframe with mean, lower, upper columns. + The mean column should have any bias-correction applied, and the lower/upper should be conf% confidence + bounds (e.g. for plotting). 
+ """ + raise NotImplementedError + + def _log_prob(self, obs: Tensor, means: Tensor, covs: Tensor) -> Tensor: + raise NotImplementedError + + def __array__(self) -> np.ndarray: + with torch.no_grad(): + stds = torch.diagonal(self.covs, dim1=-1, dim2=-2).sqrt() + out = [] + for i, m in enumerate(self.measures): + out.append(self._adjust_measured_mean(self.means[..., i], stds[..., i])) + return torch.stack(out, -1).numpy() + + def to_dataframe(self, + dataset: Union[TimeSeriesDataset, dict], + type: str = 'predictions', + group_colname: str = 'group', + time_colname: str = 'time', + conf: Optional[float] = .95, + multi: Optional[float] = None) -> pd.DataFrame: + df = super().to_dataframe( + dataset=dataset, + type=type, + group_colname=group_colname, + time_colname=time_colname, + conf=None, + multi=multi + ) + df[['mean', 'lower', 'upper']] = self._adjust_measured_mean(df['mean'], df.pop('std'), conf=conf) + return df + + @classmethod + def plot(cls, + df: pd.DataFrame, + group_colname: str = None, + time_colname: str = None, + max_num_groups: int = 1, + split_dt: Optional[np.datetime64] = None, + **kwargs) -> pd.DataFrame: + + if 'upper' not in df.columns and 'std' in df.columns: + df[['mean', 'lower', 'upper']] = cls._adjust_measured_mean(df['mean'], df['std']) + + return super().plot( + df=df, + group_colname=group_colname, + time_colname=time_colname, + max_num_groups=max_num_groups, + split_dt=split_dt, + **kwargs + ) diff --git a/torchcast/kalman_filter.py b/torchcast/kalman_filter/kalman_filter.py similarity index 87% rename from torchcast/kalman_filter.py rename to torchcast/kalman_filter/kalman_filter.py index d79d3e7..cd0a533 100644 --- a/torchcast/kalman_filter.py +++ b/torchcast/kalman_filter/kalman_filter.py @@ -19,7 +19,7 @@ from torch import nn, Tensor from typing_extensions import Final -from torchcast.utils.outliers import mahalanobis_dist +from torchcast.utils.outliers import get_outlier_multi class KalmanStep(StateSpaceStep): @@ -56,24 +56,37 @@ def _mask_mats(self, }) return masked_input, new_kwargs - def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + def _update(self, + input: Tensor, + mean: Tensor, + cov: Tensor, + kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: H = kwargs['H'] R = kwargs['R'] Ht = H.permute(0, 2, 1) - system_covariance = torch.baddbmm(R, H @ cov, Ht) - - # kalman-gain: - K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance) # residuals: - measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1) + if 'measured_mean' in kwargs: # calculated by super + measured_mean = kwargs['measured_mean'] + else: + measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1) resid = input - measured_mean - # outlier-rejection: - if kwargs['outlier_threshold'] > 0: - mdist = mahalanobis_dist(resid, system_covariance) - multi = (mdist - kwargs['outlier_threshold']).clamp(min=0) + 1 - R = R * multi.unsqueeze(-1).unsqueeze(-1) + HcHt = H @ cov @ Ht + system_covariance = HcHt + R + + # # outlier-rejection: + # if (kwargs['outlier_threshold'] != 0).any(): + # multi = get_outlier_multi( + # resid=resid, + # cov=system_covariance, + # outlier_threshold=kwargs['outlier_threshold'] + # ) + # R = R * multi.unsqueeze(-1).unsqueeze(-1) + # system_covariance = HcHt + R + + # kalman-gain: + K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance) # update: new_mean = mean + (K @ resid.unsqueeze(-1)).squeeze(-1) @@ -119,9 +132,7 @@ def __init__(self, processes: Sequence[Process], 
measures: Optional[Sequence[str]] = None, process_covariance: Optional[Covariance] = None, - measure_covariance: Optional[Covariance] = None, - outlier_threshold: float = 0., - outlier_burnin: Optional[int] = None): + measure_covariance: Optional[Covariance] = None): initial_covariance = Covariance.from_processes(processes, cov_type='initial') @@ -136,9 +147,7 @@ def __init__(self, super().__init__( processes=processes, measures=measures, - measure_covariance=measure_covariance, - outlier_threshold=outlier_threshold, - outlier_burnin=outlier_burnin + measure_covariance=measure_covariance ) self.process_covariance = process_covariance.set_id('process_covariance') self.initial_covariance = initial_covariance.set_id('initial_covariance') diff --git a/torchcast/process/season.py b/torchcast/process/season.py index 7a88a2e..5d0c287 100644 --- a/torchcast/process/season.py +++ b/torchcast/process/season.py @@ -11,7 +11,6 @@ from torchcast.process.base import Process from torchcast.process.utils import SingleOutput, Multi, Bounded, ScriptSequential -from torchcast.utils.features import fourier_tensor class _Season: diff --git a/torchcast/state_space/base.py b/torchcast/state_space/base.py index 20baf25..8aab9d5 100644 --- a/torchcast/state_space/base.py +++ b/torchcast/state_space/base.py @@ -2,10 +2,11 @@ from typing import Tuple, List, Optional, Sequence, Dict, Iterable, Callable, Union, Type from warnings import warn +import numpy as np import torch from torch import nn, Tensor -from torchcast.internals.utils import get_owned_kwargs +from torchcast.internals.utils import get_owned_kwargs, repeat from torchcast.covariance import Covariance from torchcast.state_space.predictions import Predictions from torchcast.state_space.ss_step import StateSpaceStep @@ -19,33 +20,23 @@ class StateSpaceModel(nn.Module): :param processes: A list of :class:`.Process` modules. :param measures: A list of strings specifying the names of the dimensions of the time-series being measured. :param measure_covariance: A module created with ``Covariance.from_measures(measures)``. - :param outlier_threshold: If specified, used as a threshold-for outlier under-weighting during the `update` step, - using mahalanobis distance; outliers will also be under-weighted when evaluating the ``log_prob()`` of the output - ``Predictions``. - :param outlier_burnin: If outlier_threshold is specified, this specifies the number of steps to wait before - starting to reject outliers. 
""" ss_step_cls: Type[StateSpaceStep] def __init__(self, processes: Sequence[Process], measures: Optional[Sequence[str]], - measure_covariance: Covariance, - outlier_threshold: float = 0.0, - outlier_burnin: Optional[int] = None): + measure_covariance: Covariance): super().__init__() - self.outlier_threshold = outlier_threshold - if self.outlier_threshold and outlier_burnin is None: - raise ValueError("If `outlier_threshold` is set, `outlier_burnin` must be set as well.") - self.outlier_burnin = outlier_burnin or 0 - if isinstance(measures, str): measures = [measures] warn(f"`measures` should be a list of strings not a string; interpreted as `{measures}`.") self._validate(processes, measures) - self.measure_covariance = measure_covariance.set_id('measure_covariance') + self.measure_covariance = measure_covariance + if self.measure_covariance: + self.measure_covariance.set_id('measure_covariance') self.ss_step = self.ss_step_cls() @@ -66,8 +57,23 @@ def __init__(self, # the initial mean self.initial_mean = torch.nn.Parameter(.1 * torch.randn(self.state_rank)) - # can disable for debugging/tests: - self._scale_by_measure_var = True + @property + def dt_unit(self) -> Optional[np.timedelta64]: + dt_unit_ns = None + proc_with_dt = '' + for p in self.processes.values(): + if hasattr(p, 'dt_unit_ns'): + if dt_unit_ns is None: + dt_unit_ns = p.dt_unit_ns + proc_with_dt = p.id + elif p.dt_unit_ns != dt_unit_ns: + raise ValueError( + f"Found multiple processes with different dt_units:" + f"{proc_with_dt}: {dt_unit_ns}" + f"{p.id}: {p.dt_unit_ns}" + ) + if dt_unit_ns is not None: + return np.timedelta64(dt_unit_ns, 'ns') # todo: promote @torch.jit.ignore() def fit(self, @@ -78,6 +84,7 @@ def fit(self, optimizer: Optional[torch.optim.Optimizer] = None, verbose: int = 2, callbacks: Sequence[Callable] = (), + get_loss: Optional[Callable] = None, loss_callback: Optional[Callable] = None, callable_kwargs: Optional[Dict[str, Callable]] = None, set_initial_values: bool = True, @@ -98,14 +105,17 @@ def fit(self, (the default) this progress bar will tick within each epoch to track the calls to forward. :param callbacks: A list of functions that will be called at the end of each epoch, which take the current epoch's loss value. - :param loss_callback: A callback that takes the loss and returns a modified loss, called before each call to - `backward()`. This can be used for example to add regularization. - :param callable_kwargs: A dictionary where the keys are keyword-names and the values are no-argument functions - that will be called each iteration to recompute the corresponding arguments. - :param set_initial_values: Default is to set ``initial_mean`` to sensible value given ``y``. This helps speed - up training if the data are not centered. Set to ``False`` if you're resuming training from a previous - ``fit()`` call. + :param get_loss: A function that takes the ``Predictions` object and the input data and returns the loss. + Default is ``lambda pred, y: -pred.log_prob(y).mean()``. + :param loss_callback: Deprecated; use ``get_loss`` instead. + :param set_initial_values: Will set ``initial_mean`` to sensible value given ``y``, which helps speed + up training if the data are not centered. This argument determines the number of timesteps of ``y`` to use + when doing so (default 1). Set to 0/`False``, if you're resuming training from a previous ``fit()`` call. Set + to a larger value for sparse data where the first timestep isn't informative enough. 
:param kwargs: Further keyword-arguments passed to :func:`StateSpaceModel.forward()`. + :param callable_kwargs: The kwargs passed to the forward pass are static, but sometimes you want to recompute + them each iteration. The values in this dictionary are functions that will be called each iteration to + recompute the corresponding arguments. :return: This ``StateSpaceModel`` instance. """ @@ -117,11 +127,10 @@ def fit(self, optimizer = torch.optim.LBFGS([p for p in self.parameters() if p.requires_grad], max_iter=10, line_search_fn='strong_wolfe', lr=.5) - if self.outlier_threshold and verbose: - print("``outlier_threshold`` is experimental") + self.set_initial_values(y, n=set_initial_values, verbose=verbose > 1) - if set_initial_values: - self.set_initial_values(y) + if not get_loss: + get_loss = lambda pred, y: -pred.log_prob(y).mean() prog = None if verbose > 1: @@ -134,15 +143,15 @@ def fit(self, except ImportError: warn("`progress=True` requires package `tqdm`.") - epoch = 0 - callable_kwargs = callable_kwargs or {} + if loss_callback: + warn("`loss_callback` is deprecated; use `get_loss` instead.", DeprecationWarning) def closure(): optimizer.zero_grad() kwargs.update({k: v() for k, v in callable_kwargs.items()}) pred = self(y, **kwargs) - loss = -pred.log_prob(y).mean() + loss = get_loss(pred, y) if loss_callback: loss = loss_callback(loss) loss.backward() @@ -175,10 +184,15 @@ def closure(): return self @torch.jit.ignore() - def set_initial_values(self, y: Tensor): + def set_initial_values(self, y: Tensor, n: int, ilink: Optional[callable] = None, verbose: bool = True): + if not n: + return if 'initial_mean' not in self.state_dict(): return + if ilink is None: + ilink = lambda x: x + assert len(self.measures) == y.shape[-1] hits = {m: [] for m in self.measures} @@ -191,11 +205,15 @@ def set_initial_values(self, y: Tensor): assert process.measure hits[process.measure].append(process.id) + se_idx = process.state_elements.index('position') measure_idx = list(self.measures).index(process.measure) with torch.no_grad(): - t0 = y[:, 0, measure_idx] - self.state_dict()['initial_mean'][self.process_to_slice[pid][0]] = \ - t0[~torch.isnan(t0) & ~torch.isinf(t0)].mean() + t0 = y[:, 0:n, measure_idx] + init_mean = ilink(t0[~torch.isnan(t0) & ~torch.isinf(t0)].mean()) + if verbose: + print(f"Initializing {pid}.position to {init_mean.item()}") + # TODO instead of [0], should actually get index of 'position->position' + self.state_dict()['initial_mean'][self.process_to_slice[pid][se_idx]] = init_mean for measure, procs in hits.items(): if len(procs) > 1: @@ -238,7 +256,8 @@ def _validate(processes: Sequence[Process], measures: Sequence[str]): def design_modules(self) -> Iterable[Tuple[str, nn.Module]]: for pid in self.processes: yield pid, self.processes[pid] - yield 'measure_covariance', self.measure_covariance + if self.measure_covariance: + yield 'measure_covariance', self.measure_covariance @torch.jit.ignore() def forward(self, @@ -246,9 +265,10 @@ def forward(self, n_step: Union[int, float] = 1, start_offsets: Optional[Sequence] = None, out_timesteps: Optional[Union[int, float]] = None, - initial_state: Union[Tensor, Tuple[Optional[Tensor], Optional[Tensor]]] = (None, None), + initial_state: Optional[Tuple[Tensor, Tensor]] = None, every_step: bool = True, include_updates_in_output: bool = False, + simulate: Optional[int] = None, **kwargs) -> Predictions: """ Generate n-step-ahead predictions from the model. 
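A minimal sketch of how the revised ``fit()`` reads after this change. The model and data names here are hypothetical, and it assumes ``LocalLevel`` can be imported from ``torchcast.process``:

```python
import torch
from torchcast.kalman_filter import KalmanFilter
from torchcast.process import LocalLevel

# toy data: (groups, timesteps, measures)
y = torch.randn(5, 40, 1).cumsum(1)

kf = KalmanFilter(processes=[LocalLevel(id='level', measure='y')], measures=['y'])

# `get_loss` replaces the deprecated `loss_callback`; passing the default
# explicitly is equivalent to omitting it:
kf.fit(y, get_loss=lambda pred, y: -pred.log_prob(y).mean())
```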
@@ -263,9 +283,9 @@ def forward(self, :param out_timesteps: The number of timesteps to produce in the output. This is useful when passing a tensor of predictors that goes later in time than the `input` tensor -- you can specify ``out_timesteps=X.shape[1]`` to get forecasts into this later time horizon. - :param initial_state: The initial prediction for the state of the system. XXX (single tensor for mean always - supported. for kf child class, can pass (mean,cov) tuple. this is usually if you're feeding from a previous - output. for exp-smooth child class, latter isn't supported. + :param initial_state: The initial prediction for the state of the system: a tuple of mean, cov tensors. This + would usually come from a previous call to this model, which produces a ``Predictions`` object, which you can + then call :func:`get_state_at_times()` on. :param every_step: By default, ``n_step`` ahead predictions will be generated at every timestep. If ``every_step=False``, then these predictions will only be generated every `n_step` timesteps. For example, with hourly data, ``n_step=24`` and ``every_step=True``, each timepoint would be a forecast generated with @@ -278,6 +298,7 @@ def forward(self, to True to allow this -- False by default to reduce memory. :param kwargs: Further arguments passed to the `processes`. For example, the :class:`.LinearModel` expects an ``X`` argument for predictors. + :param simulate: If specified, will generate `simulate` samples from the model. :return: A :class:`.Predictions` object with :func:`Predictions.log_prob()` and :func:`Predictions.to_dataframe()` methods. """ @@ -292,12 +313,17 @@ def forward(self, if out_timesteps is None and input is None: raise RuntimeError("If no input is passed, must specify `out_timesteps`") - if isinstance(initial_state, Tensor): - initial_state = (initial_state, None) initial_state = self._prepare_initial_state( initial_state, start_offsets=start_offsets, ) + if simulate and simulate > 1: + init_mean, init_cov = initial_state + initial_state = repeat(init_mean, simulate, dim=0), repeat(init_cov, simulate, dim=0) + if start_offsets is not None: + start_offsets = repeat(np.asarray(start_offsets), simulate, dim=0) + kwargs = {k: (repeat(v, simulate, dim=0) if isinstance(v, (Tensor, np.ndarray)) else v) + for k, v in kwargs.items()} if isinstance(n_step, float): if not n_step.is_integer(): @@ -308,7 +334,7 @@ def forward(self, raise ValueError("`out_timesteps` must be an int.") out_timesteps = int(out_timesteps) - preds, updates, R, H = self._script_forward( + preds, updates, design_mats = self._script_forward( input=input, initial_state=initial_state, n_step=n_step, @@ -318,70 +344,81 @@ def forward(self, input, out_timesteps=out_timesteps or input.shape[1], **kwargs - ) + ), + simulate=bool(simulate) + ) + preds = self._generate_predictions( + preds=preds, + updates=updates if include_updates_in_output else None, + **design_mats, + ) + return preds.set_metadata( + start_offsets=start_offsets, + dt_unit=self.dt_unit ) - return self._generate_predictions(preds, R, H, updates if include_updates_in_output else None) @torch.jit.ignore def _generate_predictions(self, preds: Tuple[List[Tensor], List[Tensor]], - R: List[Tensor], - H: List[Tensor], - updates: Optional[Tuple[List[Tensor], List[Tensor]]] = None) -> 'Predictions': + updates: Optional[Tuple[List[Tensor], List[Tensor]]] = None, + **kwargs) -> 'Predictions': """ StateSpace subclasses may pass subclasses of `Predictions` (e.g. 
for custom log-prob) """ - kwargs = { - 'state_means': preds[0], - 'state_covs': preds[1], - 'R': R, - 'H': H, - 'model': self - } if updates is not None: kwargs.update(update_means=updates[0], update_covs=updates[1]) - return Predictions(**kwargs) + preds = Predictions( + *preds, + R=kwargs.pop('R'), + H=kwargs.pop('H'), + model=self, + **kwargs + ) + return preds @torch.jit.ignore def _prepare_initial_state(self, - initial_state: Tuple[Optional[Tensor], Optional[Tensor]], + initial_state: Optional[Tuple[Tensor, Tensor]], start_offsets: Optional[Sequence] = None) -> Tuple[Tensor, Tensor]: - init_mean, init_cov = initial_state - if init_mean is None: - init_mean = self.initial_mean[None, :] - elif len(init_mean.shape) != 2: - raise ValueError( - f"Expected ``init_mean`` to have two-dimensions for (num_groups, state_dim), got {init_mean.shape}" - ) - if init_cov is None: + if initial_state is None: + init_mean = self.initial_mean[None, :].clone() init_cov = self.initial_covariance({}, num_groups=1, num_times=1, _ignore_input=True)[:, 0] - elif len(init_cov.shape) != 3: - raise ValueError( - f"Expected ``init_cov`` to be 3-D with (num_groups, state_dim, state_dim), got {init_cov.shape}" - ) + else: + init_mean, init_cov = initial_state + if len(init_mean.shape) != 2: + raise ValueError( + f"Expected ``init_mean`` to have two-dimensions for (num_groups, state_dim), got {init_mean.shape}" + ) + if len(init_cov.shape) != 3: + raise ValueError( + f"Expected ``init_cov`` to be 3-D with (num_groups, state_dim, state_dim), got {init_cov.shape}" + ) measure_scaling = torch.diag_embed(self._get_measure_scaling().unsqueeze(0)) init_cov = measure_scaling @ init_cov @ measure_scaling - # seasonal processes need to offset the initial mean: if start_offsets is not None: if init_mean.shape[0] == 1: init_mean = init_mean.expand(len(start_offsets), -1) elif init_mean.shape[0] != len(start_offsets): raise ValueError("Expected ``len(start_offets) == initial_state[0].shape[0]``") - init_mean_w_offset = [] - for pid in self.processes: - p = self.processes[pid] - _process_slice = slice(*self.process_to_slice[pid]) - init_mean_w_offset.append(p.offset_initial_state(init_mean[:, _process_slice], start_offsets)) - init_mean_offset = torch.cat(init_mean_w_offset, 1) - else: - init_mean_offset = init_mean + if initial_state is None: + # seasonal processes need to offset the initial mean: + # TODO: should also handle cov? + init_mean_w_offset = [] + for pid in self.processes: + p = self.processes[pid] + _process_slice = slice(*self.process_to_slice[pid]) + init_mean_w_offset.append(p.offset_initial_state(init_mean[:, _process_slice], start_offsets)) + init_mean = torch.cat(init_mean_w_offset, 1) + else: + # if they passed an initial_state, we assume it's from a previous call to forward, so already offset + pass - return init_mean_offset, init_cov + return init_mean, init_cov @torch.jit.export def _script_forward(self, @@ -390,12 +427,12 @@ def _script_forward(self, initial_state: Tuple[Tensor, Tensor], n_step: int = 1, out_timesteps: Optional[int] = None, - every_step: bool = True + every_step: bool = True, + simulate: bool = False ) -> Tuple[ Tuple[List[Tensor], List[Tensor]], Tuple[List[Tensor], List[Tensor]], - List[Tensor], - List[Tensor] + Dict[str, List[Tensor]] ]: """ :param input: A (group X time X measures) tensor. Optional if `initial_state` is specified. 
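To show the round-trip these changes aim at (forward pass, pull an updated state, feed it back as ``initial_state`` for simulation), here is a sketch that reuses the hypothetical ``kf`` and ``y`` from the example above and assumes a dateless series, so no ``dt_unit``/``start_offsets`` bookkeeping is involved:

```python
# keep update means/covs so `get_state_at_times(type_='update')` works:
pred = kf(y, include_updates_in_output=True)

# state at the last observed timestep of every group:
init_state = pred.get_state_at_times(y.shape[1] - 1)

# 100 simulated trajectories per group, 10 steps past the end of `y`;
# the result is a Predictions object with zero state-covariance:
sims = kf.simulate(out_timesteps=10, initial_state=init_state, num_sims=100)
```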
@@ -408,6 +445,8 @@ def _script_forward(self, Alternatively, we could generate 24-hour-ahead predictions at every 24th hour, in which case we'd save predictions 1-24. The former corresponds to every_step=True, the latter to every_step=False. If n_step=1 (the default) then this option has no effect. + :param simulate: If True, will simulate state-trajectories and return a ``Predictions`` object with zero state + covariance. :return: predictions (tuple of (means,covs)), updates (tuple of (means,covs)), R, H """ assert n_step > 0 @@ -455,11 +494,13 @@ def _script_forward(self, ) mean1s.append(mean1step) cov1s.append(cov1step) - if t < len(inputs): + + if simulate: + meanu = torch.distributions.MultivariateNormal(mean1step, cov1step, validate_args=False).sample() + covu = torch.eye(meanu.shape[-1]) * 1e-6 + elif t < len(inputs): update_kwargs_t = {k: v[t] for k, v in update_kwargs.items()} - update_kwargs_t['outlier_threshold'] = torch.tensor( - self.outlier_threshold if t > self.outlier_burnin else 0. - ) + # update_kwargs_t['outlier_threshold'] = torch.tensor(outlier_threshold if t > outlier_burnin else 0.) meanu, covu = self.ss_step.update( inputs[t], mean1step, @@ -468,6 +509,7 @@ def _script_forward(self, ) else: meanu, covu = mean1step, cov1step + meanus.append(meanu) covus.append(covu) @@ -480,7 +522,7 @@ def _script_forward(self, # t1: time of 1step tu = t1 - 1 - # - if every_step, we run this loop ever iter + # - if every_step, we run this loop every iter # - if not every_step, we run this loop every nth iter if every_step or (t1 % n_step) == 0: meanp, covp = mean1s[t1], cov1s[t1] # already had to generate h=1 above @@ -500,10 +542,8 @@ def _script_forward(self, preds = [meanps[t] for t in range(out_timesteps)], [covps[t] for t in range(out_timesteps)] updates = meanus, covus - R = update_kwargs['R'] - H = update_kwargs['H'] - return preds, updates, R, H + return preds, updates, update_kwargs def _build_design_mats(self, kwargs_per_process: Dict[str, Dict[str, Tensor]], @@ -570,15 +610,12 @@ def _parse_design_kwargs(self, input: Optional[Tensor], out_timesteps: int, **kw def _get_measure_scaling(self) -> Tensor: mcov = self.measure_covariance({}, num_groups=1, num_times=1, _ignore_input=True)[0, 0] - if self._scale_by_measure_var: - measure_var = mcov.diagonal(dim1=-2, dim2=-1) - multi = torch.zeros(mcov.shape[0:-2] + (self.state_rank,), dtype=mcov.dtype, device=mcov.device) - for pid, process in self.processes.items(): - pidx = self.process_to_slice[pid] - multi[..., slice(*pidx)] = measure_var[..., self.measure_to_idx[process.measure]].sqrt().unsqueeze(-1) - assert (multi > 0).all() - else: - multi = torch.ones((self.state_rank,), dtype=mcov.dtype, device=mcov.device) + measure_var = mcov.diagonal(dim1=-2, dim2=-1) + multi = torch.zeros(mcov.shape[0:-2] + (self.state_rank,), dtype=mcov.dtype, device=mcov.device) + for pid, process in self.processes.items(): + pidx = self.process_to_slice[pid] + multi[..., slice(*pidx)] = measure_var[..., self.measure_to_idx[process.measure]].sqrt().unsqueeze(-1) + assert (multi > 0).all() return multi def __repr__(self) -> str: @@ -589,61 +626,39 @@ def __repr__(self) -> str: @torch.jit.ignore() def simulate(self, out_timesteps: int, - initial_state: Tuple[Optional[Tensor], Optional[Tensor]] = (None, None), + initial_state: Optional[Tuple[Tensor, Tensor]] = None, start_offsets: Optional[Sequence] = None, - num_sims: Optional[int] = None, - progress: bool = False, + num_sims: int = 1, + num_groups: Optional[int] = None, **kwargs): """ Generate 
simulated state-trajectories from your model. :param out_timesteps: The number of timesteps to generate in the output. - :param initial_state: The initial state of the system: a tuple of `mean`, `cov`. + :param initial_state: The initial state of the system: a tuple of `mean`, `cov`. Can be obtained from previous + model-predictions by calling ``get_state_at_times()`` on the output predictions. :param start_offsets: If your model includes seasonal processes, then these needs to know the start-time for - each group in ``input``. If you passed ``dt_unit`` when constructing those processes, then you should pass an - array datetimes here. Otherwise you can pass an array of integers (or leave `None` if there are no seasonal - processes). - :param num_sims: The number of state-trajectories to simulate. - :param progress: Should a progress-bar be displayed? Requires `tqdm`. + each group in ``initial_state``. If you passed ``dt_unit`` when constructing those processes, then you should + pass an array of datetimes here, otherwise an array of ints. If there are no seasonal processes you can omit. + :param num_sims: The number of state-trajectories to simulate per group. The output will be laid out in blocks + (e.g. if there are 10 groups, the first ten elements of the output are sim 1, the next 10 elements are sim 2, + etc.). Tensors associated with this output can be reshaped with ``tensor.reshape(num_sims, num_groups, ...)``. + :param num_groups: The number of groups; if `None` will be inferred from the shape of `initial_state` and/or + ``start_offsets``. :param kwargs: Further arguments passed to the `processes`. - :return: A :class:`.Simulations` object with a :func:`Simulations.sample()` method. + :return: A :class:`.Predictions` object with zero state-covariance. 
""" - mean, cov = self._prepare_initial_state(initial_state, start_offsets=start_offsets) + if num_groups is not None: + if start_offsets is None: + start_offsets = [0] * num_groups + elif len(start_offsets) != num_groups: + raise ValueError("Expected `len(start_offsets) == num_groups` (or num_groups=None)") - times = range(out_timesteps) - if progress: - if progress is True: - try: - from tqdm.auto import tqdm - progress = tqdm - except ImportError: - warn("`progress=True` requires package `tqdm`.") - progress = lambda x: x - times = progress(times) - - predict_kwargs, update_kwargs = self._build_design_mats( - num_groups=num_sims, + return self( + start_offsets=start_offsets, out_timesteps=out_timesteps, - kwargs_per_process=self._parse_design_kwargs(input=None, out_timesteps=out_timesteps, **kwargs) - ) - - dist_cls = self.ss_step.get_distribution() - - means: List[Tensor] = [] - for t in times: - mean = dist_cls(mean, cov).rsample() - mean, cov = self.ss_step.predict( - mean=mean, cov=.0001 * torch.eye(mean.shape[-1]), kwargs={k: v[t] for k, v in predict_kwargs.items()} - ) - means.append(mean) - - smeans = torch.stack(means, 1) - num_groups, num_times, sdim = smeans.shape - scovs = torch.zeros((num_groups, num_times, sdim, sdim), dtype=smeans.dtype, device=smeans.device) - - return self._generate_predictions( - preds=(smeans, scovs), - R=torch.stack(update_kwargs['R'], 1), - H=torch.stack(update_kwargs['H'], 1) + initial_state=initial_state, + simulate=num_sims, + **kwargs ) diff --git a/torchcast/state_space/predictions.py b/torchcast/state_space/predictions.py index 1bb00e7..b3a7028 100644 --- a/torchcast/state_space/predictions.py +++ b/torchcast/state_space/predictions.py @@ -1,13 +1,20 @@ -from functools import cached_property -from typing import Tuple, Union, Optional, Dict, Iterator, Sequence +from dataclasses import dataclass, fields +from typing import Tuple, Union, Optional, Dict, Iterator, Sequence, TYPE_CHECKING +from warnings import warn -import numpy as np import torch from torch import nn, Tensor -from torchcast.internals.utils import get_nan_groups, is_near_zero -from torchcast.utils.data import TimeSeriesDataset -from torchcast.utils.outliers import mahalanobis_dist +import numpy as np +import pandas as pd + +from functools import cached_property + +from torchcast.internals.utils import get_nan_groups, is_near_zero, transpose_last_dims +from torchcast.utils import conf2bounds, TimeSeriesDataset + +if TYPE_CHECKING: + from torchcast.state_space import StateSpaceModel class Predictions(nn.Module): @@ -21,7 +28,7 @@ def __init__(self, state_covs: Sequence[Tensor], R: Sequence[Tensor], H: Sequence[Tensor], - model: Union['StateSpaceModel', dict], + model: Union['StateSpaceModel', 'StateSpaceModelMetadata'], update_means: Optional[Sequence[Tensor]] = None, update_covs: Optional[Sequence[Tensor]] = None): super().__init__() @@ -39,28 +46,62 @@ def __init__(self, self._R = R # some model attributes are needed for `log_prob` method and for names for plotting - if not isinstance(model, dict): + if not isinstance(model, StateSpaceModelMetadata): all_state_elements = [] for pid in model.processes: process = model.processes[pid] for state_element in process.state_elements: all_state_elements.append((pid, state_element)) - model = { - 'distribution_cls': model.ss_step.get_distribution(), - 'measures': model.measures, - 'all_state_elements': all_state_elements, - 'outlier_threshold': model.outlier_threshold - } - self.distribution_cls = model['distribution_cls'] - self.measures = 
model['measures'] - self.all_state_elements = model['all_state_elements'] - self.outlier_threshold = model['outlier_threshold'] + self._model_attributes = StateSpaceModelMetadata( + measures=model.measures, + all_state_elements=all_state_elements, + ) # for lazily populated properties: self._means = self._covs = None - # useful to have: + # metadata self.num_groups, self.num_timesteps, self.state_size = self.state_means.shape + self._dataset_metadata = None + + def set_metadata(self, + dataset: Optional[TimeSeriesDataset] = None, + group_names: Optional[Sequence[str]] = None, + start_offsets: Optional[np.ndarray] = None, + group_colname: str = 'group', + time_colname: str = 'time', + dt_unit: Optional[str] = None) -> 'Predictions': + if dataset is not None: + group_names = dataset.group_names + start_offsets = dataset.start_offsets + dt_unit = dataset.dt_unit + + if isinstance(dt_unit, str): + dt_unit = np.timedelta64(1, dt_unit) + + if group_names is not None and len(group_names) != self.num_groups: + raise ValueError("`group_names` must have the same length as the number of groups.") + if start_offsets is not None and len(start_offsets) != self.num_groups: + raise ValueError("`start_offsets` must have the same length as the number of groups.") + + kwargs = { + 'group_names': group_names, + 'start_offsets': start_offsets, + 'dt_unit': dt_unit, + 'group_colname': group_colname, + 'time_colname': time_colname + } + if self._dataset_metadata is not None: + self._dataset_metadata.update(**kwargs) + else: + self._dataset_metadata = DatasetMetadata(**kwargs) + return self + + @property + def dataset_metadata(self) -> 'DatasetMetadata': + if self._dataset_metadata is None: + raise RuntimeError("Metadata not set. Pass the dataset or call `set_metadata()`.") + return self._dataset_metadata @cached_property def R(self) -> torch.Tensor: @@ -74,6 +115,10 @@ def H(self) -> torch.Tensor: self._H = torch.stack(self._H, 1) return self._H + @property + def measures(self) -> Sequence[str]: + return self._model_attributes.measures + @cached_property def state_means(self) -> torch.Tensor: if not isinstance(self._state_means, torch.Tensor): @@ -97,7 +142,11 @@ def state_covs(self) -> torch.Tensor: @cached_property def update_means(self) -> Optional[torch.Tensor]: if self._update_means is None: - return None + raise RuntimeError( + "Cannot get ``update_means`` because update mean/cov was not passed when creating this " + "``Predictions`` object. This usually means you have to include ``include_updates_in_output=True`` " + "when calling ``StateSpaceModel()``." + ) if not isinstance(self._update_means, torch.Tensor): self._update_means = torch.stack(self._update_means, 1) if torch.isnan(self._update_means).any(): @@ -105,54 +154,100 @@ def update_means(self) -> Optional[torch.Tensor]: return self._update_means @cached_property - def update_covs(self) -> torch.Tensor: + def update_covs(self) -> Optional[torch.Tensor]: if self._update_covs is None: - return None + raise RuntimeError( + "Cannot get ``update_covs`` because update mean/cov was not passed when creating this " + "``Predictions`` object. This usually means you have to include ``include_updates_in_output=True`` " + "when calling ``StateSpaceModel()``." 
+ ) if not isinstance(self._update_covs, torch.Tensor): self._update_covs = torch.stack(self._update_covs, 1) if torch.isnan(self._update_covs).any(): raise ValueError("`nans` in `update_covs`") return self._update_covs + def with_new_start_times(self, + start_times: Union[np.ndarray, np.datetime64], + n_timesteps: int, + **kwargs) -> 'Predictions': + """ + :param start_times: An array/sequence containing the start time for each group; or a single datetime to apply + to all groups. If the model/predictions are dateless (no dt_unit) then simply an array of indices. + :param n_timesteps: Each group will be sliced to this many timesteps, so times is start and times + n_timesteps + is end. + :return: A new ``Predictions`` object, with the state and measurement tensors sliced to the given times. + """ + start_indices = self._standardize_times(times=start_times, *kwargs) + time_indices = np.arange(n_timesteps)[None, ...] + start_indices[:, None, ...] + return self[np.arange(self.num_groups)[:, None, ...], time_indices] + def get_state_at_times(self, times: Union[np.ndarray, np.datetime64], - start_times: Optional[np.ndarray] = None, - dt_unit: Optional[str] = None, - type_: str = 'update') -> Tuple[Tensor, Tensor]: + type_: str = 'update', + **kwargs) -> Tuple[Tensor, Tensor]: """ For each group, get the state (tuple of (mean, cov)) for a timepoint. This is often useful since predictions are right-aligned and padded, so that the final prediction for each group is arbitrarily padded and does not - correspond to a timepoint of interest -- e.g. for forecasting (i.e., calling - ``StateSpaceModel.forward(initial_state=get_state_at_times(...))``). - - :param times: Either (a) indices corresponding to each group (e.g. ``times[0]`` corresponds to the timestep to - take for the 0th group, ``times[1]`` the timestep to take for the 1th group, etc.) or (b) if ``start_times`` - is passed, an array of datetimes. Will also support a single datetime. - :param start_times: If ``times`` is an array of datetimes, must also pass ``start_datetimes``, i.e. the - datetimes at which each group started. - :param dt_unit: If ``times`` is an array of datetimes, must also pass ``dt_unit``, i.e. a - :class:`numpy.timedelta64` that indicates how much time passes at each timestep. (times-start_times)/dt_unit - should be an array of integers. + correspond to a timepoint of interest -- e.g. for simulation (i.e., calling + ``StateSpaceModel.simulate(initial_state=get_state_at_times(...))``). + + :param times: An array/sequence containing the time for each group; or a single datetime to apply to all groups. + If the model/predictions are dateless (no dt_unit) then simply an array of indices :param type_: What type of state? Since this method is typically used for getting an `initial_state` for another call to :func:`StateSpaceModel.forward()`, this should generally be 'update' (the default); other option is 'prediction'. :return: A tuple of state-means and state-covs, appropriate for forecasting by passing as `initial_state` for :func:`StateSpaceModel.forward()`. 
""" - sliced = self._subset_to_times(times=times, start_times=start_times, dt_unit=dt_unit) + preds = self.with_new_start_times(start_times=times, n_timesteps=1, **kwargs) if type_.startswith('pred'): - return sliced.state_means.squeeze(1), sliced.state_covs.squeeze(1) + return preds.state_means.squeeze(1), preds.state_covs.squeeze(1) elif type_.startswith('update'): - if self.update_means is None: - raise RuntimeError( - "Cannot get with ``type_='update'`` because update mean/cov was not passed when creating this " - "``Predictions`` object. This usually means you have to include ``include_updates=True`` when " - "calling ``StateSpaceModel``." - ) - return sliced.update_means.squeeze(1), sliced.update_covs.squeeze(1) + return preds.update_means.squeeze(1), preds.update_covs.squeeze(1) else: raise ValueError("Unrecognized `type_`, expected 'prediction' or 'update'.") + def _standardize_times(self, + times: Union[np.ndarray, np.datetime64], + start_offsets: Optional[np.ndarray] = None, + dt_unit: Optional[str] = None) -> np.ndarray: + if start_offsets is not None: + warn( + "Passing `start_offsets` as an argument is deprecated, first call ``set_metadata()``", + DeprecationWarning + ) + if dt_unit is not None: + warn( + "Passing `dt_unit` as an argument is deprecated, first call ``set_metadata()``", + DeprecationWarning + ) + if self.dataset_metadata.start_offsets is not None: + start_offsets = self.dataset_metadata.start_offsets + if self.dataset_metadata.dt_unit is not None: + dt_unit = self.dataset_metadata.dt_unit + + if not isinstance(times, (list, tuple, np.ndarray)): + times = [times] * self.num_groups + times = np.asanyarray(times, dtype='datetime64' if dt_unit else 'int') + + if start_offsets is None: + if dt_unit is not None: + raise ValueError("If `dt_unit` is specified, then `start_offsets` must also be specified.") + else: + if isinstance(dt_unit, str): + dt_unit = np.timedelta64(1, dt_unit) + times = times - start_offsets + if dt_unit is not None: + times = times // dt_unit # todo: validate int? + else: + assert times.dtype.name.startswith('int') + + assert len(times.shape) == 1 + assert times.shape[0] == self.num_groups + + return times + @classmethod def observe(cls, state_means: Tensor, state_covs: Tensor, R: Tensor, H: Tensor) -> Tuple[Tensor, Tensor]: """ @@ -165,10 +260,7 @@ def observe(cls, state_means: Tensor, state_covs: Tensor, R: Tensor, H: Tensor) :return: A tuple of `means`, `covs`. """ means = H.matmul(state_means.unsqueeze(-1)).squeeze(-1) - pargs = list(range(len(H.shape))) - pargs[-2:] = reversed(pargs[-2:]) - Ht = H.permute(*pargs) - assert R.shape[-1] == R.shape[-2], f"R is not symmetrical (shape is {R.shape})" + Ht = transpose_last_dims(H) covs = H.matmul(state_covs).matmul(Ht) + R return means, covs @@ -187,16 +279,12 @@ def covs(self) -> Tensor: self._means, self._covs = self.observe(self.state_means, self.state_covs, self.R, self.H) return self._covs - def sample(self) -> Tensor: - with torch.no_grad(): - dist = self.distribution_cls(self.means, self.covs) - return dist.rsample() - - def log_prob(self, obs: Tensor) -> Tensor: + def log_prob(self, obs: Tensor, weights: Optional[Tensor] = None) -> Tensor: """ - Compute the log-probability of data (e.g. data that was originally fed into the KalmanFilter). + Compute the log-probability of data (e.g. data that was originally fed into the ``StateSpaceModel``). - :param obs: A Tensor that could be used in the KalmanFilter.forward pass. 
+ :param obs: A Tensor that could be used in the ``StateSpaceModel`` forward pass. + :param weights: If specified, will be used to weight the log-probability of each group X timestep. :return: A tensor with one element for each group X timestep indicating the log-probability. """ assert len(obs.shape) == 3 @@ -208,17 +296,23 @@ def log_prob(self, obs: Tensor) -> Tensor: means_flat = self.means.view(-1, n_measure_dim) covs_flat = self.covs.view(-1, n_measure_dim, n_measure_dim) - # if the model used an outlier threshold, under-weight outliers - weights = torch.ones(obs_flat.shape[0], dtype=self.state_means.dtype, device=self.state_means.device) - if self.outlier_threshold: - obs_flat = obs_flat.clone() - for gt_idx, valid_idx in get_nan_groups(torch.isnan(obs_flat)): - if valid_idx is None: - mdist = mahalanobis_dist(obs_flat[gt_idx] - means_flat[gt_idx], covs_flat[gt_idx]) - multi = (mdist - self.outlier_threshold).clamp(min=0) + 1 - weights[gt_idx] = 1 / multi - else: - raise NotImplemented + # # if the model used an outlier threshold, under-weight outliers + if weights is None: + weights = torch.ones(obs_flat.shape[0], dtype=self.state_means.dtype, device=self.state_means.device) + else: + weights = weights.reshape(-1, n_measure_dim) + # if self.outlier_threshold: + # obs_flat = obs_flat.clone() + # for gt_idx, valid_idx in get_nan_groups(torch.isnan(obs_flat)): + # if valid_idx is None: + # multi = get_outlier_multi( + # resid=obs_flat[gt_idx] - means_flat[gt_idx], + # cov=covs_flat[gt_idx], + # outlier_threshold=torch.as_tensor(self.outlier_threshold) + # ) + # weights[gt_idx] /= multi + # else: + # raise NotImplemented state_means_flat = self.state_means.view(-1, n_state_dim) state_covs_flat = self.state_covs.view(-1, n_state_dim, n_state_dim) @@ -248,61 +342,64 @@ def log_prob(self, obs: Tensor) -> Tensor: return lp_flat.view(obs.shape[0:2]) def _log_prob(self, obs: Tensor, means: Tensor, covs: Tensor) -> Tensor: - return self.distribution_cls(means, covs, validate_args=False).log_prob(obs) + return torch.distributions.MultivariateNormal(means, covs, validate_args=False).log_prob(obs) def to_dataframe(self, - dataset: Union[TimeSeriesDataset, dict], + dataset: Optional[TimeSeriesDataset] = None, type: str = 'predictions', - group_colname: str = 'group', - time_colname: str = 'time', - multi: Optional[float] = 1.96) -> 'DataFrame': + group_colname: Optional[str] = None, + time_colname: Optional[str] = None, + conf: Optional[float] = .95, + **kwargs) -> pd.DataFrame: """ - :param dataset: Either a :class:`.TimeSeriesDataset`, or a dictionary with 'start_times', 'group_names', & - 'dt_unit' + :param dataset: The dataset which generated the predictions. If not supplied, will use the metadata set at + prediction time, but the group-names will be replaced by dummy group names, and the output will not include + actuals. :param type: Either 'predictions' or 'components'. :param group_colname: Column-name for 'group' :param time_colname: Column-name for 'time' - :param multi: Multiplier on std-dev for lower/upper CIs. Default 1.96. + :param conf: If set, specifies the confidence level for the 'lower' and 'upper' columns in the output. Default + of 0.95 means these are 0.025 and 0.975. If ``None``, then will just include 'std' column instead. :return: A pandas DataFrame with group, 'time', 'measure', 'mean', 'lower', 'upper'. For ``type='components'`` additionally includes: 'process' and 'state_element'. 
""" + multi = kwargs.pop('multi', False) + if multi is not False: + msg = "`multi` is deprecated, please use `conf` instead." + if multi is None: # old way of specifying "just return std", for backwards-compatibility + warn(msg, DeprecationWarning) + conf = None + else: + raise TypeError(msg) - from pandas import concat - - if isinstance(dataset, TimeSeriesDataset): - batch_info = { - 'start_times': dataset.start_times, - 'group_names': dataset.group_names, - 'named_tensors': {}, - 'dt_unit': dataset.dt_unit - } + named_tensors = {} + if dataset is None: + dataset = self.dataset_metadata.copy() + if dataset.group_names is None: + dataset.group_names = [f"group_{i}" for i in range(self.num_groups)] + else: for measure_group, tensor in zip(dataset.measures, dataset.tensors): for i, measure in enumerate(measure_group): if measure in self.measures: - batch_info['named_tensors'][measure] = tensor[..., [i]] + named_tensors[measure] = tensor[..., [i]] missing = set(self.measures) - set(dataset.all_measures) if missing: raise ValueError( f"Some measures in the design aren't in the dataset.\n" f"Design: {missing}\nDataset: {dataset.all_measures}" ) - elif isinstance(dataset, dict): - batch_info = dataset.copy() - if isinstance(batch_info['dt_unit'], str): - batch_info['dt_unit'] = np.timedelta64(1, batch_info['dt_unit']) - else: - raise TypeError( - "Expected `batch` to be a TimeSeriesDataset, or a dictionary with 'start_times' and 'group_names'." - ) + + group_colname = group_colname or self.dataset_metadata.group_colname + time_colname = time_colname or self.dataset_metadata.time_colname def _tensor_to_df(tens, measures): - offsets = np.arange(0, tens.shape[1]) * (batch_info['dt_unit'] if batch_info['dt_unit'] else 1) - times = batch_info['start_times'][:, None] + offsets + offsets = np.arange(0, tens.shape[1]) * (dataset.dt_unit if dataset.dt_unit else 1) + times = dataset.start_offsets[:, None] + offsets return TimeSeriesDataset.tensor_to_dataframe( tensor=tens, times=times, - group_names=batch_info['group_names'], + group_names=dataset.group_names, group_colname=group_colname, time_colname=time_colname, measures=measures @@ -312,36 +409,32 @@ def _tensor_to_df(tens, measures): assert time_colname not in {'mean', 'lower', 'upper', 'std'} out = [] - if type == 'predictions': + if type.startswith('pred'): stds = torch.diagonal(self.covs, dim1=-1, dim2=-2).sqrt() for i, measure in enumerate(self.measures): # predicted: df = _tensor_to_df(torch.stack([self.means[..., i], stds[..., i]], 2), measures=['mean', 'std']) - if multi is not None: - df['lower'] = df['mean'] - multi * df['std'] - df['upper'] = df['mean'] + multi * df.pop('std') + if conf is not None: + df['lower'], df['upper'] = conf2bounds(df['mean'], df.pop('std'), conf=conf) # actual: - orig_tensor = batch_info.get('named_tensors', {}).get(measure, None) - if orig_tensor is not None and (orig_tensor == orig_tensor).any(): + orig_tensor = named_tensors.get(measure, None) + if orig_tensor is not None and not torch.isnan(orig_tensor).all(): df_actual = _tensor_to_df(orig_tensor, measures=['actual']) df = df.merge(df_actual, on=[group_colname, time_colname], how='left') out.append(df.assign(measure=measure)) - elif type == 'components': - # components: + elif type.startswith('comp'): for (measure, process, state_element), (m, std) in self._components().items(): df = _tensor_to_df(torch.stack([m, std], 2), measures=['mean', 'std']) - if multi is not None: - df['lower'] = df['mean'] - multi * df['std'] - df['upper'] = df['mean'] + multi * 
df.pop('std') + if conf is not None: + df['lower'], df['upper'] = conf2bounds(df['mean'], df.pop('std'), conf=conf) df['process'], df['state_element'], df['measure'] = process, state_element, measure out.append(df) # residuals: - named_tensors = batch_info.get('named_tensors', {}) for i, measure in enumerate(self.measures): orig_tensor = named_tensors.get(measure) predictions = self.means[..., [i]] @@ -359,7 +452,7 @@ def _tensor_to_df(tens, measures): else: raise ValueError("Expected `type` to be 'predictions' or 'components'.") - return concat(out, sort=True) + return pd.concat(out, sort=True) @torch.no_grad() def _components(self) -> Dict[Tuple[str, str, str], Tuple[Tensor, Tensor]]: @@ -369,19 +462,20 @@ def _components(self) -> Dict[Tuple[str, str, str], Tuple[Tensor, Tensor]]: means = H * self.state_means stds = H * torch.diagonal(self.state_covs, dim1=-2, dim2=-1).sqrt() - for se_idx, (process, state_element) in enumerate(self.all_state_elements): + for se_idx, (process, state_element) in enumerate(self._model_attributes.all_state_elements): if not is_near_zero(means[:, :, se_idx]).all(): out[(measure, process, state_element)] = (means[:, :, se_idx], stds[:, :, se_idx]) return out - @staticmethod - def plot(df: 'DataFrame', + @classmethod + def plot(cls, + df: pd.DataFrame, group_colname: str = None, time_colname: str = None, max_num_groups: int = 1, split_dt: Optional[np.datetime64] = None, - **kwargs) -> 'DataFrame': + **kwargs) -> pd.DataFrame: """ :param df: The output of :func:`Predictions.to_dataframe()`. :param group_colname: The name of the group-column. @@ -397,6 +491,12 @@ def plot(df: 'DataFrame', ggplot, aes, geom_line, geom_ribbon, facet_grid, facet_wrap, theme_bw, theme, ylab, geom_vline ) + if isinstance(cls, Predictions): # using it as an instance-method + group_colname = group_colname or cls.dataset_metadata.group_colname + time_colname = time_colname or cls.dataset_metadata.time_colname + elif not group_colname or not time_colname: + raise TypeError("Please specify group_colname and time_colname") + is_components = 'process' in df.columns if is_components and 'state_element' not in df.columns: df = df.assign(state_element='all') @@ -412,8 +512,7 @@ def plot(df: 'DataFrame', df = df.copy() if 'upper' not in df.columns and 'std' in df.columns: - df['upper'] = df['mean'] + 1.96 * df['std'] - df['lower'] = df['mean'] - 1.96 * df['std'] + df['lower'], df['upper'] = conf2bounds(df['mean'], df.pop('std'), conf=.95) if df[group_colname].nunique() > max_num_groups: subset_groups = df[group_colname].drop_duplicates().sample(max_num_groups).tolist() if len(subset_groups) < df[group_colname].nunique(): @@ -440,8 +539,8 @@ def plot(df: 'DataFrame', elif num_groups == 1: plot = plot + facet_wrap(f"~ measure + process", scales='free_y', labeller='label_both') if 'figure_size' not in kwargs: - from plotnine.facets.facet_wrap import n2mfrow - nrow, _ = n2mfrow(len(df[['process', 'measure']].drop_duplicates().index)) + from plotnine.facets.facet_wrap import wrap_dims + nrow, _ = wrap_dims(len(df[['process', 'measure']].drop_duplicates().index)) kwargs['figure_size'] = (12, nrow * 2.5) else: plot = plot + facet_grid(f"{group_colname} ~ measure", scales='free_y', labeller='label_both') @@ -467,30 +566,8 @@ def plot(df: 'DataFrame', return plot + theme_bw() + theme(**kwargs) - def _subset_to_times(self, - times: Union[np.ndarray, np.datetime64], - start_times: Optional[np.ndarray] = None, - dt_unit: Optional[str] = None) -> 'Predictions': - """ - Return a `Predictions` object with a 
single timepoint for each group. - """ - if not isinstance(times, (list, tuple, np.ndarray)): - times = np.asanyarray([times] * self.num_groups) - - if start_times is not None: - if isinstance(dt_unit, str): - dt_unit = np.timedelta64(1, dt_unit) - times = times - start_times - if dt_unit is not None: - times = times // dt_unit # todo: validate int? - - assert len(times.shape) == 1 - assert times.shape[0] == self.num_groups - idx = (torch.arange(self.num_groups), torch.as_tensor(times, dtype=torch.int64)) - return self._subset(idx, collapsed_dim=1) - def __iter__(self) -> Iterator[Tensor]: - # for mean, cov = tuple(predictions) + # so that we can do ``mean, cov = predictions`` yield self.means yield self.covs @@ -499,27 +576,17 @@ def __array__(self) -> np.ndarray: return self.means.detach().numpy() def __getitem__(self, item: Tuple) -> 'Predictions': - return self._subset(item) - - def _subset(self, idx: Tuple, collapsed_dim: Optional[int] = None) -> 'Predictions': - """ - Helper for __getitem__ and get_timeslice - """ - if collapsed_dim is not None: - assert collapsed_dim < 2 kwargs = { - 'state_means': self.state_means[idx], - 'state_covs': self.state_covs[idx], - 'H': self.H[idx], - 'R': self.R[idx] + 'state_means': self.state_means[item], + 'state_covs': self.state_covs[item], + 'H': self.H[item], + 'R': self.R[item] } - if self.update_means is not None: - kwargs.update({'update_means': self.update_means[idx], 'update_covs': self.update_covs[idx]}) + if self._update_means is not None: + kwargs.update({'update_means': self.update_means[item], 'update_covs': self.update_covs[item]}) cls = type(self) for k in list(kwargs): expected_shape = getattr(self, k).shape - if collapsed_dim is not None: - kwargs[k] = kwargs[k].unsqueeze(collapsed_dim) v = kwargs[k] if len(v.shape) != len(expected_shape): raise TypeError(f"Expected {k} to have shape {expected_shape} but got {v.shape}.") @@ -529,14 +596,35 @@ def _subset(self, idx: Tuple, collapsed_dim: Optional[int] = None) -> 'Predictio raise TypeError(f"Cannot index into non-batch dims of {cls.__name__}") return cls(**kwargs, model=self._model_attributes) - @property - def _model_attributes(self) -> dict: - """ - Has the attributes of a KalmanFilter that are needed in __init__ - """ - return { - 'measures': self.measures, - 'distribution_cls': self.distribution_cls, - 'all_state_elements': self.all_state_elements, - 'outlier_threshold': self.outlier_threshold - } + +@dataclass +class StateSpaceModelMetadata: + measures: Sequence[str] + all_state_elements: Sequence[Tuple[str, str]] + + +@dataclass +class DatasetMetadata: + group_names: Optional[Sequence[str]] + start_offsets: Optional[np.ndarray] + dt_unit: Optional[np.timedelta64] + group_colname: str = 'group' + time_colname: str = 'time' + + def update(self, **kwargs) -> 'DatasetMetadata': + for f in fields(self): + v = kwargs.pop(f.name, None) + if v is not None: + setattr(self, f.name, v) + if kwargs: + raise TypeError(f"Unrecognized kwargs: {list(kwargs)}") + return self + + def copy(self) -> 'DatasetMetadata': + return DatasetMetadata( + group_names=self.group_names, + start_offsets=self.start_offsets, + dt_unit=self.dt_unit, + group_colname=self.group_colname, + time_colname=self.time_colname + ) diff --git a/torchcast/state_space/ss_step.py b/torchcast/state_space/ss_step.py index 903898c..dd2228d 100644 --- a/torchcast/state_space/ss_step.py +++ b/torchcast/state_space/ss_step.py @@ -1,4 +1,4 @@ -from typing import Type, Tuple, Dict, Optional +from typing import Tuple, Dict, Optional 
import torch from torch import Tensor @@ -11,11 +11,6 @@ class StateSpaceStep(torch.nn.Module): Base-class for modules that handle predict/update within a state-space model. """ - # this would ideally be a class-attribute but torch.jit.trace strips them - @torch.jit.ignore() - def get_distribution(self) -> Type[torch.distributions.Distribution]: - return torch.distributions.MultivariateNormal - def forward(self, input: Tensor, mean: Tensor, @@ -29,7 +24,11 @@ def forward(self, def predict(self, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: raise NotImplementedError - def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + def _update(self, + input: Tensor, + mean: Tensor, + cov: Tensor, + kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: raise NotImplementedError def update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: @@ -55,8 +54,11 @@ def update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Ten new_cov = cov.clone() for groups, val_idx in get_nan_groups(isnan): masked_input, masked_kwargs = self._mask_mats(groups, val_idx, input=input, kwargs=kwargs) - m,c = self._update( - input=masked_input, mean=mean[groups], cov=cov[groups], kwargs=masked_kwargs + m, c = self._update( + input=masked_input, + mean=mean[groups], + cov=cov[groups], + kwargs=masked_kwargs ) new_mean[groups] = m if c is None: diff --git a/torchcast/utils/__init__.py b/torchcast/utils/__init__.py index 4f170c3..7737661 100644 --- a/torchcast/utils/__init__.py +++ b/torchcast/utils/__init__.py @@ -9,3 +9,5 @@ from .features import add_season_features from .data import TimeSeriesDataset, TimeSeriesDataLoader, complete_times +from .stats import conf2bounds +from .outliers import get_outlier_multi diff --git a/torchcast/utils/data.py b/torchcast/utils/data.py index 18ae5df..7f0356c 100644 --- a/torchcast/utils/data.py +++ b/torchcast/utils/data.py @@ -43,6 +43,8 @@ def __init__(self, assert len(tensors) == len(measures) for i, (tensor, tensor_measures) in enumerate(zip(tensors, measures)): + if isinstance(tensor_measures, str): + raise ValueError(f"Expected measures to be a list of lists/tuples, but element-{i} is a string.") if len(tensor.shape) < 3: raise ValueError(f"Tensor {i} has < 3 dimensions") if tensor.shape[0] != len(group_names): @@ -51,7 +53,7 @@ def __init__(self, raise ValueError(f"Tensor {i}'s 3rd dimension has length != len({tensor_measures}).") self.measures = tuple(tuple(m) for m in measures) - self.all_measures = tuple(itertools.chain.from_iterable(self.measures)) + self.group_names = group_names self.dt_unit = None if dt_unit: @@ -74,6 +76,10 @@ def __init__(self, self.start_times = start_times super().__init__(*tensors) + @property + def all_measures(self) -> tuple: + return tuple(itertools.chain.from_iterable(self.measures)) + def to(self, *args, **kwargs) -> 'TimeSeriesDataset': new_tensors = [x.to(*args, **kwargs) for x in self.tensors] return self.with_new_tensors(*new_tensors) @@ -151,15 +157,19 @@ def train_val_split(self, return train_dataset, val_dataset def with_new_start_times(self, - start_times: Union[np.ndarray, Sequence], + start_times: Union[datetime.datetime, np.datetime64, np.ndarray, Sequence], + n_timesteps: Optional[int] = None, quiet: bool = False) -> 'TimeSeriesDataset': """ Subset a :class:`.TimeSeriesDataset` so that some/all of the groups have later start times. 
- :param start_times: An array/list of new datetimes. + :param start_times: An array/list of new datetimes, or a single datetime that will be used for all groups. + :param n_timesteps: The number of timesteps in the output (nan-padded). :param quiet: If True, will not emit a warning for groups having only `nan` after the start-time. :return: A new :class:`.TimeSeriesDataset`. """ + if isinstance(start_times, (datetime.datetime, np.datetime64)): + start_times = np.full(len(self.group_names), start_times, dtype='datetime64[ns]' if self.dt_unit else 'int') new_tensors = [] for i, tens in enumerate(self.tensors): times = self.times(i) @@ -185,6 +195,14 @@ def with_new_start_times(self, end_idx = true1d_idx(~all_nan).max() + 1 new_tens.append(g_tens[:end_idx].unsqueeze(0)) new_tens = ragged_cat(new_tens, ragged_dim=1, cat_dim=0) + if n_timesteps: + if new_tens.shape[1] > n_timesteps: + new_tens = new_tens[:, :n_timesteps, :] + else: + tmp = torch.empty((new_tens.shape[0], n_timesteps, new_tens.shape[2]), dtype=new_tens.dtype) + tmp[:] = float('nan') + tmp[:, :new_tens.shape[1], :] = new_tens + new_tens = tmp new_tensors.append(new_tens) return type(self)( *new_tensors, @@ -311,13 +329,14 @@ def tensor_to_dataframe(tensor: Tensor, assert tensor.shape[1] <= times.shape[1] assert tensor.shape[2] == len(measures) + _all_nan_groups = [] dfs = [] for g, group_name in enumerate(group_names): # get values, don't store trailing nans: values = tensor[g] all_nan_per_row = np.min(np.isnan(values), axis=1) if all_nan_per_row.all(): - warn(f"Group {group_name} has only missing values.") + _all_nan_groups.append(group_name) continue end_idx = true1d_idx(~all_nan_per_row).max() + 1 # convert to dataframe: @@ -326,6 +345,9 @@ def tensor_to_dataframe(tensor: Tensor, df[time_colname] = np.nan df[time_colname] = times[g, 0:len(df.index)] dfs.append(df) + if _all_nan_groups: + warn(f"Groups have only missing values:{_all_nan_groups}") + if dfs: return concat(dfs) else: @@ -586,6 +608,7 @@ def complete_times(data: 'DataFrame', group_colnames: Sequence[str] = None, time_colname: Optional[str] = None, dt_unit: Optional[str] = None, + max_dt_colname: Optional[str] = None, global_max: Union[bool, datetime.datetime] = False, group_colname: Optional[str] = None): """ @@ -594,10 +617,10 @@ def complete_times(data: 'DataFrame', :param data: A pandas dataframe. :param group_colnames: The column name(s) for the groups. :param time_colname: The column name for the times. Will attempt to guess based on common labels. - :param dt_unit: A :class:`numpy.datetime64` or string representing the datetime increments. If not supplied will - try to guess based on the smallest difference in the data. - :param global_max: If `True`, will use the max time of all groups for the max time of each group. If false, will - keep times past each group's max time as implicitly missing. If a datetime is passed, will use that as the max. + :param dt_unit: Passed to ``pandas.date_range``. If not passed, will attempt to guess based on the minimum + difference between times. + :param max_dt_colname: Optional, a column-name that indicates the maximum time for each group. If not supplied, the + actual maximum time for each group will be used. :return: A dataframe where implicit missings are converted to explicit missings, but the min/max time for each group is preserved. 
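To make the new ``max_dt_colname`` behaviour concrete, a small illustrative sketch (the toy frame and its ``max_date`` column are made up for this example):

```python
import pandas as pd
from torchcast.utils import complete_times

df = pd.DataFrame({
    'group': ['a', 'a', 'b'],
    'time': pd.to_datetime(['2021-01-01', '2021-01-02', '2021-01-01']),
    'value': [1.0, 2.0, 3.0],
    # per-group horizon; must be constant within each group:
    'max_date': pd.to_datetime(['2021-01-04', '2021-01-04', '2021-01-02']),
})
df_full = complete_times(df, group_colnames=['group'], time_colname='time',
                         dt_unit='D', max_dt_colname='max_date')
# group 'a' now has explicit NaN rows for 2021-01-03 and 2021-01-04;
# group 'b' gets an explicit NaN row for 2021-01-02.
```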
""" @@ -610,6 +633,9 @@ def complete_times(data: 'DataFrame', raise TypeError("Missing required argument `group_colnames`") warn("Please pass `group_colnames` instead of `group_colname`", DeprecationWarning) group_colnames = [group_colname] + if max_dt_colname and max_dt_colname not in group_colnames: + assert (data.groupby(group_colnames)[max_dt_colname].nunique() == 1).all() + group_colnames.append(max_dt_colname) if time_colname is None: for col in ('datetime', 'date', 'timestamp', 'time', 'dt'): @@ -628,31 +654,26 @@ def complete_times(data: 'DataFrame', # (e.g. does not match behavior of `my_dates.to_period('W').dt.to_timestamp()`) dt_unit = pd.Timedelta('7 days 00:00:00') - max_time = data[time_colname].max() - if global_max is True: # they can specify a specific value, or pass True for the max in the data - global_max = max_time - # or they can leave global_max=None, in which case will filter to group-specific max below + if global_max: + warn("`global_max=True` is deprecated, use `max_dt_colname` instead.", DeprecationWarning) + + df_group_summary = (data + .groupby(group_colnames) + .agg(_min=(time_colname, 'min'), _max=(time_colname, 'max')) + .reset_index()) + if max_dt_colname: + df_group_summary['_max'] = df_group_summary[max_dt_colname] - df_grid = pd.DataFrame( - {time_colname: pd.date_range(data[time_colname].min(), global_max or max_time, freq=dt_unit)} - ) + max_of_maxes = df_group_summary['_max'].max() - df_group_summary = data. \ - groupby(group_colnames). \ - agg(_min=(time_colname, 'min'), - _max=(time_colname, 'max')). \ - reset_index() - if global_max: - df_group_summary['_max'] = global_max + df_grid = pd.DataFrame({time_colname: pd.date_range(data[time_colname].min(), max_of_maxes, freq=dt_unit)}) # cross-join for all times to all groups (todo: not very memory efficient) - df_cj = df_grid. \ - assign(_cj=1). \ - merge(df_group_summary.assign(_cj=1), how='left', on=['_cj']) + df_cj = df_grid.merge(df_group_summary, how='cross') # filter to min/max for each group - df_cj = df_cj. \ - loc[df_cj[time_colname].between(df_cj['_min'], df_cj['_max']), group_colnames + [time_colname]]. 
\ - reset_index(drop=True) + df_cj = (df_cj + .loc[df_cj[time_colname].between(df_cj['_min'], df_cj['_max']), group_colnames + [time_colname]] + .reset_index(drop=True)) return df_cj.merge(data, how='left', on=group_colnames + [time_colname]) diff --git a/torchcast/utils/outliers.py b/torchcast/utils/outliers.py index 0281119..8081118 100644 --- a/torchcast/utils/outliers.py +++ b/torchcast/utils/outliers.py @@ -1,4 +1,25 @@ import torch +from torch.linalg import LinAlgError + + +def get_outlier_multi(resid: torch.Tensor, + cov: torch.Tensor, + outlier_threshold: torch.Tensor) -> torch.Tensor: + if len(outlier_threshold) == 2: + if resid.shape[-1] != 1: + raise NotImplementedError + neg_mask = (resid < 0).squeeze(-1) + mdist_neg = mahalanobis_dist(resid[neg_mask], cov[neg_mask]) + mdist_pos = mahalanobis_dist(resid[~neg_mask], cov[~neg_mask]) + multi = torch.ones_like(resid).squeeze(-1) + neg_thresh, pos_thresh = outlier_threshold.abs() + multi[neg_mask] = (mdist_neg - neg_thresh).clamp(min=0) + 1 + multi[~neg_mask] = (mdist_pos - pos_thresh).clamp(min=0) + 1 + else: + assert outlier_threshold.numel() + mdist = mahalanobis_dist(resid, cov) + multi = (mdist - outlier_threshold).clamp(min=0) + 1 + return multi def mahalanobis_dist(diff: torch.Tensor, covariance: torch.Tensor) -> torch.Tensor: diff --git a/torchcast/utils/simulate.py b/torchcast/utils/simulate.py deleted file mode 100644 index c42ea32..0000000 --- a/torchcast/utils/simulate.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Optional - -from torch import Tensor -from torch.distributions import MultivariateNormal -from torch.distributions.multivariate_normal import _batch_mv -from torch.distributions.utils import _standard_normal - - -def deterministic_sample_mvnorm(distribution: MultivariateNormal, eps: Optional[Tensor] = None) -> Tensor: - if isinstance(eps, Tensor): - if eps.shape[-len(distribution.event_shape):] != distribution.event_shape: - raise RuntimeError(f"Expected shape ending in {distribution.event_shape}, got {eps.shape}.") - else: - shape = distribution.batch_shape + distribution.event_shape - if eps is None: - eps = 1.0 - eps *= _standard_normal(shape, dtype=distribution.loc.dtype, device=distribution.loc.device) - return distribution.loc + _batch_mv(distribution._unbroadcasted_scale_tril, eps) diff --git a/torchcast/utils/stats.py b/torchcast/utils/stats.py new file mode 100644 index 0000000..6fa7d9e --- /dev/null +++ b/torchcast/utils/stats.py @@ -0,0 +1,9 @@ +from scipy import stats + + +def conf2bounds(mean, std, conf) -> tuple: + assert conf >= .50 + multi = -stats.norm.ppf((1 - conf) / 2) + lower = mean - multi * std + upper = mean + multi * std + return lower, upper
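As a quick worked example of the new ``conf2bounds`` helper (now exported from ``torchcast.utils`` alongside ``get_outlier_multi``):

```python
import numpy as np
from torchcast.utils import conf2bounds

mean = np.array([10.0])
std = np.array([2.0])
lower, upper = conf2bounds(mean, std, conf=.95)
# multiplier is -norm.ppf(.025) ~= 1.96, so lower ~= 6.08 and upper ~= 13.92
print(lower, upper)
```

Note that ``conf`` is the total two-sided coverage (the function asserts ``conf >= .50``), so the default of ``.95`` in ``to_dataframe`` corresponds to the familiar +/-1.96 standard deviations.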