diff --git a/keypoint_moseq/util.py b/keypoint_moseq/util.py index f276dd8..b929f5e 100644 --- a/keypoint_moseq/util.py +++ b/keypoint_moseq/util.py @@ -35,64 +35,6 @@ def interpolate_keypoints(keypoints, outliers, axis=1): return np.moveaxis(keypoints.reshape(init_shape),0,axis) - -''' -def ppca(key, data, missing, num_components, num_iters=50, whiten=True): - batch_shape = data.shape[:-1] - data = data.reshape(-1, data.shape[-1]) - missing = missing.reshape(-1, data.shape[-1]) - N,D = data.shape - - observed = ~missing - total_missing = missing.sum() - means = masked_mean(data, observed) - stds = jnp.sqrt(masked_mean((data-means)**2, observed)) - data = (data - means) / stds - - # initial - C = jr.normal(key, (D, num_components)) - X = data @ C @ jnp.linalg.inv(C.T@C) - recon = jnp.where(observed, X @ C.T, 0) - ss = ((recon - data)**2).sum() / (N*D-total_missing) - - for itr in range(num_iters): - - # e-step - data = jnp.where(observed, data, X@C.T) - Sx = jnp.linalg.inv(jnp.eye(num_components) + C.T@C/ss) - X = data @ C @ Sx / ss - - # m-step - C = data.T @ X @ jnp.linalg.pinv(X.T@X + N*Sx) - recon = jnp.where(observed, X@C.T, 0) - ss = (((recon-data)**2).sum() + N*(C.T@C*Sx).sum() + total_missing*ss)/(N*D) - - U = jnp.linalg.svd(C)[0] - C = U[:,:num_components] - vals, vecs = jnp.linalg.eigh(jnp.cov((data @ C).T)) - order = jnp.flipud(jnp.argsort(vals)) - vecs = vecs[:, order] - vals = vals[order] - C = C @ vecs - latents = (data*stds)@C - - if whiten: - Sigma = jnp.cov(latents.T) - L = jnp.linalg.cholesky(Sigma) - latents = jnp.linalg.solve(L, latents.T).T - C = jnp.linalg.solve(L, C.T).T - return C, means, latents.reshape(*batch_shape, num_components) - - -def whiten(x): - shape,x = x.shape, x.reshape(-1,x.shape[-1]) - mu = x[mask.flatten()>0].mean(0) - Sigma = jnp.cov(x[mask.flatten()>0].T) - L = jnp.linalg.cholesky(Sigma) - x = jnp.linalg.solve(L, (x-mu).T).T - return x.reshape(shape), L -''' - def center_embedding(k): return jnp.linalg.svd(jnp.eye(k) - jnp.ones((k,k))/k)[0][:,:-1]