From 9f07ce27446f1ecd1d2597bd92e622bb1d8a68cb Mon Sep 17 00:00:00 2001 From: "Feras A. Saad" Date: Thu, 29 Jun 2023 14:37:17 -0400 Subject: [PATCH] BaseModel.predict needs predict=True in SpatioTemporalKernel.spatial_conditional. In kernels.py, if SpatioTemporalKernel.sparse=False and then spatial_conditional is given predict=False (default value) then the provided arguments (X, R) will be ignored and the predictions computed on the training data. --- bayesnewton/basemodels.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesnewton/basemodels.py b/bayesnewton/basemodels.py index ac221dd..b992f34 100644 --- a/bayesnewton/basemodels.py +++ b/bayesnewton/basemodels.py @@ -803,7 +803,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None): H = self.kernel.measurement_model() if self.spatio_temporal: # TODO: if R is fixed, only compute B, C once - B, C = self.kernel.spatial_conditional(X, R) + B, C = self.kernel.spatial_conditional(X, R, predict=True) W = B @ H test_mean = W @ state_mean test_var = W @ state_cov @ transpose(W) + C @@ -1060,7 +1060,7 @@ def predict(self, X, R=None): H = self.kernel.measurement_model() if self.spatio_temporal: # TODO: if R is fixed, only compute B, C once - B, C = self.kernel.spatial_conditional(X, R) + B, C = self.kernel.spatial_conditional(X, R, predict=True) W = B @ H test_mean = W @ state_mean test_var = W @ state_cov @ transpose(W) + C @@ -1216,7 +1216,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None): H = self.kernel.measurement_model() if self.spatio_temporal: # TODO: if R is fixed, only compute B, C once - B, C = self.kernel.spatial_conditional(X, R) + B, C = self.kernel.spatial_conditional(X, R, predict=True) W = B @ H test_mean = W @ state_mean test_var = W @ state_cov @ transpose(W) + C