From a3a05897a3120103ad41e304c146017882aac614 Mon Sep 17 00:00:00 2001 From: Kalvik Date: Wed, 14 Jun 2023 16:20:54 -0400 Subject: [PATCH] Fix Independent kernel Fixes a minor bug that was introduced when fixing the deprecated jax.ops.index_update and jax.ops.index_add. --- bayesnewton/kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesnewton/kernels.py b/bayesnewton/kernels.py index 8a97509..8963233 100644 --- a/bayesnewton/kernels.py +++ b/bayesnewton/kernels.py @@ -1586,11 +1586,11 @@ def __init__(self, kernels): def K(self, X, X2): zeros = np.zeros(self.num_kernels) K0 = self.kernel0.K(X, X2) - index_vector = index_vector.at[0].add(1.) + index_vector = zeros.at[0].add(1.) Kstack = np.kron(K0, np.diag(index_vector)) for i in range(1, self.num_kernels): kerneli = eval("self.kernel" + str(i)) - index_vector = index_vector.at[i].add(1.) + index_vector = zeros.at[i].add(1.) Kstack += np.kron(kerneli.K(X, X2), np.diag(index_vector)) return Kstack