diff --git a/bayesnewton/kernels.py b/bayesnewton/kernels.py index 8a97509..8963233 100644 --- a/bayesnewton/kernels.py +++ b/bayesnewton/kernels.py @@ -1586,11 +1586,11 @@ def __init__(self, kernels): def K(self, X, X2): zeros = np.zeros(self.num_kernels) K0 = self.kernel0.K(X, X2) - index_vector = index_vector.at[0].add(1.) + index_vector = zeros.at[0].add(1.) Kstack = np.kron(K0, np.diag(index_vector)) for i in range(1, self.num_kernels): kerneli = eval("self.kernel" + str(i)) - index_vector = index_vector.at[i].add(1.) + index_vector = zeros.at[i].add(1.) Kstack += np.kron(kerneli.K(X, X2), np.diag(index_vector)) return Kstack diff --git a/bayesnewton/ops.py b/bayesnewton/ops.py index 155cc2a..5bd781a 100644 --- a/bayesnewton/ops.py +++ b/bayesnewton/ops.py @@ -44,26 +44,9 @@ def get_3d_off_diag(offdiag_elems): def blockdiagmatrix_to_blocktensor(blockdiagmatrix, N, D): """ Convert [ND, ND] block-diagonal matrix to [N, D, D] tensor - TODO: extend to D>3 case + Code from https://stackoverflow.com/questions/10831417/extracting-diagonal-blocks-from-a-numpy-array """ - diags = vmap(np.diag)(np.diag(blockdiagmatrix).reshape(N, D)) - if D == 1: - blocktensors = diags - elif D == 2: - offdiag_elems = np.diag(np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1)).reshape(N, D)[:, 1:] - offdiags = offdiag_elems[..., None] * np.fliplr(np.eye(D)) - blocktensors = diags + offdiags - elif D == 3: - addzeros = np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1) - addzeros2 = np.concatenate([np.zeros([N * D, 1]), addzeros], axis=1) - offdiag_elems = np.diag(addzeros).reshape(N, D)[:, 1:] - offdiags = get_3d_off_diag(offdiag_elems) - corner_elements = np.diag(addzeros2).reshape(N, D)[:, 2:] - corners = corner_elements[..., None] * np.array([[0., 0., 1.], [0., 0., 0.], [1., 0., 0.]]) - blocktensors = diags + offdiags + corners - else: - raise NotImplementedError('Multi-latent case with D>3 not implemented') - return blocktensors + return np.array([blockdiagmatrix[i*D:(i+1)*D,i*D:(i+1)*D] for i in range(N)]) def gaussian_conditional(kernel, y, noise_cov, X, X_star=None):