From a92ea5640282161907afb99217483d2fc491c46c Mon Sep 17 00:00:00 2001 From: Jacob Date: Wed, 24 Jul 2024 09:13:18 -0500 Subject: [PATCH] don't validate args --- torchcast/state_space/base.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchcast/state_space/base.py b/torchcast/state_space/base.py index 9b89aa5..8aab9d5 100644 --- a/torchcast/state_space/base.py +++ b/torchcast/state_space/base.py @@ -347,18 +347,20 @@ def forward(self, ), simulate=bool(simulate) ) - return self._generate_predictions( + preds = self._generate_predictions( preds=preds, updates=updates if include_updates_in_output else None, - start_offsets=start_offsets, **design_mats, ) + return preds.set_metadata( + start_offsets=start_offsets, + dt_unit=self.dt_unit + ) @torch.jit.ignore def _generate_predictions(self, preds: Tuple[List[Tensor], List[Tensor]], updates: Optional[Tuple[List[Tensor], List[Tensor]]] = None, - start_offsets: Optional[np.ndarray] = None, **kwargs) -> 'Predictions': """ StateSpace subclasses may pass subclasses of `Predictions` (e.g. for custom log-prob) @@ -373,10 +375,7 @@ def _generate_predictions(self, model=self, **kwargs ) - return preds.set_metadata( - start_offsets=start_offsets, - dt_unit=self.dt_unit - ) + return preds @torch.jit.ignore def _prepare_initial_state(self, @@ -497,7 +496,7 @@ def _script_forward(self, cov1s.append(cov1step) if simulate: - meanu = torch.distributions.MultivariateNormal(mean1step, cov1step).sample() + meanu = torch.distributions.MultivariateNormal(mean1step, cov1step, validate_args=False).sample() covu = torch.eye(meanu.shape[-1]) * 1e-6 elif t < len(inputs): update_kwargs_t = {k: v[t] for k, v in update_kwargs.items()}