Skip to content

Commit

Permalink
bugfixes, etc. changed a lot
Browse files Browse the repository at this point in the history
  • Loading branch information
Weinreb committed Sep 23, 2022
1 parent f995f2a commit 58657cc
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 63 deletions.
6 changes: 3 additions & 3 deletions keypoint_moseq/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def sample_chi2(key, degs):
return jr.gamma(key, degs/2)*2

def sample_discrete(key, distn,dtype=jnp.int32):
return jr.categorical(key, jnp.log(distn+1e-16))
return jr.categorical(key, jnp.log(distn))

def sample_mn(key, M, U, V):
G = jr.normal(key,M.shape)
Expand Down Expand Up @@ -72,10 +72,10 @@ def _backward_message(carry, args):
return jax.lax.cond(mask_t>0, _sample, lambda args: (args[:-1],0), (key, next_potential, alphan_t))

init_distn = jnp.ones(pi.shape[0])/pi.shape[0]
alphan = jax.lax.scan(_forward_message, (init_distn,0.), (log_likelihoods, mask))[1]
(_,log_likelihood), alphan = jax.lax.scan(_forward_message, (init_distn,0.), (log_likelihoods, mask))

init_potential = jnp.ones(pi.shape[0])
_,stateseq = jax.lax.scan(_backward_message, (key,init_potential), (alphan,mask), reverse=True)
return stateseq
return stateseq, log_likelihood


31 changes: 15 additions & 16 deletions keypoint_moseq/gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def resample_latents(key, *, Y, mask, v, h, z, s, Cd, sigmasq, Ab, Q, **kwargs):
Cd = jnp.kron(Gamma, jnp.eye(Y.shape[-1])) @ Cd
ys = inverse_affine_transform(Y,v,h).reshape(*Y.shape[:-2],-1)
A, B, Q, C, D = *ar_to_lds(Ab[...,:-1],Ab[...,-1],Q,Cd[...,:-1]),Cd[...,-1]
R = jnp.repeat(s*sigmasq,Y.shape[-1],axis=-1)[:,nlags:]
mu0,S0 = jnp.zeros((n,d*nlags)),jnp.repeat(jnp.eye(d*nlags)[na]*10,n,axis=0)
xs = jax.vmap(kalman_sample, in_axes=(0,0,0,0,0,0,0,0,0,na,na,0))(
jr.split(key, ys.shape[0]), ys[:,nlags:], mask[:,nlags:],
mu0, S0, A[z], B[z], Q[z], jnp.linalg.inv(Q)[z], C, D, R)
xs = jnp.concatenate([xs[:,0,:-d].reshape(-1,nlags-1,d)[::-1], xs[:,:,-d:]],axis=1)
R = jnp.repeat(s*sigmasq,Y.shape[-1],axis=-1)[:,nlags-1:]
mu0,S0 = jnp.zeros(d*nlags),jnp.eye(d*nlags)*10
xs = jax.vmap(kalman_sample, in_axes=(0,0,0,0,na,na,na,na,na,na,na,0))(
jr.split(key, ys.shape[0]), ys[:,nlags-1:], mask[:,nlags-1:-1], z,
mu0, S0, A, B, Q, C, D, R)
xs = jnp.concatenate([xs[:,0,:-d].reshape(-1,nlags-1,d), xs[:,:,-d:]],axis=1)
return xs

