Skip to content

Commit

Permalink
Generalize blockdiagmatrix_to_blocktensor to D>3
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik authored Jun 15, 2023
1 parent a3a0589 commit 3014b7d
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions bayesnewton/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,9 @@ def get_3d_off_diag(offdiag_elems):
def blockdiagmatrix_to_blocktensor(blockdiagmatrix, N, D):
"""
Convert [ND, ND] block-diagonal matrix to [N, D, D] tensor
TODO: extend to D>3 case
Code from https://stackoverflow.com/questions/10831417/extracting-diagonal-blocks-from-a-numpy-array
"""
diags = vmap(np.diag)(np.diag(blockdiagmatrix).reshape(N, D))
if D == 1:
blocktensors = diags
elif D == 2:
offdiag_elems = np.diag(np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1)).reshape(N, D)[:, 1:]
offdiags = offdiag_elems[..., None] * np.fliplr(np.eye(D))
blocktensors = diags + offdiags
elif D == 3:
addzeros = np.concatenate([np.zeros([N * D, 1]), blockdiagmatrix], axis=1)
addzeros2 = np.concatenate([np.zeros([N * D, 1]), addzeros], axis=1)
offdiag_elems = np.diag(addzeros).reshape(N, D)[:, 1:]
offdiags = get_3d_off_diag(offdiag_elems)
corner_elements = np.diag(addzeros2).reshape(N, D)[:, 2:]
corners = corner_elements[..., None] * np.array([[0., 0., 1.], [0., 0., 0.], [1., 0., 0.]])
blocktensors = diags + offdiags + corners
else:
raise NotImplementedError('Multi-latent case with D>3 not implemented')
return blocktensors
return np.array([blockdiagmatrix[i*D:(i+1)*D,i*D:(i+1)*D] for i in range(N)])


def gaussian_conditional(kernel, y, noise_cov, X, X_star=None):
Expand Down

0 comments on commit 3014b7d

Please sign in to comment.