diff --git a/keypoint_moseq/distributions.py b/keypoint_moseq/distributions.py index 580d191..8033de6 100644 --- a/keypoint_moseq/distributions.py +++ b/keypoint_moseq/distributions.py @@ -18,7 +18,7 @@ def sample_chi2(key, degs): return jr.gamma(key, degs/2)*2 def sample_discrete(key, distn,dtype=jnp.int32): - return jr.categorical(key, jnp.log(distn+1e-16)) + return jr.categorical(key, jnp.log(distn)) def sample_mn(key, M, U, V): G = jr.normal(key,M.shape) @@ -72,10 +72,10 @@ def _backward_message(carry, args): return jax.lax.cond(mask_t>0, _sample, lambda args: (args[:-1],0), (key, next_potential, alphan_t)) init_distn = jnp.ones(pi.shape[0])/pi.shape[0] - alphan = jax.lax.scan(_forward_message, (init_distn,0.), (log_likelihoods, mask))[1] + (_,log_likelihood), alphan = jax.lax.scan(_forward_message, (init_distn,0.), (log_likelihoods, mask)) init_potential = jnp.ones(pi.shape[0]) _,stateseq = jax.lax.scan(_backward_message, (key,init_potential), (alphan,mask), reverse=True) - return stateseq + return stateseq, log_likelihood diff --git a/keypoint_moseq/gibbs.py b/keypoint_moseq/gibbs.py index 0a90842..4865efc 100644 --- a/keypoint_moseq/gibbs.py +++ b/keypoint_moseq/gibbs.py @@ -14,12 +14,12 @@ def resample_latents(key, *, Y, mask, v, h, z, s, Cd, sigmasq, Ab, Q, **kwargs): Cd = jnp.kron(Gamma, jnp.eye(Y.shape[-1])) @ Cd ys = inverse_affine_transform(Y,v,h).reshape(*Y.shape[:-2],-1) A, B, Q, C, D = *ar_to_lds(Ab[...,:-1],Ab[...,-1],Q,Cd[...,:-1]),Cd[...,-1] - R = jnp.repeat(s*sigmasq,Y.shape[-1],axis=-1)[:,nlags:] - mu0,S0 = jnp.zeros((n,d*nlags)),jnp.repeat(jnp.eye(d*nlags)[na]*10,n,axis=0) - xs = jax.vmap(kalman_sample, in_axes=(0,0,0,0,0,0,0,0,0,na,na,0))( - jr.split(key, ys.shape[0]), ys[:,nlags:], mask[:,nlags:], - mu0, S0, A[z], B[z], Q[z], jnp.linalg.inv(Q)[z], C, D, R) - xs = jnp.concatenate([xs[:,0,:-d].reshape(-1,nlags-1,d)[::-1], xs[:,:,-d:]],axis=1) + R = jnp.repeat(s*sigmasq,Y.shape[-1],axis=-1)[:,nlags-1:] + mu0,S0 = jnp.zeros(d*nlags),jnp.eye(d*nlags)*10 + xs = jax.vmap(kalman_sample, in_axes=(0,0,0,0,na,na,na,na,na,na,na,0))( + jr.split(key, ys.shape[0]), ys[:,nlags-1:], mask[:,nlags-1:-1], z, + mu0, S0, A, B, Q, C, D, R) + xs = jnp.concatenate([xs[:,0,:-d].reshape(-1,nlags-1,d), xs[:,:,-d:]],axis=1) return xs @jax.jit @@ -45,14 +45,13 @@ def resample_location(key, *, mask, Y, h, x, s, Cd, sigmasq, sigmasq_loc, **kwar mu = ((Y - (rot_matrix[...,na,:,:]*Ybar[...,na,:]).sum(-1)) \ *(gammasq[...,na]/(s*sigmasq))[...,na]).sum(-2) - m0, S0 = mu[:,0], gammasq[:,0][...,na,na]*jnp.eye(d) - As = jnp.tile(jnp.eye(d), (*mask.shape,1,1)) - Bs = jnp.zeros((*mask.shape,d)) - Qs = jnp.tile(jnp.eye(d), (*mask.shape,1,1))*sigmasq_loc - Qinvs = jnp.tile(jnp.eye(d), (*mask.shape,1,1))/sigmasq_loc - C,D,Rs = jnp.eye(d),jnp.zeros(d),gammasq[...,na]*jnp.ones(d) - return jax.vmap(kalman_sample, in_axes=(0,0,0,0,0,0,0,0,0,na,na,0))( - jr.split(key,mask.shape[0]), mu, mask, m0, S0, As, Bs, Qs, Qinvs, C, D, Rs)[...,:-1,:] + m0,S0 = jnp.zeros(d), jnp.eye(d)*1e6 + A,B,Q = jnp.eye(d)[na],jnp.zeros(d)[na],jnp.eye(d)[na]*sigmasq_loc + C,D,R = jnp.eye(d),jnp.zeros(d),gammasq[...,na]*jnp.ones(d) + z = jnp.zeros_like(mask[:,1:], dtype=int) + + return jax.vmap(kalman_sample, in_axes=(0,0,0,0,na,na,na,na,na,na,na,0))( + jr.split(key, mask.shape[0]), mu, mask[:,:-1], z, m0, S0, A, B, Q, C, D, R) @@ -109,11 +108,11 @@ def _ar_log_likelihood(x, params): def resample_stateseqs(key, *, x, mask, Ab, Q, pi, **kwargs): nlags = Ab.shape[2]//Ab.shape[1] log_likelihoods = jax.lax.map(partial(_ar_log_likelihood,x), (Ab, Q)) - stateseqs = jax.vmap(sample_hmm_stateseq, in_axes=(0,0,0,na))( + stateseqs, log_likelihoods = jax.vmap(sample_hmm_stateseq, in_axes=(0,0,0,na))( jr.split(key,mask.shape[0]), jnp.moveaxis(log_likelihoods,0,-1), mask.astype(float)[:,nlags:], pi) - return stateseqs + return stateseqs, log_likelihoods @jax.jit def resample_scales(key, *, x, v, h, Y, Cd, sigmasq, nu_s, s_0, **kwargs): diff --git a/keypoint_moseq/initialize.py b/keypoint_moseq/initialize.py index dfad034..a590fe9 100644 --- a/keypoint_moseq/initialize.py +++ b/keypoint_moseq/initialize.py @@ -2,30 +2,48 @@ from keypoint_moseq.transitions import sample_hdp_transitions, sample_transitions from keypoint_moseq.distributions import sample_mniw from keypoint_moseq.util import * +from sklearn.decomposition import PCA na = jnp.newaxis +''' +def initial_latents(*, Y, mask, v, h, latent_dim, num_samples=100000, **kwargs): + n,t,k,d = Y.shape + y = center_embedding(k).T @ inverse_affine_transform(Y,v,h) + yflat = y.reshape(t*n, (k-1)*d) + ysample = np.array(yflat)[np.random.choice(t*n,num_samples)] + pca = PCA(n_components=latent_dim, whiten=True).fit(ysample) + latents = jnp.array(pca.transform(yflat).reshape(n,t,latent_dim)) + Cd = jnp.array(jnp.hstack([pca.components_.T, pca.mean_[:,na]])) + return latents, Cd, pca +''' -def initial_location(*, Y, outliers, **kwargs): - m = (outliers==0)[...,na] * jnp.ones_like(Y) - v = masked_mean(Y, m, axis=-2) - return v.at[...,2:].set(0) +def initial_latents(*, Y, mask, v, h, latent_dim, num_samples=100000, whiten=True, pca=None, **kwargs): + n,t,k,d = Y.shape + y = center_embedding(k).T @ inverse_affine_transform(Y,v,h) + yflat = y.reshape(t*n, (k-1)*d) + ysample = np.array(yflat)[np.random.choice(t*n,num_samples)] + if pca is None: pca = PCA(n_components=latent_dim).fit(ysample) + latents_flat = jnp.array(pca.transform(yflat))[:,:latent_dim] + Cd = jnp.array(jnp.hstack([pca.components_.T, pca.mean_[:,na]])) + + if whiten: + cov = jnp.cov(latents_flat[mask.flatten()>0].T) + L = jnp.linalg.cholesky(cov) + Linv = jnp.linalg.inv(L) + latents_flat = latents_flat @ Linv.T + Cd = Cd.at[:,:-1].set(Cd[:,:-1] @ L) + + latents = latents_flat.reshape(n,t,latent_dim) + return latents, Cd, pca -def initial_heading(posterior_keypoints, anterior_keypoints, *, Y, outliers, **kwargs): - m = (outliers==0)[...,na] * jnp.ones_like(Y) - posterior_loc = masked_mean(Y[...,posterior_keypoints,:2], - m[...,posterior_keypoints,:2], axis=-2) - anterior_loc = masked_mean(Y[...,anterior_keypoints,:2], - m[...,anterior_keypoints,:2], axis=-2) +def initial_location(*, Y, **kwargs): + return Y.mean(-2).at[...,2:].set(0) + +def initial_heading(posterior_keypoints, anterior_keypoints, *, Y, **kwargs): + posterior_loc = Y[..., posterior_keypoints,:2].mean(-2) + anterior_loc = Y[..., anterior_keypoints,:2].mean(-2) return vector_to_angle(anterior_loc-posterior_loc) -def initial_latents(key, *, Y, outliers, mask, v, h, latent_dim, **kwargs): - y = inverse_affine_transform(Y,v,h).reshape(*Y.shape[:-2],-1) - missing = jnp.repeat(outliers,Y.shape[-1],axis=-1)>0 - components, means, latents = ppca(key, y, missing, latent_dim) - Gamma = center_embedding(Y.shape[-2]) - Cd = jnp.hstack([components, means[:,None]]) - Cd = jnp.kron(Gamma.T, jnp.eye(Y.shape[-1]))@Cd - return latents, Cd def initial_ar_params(key, *, num_states, nu_0, S_0, M_0, K_0, **kwargs): Ab,Q = jax.vmap(sample_mniw, in_axes=(0,na,na,na,na))( diff --git a/keypoint_moseq/kalman.py b/keypoint_moseq/kalman.py index 80c1e56..305d743 100644 --- a/keypoint_moseq/kalman.py +++ b/keypoint_moseq/kalman.py @@ -6,20 +6,23 @@ - +''' def kalman_filter(ys, mask, m0, S0, As, Bs, Qs, C, D, Rs): """ Run a Kalman filter to produce the marginal likelihood and filtered state estimates. """ - def _step(carry, args): - m_pred, S_pred = carry - A, B, Q, CRC, CRyD = args - # condition + + def _cond(m_pred, S_pred, CRC, CRyD): S_pred_inv = jnp.linalg.inv(S_pred) S_cond = jnp.linalg.inv(S_pred_inv + CRC) m_cond = S_cond @ (S_pred_inv @ m_pred + CRyD) - # predict + return m_cond, S_cond + + def _step(carry, args): + m_pred, S_pred = carry + A, B, Q, CRC, CRyD = args + m_cond, S_cond = _cond(m_pred, S_pred, CRC, CRyD) m_pred = A @ m_cond + B S_pred = ensure_symmetric(A @ S_cond @ A.T + Q) S_pred = S_pred + jnp.eye(S_pred.shape[0])*1e-4 @@ -32,10 +35,13 @@ def _masked_step(carry, args): CRCs = C.T@(C/Rs[...,na]) CRyDs = ((ys-D)/Rs)@C - _,(filtered_mus, filtered_Sigmas) = jax.lax.scan( + (m_pred, S_pred),(filtered_ms, filtered_Ss) = jax.lax.scan( lambda carry,args: jax.lax.cond(args[0]>0, _step, _masked_step, carry, args[1:]), - (m0, S0), (mask, As, Bs, Qs, CRCs, CRyDs)) - return filtered_mus, filtered_Sigmas + (m0, S0), (mask, As, Bs, Qs, CRCs[:-1], CRyDs[:-1])) + m_cond, S_cond = _cond(m_pred, S_pred, CRCs[-1], CRyDs[-1]) + filtered_ms = jnp.concatenate((filtered_ms,m_cond[na]),axis=0) + filtered_Ss = jnp.concatenate((filtered_Ss,S_cond[na]),axis=0) + return filtered_ms, filtered_Ss @@ -54,8 +60,8 @@ def _masked_step(x, args): # precompute and sample AQinvs = jnp.swapaxes(As,-2,-1)@Qinvs filtered_Sinvs = jax.lax.map(jnp.linalg.inv, filtered_Ss) - Ss = jax.lax.map(jnp.linalg.inv, filtered_Sinvs + AQinvs@As) - means = (Ss @ filtered_Sinvs @ filtered_ms[...,na])[...,0] + Ss = jax.lax.map(jnp.linalg.inv, filtered_Sinvs[:-1] + AQinvs@As) + means = (Ss @ filtered_Sinvs[:-1] @ filtered_ms[:-1,...,na])[...,0] samples = jr.multivariate_normal(rng, means, Ss) SAQinvs = Ss @ AQinvs @@ -67,5 +73,85 @@ def _masked_step(x, args): _, xs = jax.lax.scan(lambda carry,args: jax.lax.cond( args[0]>0, _step, _masked_step, carry, args[1:]), x, args, reverse=True) return jnp.vstack([xs, x]) +''' + + +def kalman_filter(ys, mask, zs, m0, S0, A, B, Q, C, D, Rs): + """ + Run a Kalman filter to produce the marginal likelihood and filtered state + estimates. + """ + + def _predict(m, S, A, B, Q): + mu_pred = A @ m + B + Sigma_pred = A @ S @ A.T + Q + return mu_pred, Sigma_pred + + def _condition_on(m, S, C, D, R, y): + Sinv = jnp.linalg.inv(S) + S_cond = jnp.linalg.inv(Sinv + (C.T / R) @ C) + m_cond = S_cond @ (Sinv @ m + (C.T / R) @ (y-D)) + return m_cond, S_cond + + def _step(carry, args): + m_pred, S_pred = carry + z, y, R = args + + m_cond, S_cond = _condition_on( + m_pred, S_pred, C, D, R, y) + + m_pred, S_pred = _predict( + m_cond, S_cond, A[z], B[z], Q[z]) + + return (m_pred, S_pred), (m_cond, S_cond) + + def _masked_step(carry, args): + m_pred, S_pred = carry + return (m_pred, S_pred), (m_pred, S_pred) + + + (m_pred, S_pred),(filtered_ms, filtered_Ss) = jax.lax.scan( + lambda carry,args: jax.lax.cond(args[0]>0, _step, _masked_step, carry, args[1:]), + (m0, S0), (mask, zs, ys[:-1], Rs[:-1])) + m_cond, S_cond = _condition_on(m_pred, S_pred, C, D, Rs[-1], ys[-1]) + filtered_ms = jnp.concatenate((filtered_ms,m_cond[na]),axis=0) + filtered_Ss = jnp.concatenate((filtered_Ss,S_cond[na]),axis=0) + return filtered_ms, filtered_Ss + + +@jax.jit +def kalman_sample(rng, ys, mask, zs, m0, S0, A, B, Q, C, D, Rs): + + # run the kalman filter + filtered_ms, filtered_Ss = kalman_filter(ys, mask, zs, m0, S0, A, B, Q, C, D, Rs) + + def _condition_on(m, S, A, B, Qinv, x): + Sinv = jnp.linalg.inv(S) + S_cond = jnp.linalg.inv(Sinv + A.T @ Qinv @ A) + m_cond = S_cond @ (Sinv @ m + A.T @ Qinv @ (x-B)) + return m_cond, S_cond + + def _step(x, args): + m_pred, S_pred, z, w = args + m_cond, S_cond = _condition_on(m_pred, S_pred, A[z], B[z], Qinv[z], x) + L = jnp.linalg.cholesky(S_cond) + x = L @ w + m_cond + return x, x + + def _masked_step(x, args): + return x,jnp.zeros_like(x) + + # precompute and sample + Qinv = jnp.linalg.inv(Q) + samples = jr.normal(rng, filtered_ms[:-1].shape) + + # initialize the last state + x = jr.multivariate_normal(rng, filtered_ms[-1], filtered_Ss[-1]) + + # scan (reverse direction) + args = (mask, filtered_ms[:-1], filtered_Ss[:-1], zs, samples) + _, xs = jax.lax.scan(lambda carry,args: jax.lax.cond( + args[0]>0, _step, _masked_step, carry, args[1:]), x, args, reverse=True) + return jnp.vstack([xs, x]) diff --git a/keypoint_moseq/util.py b/keypoint_moseq/util.py index 739f225..f276dd8 100644 --- a/keypoint_moseq/util.py +++ b/keypoint_moseq/util.py @@ -8,10 +8,36 @@ from functools import partial na = jnp.newaxis -def masked_mean(X, m, axis=0): - return (X*m).sum(axis) / (m.sum(axis)+1e-10) +def expected_keypoints(*, Y, v, h, x, Cd, **kwargs): + k,d = Y.shape[-2:] + Gamma = center_embedding(k) + Ybar = Gamma @ (pad_affine(x)@Cd.T).reshape(*Y.shape[:-2],k-1,d) + Yexp = affine_transform(Ybar,v,h) + return Yexp + + +def pca_transform(components, means, *, Y, v, h, **kwargs): + y = inverse_affine_transform(Y,v,h).reshape(*Y.shape[:-2],-1) + return (y-means) @ components + +def interpolate_keypoints(keypoints, outliers, axis=1): + keypoints = np.moveaxis(keypoints, axis, 0) + outliers = np.moveaxis(outliers, axis, 0) + init_shape = keypoints.shape + outliers = np.repeat(outliers[...,None],init_shape[-1],axis=-1) + keypoints = keypoints.reshape(init_shape[0],-1) + outliers = outliers.reshape(init_shape[0],-1) + for i in range(keypoints.shape[1]): + keypoints[:,i] = np.interp( + np.arange(init_shape[0]), + np.nonzero(~outliers[:,i])[0], + keypoints[:,i][~outliers[:,i]]) + return np.moveaxis(keypoints.reshape(init_shape),0,axis) + + -def ppca(key, data, missing, num_components, num_iters=50): +''' +def ppca(key, data, missing, num_components, num_iters=50, whiten=True): batch_shape = data.shape[:-1] data = data.reshape(-1, data.shape[-1]) missing = missing.reshape(-1, data.shape[-1]) @@ -49,17 +75,23 @@ def ppca(key, data, missing, num_components, num_iters=50): vals = vals[order] C = C @ vecs latents = (data*stds)@C + + if whiten: + Sigma = jnp.cov(latents.T) + L = jnp.linalg.cholesky(Sigma) + latents = jnp.linalg.solve(L, latents.T).T + C = jnp.linalg.solve(L, C.T).T return C, means, latents.reshape(*batch_shape, num_components) -@jax.jit -def obs_log_prob(*, Y, mask, x, s, v, h, Cd, sigmasq, **kwargs): - k,d = Y.shape[-2:] - Gamma = center_embedding(k) - Ybar = Gamma @ (pad_affine(x)@Cd.T).reshape(*Y.shape[:-2],k-1,d) - sqerr = ((Y - affine_transform(Ybar,v,h))**2).sum(-1) - return -1/2 * sqerr/s/sigmasq - 3/2 * jnp.log(s*sigmasq*jnp.pi) - +def whiten(x): + shape,x = x.shape, x.reshape(-1,x.shape[-1]) + mu = x[mask.flatten()>0].mean(0) + Sigma = jnp.cov(x[mask.flatten()>0].T) + L = jnp.linalg.cholesky(Sigma) + x = jnp.linalg.solve(L, (x-mu).T).T + return x.reshape(shape), L +''' def center_embedding(k): return jnp.linalg.svd(jnp.eye(k) - jnp.ones((k,k))/k)[0][:,:-1] @@ -74,15 +106,15 @@ def merge_data(data_dict, keys=None, batch_length=None): if keys is None: keys = sorted(data_dict.keys()) max_length = np.max([data_dict[k].shape[0] for k in keys]) if batch_length is None: batch_length = max_length - else: max_length = int(np.ceil(max_length/batch_length)*batch_length) - + def reshape(x): - x = np.concatenate([x, np.zeros((max_length-x.shape[0],*x.shape[1:]))],axis=0) + padding = (-x.shape[0])%batch_length + x = np.concatenate([x, np.zeros((padding,*x.shape[1:]))],axis=0) return x.reshape(-1, batch_length, *x.shape[1:]) data = np.concatenate([reshape(data_dict[k]) for k in keys],axis=0) mask = np.concatenate([reshape(np.ones(data_dict[k].shape[0])) for k in keys],axis=0) - keys = [(k,i) for k in keys for i in range(int(len(data_dict[k])/batch_length+1))] + keys = [(k,i) for k in keys for i in range(int(np.ceil(len(data_dict[k])/batch_length)))] return data, mask, keys def ensure_symmetric(X): @@ -255,4 +287,48 @@ def ar_to_lds(As, bs, Qs, Cs): C_ = C_.at[:,-k:].set(Cs) return A_, b_, Q_, C_ +def gaussian_log_prob(x, mu, sigma_inv): + return (-((mu-x)[...,na,:]*sigma_inv*(mu-x)[...,:,na]).sum((-1,-2))/2 + +jnp.log(jnp.linalg.det(sigma_inv))/2) + + +def latent_log_prob(*, x, z, Ab, Q, **kwargs): + Qinv = jnp.linalg.inv(Q) + Qdet = jnp.linalg.det(Q) + + nlags = Ab.shape[2]//Ab.shape[1] + x_lagged = get_lags(x, nlags) + x_pred = (Ab[z] @ pad_affine(x_lagged)[...,na])[...,0] + + d = x_pred - x[:,nlags:] + return (-(d[...,na,:]*Qinv[z]*d[...,:,na]).sum((2,3))/2 + -jnp.log(Qdet[z])/2 -jnp.log(2*jnp.pi)*Q.shape[-1]/2) + +def stateseq_log_prob(*, z, pi, **kwargs): + return jnp.log(pi[z[:,:-1],z[:,1:]]) + +def scale_log_prob(*, s, nu_s, s_0, **kwargs): + return -nu_s*s_0 / s / 2 - (1+nu_s/2)*jnp.log(s) + +def location_log_prob(*, v, sigmasq_loc): + d = v[:,:-1]-v[:,1:] + return (-(d**2).sum(-1)/sigmasq_loc/2 + -v.shape[-1]/2*jnp.log(sigmasq_loc*2*jnp.pi)) + +def obs_log_prob(*, Y, x, s, v, h, Cd, sigmasq, **kwargs): + k,d = Y.shape[-2:] + Gamma = center_embedding(k) + Ybar = Gamma @ (pad_affine(x)@Cd.T).reshape(*Y.shape[:-2],k-1,d) + sqerr = ((Y - affine_transform(Ybar,v,h))**2).sum(-1) + return (-1/2 * sqerr/s/sigmasq - d/2 * jnp.log(2*s*sigmasq*jnp.pi)) + +@jax.jit +def log_joint_likelihood(*, Y, mask, x, s, v, h, z, pi, Ab, Q, Cd, sigmasq, sigmasq_loc, nu_s, s_0, **kwargs): + nlags = Ab.shape[2]//Ab.shape[1] + return { + 'Y': (obs_log_prob(Y=Y, x=x, s=s, v=v, h=h, Cd=Cd, sigmasq=sigmasq)*mask[:,:,na]).sum(), + 'x': (latent_log_prob(x=x, z=z, Ab=Ab, Q=Q)*mask[:,nlags:]).sum(), + 'z': (stateseq_log_prob(z=z, pi=pi)*mask[:,nlags+1:]).sum(), + 'v': (location_log_prob(v=v, sigmasq_loc=sigmasq_loc)*mask[:,1:]).sum(), + 's': (scale_log_prob(s=s, nu_s=nu_s, s_0=s_0)*mask[:,:,na]).sum()}