Skip to content

Commit

Permalink
debug oryx
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Apr 2, 2024
1 parent 53954c7 commit 05981aa
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
10 changes: 9 additions & 1 deletion coix/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@

def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""RWS objective that exploits conditional dependency."""
del incoming_log_weight, incremental_log_weight
# del incoming_log_weight, incremental_log_weight
print("p_trace", jax.tree.map(jnp.shape, p_trace))
print("q_trace", jax.tree.map(jnp.shape, q_trace))
jax.debug.print("incoming={ilw}", ilw=incoming_log_weight)
jax.debug.print("incremental={ilw}", ilw=incremental_log_weight)
p_log_probs = {
name: util.get_site_log_prob(site) for name, site in p_trace.items()
}
Expand Down Expand Up @@ -206,6 +210,10 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):

def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""Reweighted Wake-Sleep objective."""
print("p_trace", jax.tree.map(jnp.shape, p_trace))
print("q_trace", jax.tree.map(jnp.shape, q_trace))
jax.debug.print("incoming={ilw}", ilw=incoming_log_weight)
jax.debug.print("incremental={ilw}", ilw=incremental_log_weight)
batch_ndims = incoming_log_weight.ndim
p_log_probs = {
name: util.get_site_log_prob(site) for name, site in p_trace.items()
Expand Down
6 changes: 2 additions & 4 deletions examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __call__(self, x):
x = nn.tanh(x)
x = nn.Dense(2)(x)
angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
radius = self.param("radius", nn.initializers.ones, (1,))
radius = 1. # self.param("radius", nn.initializers.ones, (1,))
return radius * angle


Expand Down Expand Up @@ -274,9 +274,7 @@ def main(args):
train_ds,
)

program = make_dmm(dmm_params, num_sweeps)
next(test_ds)
next(test_ds)
program = make_dmm(dmm_params, num_sweeps, num_particles)
batch = next(test_ds)
out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)
out = out[0]
Expand Down
31 changes: 22 additions & 9 deletions examples/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.ops.indexing import Vindex
import optax
import tensorflow as tf

Expand Down Expand Up @@ -149,17 +150,29 @@ def __call__(self, x): # N x D
# Then, we define the target and kernels as in Section 6.2.


def gmm_target(network, key, inputs):
key_out, key_mean, key_tau, key_c = random.split(key, 4)
N = inputs.shape[-2]

tau = coix.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau)
mean = coix.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean)
c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
x = coix.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x")
def gmm_target(network, inputs):
with numpyro.plate("N", inputs.shape[-2], dim=-1):
tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event())
mean = numpyro.sample("mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)))
c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
loc = Vindex(mean)[..., c, :]
scale = 1 / jnp.sqrt(Vindex(tau)[..., c, :])
x = numpyro.sample("x", dist.Normal(loc, scale).to_event(1), obs=inputs)

out = {"mean": mean, "tau": tau, "c": c, "x": x}
return key_out, out
return out,


def dmm_target(network, inputs):
mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2]).to_event())
with numpyro.plate("N", inputs.shape[-2], dim=-1):
c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
h = numpyro.sample("h", dist.Beta(1, 1))
x_recon = network.decode_h(h) + Vindex(mu)[..., c, :]
x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs)

out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
return (out,)


def gmm_kernel_mean_tau(network, key, inputs):
Expand Down

0 comments on commit 05981aa

Please sign in to comment.