Skip to content

Commit

Permalink
BaseModel.predict needs predict=True in SpatioTemporalKernel.spatial_…
Browse files Browse the repository at this point in the history
…conditional.

In kernels.py, if SpatioTemporalKernel.sparse=False and then spatial_conditional
is given predict=False (default value) then the provided arguments (X, R) will
be ignored and the predictions computed on the training data.
  • Loading branch information
fsaad committed Jun 29, 2023
1 parent f75f9c9 commit 9f07ce2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bayesnewton/basemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None):
H = self.kernel.measurement_model()
if self.spatio_temporal:
# TODO: if R is fixed, only compute B, C once
B, C = self.kernel.spatial_conditional(X, R)
B, C = self.kernel.spatial_conditional(X, R, predict=True)
W = B @ H
test_mean = W @ state_mean
test_var = W @ state_cov @ transpose(W) + C
Expand Down Expand Up @@ -1060,7 +1060,7 @@ def predict(self, X, R=None):
H = self.kernel.measurement_model()
if self.spatio_temporal:
# TODO: if R is fixed, only compute B, C once
B, C = self.kernel.spatial_conditional(X, R)
B, C = self.kernel.spatial_conditional(X, R, predict=True)
W = B @ H
test_mean = W @ state_mean
test_var = W @ state_cov @ transpose(W) + C
Expand Down Expand Up @@ -1216,7 +1216,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None):
H = self.kernel.measurement_model()
if self.spatio_temporal:
# TODO: if R is fixed, only compute B, C once
B, C = self.kernel.spatial_conditional(X, R)
B, C = self.kernel.spatial_conditional(X, R, predict=True)
W = B @ H
test_mean = W @ state_mean
test_var = W @ state_cov @ transpose(W) + C
Expand Down

0 comments on commit 9f07ce2

Please sign in to comment.