Skip to content

Commit

Permalink
Fix Independent kernel
Browse files Browse the repository at this point in the history
Fixes a minor bug that was introduced when fixing the deprecated jax.ops.index_update and jax.ops.index_add.
  • Loading branch information
itskalvik authored Jun 14, 2023
1 parent 8ce9598 commit a3a0589
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bayesnewton/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,11 +1586,11 @@ def __init__(self, kernels):
def K(self, X, X2):
zeros = np.zeros(self.num_kernels)
K0 = self.kernel0.K(X, X2)
index_vector = index_vector.at[0].add(1.)
index_vector = zeros.at[0].add(1.)
Kstack = np.kron(K0, np.diag(index_vector))
for i in range(1, self.num_kernels):
kerneli = eval("self.kernel" + str(i))
index_vector = index_vector.at[i].add(1.)
index_vector = zeros.at[i].add(1.)
Kstack += np.kron(kerneli.K(X, X2), np.diag(index_vector))
return Kstack

Expand Down

0 comments on commit a3a0589

Please sign in to comment.