From a3a05897a3120103ad41e304c146017882aac614 Mon Sep 17 00:00:00 2001 From: Kalvik Date: Wed, 14 Jun 2023 16:20:54 -0400 Subject: [PATCH 1/2] Fix Independent kernel Fixes a minor bug that was introduced when fixing the deprecated jax.ops.index_update and jax.ops.index_add. --- bayesnewton/kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesnewton/kernels.py b/bayesnewton/kernels.py index 8a97509..8963233 100644 --- a/bayesnewton/kernels.py +++ b/bayesnewton/kernels.py @@ -1586,11 +1586,11 @@ def __init__(self, kernels): def K(self, X, X2): zeros = np.zeros(self.num_kernels) K0 = self.kernel0.K(X, X2) - index_vector = index_vector.at[0].add(1.) + index_vector = zeros.at[0].add(1.) Kstack = np.kron(K0, np.diag(index_vector)) for i in range(1, self.num_kernels): kerneli = eval("self.kernel" + str(i)) - index_vector = index_vector.at[i].add(1.) + index_vector = zeros.at[i].add(1.) Kstack += np.kron(kerneli.K(X, X2), np.diag(index_vector)) return Kstack From 3014b7de3260f759f6d6d1b680bbf5c5206708ca Mon Sep 17 00:00:00 2001 From: Kalvik Date: Thu, 15 Jun 2023 10:21:12 -0400 Subject: [PATCH 2/2] Generalize blockdiagmatrix_to_blocktensor to D>3 Found a clean solution at https://stackoverflow.com/questions/10831417/extracting-diagonal-blocks-from-a-numpy-array --- bayesnewton/ops.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/bayesnewton/ops.py b/bayesnewton/ops.py index 155cc2a..5bd781a 100644 --- a/bayesnewton/ops.py +++ b/bayesnewton/ops.py @@ -44,26 +44,9 @@ def get_3d_off_diag(offdiag_elems): def blockdiagmatrix_to_blocktensor(blockdiagmatrix, N, D): """ Convert [ND, ND] block-diagonal matrix to [N, D, D] tensor - TODO: extend to D>3 case + Code from https://stackoverflow.com/questions/10831417/extracting-diagonal-blocks-from-a-numpy-array """ - diags = vmap(np.diag)(np.diag(blockdiagmatrix).reshape(N, D)) - if D == 1: - blocktensors = diags - elif D == 2: - offdiag_elems = np.diag(np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1)).reshape(N, D)[:, 1:] - offdiags = offdiag_elems[..., None] * np.fliplr(np.eye(D)) - blocktensors = diags + offdiags - elif D == 3: - addzeros = np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1) - addzeros2 = np.concatenate([np.zeros([N * D, 1]), addzeros], axis=1) - offdiag_elems = np.diag(addzeros).reshape(N, D)[:, 1:] - offdiags = get_3d_off_diag(offdiag_elems) - corner_elements = np.diag(addzeros2).reshape(N, D)[:, 2:] - corners = corner_elements[..., None] * np.array([[0., 0., 1.], [0., 0., 0.], [1., 0., 0.]]) - blocktensors = diags + offdiags + corners - else: - raise NotImplementedError('Multi-latent case with D>3 not implemented') - return blocktensors + return np.array([blockdiagmatrix[i*D:(i+1)*D,i*D:(i+1)*D] for i in range(N)]) def gaussian_conditional(kernel, y, noise_cov, X, X_star=None):