Skip to content

Commit

Permalink
wip: in fit, dont predict past last measured for each group
Browse files Browse the repository at this point in the history
  • Loading branch information
jwdink committed Jan 16, 2025
1 parent ba322c3 commit 3f0e0ed
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions torchcast/state_space/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,8 @@ def _script_forward(self,
n_step: int = 1,
out_timesteps: Optional[int] = None,
every_step: bool = True,
simulate: bool = False
simulate: bool = False,
stop_at_last_measured: bool = False
) -> Tuple[
Tuple[List[Tensor], List[Tensor]],
Tuple[List[Tensor], List[Tensor]],
Expand All @@ -445,6 +446,7 @@ def _script_forward(self,
(the default) then this option has no effect.
:param simulate: If True, will simulate state-trajectories and return a ``Predictions`` object with zero state
covariance.
:param stop_at_last_measured: TODO
:return: predictions (tuple of (means,covs)), updates (tuple of (means,covs)), R, H
"""
assert n_step > 0
Expand All @@ -457,6 +459,8 @@ def _script_forward(self,
inputs = []

num_groups = meanu.shape[0]
if stop_at_last_measured:
warn("Ignoring `stop_at_last_measured` since `input` is None.")
else:
if len(input.shape) != 3:
raise ValueError(f"Expected len(input.shape) == 3 (group,time,measure)")
Expand All @@ -473,6 +477,11 @@ def _script_forward(self,
if out_timesteps is None:
out_timesteps = len(inputs)

if stop_at_last_measured:
raise NotImplementedError("todo")
else:
last_measured_per_group = torch.full((num_groups,), out_timesteps, dtype=torch.int, device=input.device)

predict_kwargs, update_kwargs = self._build_design_mats(
kwargs_per_process=kwargs_per_process,
num_groups=num_groups,
Expand All @@ -485,10 +494,13 @@ def _script_forward(self,
mean1s: List[Tensor] = []
cov1s: List[Tensor] = []
for t in range(out_timesteps):
mean1step, cov1step = self.ss_step.predict(
meanu,
covu,
{k: v[t] for k, v in predict_kwargs.items()}
group_mask = (t <= last_measured_per_group)
mean1step = meanu.clone()
cov1step = covu.clone()
mean1step[group_mask], cov1step[group_mask] = self.ss_step.predict(
meanu[group_mask],
covu[group_mask],
{k: v[t][group_mask] for k, v in predict_kwargs.items()}
)
mean1s.append(mean1step)
cov1s.append(cov1step)
Expand Down Expand Up @@ -528,10 +540,14 @@ def _script_forward(self,
if tu + h >= out_timesteps:
break
if h > 1:
meanp, covp = self.ss_step.predict(
meanp,
covp,
{k: v[tu + h] for k, v in predict_kwargs.items()}
tu_h = tu + h
group_mask = (tu_h <= last_measured_per_group)
meanp = meanp.clone()
covp = covp.clone()
meanp[group_mask], covp[group_mask] = self.ss_step.predict(
meanp[group_mask],
covp[group_mask],
{k: v[tu_h][group_mask] for k, v in predict_kwargs.items()}
)
if tu + h not in meanps:
# idx[tu + h] = tu
Expand Down

0 comments on commit 3f0e0ed

Please sign in to comment.