wip more performant last_measured_per_group
jwdink committed Jan 24, 2025
1 parent 130630e commit d3f8d0d
Showing 5 changed files with 108 additions and 74 deletions.
24 changes: 17 additions & 7 deletions torchcast/exp_smooth/exp_smooth.py
@@ -10,6 +10,7 @@
 
 from torchcast.exp_smooth.smoothing_matrix import SmoothingMatrix
 from torchcast.covariance import Covariance
+from torchcast.internals.utils import update_tensor
 from torchcast.process import Process
 from torchcast.state_space import StateSpaceModel, Predictions
 from torchcast.state_space.ss_step import StateSpaceStep
@@ -37,18 +38,27 @@ def _update(self,
                 input: Tensor,
                 mean: Tensor,
                 cov: Tensor,
-                kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
+                kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
         measured_mean = (kwargs['H'] @ mean.unsqueeze(-1)).squeeze(-1)
         resid = input - measured_mean
         new_mean = mean + (kwargs['K'] @ resid.unsqueeze(-1)).squeeze(-1)
-        return new_mean, None
+        # _update doesn't waste compute creating new_cov; in predict cov will be replaced by cov1step
+        # TODO: why not replace it here?
+        new_cov = torch.tensor(0.0, dtype=mean.dtype, device=mean.device)
+        return new_mean, new_cov
 
-    def predict(self, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
-        F = kwargs['F']
-        new_mean = (F @ mean.unsqueeze(-1)).squeeze(-1)
+    def predict(self,
+                mean: Tensor,
+                cov: Tensor,
+                mask: Tensor,
+                kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
+        F = kwargs['F'][mask]
+
+        new_mean = update_tensor(mean, new=(F @ mean[mask].unsqueeze(-1)).squeeze(-1), mask=mask)
         new_cov = kwargs['cov1step']
-        if cov is not None:
-            new_cov = new_cov + F @ cov @ F.permute(0, 2, 1)
+        if len(cov.shape):  # see note in _update() above
+            new_cov = update_tensor(new_cov.clone(), new=F @ cov[mask] @ F.permute(0, 2, 1), mask=mask)
 
         return new_mean, new_cov


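A note on the pattern above: ``_update()`` now returns a zero-dimensional sentinel tensor rather than a real covariance, and ``predict()`` detects it via ``len(cov.shape)``. A minimal sketch of why that check works (my illustration, not code from the commit):

    import torch

    # ExpSmoothStep._update() returns a 0-dim scalar instead of a covariance;
    # `len(cov.shape)` is then falsy, so predict() skips the F @ cov @ F'
    # propagation and uses only kwargs['cov1step'].
    sentinel = torch.tensor(0.0)
    batched_cov = torch.eye(3).expand(4, 3, 3)  # a real (groups, state, state) covariance
    assert len(sentinel.shape) == 0     # falsy -> skip propagation
    assert len(batched_cov.shape) == 3  # truthy -> propagate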
29 changes: 29 additions & 0 deletions torchcast/internals/utils.py
@@ -6,6 +6,35 @@
 import numpy as np
 
 
+def update_tensor(orig: torch.Tensor, new: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """
+    In some cases we will want to save compute by only performing operations on a subset of a tensor, leaving the
+    other elements as-is. In that case we'd need:
+
+    .. code-block:: python
+
+        new = orig.clone()
+        new[mask] = some_op(orig[mask])
+
+    However, in other cases we don't want to mask; if that second case is common, the call to ``.clone()`` would
+    waste compute.
+
+    This convenience function handles the masking for you, but only if a non-trivial mask is provided: if the mask
+    is all True, it returns the new tensor as-is.
+
+    :param orig: The original tensor.
+    :param new: The new tensor. Should have the same shape as ``orig[mask]``.
+    :param mask: A boolean mask Tensor.
+    :return: If ``mask`` is all True, returns ``new`` (not a copy). Otherwise, returns a new tensor with the same
+        shape as ``orig``, where the elements in ``mask`` are replaced with the elements in ``new``.
+    """
+    if mask.all():
+        return new
+    else:
+        out = orig.clone()
+        out[mask] = new
+        return out
+
 
 def transpose_last_dims(x: torch.Tensor) -> torch.Tensor:
     args = list(range(len(x.shape)))
     args[-2], args[-1] = args[-1], args[-2]
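A minimal usage sketch of ``update_tensor`` (my example, not part of the commit):

    import torch
    from torchcast.internals.utils import update_tensor

    orig = torch.zeros(4, 2)
    mask = torch.tensor([True, False, True, False])
    out = update_tensor(orig, new=torch.ones(2, 2), mask=mask)
    # rows 0 and 2 replaced, rows 1 and 3 kept:
    # tensor([[1., 1.], [0., 0.], [1., 1.], [0., 0.]])

    new_full = torch.ones(4, 2)
    out2 = update_tensor(orig, new=new_full, mask=torch.ones(4, dtype=torch.bool))
    assert out2 is new_full  # all-True mask: `new` returned as-is, no clone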
19 changes: 13 additions & 6 deletions torchcast/kalman_filter/kalman_filter.py
@@ -7,6 +7,7 @@
 from typing import Sequence, Dict, List, Iterable
 
 from torchcast.covariance import Covariance
+from torchcast.internals.utils import update_tensor
 from torchcast.process import Process
 from torchcast.state_space.base import StateSpaceModel
 from torchcast.state_space.ss_step import StateSpaceStep
@@ -25,12 +26,18 @@ class KalmanStep(StateSpaceStep):
     """
     use_stable_cov_update: Final[bool] = True
 
-    def predict(self, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
-        F = kwargs['F']
-        Q = kwargs['Q']
-        mean = (F @ mean.unsqueeze(-1)).squeeze(-1)
-        cov = F @ cov @ F.permute(0, 2, 1) + Q
-        return mean, cov
+    def predict(self,
+                mean: Tensor,
+                cov: Tensor,
+                mask: Tensor,
+                kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
+        F = kwargs['F'][mask]
+        Q = kwargs['Q'][mask]
+
+        new_mean = update_tensor(mean, new=(F @ mean[mask].unsqueeze(-1)).squeeze(-1), mask=mask)
+        new_cov = update_tensor(cov, new=(F @ cov[mask] @ F.permute(0, 2, 1) + Q), mask=mask)
+
+        return new_mean, new_cov
 
     def _mask_mats(self,
                    groups: Tensor,
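To make the masked-predict contract concrete, a toy walk-through (mine, not from the repo's tests): groups whose mask entry is False keep their state unchanged and cost no compute.

    import torch
    from torchcast.internals.utils import update_tensor

    # two groups with a 2-dim state; group 1 is past its last measurement
    mean = torch.tensor([[1., 2.], [3., 4.]])
    mask = torch.tensor([True, False])
    F = torch.eye(2).expand(2, 2, 2) * 2.  # per-group "doubling" transition

    Fm = F[mask]
    new_mean = update_tensor(mean, new=(Fm @ mean[mask].unsqueeze(-1)).squeeze(-1), mask=mask)
    # group 0 advances to [2., 4.]; group 1 stays frozen at [3., 4.]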
84 changes: 35 additions & 49 deletions torchcast/state_space/base.py
@@ -144,8 +144,14 @@ def fit(self,
         if loss_callback:
             warn("`loss_callback` is deprecated; use `get_loss` instead.", DeprecationWarning)
 
-        # see doc in ``forward()``: at training, no need to waste compute on unused predictions
-        kwargs['stop_at_last_measured'] = kwargs.get('stop_at_last_measured', True)
+        # at training, no need to waste compute on predictions past each group's last measurement (see ``forward()``)
+        # todo: duplicate code in ``TimeSeriesDataset.get_durations()``
+        any_measured_bool = ~np.isnan(y.numpy()).all(2)
+        kwargs['last_measured_per_group'] = torch.as_tensor(
+            [np.max(true1d_idx(any_measured_bool[g]).numpy(), initial=0) for g in range(y.shape[0])],
+            dtype=torch.int,
+            device=y.device
+        ) + 1
 
         def closure():
             optimizer.zero_grad()
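As a sanity check on what this block computes, a torch-only sketch with made-up data (the committed code goes through numpy and the ``true1d_idx`` helper instead):

    import torch

    nan = float('nan')
    y = torch.tensor([[[1.], [nan], [2.], [nan]],     # last measured at t=2
                      [[nan], [nan], [nan], [nan]]])  # never measured
    any_measured = ~torch.isnan(y).all(2)             # (groups, times) bool
    t_idx = torch.arange(y.shape[1])
    last = torch.where(any_measured, t_idx, torch.zeros_like(t_idx)).max(1).values
    last_measured_per_group = (last + 1).to(torch.int)
    # tensor([3, 1], dtype=torch.int32) -- i.e. np.max(..., initial=0) + 1 per group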
@@ -267,7 +273,7 @@ def forward(self,
                 every_step: bool = True,
                 include_updates_in_output: bool = False,
                 simulate: Optional[int] = None,
-                stop_at_last_measured: bool = False,
+                last_measured_per_group: Optional[Tensor] = None,
                 prediction_kwargs: Optional[dict] = None,
                 **kwargs) -> Predictions:
         """
@@ -296,7 +302,7 @@
         :param include_updates_in_output: If False, only the ``n_step`` ahead predictions are included in the output.
          This means that we cannot use this output to generate the ``initial_state`` for subsequent forward-passes. Set
          to True to allow this -- False by default to reduce memory.
-        :param stop_at_last_measured: If True, then predictions will not be generated after the last measured timestep.
+        :param last_measured_per_group: An optional tensor with one element per group; if provided, predictions will not be generated after the last measured timestep.
         Default False. The 'last measured timestep' is the last timestep where at least one measure is non-nan. Setting
         to True can be useful in training when the series in a batch are highly variable in length. For example, if a
         batch contains two series, one with length=1 and the other with length=999, then we can save the compute used
@@ -358,7 +364,7 @@ def forward(self,
                 out_timesteps=out_timesteps or input.shape[1],
                 **kwargs
             ),
-            stop_at_last_measured=stop_at_last_measured,
+            last_measured_per_group=last_measured_per_group,
             simulate=bool(simulate)
         )
         prediction_kwargs = prediction_kwargs or {}
Expand Down Expand Up @@ -444,9 +450,9 @@ def _script_forward(self,
initial_state: Tuple[Tensor, Tensor],
n_step: int = 1,
out_timesteps: Optional[int] = None,
last_measured_per_group: Optional[Tensor] = None,
every_step: bool = True,
simulate: bool = False,
stop_at_last_measured: bool = False
) -> Tuple[
Tuple[List[Tensor], List[Tensor]],
Tuple[List[Tensor], List[Tensor]],
@@ -465,8 +471,7 @@
         (the default) then this option has no effect.
         :param simulate: If True, will simulate state-trajectories and return a ``Predictions`` object with zero state
          covariance.
-        :param stop_at_last_measured: If True, then predictions will not be generated after the last measured timestep.
-         For more details, see :func:`StateSpaceModel.forward()`.
+        :param last_measured_per_group: An optional tensor with one element per group; for details, see :func:`StateSpaceModel.forward()`.
         """
         assert n_step > 0
 
@@ -478,9 +483,6 @@
             inputs = []
 
             num_groups = meanu.shape[0]
-            if stop_at_last_measured:
-                warn("Ignoring `stop_at_last_measured` since `input` is None.")
-            last_measured_per_group = torch.full((num_groups,), out_timesteps, dtype=torch.int, device=meanu.device)
 
             if covu.shape[0] == 1:
                 covu = repeat(covu, times=num_groups, dim=0)
@@ -500,36 +502,25 @@
         if out_timesteps is None:
             out_timesteps = len(inputs)
 
-        if stop_at_last_measured:
-            # todo: duplicate code in ``TimeSeriesDataset.get_durations()``
-            any_measured_bool = ~np.isnan(input.numpy()).all(2)
-            last_measured_per_group = torch.as_tensor(
-                [np.max(true1d_idx(any_measured_bool[g]).numpy(), initial=0) for g in range(num_groups)],
-                dtype=torch.int,
-                device=input.device
-            ) + 1
-        else:
-            last_measured_per_group = torch.full((num_groups,), out_timesteps, dtype=torch.int, device=input.device)
-
         predict_kwargs, update_kwargs = self._build_design_mats(
             kwargs_per_process=kwargs_per_process,
             num_groups=num_groups,
             out_timesteps=out_timesteps
         )
+        if last_measured_per_group is None:
+            last_measured_per_group = torch.full((num_groups,), out_timesteps, dtype=torch.int, device=meanu.device)
 
         # first loop through to do predict -> update
         meanus: List[Tensor] = []
         covus: List[Tensor] = []
         mean1s: List[Tensor] = []
         cov1s: List[Tensor] = []
         for t in range(out_timesteps):
-            group_mask = (t <= last_measured_per_group)
-            mean1step = meanu.clone()
-            cov1step = covu.clone()
-            mean1step[group_mask], cov1step[group_mask] = self.ss_step.predict(
-                meanu[group_mask],
-                covu[group_mask],
-                {k: v[t][group_mask] for k, v in predict_kwargs.items()}
+            mean1step, cov1step = self.ss_step.predict(
+                meanu,
+                covu,
+                mask=(t <= last_measured_per_group),
+                kwargs={k: v[t] for k, v in predict_kwargs.items()}
             )
             mean1s.append(mean1step)
             cov1s.append(cov1step)
@@ -538,14 +529,11 @@
                 meanu = torch.distributions.MultivariateNormal(mean1step, cov1step, validate_args=False).sample()
                 covu = torch.eye(meanu.shape[-1]).expand(num_groups, -1, -1) * 1e-6
             elif t < len(inputs):
-                update_kwargs_t = {k: v[t][group_mask] for k, v in update_kwargs.items()}
-                # update_kwargs_t['outlier_threshold'] = torch.tensor(outlier_threshold if t > outlier_burnin else 0.)
-                meanu = meanu.clone()
-                covu = covu.clone()
-                meanu[group_mask], covu[group_mask] = self.ss_step.update(
-                    inputs[t][group_mask],
-                    mean1step[group_mask],
-                    cov1step[group_mask],
+                update_kwargs_t = {k: v[t] for k, v in update_kwargs.items()}
+                meanu, covu = self.ss_step.update(
+                    inputs[t],
+                    mean1step,
+                    cov1step,
                     update_kwargs_t,
                 )
             else:
@@ -568,22 +556,20 @@
             if every_step or (t1 % n_step) == 0:
                 meanp, covp = mean1s[t1], cov1s[t1]  # already had to generate h=1 above
                 for h in range(1, n_step + 1):
-                    if tu + h >= out_timesteps:
+                    tu_h = tu + h
+                    if tu_h >= out_timesteps:
                         break
                     if h > 1:
-                        tu_h = tu + h
-                        group_mask = (tu_h <= last_measured_per_group)
-                        meanp = meanp.clone()
-                        covp = covp.clone()
-                        meanp[group_mask], covp[group_mask] = self.ss_step.predict(
-                            meanp[group_mask],
-                            covp[group_mask],
-                            {k: v[tu_h][group_mask] for k, v in predict_kwargs.items()}
+                        meanp, covp = self.ss_step.predict(
+                            meanp,
+                            covp,
+                            mask=(tu_h <= last_measured_per_group),
+                            kwargs={k: v[tu_h] for k, v in predict_kwargs.items()},
                         )
-                    if tu + h not in meanps:
+                    if tu_h not in meanps:
                         # idx[tu + h] = tu
-                        meanps[tu + h] = meanp
-                        covps[tu + h] = covp
+                        meanps[tu_h] = meanp
+                        covps[tu_h] = covp
 
         preds = [meanps[t] for t in range(out_timesteps)], [covps[t] for t in range(out_timesteps)]
         updates = meanus, covus
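Taken together, the base.py changes shift the public calling convention from a boolean flag to an explicit per-group tensor. A hypothetical call (``kf`` stands in for any fitted ``StateSpaceModel``; the values are made up):

    # before: preds = kf(y, n_step=1, stop_at_last_measured=True)
    # after: pass one-past-the-last-measured index per group (or None for all timesteps)
    preds = kf(y, n_step=1, last_measured_per_group=torch.tensor([3, 1], dtype=torch.int))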
26 changes: 14 additions & 12 deletions torchcast/state_space/ss_step.py
@@ -11,17 +11,19 @@ class StateSpaceStep(torch.nn.Module):
     Base-class for modules that handle predict/update within a state-space model.
     """
 
-    def forward(self,
-                input: Tensor,
+    def predict(self,
                 mean: Tensor,
                 cov: Tensor,
-                predict_kwargs: Dict[str, Tensor],
-                update_kwargs: Dict[str, Tensor],
-                ) -> Tuple[Tensor, Tensor]:
-        mean, cov = self.update(input, mean, cov, update_kwargs)
-        return self.predict(mean, cov, predict_kwargs)
-
-    def predict(self, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
+                mask: Tensor,
+                kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
+        """
+        :param mean: The current mean tensor.
+        :param cov: The current covariance tensor.
+        :param mask: A boolean mask tensor. Only masked elements of mean/cov will be updated, and remaining elements
+            will be returned as-is.
+        :param kwargs: A dictionary of keyword arguments.
+        :return: A tuple of (new_mean, new_cov) tensors.
+        """
         raise NotImplementedError
 
     def _update(self,
@@ -61,8 +63,6 @@ def update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Ten
                 kwargs=masked_kwargs
             )
             new_mean[groups] = m
-            if c is None:
-                c = 0
             new_cov[groups] = c
             return new_mean, new_cov
         else:
@@ -73,7 +73,9 @@ def _mask_mats(self,
                    val_idx: Optional[Tensor],
                    input: Tensor,
                    kwargs: Dict[str, Tensor],
-                   kwargs_dims: Optional[Dict[str, int]]) -> Tuple[Tensor, Dict[str, Tensor]]:
+                   kwargs_dims: Optional[Dict[str, int]] = None) -> Tuple[Tensor, Dict[str, Tensor]]:
+        if kwargs_dims is None:
+            raise RuntimeError("_mask_mats should only ever be called from subclasses which pass `kwargs_dims`")
         new_kwargs = kwargs.copy()
         if val_idx is None:
             for k in kwargs_dims:
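For orientation, a subclass satisfying the new ``predict`` contract might look like this sketch of a random-walk step (illustrative only, mirroring ``KalmanStep`` above; not code from the commit):

    from typing import Dict, Tuple

    from torch import Tensor
    from torchcast.internals.utils import update_tensor
    from torchcast.state_space.ss_step import StateSpaceStep

    class RandomWalkStep(StateSpaceStep):
        def predict(self,
                    mean: Tensor,
                    cov: Tensor,
                    mask: Tensor,
                    kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
            # identity transition: the mean carries forward; only the masked
            # groups pay for the covariance bump by process noise Q
            new_mean = update_tensor(mean, new=mean[mask], mask=mask)
            new_cov = update_tensor(cov, new=cov[mask] + kwargs['Q'][mask], mask=mask)
            return new_mean, new_cov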
