Skip to content

Commit

Permalink
make dmm numpyro work
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Apr 2, 2024
1 parent 11425c5 commit 53954c7
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 85 deletions.
6 changes: 5 additions & 1 deletion coix/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def wrapped(*args, **kwargs):
value = site["value"]
log_prob = site["fn"].log_prob(value)
event_dim_holder = jnp.empty([1] * site["fn"].event_dim)
trace[name] = {"value": value, "log_prob": log_prob, "_event_dim_holder": event_dim_holder}
trace[name] = {
"value": value,
"log_prob": log_prob,
"_event_dim_holder": event_dim_holder,
}
if site.get("is_observed", False):
trace[name]["is_observed"] = True
metrics = {
Expand Down
Binary file added docs/_static/anneal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/anneal_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/gmm_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def eval_program(seed):

plt.figure(figsize=(8, 8))
x = trace["x"]["value"].reshape((-1, 2))
H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100)
H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100)
plt.imshow(H.T)
plt.show()

Expand Down
2 changes: 1 addition & 1 deletion examples/anneal_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def main(args):

plt.figure(figsize=(8, 8))
x = trace["x"]["value"].reshape((-1, 2))
H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100)
H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100)
plt.imshow(H.T)
plt.show()

Expand Down
50 changes: 27 additions & 23 deletions examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from numpyro.ops.indexing import Vindex
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

# %%
# First, let's simulate a synthetic dataset of 2D ring-shaped mixtures.
Expand All @@ -64,17 +63,22 @@ def simulate_rings(num_instances=1, N=200, seed=0):
return np.take_along_axis(x, shuffle_idx, axis=1)


def load_dataset(split, *, is_training, batch_size):
num_data = 20000 if is_training else batch_size
num_points = 200 if is_training else 600
seed = 0 if is_training else 1
def load_dataset(split, *, batch_size):
if split == "train":
num_data = 20000
num_points = 200
seed = 0
else:
num_data = batch_size
num_points = 600
seed = 1
data = simulate_rings(num_data, num_points, seed=seed)
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.repeat()
if is_training:
if split == "train":
ds = ds.shuffle(10 * batch_size, seed=0)
ds = ds.batch(batch_size)
return iter(tfds.as_numpy(ds))
return ds.as_numpy_iterator()


# %%
Expand Down Expand Up @@ -155,8 +159,6 @@ def __call__(self, x): # N x D
# Heuristic procedure to setup initial parameters.
mu, _ = self.encode_initial_mu(x) # M x D

# concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1)
# xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(x, mu)
xmu = jnp.expand_dims(x, -2) - mu
logits = self.encode_c(xmu) # N x M
c = jnp.argmax(logits, -1) # N
Expand All @@ -178,15 +180,15 @@ def __call__(self, x): # N x D


def dmm_target(network, inputs):
mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2 ]).to_event())
mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2]).to_event())
with numpyro.plate("N", inputs.shape[-2], dim=-1):
c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
h = numpyro.sample("h", dist.Beta(1, 1))
x_recon = network.decode_h(h) + Vindex(mu)[..., c, :]
x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs)

out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
return out,
return (out,)


def dmm_kernel_mu(network, inputs):
Expand All @@ -205,7 +207,7 @@ def dmm_kernel_mu(network, inputs):
mu = numpyro.sample("mu", dist.Normal(loc, scale).to_event(2))

out = {**inputs, **{"mu": mu}}
return out,
return (out,)


def dmm_kernel_c_h(network, inputs):
Expand All @@ -218,7 +220,7 @@ def dmm_kernel_c_h(network, inputs):
h = numpyro.sample("h", dist.Beta(alpha, beta))

out = {**inputs, **{"c": c, "h": h}}
return out,
return (out,)


# %%
Expand Down Expand Up @@ -258,8 +260,8 @@ def main(args):
num_sweeps = args.num_sweeps
num_particles = args.num_particles

train_ds = load_dataset("train", is_training=True, batch_size=batch_size)
test_ds = load_dataset("test", is_training=False, batch_size=batch_size)
train_ds = load_dataset("train", batch_size=batch_size)
test_ds = load_dataset("test", batch_size=batch_size)

init_params = DMMAutoEncoder().init(
jax.random.PRNGKey(0), jnp.zeros((200, 2))
Expand All @@ -273,23 +275,25 @@ def main(args):
)

program = make_dmm(dmm_params, num_sweeps)
next(test_ds)
next(test_ds)
batch = next(test_ds)
out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)
out = out[0]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
_, axes = plt.subplots(2, 3, figsize=(15, 10))
for i in range(3):
n = i
axes[0][i].scatter(out["x"][n, 0, :, 0], out["x"][n, 0, :, 1], marker=".")
axes[0][i].scatter(out["x"][i, :, 0], out["x"][i, :, 1], marker=".")
axes[1][i].scatter(
out["x_recon"][n, 0, :, 0],
out["x_recon"][n, 0, :, 1],
c=out["c"][n, 0],
out["x_recon"][0, i, :, 0],
out["x_recon"][0, i, :, 1],
c=out["c"][0, i],
cmap="Accent",
marker=".",
)
axes[1][i].scatter(
out["mu"][n, 0, :, 0],
out["mu"][n, 0, :, 1],
out["mu"][0, i, 0, :, 0],
out["mu"][0, i, 0, :, 1],
c=range(4),
marker="x",
cmap="Accent",
Expand Down
32 changes: 17 additions & 15 deletions examples/dmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import numpyro.distributions as dist
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

# %%
# First, let's simulate a synthetic dataset of 2D ring-shaped mixtures.
Expand All @@ -65,16 +64,21 @@ def simulate_rings(num_instances=1, N=200, seed=0):


def load_dataset(split, *, is_training, batch_size):
num_data = 20000 if is_training else batch_size
num_points = 200 if is_training else 600
seed = 0 if is_training else 1
if split == "train":
num_data = 20000
num_points = 200
seed = 0
else:
num_data = batch_size
num_points = 600
seed = 1
data = simulate_rings(num_data, num_points, seed=seed)
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.cache().repeat()
if is_training:
ds = ds.repeat()
if split == "train":
ds = ds.shuffle(10 * batch_size, seed=0)
ds = ds.batch(batch_size)
return iter(tfds.as_numpy(ds))
return ds.as_numpy_iterator()


# %%
Expand Down Expand Up @@ -154,8 +158,7 @@ def __call__(self, x): # N x D
# Heuristic procedure to setup initial parameters.
mu, _ = self.encode_initial_mu(x) # M x D

concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1)
xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(x, mu)
xmu = jnp.expand_dims(x, -2) - mu
logits = self.encode_c(xmu) # N x M
c = jnp.argmax(logits, -1) # N

