From 05981aaf4ab3a8c8cb5dbdfad9309f6407cb3b29 Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Tue, 2 Apr 2024 08:33:04 -0400
Subject: [PATCH] debug oryx

---
 coix/loss.py    | 10 +++++++++-
 examples/dmm.py |  6 ++----
 examples/gmm.py | 31 ++++++++++++++++++++++---------
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/coix/loss.py b/coix/loss.py
index bc7fda4..b925d56 100644
--- a/coix/loss.py
+++ b/coix/loss.py
@@ -31,7 +31,11 @@
 
 def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   """RWS objective that exploits conditional dependency."""
-  del incoming_log_weight, incremental_log_weight
+  # del incoming_log_weight, incremental_log_weight
+  print("p_trace", jax.tree.map(jnp.shape, p_trace))
+  print("q_trace", jax.tree.map(jnp.shape, q_trace))
+  jax.debug.print("incoming={ilw}", ilw=incoming_log_weight)
+  jax.debug.print("incremental={ilw}", ilw=incremental_log_weight)
   p_log_probs = {
       name: util.get_site_log_prob(site) for name, site in p_trace.items()
   }
@@ -206,6 +210,10 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
 
 def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   """Reweighted Wake-Sleep objective."""
+  print("p_trace", jax.tree.map(jnp.shape, p_trace))
+  print("q_trace", jax.tree.map(jnp.shape, q_trace))
+  jax.debug.print("incoming={ilw}", ilw=incoming_log_weight)
+  jax.debug.print("incremental={ilw}", ilw=incremental_log_weight)
   batch_ndims = incoming_log_weight.ndim
   p_log_probs = {
       name: util.get_site_log_prob(site) for name, site in p_trace.items()
diff --git a/examples/dmm.py b/examples/dmm.py
index 0ae8b7d..a20fc52 100644
--- a/examples/dmm.py
+++ b/examples/dmm.py
@@ -142,7 +142,7 @@ def __call__(self, x):
     x = nn.tanh(x)
     x = nn.Dense(2)(x)
     angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
-    radius = self.param("radius", nn.initializers.ones, (1,))
+    radius = 1.  # self.param("radius", nn.initializers.ones, (1,))
     return radius * angle
 
 
@@ -274,9 +274,7 @@ def main(args):
       train_ds,
   )
 
-  program = make_dmm(dmm_params, num_sweeps)
-  next(test_ds)
-  next(test_ds)
+  program = make_dmm(dmm_params, num_sweeps, num_particles)
   batch = next(test_ds)
   out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)
   out = out[0]
diff --git a/examples/gmm.py b/examples/gmm.py
index 32fbe94..48fbdd2 100644
--- a/examples/gmm.py
+++ b/examples/gmm.py
@@ -40,6 +40,7 @@
 import numpy as np
 import numpyro
 import numpyro.distributions as dist
+from numpyro.ops.indexing import Vindex
 import optax
 import tensorflow as tf
 
@@ -149,17 +150,29 @@ def __call__(self, x):  # N x D
 # Then, we define the target and kernels as in Section 6.2.
 
 
-def gmm_target(network, key, inputs):
-  key_out, key_mean, key_tau, key_c = random.split(key, 4)
-  N = inputs.shape[-2]
-
-  tau = coix.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau)
-  mean = coix.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean)
-  c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
-  x = coix.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x")
+def gmm_target(network, inputs):
+  with numpyro.plate("N", inputs.shape[-2], dim=-1):
+    tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event())
+    mean = numpyro.sample("mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)))
+    c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
+    loc = Vindex(mean)[..., c, :]
+    scale = 1 / jnp.sqrt(Vindex(tau)[..., c, :])
+    x = numpyro.sample("x", dist.Normal(loc, scale).to_event(1), obs=inputs)
 
   out = {"mean": mean, "tau": tau, "c": c, "x": x}
-  return key_out, out
+  return out,
+
+
+def dmm_target(network, inputs):
+  mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2]).to_event())
+  with numpyro.plate("N", inputs.shape[-2], dim=-1):
+    c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
+    h = numpyro.sample("h", dist.Beta(1, 1))
+    x_recon = network.decode_h(h) + Vindex(mu)[..., c, :]
+    x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs)
+
+  out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
+  return (out,)
 
 
 def gmm_kernel_mean_tau(network, key, inputs):
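
Two side notes on the hunks above; neither is part of the patch itself.

The loss.py changes use plain print for the trace shapes and jax.debug.print for the weights. Assuming these losses run under jit/tracing, that split matters: a plain print fires once at trace time, which is enough for static information such as the pytree of site shapes, while jax.debug.print is staged into the compiled computation and reports runtime values. A minimal standalone sketch of the distinction (the function f and its dummy input are made up for illustration):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(incoming_log_weight):
      # Runs once while JAX traces `f`: fine for static info such as shapes.
      print("traced shape:", jnp.shape(incoming_log_weight))
      # Staged into the compiled function: prints the actual runtime values.
      jax.debug.print("incoming={ilw}", ilw=incoming_log_weight)
      return incoming_log_weight.sum()

    f(jnp.zeros((2, 3)))

The gmm.py changes import Vindex to gather per-data-point cluster parameters while still broadcasting over any leading particle/batch dimensions. A small example of the gather it performs (the array values are made up):

    import jax.numpy as jnp
    from numpyro.ops.indexing import Vindex

    mean = jnp.arange(6.0).reshape(3, 2)  # 3 cluster means in 2-D
    c = jnp.array([0, 2, 1, 0])           # one cluster index per data point
    loc = Vindex(mean)[..., c, :]         # picks mean[c[n], :] for each point n
    print(loc.shape)                      # (4, 2)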