@jax.jit
Expand All @@ -45,14 +45,13 @@ def resample_location(key, *, mask, Y, h, x, s, Cd, sigmasq, sigmasq_loc, **kwar
mu = ((Y - (rot_matrix[...,na,:,:]*Ybar[...,na,:]).sum(-1)) \
*(gammasq[...,na]/(s*sigmasq))[...,na]).sum(-2)

m0, S0 = mu[:,0], gammasq[:,0][...,na,na]*jnp.eye(d)
As = jnp.tile(jnp.eye(d), (*mask.shape,1,1))
Bs = jnp.zeros((*mask.shape,d))
Qs = jnp.tile(jnp.eye(d), (*mask.shape,1,1))*sigmasq_loc
Qinvs = jnp.tile(jnp.eye(d), (*mask.shape,1,1))/sigmasq_loc
C,D,Rs = jnp.eye(d),jnp.zeros(d),gammasq[...,na]*jnp.ones(d)
return jax.vmap(kalman_sample, in_axes=(0,0,0,0,0,0,0,0,0,na,na,0))(
jr.split(key,mask.shape[0]), mu, mask, m0, S0, As, Bs, Qs, Qinvs, C, D, Rs)[...,:-1,:]
m0,S0 = jnp.zeros(d), jnp.eye(d)*1e6
A,B,Q = jnp.eye(d)[na],jnp.zeros(d)[na],jnp.eye(d)[na]*sigmasq_loc
C,D,R = jnp.eye(d),jnp.zeros(d),gammasq[...,na]*jnp.ones(d)
z = jnp.zeros_like(mask[:,1:], dtype=int)

return jax.vmap(kalman_sample, in_axes=(0,0,0,0,na,na,na,na,na,na,na,0))(
jr.split(key, mask.shape[0]), mu, mask[:,:-1], z, m0, S0, A, B, Q, C, D, R)



Expand Down Expand Up @@ -109,11 +108,11 @@ def _ar_log_likelihood(x, params):
def resample_stateseqs(key, *, x, mask, Ab, Q, pi, **kwargs):
nlags = Ab.shape[2]//Ab.shape[1]
log_likelihoods = jax.lax.map(partial(_ar_log_likelihood,x), (Ab, Q))
stateseqs = jax.vmap(sample_hmm_stateseq, in_axes=(0,0,0,na))(
stateseqs, log_likelihoods = jax.vmap(sample_hmm_stateseq, in_axes=(0,0,0,na))(
jr.split(key,mask.shape[0]),
jnp.moveaxis(log_likelihoods,0,-1),
mask.astype(float)[:,nlags:], pi)
return stateseqs
return stateseqs, log_likelihoods

@jax.jit
def resample_scales(key, *, x, v, h, Y, Cd, sigmasq, nu_s, s_0, **kwargs):
Expand Down
54 changes: 36 additions & 18 deletions keypoint_moseq/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,48 @@
from keypoint_moseq.transitions import sample_hdp_transitions, sample_transitions
from keypoint_moseq.distributions import sample_mniw
from keypoint_moseq.util import *
from sklearn.decomposition import PCA
na = jnp.newaxis

'''
def initial_latents(*, Y, mask, v, h, latent_dim, num_samples=100000, **kwargs):
n,t,k,d = Y.shape
y = center_embedding(k).T @ inverse_affine_transform(Y,v,h)
yflat = y.reshape(t*n, (k-1)*d)
ysample = np.array(yflat)[np.random.choice(t*n,num_samples)]
pca = PCA(n_components=latent_dim, whiten=True).fit(ysample)
latents = jnp.array(pca.transform(yflat).reshape(n,t,latent_dim))
Cd = jnp.array(jnp.hstack([pca.components_.T, pca.mean_[:,na]]))
return latents, Cd, pca
'''

def initial_location(*, Y, outliers, **kwargs):
    """
    Initial location estimate computed from non-outlier keypoints only.

    A weight array shaped like ``Y`` is built that is 1 wherever
    ``outliers == 0`` and 0 elsewhere, and ``masked_mean`` (helper defined
    elsewhere; presumably a weighted mean -- confirm) reduces ``Y`` over
    axis -2 with it. Entries from index 2 onward along the last axis of
    the result are set to zero.
    """
    # broadcast the per-keypoint inlier flag across the coordinate axis
    inlier_weights = (outliers == 0)[..., na] * jnp.ones_like(Y)
    centroid = masked_mean(Y, inlier_weights, axis=-2)
    # keep the first two coordinates, zero out the rest
    return centroid.at[..., 2:].set(0)
def initial_latents(*, Y, mask, v, h, latent_dim, num_samples=100000, whiten=True, pca=None, **kwargs):
    """
    Initialize the latent trajectories and the loading matrix ``Cd`` by PCA.

    ``Y`` is unpacked as ``(n, t, k, d)``. The keypoints are passed through
    ``inverse_affine_transform(Y, v, h)`` and projected with
    ``center_embedding(k).T`` (both helpers are defined elsewhere), then
    flattened to ``(n*t, (k-1)*d)`` rows.

    Parameters
    ----------
    Y, v, h : arrays forwarded to ``inverse_affine_transform``.
    mask : array; entries ``> 0`` select the rows used for the whitening
        covariance (flattened in the same order as the rows of ``Y``).
    latent_dim : int, number of latent dimensions to keep.
    num_samples : int, number of rows drawn (with ``np.random.choice``) to
        fit the PCA when ``pca`` is not supplied.
    whiten : bool, if True rescale the latents to identity covariance and
        fold the inverse scaling into ``Cd``.
    pca : optional pre-fit PCA object exposing ``transform``,
        ``components_`` and ``mean_``; it may have more than ``latent_dim``
        components.

    Returns
    -------
    latents : array of shape ``(n, t, latent_dim)``
    Cd : array whose first ``latent_dim`` columns are the loadings and
        whose last column is the PCA mean
    pca : the PCA object that was used
    """
    n, t, k, d = Y.shape
    y = center_embedding(k).T @ inverse_affine_transform(Y, v, h)
    yflat = y.reshape(t*n, (k-1)*d)

    if pca is None:
        # fit on a random subsample of frames to keep the fit cheap
        ysample = np.array(yflat)[np.random.choice(t*n, num_samples)]
        pca = PCA(n_components=latent_dim).fit(ysample)

    latents_flat = jnp.array(pca.transform(yflat))[:, :latent_dim]

    # Keep only the first `latent_dim` components so the loadings match the
    # truncated latents even when a supplied `pca` has more components
    # (previously all components were used, which made `Cd` inconsistent
    # with `latents` and broke the whitening product below).
    components = jnp.array(pca.components_)[:latent_dim].T
    Cd = jnp.hstack([components, jnp.array(pca.mean_)[:, na]])

    if whiten:
        # x_white = Linv @ x  <=>  x = L @ x_white, hence loadings -> loadings @ L
        cov = jnp.cov(latents_flat[mask.flatten() > 0].T)
        L = jnp.linalg.cholesky(cov)
        Linv = jnp.linalg.inv(L)
        latents_flat = latents_flat @ Linv.T
        Cd = Cd.at[:, :-1].set(Cd[:, :-1] @ L)

    latents = latents_flat.reshape(n, t, latent_dim)
    return latents, Cd, pca

def initial_heading(posterior_keypoints, anterior_keypoints, *, Y, outliers, **kwargs):
m = (outliers==0)[...,na] * jnp.ones_like(Y)
posterior_loc = masked_mean(Y[...,posterior_keypoints,:2],
m[...,posterior_keypoints,:2], axis=-2)
anterior_loc = masked_mean(Y[...,anterior_keypoints,:2],
m[...,anterior_keypoints,:2], axis=-2)
def initial_location(*, Y, **kwargs):
    """
    Initial location estimate: the mean of ``Y`` over its second-to-last
    axis, with every entry from index 2 onward along the last axis set
    to zero. Extra keyword arguments are accepted and ignored.
    """
    centroid = Y.mean(axis=-2)
    # keep the first two coordinates, zero out the rest
    centroid = centroid.at[..., 2:].set(0)
    return centroid

def initial_heading(posterior_keypoints, anterior_keypoints, *, Y, **kwargs):
    """
    Initial heading estimate from two groups of keypoints.

    The first two coordinates of the keypoints selected by
    ``posterior_keypoints`` and by ``anterior_keypoints`` are each
    averaged over axis -2, and the difference (anterior minus posterior)
    is passed to ``vector_to_angle`` (helper defined elsewhere).
    """
    back = Y[..., posterior_keypoints, :2].mean(axis=-2)
    front = Y[..., anterior_keypoints, :2].mean(axis=-2)
    direction = front - back
    return vector_to_angle(direction)

def initial_latents(key, *, Y, outliers, mask, v, h, latent_dim, **kwargs):
    """
    Initialize latents and the loading matrix ``Cd`` with ``ppca``.

    ``Y`` is passed through ``inverse_affine_transform(Y, v, h)`` and its
    last two axes are flattened into one. ``outliers`` is repeated
    ``Y.shape[-1]`` times along the last axis so that it lines up with the
    flattened coordinates, and positive entries are handed to ``ppca`` as
    the missing-data indicator. The components and means returned by
    ``ppca`` are stacked column-wise and left-multiplied by
    ``kron(center_embedding(num_keypoints).T, eye(dim))``.

    ``ppca``, ``center_embedding`` and ``inverse_affine_transform`` are
    helpers defined elsewhere. Returns ``(latents, Cd)``.
    """
    num_keypoints, dim = Y.shape[-2], Y.shape[-1]

    aligned = inverse_affine_transform(Y, v, h)
    y = aligned.reshape(*Y.shape[:-2], -1)
    missing = jnp.repeat(outliers, dim, axis=-1) > 0

    components, means, latents = ppca(key, y, missing, latent_dim)

    # last column of Cd holds the offset term
    stacked = jnp.hstack([components, means[:, None]])
    lift = jnp.kron(center_embedding(num_keypoints).T, jnp.eye(dim))
    Cd = lift @ stacked
    return latents, Cd

def initial_ar_params(key, *, num_states, nu_0, S_0, M_0, K_0, **kwargs):
Ab,Q = jax.vmap(sample_mniw, in_axes=(0,na,na,na,na))(
Expand Down
108 changes: 97 additions & 11 deletions keypoint_moseq/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,23 @@




'''
def kalman_filter(ys, mask, m0, S0, As, Bs, Qs, C, D, Rs):
"""
Run a Kalman filter to produce the marginal likelihood and filtered state
estimates.
"""
def _step(carry, args):
m_pred, S_pred = carry
A, B, Q, CRC, CRyD = args
# condition
def _cond(m_pred, S_pred, CRC, CRyD):
S_pred_inv = jnp.linalg.inv(S_pred)
S_cond = jnp.linalg.inv(S_pred_inv + CRC)
m_cond = S_cond @ (S_pred_inv @ m_pred + CRyD)
# predict
return m_cond, S_cond
def _step(carry, args):
m_pred, S_pred = carry
A, B, Q, CRC, CRyD = args
m_cond, S_cond = _cond(m_pred, S_pred, CRC, CRyD)
m_pred = A @ m_cond + B
S_pred = ensure_symmetric(A @ S_cond @ A.T + Q)
S_pred = S_pred + jnp.eye(S_pred.shape[0])*1e-4
Expand All @@ -32,10 +35,13 @@ def _masked_step(carry, args):
CRCs = C.T@(C/Rs[...,na])
CRyDs = ((ys-D)/Rs)@C
_,(filtered_mus, filtered_Sigmas) = jax.lax.scan(
(m_pred, S_pred),(filtered_ms, filtered_Ss) = jax.lax.scan(
lambda carry,args: jax.lax.cond(args[0]>0, _step, _masked_step, carry, args[1:]),
(m0, S0), (mask, As, Bs, Qs, CRCs, CRyDs))
return filtered_mus, filtered_Sigmas
(m0, S0), (mask, As, Bs, Qs, CRCs[:-1], CRyDs[:-1]))
m_cond, S_cond = _cond(m_pred, S_pred, CRCs[-1], CRyDs[-1])
filtered_ms = jnp.concatenate((filtered_ms,m_cond[na]),axis=0)
filtered_Ss = jnp.concatenate((filtered_Ss,S_cond[na]),axis=0)
return filtered_ms, filtered_Ss
Expand All @@ -54,8 +60,8 @@ def _masked_step(x, args):
# precompute and sample
AQinvs = jnp.swapaxes(As,-2,-1)@Qinvs
filtered_Sinvs = jax.lax.map(jnp.linalg.inv, filtered_Ss)
Ss = jax.lax.map(jnp.linalg.inv, filtered_Sinvs + AQinvs@As)
means = (Ss @ filtered_Sinvs @ filtered_ms[...,na])[...,0]
Ss = jax.lax.map(jnp.linalg.inv, filtered_Sinvs[:-1] + AQinvs@As)
means = (Ss @ filtered_Sinvs[:-1] @ filtered_ms[:-1,...,na])[...,0]
samples = jr.multivariate_normal(rng, means, Ss)
SAQinvs = Ss @ AQinvs
Expand All @@ -67,5 +73,85 @@ def _masked_step(x, args):
_, xs = jax.lax.scan(lambda carry,args: jax.lax.cond(
args[0]>0, _step, _masked_step, carry, args[1:]), x, args, reverse=True)
return jnp.vstack([xs, x])
'''



def kalman_filter(ys, mask, zs, m0, S0, A, B, Q, C, D, Rs):
    """
    Run a Kalman filter and return the filtered state means and covariances.

    The dynamics parameters ``A``, ``B``, ``Q`` are indexed by the entries
    of ``zs`` at every step; ``C`` and ``D`` are shared across steps and
    ``Rs`` supplies a per-step observation variance that divides the
    columns of ``C.T``.

    The scan covers all observations but the last (``ys[:-1]``), so
    ``mask`` and ``zs`` must have one entry fewer than ``ys``. Where
    ``mask`` is not positive, the predicted moments are carried through
    unchanged and also emitted as that step's output. The final
    observation ``ys[-1]`` is always conditioned on, and its posterior is
    appended to the outputs.

    Returns ``(filtered_ms, filtered_Ss)`` with one entry per row of ``ys``.
    """

    def _update(mean, cov, obs, obs_var):
        # measurement update written in information (inverse-covariance) form
        cov_inv = jnp.linalg.inv(cov)
        Ct_over_R = C.T / obs_var
        post_cov = jnp.linalg.inv(cov_inv + Ct_over_R @ C)
        post_mean = post_cov @ (cov_inv @ mean + Ct_over_R @ (obs - D))
        return post_mean, post_cov

    def _observed(carry, inputs):
        prior_mean, prior_cov = carry
        state, obs, obs_var = inputs
        post_mean, post_cov = _update(prior_mean, prior_cov, obs, obs_var)
        # propagate through the dynamics selected by the discrete state
        next_mean = A[state] @ post_mean + B[state]
        next_cov = A[state] @ post_cov @ A[state].T + Q[state]
        return (next_mean, next_cov), (post_mean, post_cov)

    def _skipped(carry, inputs):
        prior_mean, prior_cov = carry
        return (prior_mean, prior_cov), (prior_mean, prior_cov)

    def _scan_body(carry, inputs):
        return jax.lax.cond(inputs[0] > 0, _observed, _skipped, carry, inputs[1:])

    (last_mean, last_cov), (filtered_ms, filtered_Ss) = jax.lax.scan(
        _scan_body, (m0, S0), (mask, zs, ys[:-1], Rs[:-1]))

    # the scan stops one step early; condition on the final observation here
    final_mean, final_cov = _update(last_mean, last_cov, ys[-1], Rs[-1])
    filtered_ms = jnp.concatenate((filtered_ms, final_mean[na]), axis=0)
    filtered_Ss = jnp.concatenate((filtered_Ss, final_cov[na]), axis=0)
    return filtered_ms, filtered_Ss


@jax.jit
def kalman_sample(rng, ys, mask, zs, m0, S0, A, B, Q, C, D, Rs):
    """
    Draw a state trajectory by forward filtering and backward sampling.

    ``kalman_filter`` is run first; the last state is sampled from the
    final filtered Gaussian, and a reverse scan then samples each earlier
    state conditioned on its successor using the dynamics selected by
    ``zs``. Where ``mask`` is not positive, the carried state is left
    unchanged and zeros are emitted for that step.

    Parameters are forwarded unchanged to ``kalman_filter``; ``rng`` is a
    JAX PRNG key.

    Returns the sampled states stacked along the first axis, with the
    final state last.
    """

    # run the kalman filter
    filtered_ms, filtered_Ss = kalman_filter(ys, mask, zs, m0, S0, A, B, Q, C, D, Rs)

    def _condition_on(m, S, A, B, Qinv, x):
        # posterior of the current state given the next state x
        Sinv = jnp.linalg.inv(S)
        S_cond = jnp.linalg.inv(Sinv + A.T @ Qinv @ A)
        m_cond = S_cond @ (Sinv @ m + A.T @ Qinv @ (x-B))
        return m_cond, S_cond

    def _step(x, args):
        m_pred, S_pred, z, w = args
        m_cond, S_cond = _condition_on(m_pred, S_pred, A[z], B[z], Qinv[z], x)
        # turn the pre-drawn standard normal w into a draw from N(m_cond, S_cond)
        L = jnp.linalg.cholesky(S_cond)
        x = L @ w + m_cond
        return x, x

    def _masked_step(x, args):
        return x, jnp.zeros_like(x)

    # precompute and sample
    Qinv = jnp.linalg.inv(Q)

    # Use independent keys for the two draws: reusing one key for both
    # would correlate the final-state sample with the backward-pass noise.
    noise_rng, last_rng = jr.split(rng)
    samples = jr.normal(noise_rng, filtered_ms[:-1].shape)

    # initialize the last state
    x = jr.multivariate_normal(last_rng, filtered_ms[-1], filtered_Ss[-1])

    # scan (reverse direction)
    args = (mask, filtered_ms[:-1], filtered_Ss[:-1], zs, samples)
    _, xs = jax.lax.scan(lambda carry, args: jax.lax.cond(
        args[0] > 0, _step, _masked_step, carry, args[1:]), x, args, reverse=True)
    return jnp.vstack([xs, x])

Loading

0 comments on commit 58657cc

Please sign in to comment.