Skip to content

Commit

Permalink
don't validate args
Browse files Browse the repository at this point in the history
  • Loading branch information
jwdink committed Jul 24, 2024
1 parent 450f1c2 commit a92ea56
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions torchcast/state_space/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,20 @@ def forward(self,
),
simulate=bool(simulate)
)
return self._generate_predictions(
preds = self._generate_predictions(
preds=preds,
updates=updates if include_updates_in_output else None,
start_offsets=start_offsets,
**design_mats,
)
return preds.set_metadata(
start_offsets=start_offsets,
dt_unit=self.dt_unit
)

@torch.jit.ignore
def _generate_predictions(self,
preds: Tuple[List[Tensor], List[Tensor]],
updates: Optional[Tuple[List[Tensor], List[Tensor]]] = None,
start_offsets: Optional[np.ndarray] = None,
**kwargs) -> 'Predictions':
"""
StateSpace subclasses may pass subclasses of `Predictions` (e.g. for custom log-prob)
Expand All @@ -373,10 +375,7 @@ def _generate_predictions(self,
model=self,
**kwargs
)
return preds.set_metadata(
start_offsets=start_offsets,
dt_unit=self.dt_unit
)
return preds

@torch.jit.ignore
def _prepare_initial_state(self,
Expand Down Expand Up @@ -497,7 +496,7 @@ def _script_forward(self,
cov1s.append(cov1step)

if simulate:
meanu = torch.distributions.MultivariateNormal(mean1step, cov1step).sample()
meanu = torch.distributions.MultivariateNormal(mean1step, cov1step, validate_args=False).sample()
covu = torch.eye(meanu.shape[-1]) * 1e-6
elif t < len(inputs):
update_kwargs_t = {k: v[t] for k, v in update_kwargs.items()}
Expand Down

0 comments on commit a92ea56

Please sign in to comment.