diff --git a/bayesnewton/basemodels.py b/bayesnewton/basemodels.py index ac221dd..b992f34 100644 --- a/bayesnewton/basemodels.py +++ b/bayesnewton/basemodels.py @@ -803,7 +803,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None): H = self.kernel.measurement_model() if self.spatio_temporal: # TODO: if R is fixed, only compute B, C once - B, C = self.kernel.spatial_conditional(X, R) + B, C = self.kernel.spatial_conditional(X, R, predict=True) W = B @ H test_mean = W @ state_mean test_var = W @ state_cov @ transpose(W) + C @@ -1060,7 +1060,7 @@ def predict(self, X, R=None): H = self.kernel.measurement_model() if self.spatio_temporal: # TODO: if R is fixed, only compute B, C once - B, C = self.kernel.spatial_conditional(X, R) + B, C = self.kernel.spatial_conditional(X, R, predict=True) W = B @ H test_mean = W @ state_mean test_var = W @ state_cov @ transpose(W) + C @@ -1216,7 +1216,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None): H = self.kernel.measurement_model() if self.spatio_temporal: # TODO: if R is fixed, only compute B, C once - B, C = self.kernel.spatial_conditional(X, R) + B, C = self.kernel.spatial_conditional(X, R, predict=True) W = B @ H test_mean = W @ state_mean test_var = W @ state_cov @ transpose(W) + C