Skip to content

Commit

Permalink
removed unused util code
Browse files Browse the repository at this point in the history
  • Loading branch information
Weinreb committed Oct 8, 2022
1 parent 58657cc commit 5f245c4
Showing 1 changed file with 0 additions and 58 deletions.
58 changes: 0 additions & 58 deletions keypoint_moseq/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,64 +35,6 @@ def interpolate_keypoints(keypoints, outliers, axis=1):
return np.moveaxis(keypoints.reshape(init_shape),0,axis)



'''
def ppca(key, data, missing, num_components, num_iters=50, whiten=True):
batch_shape = data.shape[:-1]
data = data.reshape(-1, data.shape[-1])
missing = missing.reshape(-1, data.shape[-1])
N,D = data.shape
observed = ~missing
total_missing = missing.sum()
means = masked_mean(data, observed)
stds = jnp.sqrt(masked_mean((data-means)**2, observed))
data = (data - means) / stds
# initial
C = jr.normal(key, (D, num_components))
X = data @ C @ jnp.linalg.inv(C.T@C)
recon = jnp.where(observed, X @ C.T, 0)
ss = ((recon - data)**2).sum() / (N*D-total_missing)
for itr in range(num_iters):
# e-step
data = jnp.where(observed, data, [email protected])
Sx = jnp.linalg.inv(jnp.eye(num_components) + C.T@C/ss)
X = data @ C @ Sx / ss
# m-step
C = data.T @ X @ jnp.linalg.pinv(X.T@X + N*Sx)
recon = jnp.where(observed, [email protected], 0)
ss = (((recon-data)**2).sum() + N*(C.T@C*Sx).sum() + total_missing*ss)/(N*D)
U = jnp.linalg.svd(C)[0]
C = U[:,:num_components]
vals, vecs = jnp.linalg.eigh(jnp.cov((data @ C).T))
order = jnp.flipud(jnp.argsort(vals))
vecs = vecs[:, order]
vals = vals[order]
C = C @ vecs
latents = (data*stds)@C
if whiten:
Sigma = jnp.cov(latents.T)
L = jnp.linalg.cholesky(Sigma)
latents = jnp.linalg.solve(L, latents.T).T
C = jnp.linalg.solve(L, C.T).T
return C, means, latents.reshape(*batch_shape, num_components)
def whiten(x):
shape,x = x.shape, x.reshape(-1,x.shape[-1])
mu = x[mask.flatten()>0].mean(0)
Sigma = jnp.cov(x[mask.flatten()>0].T)
L = jnp.linalg.cholesky(Sigma)
x = jnp.linalg.solve(L, (x-mu).T).T
return x.reshape(shape), L
'''

def center_embedding(k):
return jnp.linalg.svd(jnp.eye(k) - jnp.ones((k,k))/k)[0][:,:-1]

Expand Down

0 comments on commit 5f245c4

Please sign in to comment.