From 3f0e0ed7eeb576c85849e4a12ee24d16e37e9034 Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 16 Jan 2025 08:45:15 -0600 Subject: [PATCH] wip: in fit, dont predict past last measured for each group --- torchcast/state_space/base.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/torchcast/state_space/base.py b/torchcast/state_space/base.py index 70de3a6..127b043 100644 --- a/torchcast/state_space/base.py +++ b/torchcast/state_space/base.py @@ -426,7 +426,8 @@ def _script_forward(self, n_step: int = 1, out_timesteps: Optional[int] = None, every_step: bool = True, - simulate: bool = False + simulate: bool = False, + stop_at_last_measured: bool = False ) -> Tuple[ Tuple[List[Tensor], List[Tensor]], Tuple[List[Tensor], List[Tensor]], @@ -445,6 +446,7 @@ def _script_forward(self, (the default) then this option has no effect. :param simulate: If True, will simulate state-trajectories and return a ``Predictions`` object with zero state covariance. + :param stop_at_last_measured: TODO :return: predictions (tuple of (means,covs)), updates (tuple of (means,covs)), R, H """ assert n_step > 0 @@ -457,6 +459,8 @@ def _script_forward(self, inputs = [] num_groups = meanu.shape[0] + if stop_at_last_measured: + warn("Ignoring `stop_at_last_measured` since `input` is None.") else: if len(input.shape) != 3: raise ValueError(f"Expected len(input.shape) == 3 (group,time,measure)") @@ -473,6 +477,11 @@ def _script_forward(self, if out_timesteps is None: out_timesteps = len(inputs) + if stop_at_last_measured: + raise NotImplementedError("todo") + else: + last_measured_per_group = torch.full((num_groups,), out_timesteps, dtype=torch.int, device=input.device) + predict_kwargs, update_kwargs = self._build_design_mats( kwargs_per_process=kwargs_per_process, num_groups=num_groups, @@ -485,10 +494,13 @@ def _script_forward(self, mean1s: List[Tensor] = [] cov1s: List[Tensor] = [] for t in range(out_timesteps): - mean1step, cov1step = self.ss_step.predict( - meanu, - covu, - {k: v[t] for k, v in predict_kwargs.items()} + group_mask = (t <= last_measured_per_group) + mean1step = meanu.clone() + cov1step = covu.clone() + mean1step[group_mask], cov1step[group_mask] = self.ss_step.predict( + meanu[group_mask], + covu[group_mask], + {k: v[t][group_mask] for k, v in predict_kwargs.items()} ) mean1s.append(mean1step) cov1s.append(cov1step) @@ -528,10 +540,14 @@ def _script_forward(self, if tu + h >= out_timesteps: break if h > 1: - meanp, covp = self.ss_step.predict( - meanp, - covp, - {k: v[tu + h] for k, v in predict_kwargs.items()} + tu_h = tu + h + group_mask = (tu_h <= last_measured_per_group) + meanp = meanp.clone() + covp = covp.clone() + meanp[group_mask], covp[group_mask] = self.ss_step.predict( + meanp[group_mask], + covp[group_mask], + {k: v[tu_h][group_mask] for k, v in predict_kwargs.items()} ) if tu + h not in meanps: # idx[tu + h] = tu