Expand Down Expand Up @@ -195,9 +198,10 @@ def dmm_kernel_mu(network, key, inputs):
key_out, key_mu = random.split(key)

if "c" in inputs:
x = inputs["x"]
c = jax.nn.one_hot(inputs["c"], 4)
h = jnp.expand_dims(inputs["h"], -1)
xch = jnp.concatenate([inputs["x"], c, h], -1)
xch = jnp.concatenate([x, c, h], -1)
loc, scale = network.encode_mu(xch)
else:
loc, scale = network.encode_initial_mu(inputs["x"])
Expand All @@ -210,13 +214,11 @@ def dmm_kernel_mu(network, key, inputs):
def dmm_kernel_c_h(network, key, inputs):
key_out, key_c, key_h = random.split(key, 3)

concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1)
xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(
inputs["x"], inputs["mu"]
)
x, mu = inputs["x"], inputs["mu"]
xmu = jnp.expand_dims(x, -2) - mu
logits = network.encode_c(xmu)
c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c)
alpha, beta = network.encode_h(inputs["x"] - inputs["mu"][c])
alpha, beta = network.encode_h(x - mu[c])
h = coryx.rv(dist.Beta(alpha, beta), name="h")(key_h)

out = {**inputs, **{"c": c, "h": h}}
Expand Down
26 changes: 15 additions & 11 deletions examples/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import numpyro.distributions as dist
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

# %%
# First, let's simulate a synthetic dataset of 2D Gaussian mixtures.
Expand All @@ -61,20 +60,25 @@ def simulate_clusters(num_instances=1, N=60, seed=0):
return x, c


def load_dataset(split, *, is_training, batch_size):
num_data = 20000 if is_training else batch_size
num_points = 60 if is_training else 100
seed = 0 if is_training else 1
def load_dataset(split, *, batch_size):
if split == "train":
num_data = 20000
num_points = 60
seed = 0
else:
num_data = batch_size
num_points = 100
seed = 1
data, label = simulate_clusters(num_data, num_points, seed=seed)
if is_training:
if split == "train":
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.cache().repeat()
ds = ds.repeat()
ds = ds.shuffle(10 * batch_size, seed=0)
else:
ds = tf.data.Dataset.from_tensor_slices((data, label))
ds = ds.cache().repeat()
ds = ds.repeat()
ds = ds.batch(batch_size)
return iter(tfds.as_numpy(ds))
return ds.as_numpy_iterator()


# %%
Expand Down Expand Up @@ -231,8 +235,8 @@ def main(args):
num_sweeps = args.num_sweeps
num_particles = args.num_particles

train_ds = load_dataset("train", is_training=True, batch_size=batch_size)
test_ds = load_dataset("test", is_training=False, batch_size=batch_size)
train_ds = load_dataset("train", batch_size=batch_size)
test_ds = load_dataset("test", batch_size=batch_size)

init_params = GMMEncoder().init(jax.random.PRNGKey(0), jnp.zeros((60, 2)))
gmm_params, _ = coix.util.train(
Expand Down
26 changes: 15 additions & 11 deletions examples/gmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import numpyro.distributions as dist
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

# %%
# First, let's simulate a synthetic dataset of Gaussian clusters.
Expand All @@ -65,20 +64,25 @@ def simulate_clusters(num_instances=1, N=60, seed=0):
return x, c


def load_dataset(split, *, is_training, batch_size):
num_data = 20000 if is_training else batch_size
num_points = 60 if is_training else 100
seed = 0 if is_training else 1
def load_dataset(split, *, batch_size):
if split == "train":
num_data = 20000
num_points = 60
seed = 0
else:
num_data = batch_size
num_points = 100
seed = 1
data, label = simulate_clusters(num_data, num_points, seed=seed)
if is_training:
if split == "train":
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.cache().repeat()
ds = ds.repeat()
ds = ds.shuffle(10 * batch_size, seed=0)
else:
ds = tf.data.Dataset.from_tensor_slices((data, label))
ds = ds.cache().repeat()
ds = ds.repeat()
ds = ds.batch(batch_size)
return iter(tfds.as_numpy(ds))
return ds.as_numpy_iterator()


# %%
Expand Down Expand Up @@ -239,8 +243,8 @@ def main(args):
num_sweeps = args.num_sweeps
num_particles = args.num_particles

train_ds = load_dataset("train", is_training=True, batch_size=batch_size)
test_ds = load_dataset("test", is_training=False, batch_size=batch_size)
train_ds = load_dataset("train", batch_size=batch_size)
test_ds = load_dataset("test", batch_size=batch_size)

init_params = GMMEncoder().init(jax.random.PRNGKey(0), jnp.zeros((60, 2)))
gmm_params, _ = coix.util.train(
Expand Down
Loading

0 comments on commit 53954c7

Please sign in to comment.