Skip to content

Commit

Permalink
Merge pull request #25 from itskalvik/main
Browse files Browse the repository at this point in the history
Fix Independent kernel
  • Loading branch information
wil-j-wil authored Jun 22, 2023
2 parents 8ce9598 + 3014b7d commit a8836f1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 21 deletions.
4 changes: 2 additions & 2 deletions bayesnewton/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,11 +1586,11 @@ def __init__(self, kernels):
def K(self, X, X2):
zeros = np.zeros(self.num_kernels)
K0 = self.kernel0.K(X, X2)
index_vector = index_vector.at[0].add(1.)
index_vector = zeros.at[0].add(1.)
Kstack = np.kron(K0, np.diag(index_vector))
for i in range(1, self.num_kernels):
kerneli = eval("self.kernel" + str(i))
index_vector = index_vector.at[i].add(1.)
index_vector = zeros.at[i].add(1.)
Kstack += np.kron(kerneli.K(X, X2), np.diag(index_vector))
return Kstack

Expand Down
21 changes: 2 additions & 19 deletions bayesnewton/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,9 @@ def get_3d_off_diag(offdiag_elems):
def blockdiagmatrix_to_blocktensor(blockdiagmatrix, N, D):
"""
Convert [ND, ND] block-diagonal matrix to [N, D, D] tensor
TODO: extend to D>3 case
Code from https://stackoverflow.com/questions/10831417/extracting-diagonal-blocks-from-a-numpy-array
"""
diags = vmap(np.diag)(np.diag(blockdiagmatrix).reshape(N, D))
if D == 1:
blocktensors = diags
elif D == 2:
offdiag_elems = np.diag(np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1)).reshape(N, D)[:, 1:]
offdiags = offdiag_elems[..., None] * np.fliplr(np.eye(D))
blocktensors = diags + offdiags
elif D == 3:
addzeros = np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1)
addzeros2 = np.concatenate([np.zeros([N * D, 1]), addzeros], axis=1)
offdiag_elems = np.diag(addzeros).reshape(N, D)[:, 1:]
offdiags = get_3d_off_diag(offdiag_elems)
corner_elements = np.diag(addzeros2).reshape(N, D)[:, 2:]
corners = corner_elements[..., None] * np.array([[0., 0., 1.], [0., 0., 0.], [1., 0., 0.]])
blocktensors = diags + offdiags + corners
else:
raise NotImplementedError('Multi-latent case with D>3 not implemented')
return blocktensors
return np.array([blockdiagmatrix[i*D:(i+1)*D,i*D:(i+1)*D] for i in range(N)])


def gaussian_conditional(kernel, y, noise_cov, X, X_star=None):
Expand Down

0 comments on commit a8836f1

Please sign in to comment.