diff --git a/notebooks/tutorial_part1_intro.ipynb b/notebooks/tutorial_part1_intro.ipynb
index 73487fe..26e4b3a 100644
--- a/notebooks/tutorial_part1_intro.ipynb
+++ b/notebooks/tutorial_part1_intro.ipynb
@@ -16,7 +16,7 @@
     " \\\\\n",
     " \\textbf{p} &::= \\textbf{f} \\mid \\mathrm{extend}(\\textbf{p}, \\textbf{f}) && (\\text{target programs})\n",
     " \\\\\n",
-    " \\textbf{q} &::= \\textbf{p} \\mid \\mathrm{propose}(\\textbf{p}, \\textbf{q}) \\mid \\mathrm{resample}(\\textbf{q}) \\mid \\mathrm{compose}(\\textbf{q}, \\textbf{q}) \n",
+    " \\textbf{q} &::= \\textbf{p} \\mid \\mathrm{propose}(\\textbf{p}, \\textbf{q}) \\mid \\mathrm{resample}(\\textbf{q}) \\mid \\mathrm{compose}(\\textbf{q}, \\textbf{q})\n",
     "    && (\\text{inference programs})\n",
     "\\end{align}\n",
     "\n",
@@ -45,15 +45,19 @@
     "from numpyro.handlers import seed, trace\n",
     "import numpyro.distributions as dist\n",
     "import coix\n",
-    "coix.set_backend(\"coix.numpyro\")  # Setting the backend depending on the modeling language, here python + numpyro\n",
+    "\n",
+    "coix.set_backend(\n",
+    "    \"coix.numpyro\"\n",
+    ")  # Setting the backend depending on the modeling language, here python + numpyro\n",
     "from coix import traced_evaluate\n",
     "\n",
-    "log_phi = lambda x: -0.5*((x-1.)/0.1)**2\n",
+    "log_phi = lambda x: -0.5 * ((x - 1.0) / 0.1) ** 2\n",
+    "\n",
     "\n",
     "def f():\n",
-    "  x = numpyro.sample(\"x\", dist.Normal(0., 1.))\n",
-    "  numpyro.factor(\"phi_x\", log_phi(x))\n",
-    "  return (x,)"
+    "  x = numpyro.sample(\"x\", dist.Normal(0.0, 1.0))\n",
+    "  numpyro.factor(\"phi_x\", log_phi(x))\n",
+    "  return (x,)"
    ]
   },
   {
@@ -118,14 +122,14 @@
    "id": "ef210bb5-54b1-42f4-93b9-652d57b8a2d3",
    "metadata": {},
    "source": [
-    "We can see that the program trace has two nodes: (1) a node corresponding to the random variable $x$ and (2) a node corresponding to the factor node $\\phi_x$ which are both of type `sample`. The factor node `phi_x` is *observerd* while the random variable node `x` is not. \n",
+    "We can see that the program trace has two nodes: (1) a node corresponding to the random variable $x$ and (2) a node corresponding to the factor node $\\phi_x$, which are both of type `sample`. The factor node `phi_x` is *observed* while the random variable node `x` is not.\n",
     "We will see that whether a node is *observed* or *unobserved* plays an important role in the semantics of a program, as it changes the density the program denotes. In combinators, each primitive program denotes two densities:\n",
     "1. a **prior density**, which is defined as the joint density over all unobserved variables in the program\n",
     "2. an **unnormalized target density**, which is defined as the joint density over all variables in the program\n",
     "\n",
     "To get a better understanding of these densities and why their distinction is important, let's visualize these densities for the primitive program `f` that we defined above:\n",
     "1. The prior density is given by the density of the normal distribution\n",
-    "2. The unnormalized target denstity given by the product of the densities of the normal distribution and the factor node\n"
+    "2. The unnormalized target density is given by the product of the densities of the normal distribution and the factor node\n",
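+    "\n",
+    "Concretely, for `f` these are the prior $\\pi(x) = \\mathcal{N}(x; 0, 1)$ and the unnormalized target $\\gamma(x) = \\pi(x) \\exp(\\log \\phi_x(x))$, where $\\log \\phi_x(x) = -\\frac{1}{2}\\left((x - 1)/0.1\\right)^2$ is the factor defined in the code above."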
    ]
   },
   {
@@ -153,10 +157,16 @@
     "log_target_density = lambda x: log_prior_density(x) + log_phi(x)\n",
     "xrange_prior = np.linspace(-5, 5, 1000)\n",
     "\n",
-    "plt.figure(figsize=(8,4))\n",
+    "plt.figure(figsize=(8, 4))\n",
     "plt.title(\"Prior and target density of a primitive program\")\n",
-    "plt.plot(xrange_prior, np.exp(log_prior_density(xrange_prior)), label=\"prior density\")\n",
-    "plt.plot(xrange_prior, np.exp(log_target_density(xrange_prior)), label=\"target density\")\n",
+    "plt.plot(\n",
+    "    xrange_prior, np.exp(log_prior_density(xrange_prior)), label=\"prior density\"\n",
+    ")\n",
+    "plt.plot(\n",
+    "    xrange_prior,\n",
+    "    np.exp(log_target_density(xrange_prior)),\n",
+    "    label=\"target density\",\n",
+    ")\n",
     "plt.plot(xrange_prior, np.exp(log_phi(xrange_prior)), label=\"factor density\")\n",
     "plt.legend();"
    ]
   },
   {
@@ -167,7 +177,7 @@
    "metadata": {},
    "source": [
     "\n",
-    "While it is possible to evaluate programs using numpyro effect handlers, coix implements it's own evaluation handler, `traced_evaluate`, which uses numpyro's `seed` and `trace` handlers under the hood. It exposes the return value of the program, a simplified trace, and an additional `metrics` dictionary, which stores different evaluation metrics that are accumulated during the execution of the program. \n",
+    "While it is possible to evaluate programs using numpyro effect handlers, coix implements its own evaluation handler, `traced_evaluate`, which uses numpyro's `seed` and `trace` handlers under the hood. It exposes the return value of the program, a simplified trace, and an additional `metrics` dictionary, which stores different evaluation metrics that are accumulated during the execution of the program.\n",
     "Most importantly, the metrics dictionary contains the log-importance weights corresponding to the execution traces of the program. For primitive programs the log weight is defined as the sum of the log probabilities of all observed random variables in the trace. Hence, the log weight is precisely the difference between the log-target density and the log-prior density of a program.
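In the notation introduced above, $\\log w = \\log \\gamma(x) - \\log \\pi(x)$, which for our example is exactly the factor term $\\log \\phi_x(x)$.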
Let's verify this for our example program `f`:"
    ]
   },
   {
@@ -189,16 +199,31 @@
     }
    ],
    "source": [
-    "_, f_batch_trace, f_batch_metrics = traced_evaluate(numpyro.plate(\"particle_plate\", 20)(f), seed=0)()\n",
+    "_, f_batch_trace, f_batch_metrics = traced_evaluate(\n",
+    "    numpyro.plate(\"particle_plate\", 20)(f), seed=0\n",
+    ")()\n",
     "\n",
     "plt.figure(figsize=(10, 6))\n",
     "plt.title(\"Prior and target density of a primitive program\")\n",
-    "plt.plot(xrange_prior, np.exp(log_prior_density(xrange_prior)), label=\"prior density\")\n",
-    "plt.plot(xrange_prior, np.exp(log_target_density(xrange_prior)), label=\"target density\")\n",
-    "\n",
-    "plt.scatter(f_batch_trace[\"x\"][\"value\"], np.exp(f_batch_trace[\"x\"][\"log_prob\"]), label=\"prior density\")\n",
-    "plt.scatter(f_batch_trace[\"x\"][\"value\"], np.exp(f_batch_trace[\"x\"][\"log_prob\"] + f_batch_metrics[\"log_weight\"]), \n",
-    "            label=\"density * weight\")\n",
+    "plt.plot(\n",
+    "    xrange_prior, np.exp(log_prior_density(xrange_prior)), label=\"prior density\"\n",
+    ")\n",
+    "plt.plot(\n",
+    "    xrange_prior,\n",
+    "    np.exp(log_target_density(xrange_prior)),\n",
+    "    label=\"target density\",\n",
+    ")\n",
+    "\n",
+    "plt.scatter(\n",
+    "    f_batch_trace[\"x\"][\"value\"],\n",
+    "    np.exp(f_batch_trace[\"x\"][\"log_prob\"]),\n",
+    "    label=\"prior density\",\n",
+    ")\n",
+    "plt.scatter(\n",
+    "    f_batch_trace[\"x\"][\"value\"],\n",
+    "    np.exp(f_batch_trace[\"x\"][\"log_prob\"] + f_batch_metrics[\"log_weight\"]),\n",
+    "    label=\"density * weight\",\n",
+    ")\n",
     "plt.legend();"
    ]
   },
   {
@@ -268,21 +293,37 @@
     }
    ],
    "source": [
-    "_, f_batch_trace, f_batch_metrics = traced_evaluate(numpyro.plate(\"particle_plate\", 10000)(f), seed=0)()\n",
+    "_, f_batch_trace, f_batch_metrics = traced_evaluate(\n",
+    "    numpyro.plate(\"particle_plate\", 10000)(f), seed=0\n",
+    ")()\n",
     "\n",
     "approx_target_sampels = f_batch_trace[\"x\"][\"value\"]\n",
     "weights = jax.nn.softmax(f_batch_metrics[\"log_weight\"])\n",
     "\n",
-    "var_prior = 1.\n",
+    "var_prior = 1.0\n",
     "var_factor = 0.1**2\n",
-    "var_target = 1/(1/var_prior + 1/var_factor)\n",
-    "Z_target = np.sqrt(2*np.pi*var_target) * jnp.sqrt(2*jnp.pi*var_factor)\n",
-    "normalized_log_target_density = lambda x: log_target_density(x) - jnp.log(Z_target)\n",
+    "var_target = 1 / (1 / var_prior + 1 / var_factor)\n",
+    "Z_target = np.sqrt(2 * np.pi * var_target) * jnp.sqrt(2 * jnp.pi * var_factor)\n",
+    "normalized_log_target_density = lambda x: log_target_density(x) - jnp.log(\n",
+    "    Z_target\n",
+    ")\n",
     "\n",
     "xrange_target = np.linspace(0, 2, 100)\n",
-    "plt.plot(xrange_target, np.exp(normalized_log_target_density(xrange_target)), label=\"target density\", color=\"C1\")\n",
-    "_ = plt.hist(approx_target_sampels, weights=weights, density=True, \n",
-    "             bins=100, range=(xrange_target[0], xrange_target[-1]), color=\"C1\", alpha=0.5)\n",
+    "plt.plot(\n",
+    "    xrange_target,\n",
+    "    np.exp(normalized_log_target_density(xrange_target)),\n",
+    "    label=\"target density\",\n",
+    "    color=\"C1\",\n",
+    ")\n",
+    "_ = plt.hist(\n",
+    "    approx_target_sampels,\n",
+    "    weights=weights,\n",
+    "    density=True,\n",
+    "    bins=100,\n",
+    "    range=(xrange_target[0], xrange_target[-1]),\n",
+    "    color=\"C1\",\n",
+    "    alpha=0.5,\n",
+    ")\n",
     "plt.legend()"
    ]
   },
   {
@@ -344,60 +385,147 @@
     }
    ],
    "source": [
-    "from matplotlib import colors, lines, gridspec \n",
-    "\n",
-    "def plot_extended_density_samples(p, name_x=\"x\", name_y=\"y\", color1=\"C0\", color2=\"C1\"):\n",
-    "    p_batch = numpyro.plate(\"particle_plate\", 10000)(p)\n",
-    "    out, trace, metrics = traced_evaluate(p_batch, seed=0)()\n",
-    "    out_rs, trace_rs, metrics_rs = traced_evaluate(coix.resample(p_batch), seed=0)()\n",
-    "    xs, ys, ws = trace[name_x][\"value\"], trace[name_y][\"value\"], jax.nn.softmax(metrics[\"log_weight\"])\n",
-    "    # xs_rs, ys_rs = trace_rs[name_x][\"value\"], trace_rs[name_y][\"value\"]\n",
-    "    \n",
-    "    fig = plt.figure(figsize=(8,8))\n",
-    "    gs = gridspec.GridSpec(3, 3, wspace=0, hspace=0)\n",
-    "    \n",
-    "    ax_xy = plt.subplot(gs[1:3, :2])\n",
-    "    cmap_c0 = colors.LinearSegmentedColormap.from_list('c0_alpha', [colors.colorConverter.to_rgba(color1, alpha=0), colors.colorConverter.to_rgba(color1, alpha=1)], 256)\n",
-    "    cmap_c1 = colors.LinearSegmentedColormap.from_list('c1_alpha', [colors.colorConverter.to_rgba(color2, alpha=0), colors.colorConverter.to_rgba(color2, alpha=1)], 256)\n",
-    "    ax_xy.hist2d(xs, ys, bins=100, density=True, label=\"approx. samples from extended target density\", cmap=cmap_c0)\n",
-    "    ax_xy.hist2d(xs, ys, bins=100, density=True, weights=ws, label=\"approx. samples from extended target density\", cmap=cmap_c1)\n",
-    "    ax_xy.set(xlabel=name_x, ylabel=name_y)\n",
-    "    \n",
-    "    ax_x = plt.subplot(gs[0, :2], sharex=ax_xy)\n",
-    "    ax_x.set(title=f\"{name_x}-marginal densities\")\n",
-    "    ax_x.hist(xs, bins=100, align='mid', density=True, alpha=0.5, color=color1)\n",
-    "    ax_x.hist(xs, bins=100, weights=ws, align='mid', density=True, alpha=0.5, color=color2)\n",
-    "\n",
-    "    ax_y = plt.subplot(gs[1:3, 2], sharey=ax_xy)\n",
-    "    ax_y.set(title=f\"{name_y}-marginal densities\")\n",
-    "    ax_y.hist(ys, bins=100, orientation='horizontal', align='mid', density=True, alpha=0.5, color=color1)\n",
-    "    ax_y.hist(ys, bins=100, weights=ws, align='mid', density=True, orientation='horizontal', alpha=0.5, color=color2)\n",
-    "    return ax_xy, ax_x, ax_y\n",
-    "    \n",
+    "from matplotlib import colors, lines, gridspec\n",
+    "\n",
+    "\n",
+    "def plot_extended_density_samples(\n",
+    "    p, name_x=\"x\", name_y=\"y\", color1=\"C0\", color2=\"C1\"\n",
+    "):\n",
+    "  p_batch = numpyro.plate(\"particle_plate\", 10000)(p)\n",
+    "  out, trace, metrics = traced_evaluate(p_batch, seed=0)()\n",
+    "  out_rs, trace_rs, metrics_rs = traced_evaluate(\n",
+    "      coix.resample(p_batch), seed=0\n",
+    "  )()\n",
+    "  xs, ys, ws = (\n",
+    "      trace[name_x][\"value\"],\n",
+    "      trace[name_y][\"value\"],\n",
+    "      jax.nn.softmax(metrics[\"log_weight\"]),\n",
+    "  )\n",
+    "  # xs_rs, ys_rs = trace_rs[name_x][\"value\"], trace_rs[name_y][\"value\"]\n",
+    "\n",
+    "  fig = plt.figure(figsize=(8, 8))\n",
+    "  gs = gridspec.GridSpec(3, 3, wspace=0, hspace=0)\n",
+    "\n",
+    "  ax_xy = plt.subplot(gs[1:3, :2])\n",
+    "  cmap_c0 = colors.LinearSegmentedColormap.from_list(\n",
+    "      \"c0_alpha\",\n",
+    "      [\n",
+    "          colors.colorConverter.to_rgba(color1, alpha=0),\n",
+    "          colors.colorConverter.to_rgba(color1, alpha=1),\n",
+    "      ],\n",
+    "      256,\n",
+    "  )\n",
+    "  cmap_c1 = colors.LinearSegmentedColormap.from_list(\n",
+    "      \"c1_alpha\",\n",
+    "      [\n",
+    "          colors.colorConverter.to_rgba(color2, alpha=0),\n",
+    "          colors.colorConverter.to_rgba(color2, alpha=1),\n",
+    "      ],\n",
+    "      256,\n",
+    "  )\n",
+    "  ax_xy.hist2d(\n",
+    "      xs,\n",
+    "      ys,\n",
+    "      bins=100,\n",
+    "      density=True,\n",
+    "      label=\"approx. samples from extended target density\",\n",
+    "      cmap=cmap_c0,\n",
+    "  )\n",
+    "  ax_xy.hist2d(\n",
+    "      xs,\n",
+    "      ys,\n",
+    "      bins=100,\n",
+    "      density=True,\n",
+    "      weights=ws,\n",
+    "      label=\"approx. samples from extended target density\",\n",
+    "      cmap=cmap_c1,\n",
+    "  )\n",
+    "  ax_xy.set(xlabel=name_x, ylabel=name_y)\n",
+    "\n",
+    "  ax_x = plt.subplot(gs[0, :2], sharex=ax_xy)\n",
+    "  ax_x.set(title=f\"{name_x}-marginal densities\")\n",
+    "  ax_x.hist(xs, bins=100, align=\"mid\", density=True, alpha=0.5, color=color1)\n",
+    "  ax_x.hist(\n",
+    "      xs,\n",
+    "      bins=100,\n",
+    "      weights=ws,\n",
+    "      align=\"mid\",\n",
+    "      density=True,\n",
+    "      alpha=0.5,\n",
+    "      color=color2,\n",
+    "  )\n",
+    "\n",
+    "  ax_y = plt.subplot(gs[1:3, 2], sharey=ax_xy)\n",
+    "  ax_y.set(title=f\"{name_y}-marginal densities\")\n",
+    "  ax_y.hist(\n",
+    "      ys,\n",
+    "      bins=100,\n",
+    "      orientation=\"horizontal\",\n",
+    "      align=\"mid\",\n",
+    "      density=True,\n",
+    "      alpha=0.5,\n",
+    "      color=color1,\n",
+    "  )\n",
+    "  ax_y.hist(\n",
+    "      ys,\n",
+    "      bins=100,\n",
+    "      weights=ws,\n",
+    "      align=\"mid\",\n",
+    "      density=True,\n",
+    "      orientation=\"horizontal\",\n",
+    "      alpha=0.5,\n",
+    "      color=color2,\n",
+    "  )\n",
+    "  return ax_xy, ax_x, ax_y\n",
+    "\n",
+    "\n",
     "def f2(x):\n",
-    "    y = numpyro.sample(\"y\", dist.Normal(2*x + 3, 0.5))\n",
-    "    return (y,)\n",
+    "  y = numpyro.sample(\"y\", dist.Normal(2 * x + 3, 0.5))\n",
+    "  return (y,)\n",
+    "\n",
     "\n",
     "p_ext = coix.extend(f, f2)\n",
-    "log_extend_density = lambda x, y: dist.Normal(2*x + 3, 0.5).log_prob(y)\n",
-    "log_extended_target_density = lambda x, y: log_target_density(x) + log_extend_density(x, y)\n",
-    "log_extended_prior_density = lambda x, y: log_prior_density(x) + log_extend_density(x, y)\n",
+    "log_extend_density = lambda x, y: dist.Normal(2 * x + 3, 0.5).log_prob(y)\n",
+    "log_extended_target_density = lambda x, y: log_target_density(\n",
+    "    x\n",
+    ") + log_extend_density(x, y)\n",
+    "log_extended_prior_density = lambda x, y: log_prior_density(\n",
+    "    x\n",
+    ") + log_extend_density(x, y)\n",
     "\n",
     "N_x, N_y = 200, 400\n",
     "xrange_ext = np.linspace(-4, 4, N_x)\n",
     "yrange_ext = np.linspace(-4, 10, N_y)\n",
     "m_xy = np.dstack(np.meshgrid(xrange_ext, yrange_ext))\n",
-    "m_p_target = np.exp(log_extended_target_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x))\n",
-    "m_p_prior = np.exp(log_extended_prior_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x))\n",
+    "m_p_target = np.exp(\n",
+    "    log_extended_target_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x)\n",
+    ")\n",
+    "m_p_prior = np.exp(\n",
+    "    log_extended_prior_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x)\n",
+    ")\n",
     "\n",
     "ax_xy, ax_x, ax_y = plot_extended_density_samples(p_ext)\n",
     "ax_x.plot(xrange_prior, np.exp(log_prior_density(xrange_prior)), color=\"C0\")\n",
-    "ax_x.plot(xrange_prior, np.exp(normalized_log_target_density(xrange_prior)), color=\"C1\")\n",
-    "ax_xy.contour(m_xy[..., 0], m_xy[..., 1], m_p_prior, levels=[0.05, 0.3], colors=\"C0\")\n",
-    "ax_xy.contour(m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\")\n",
+    "ax_x.plot(\n",
+    "    xrange_prior,\n",
+    "    np.exp(normalized_log_target_density(xrange_prior)),\n",
+    "    color=\"C1\",\n",
+    ")\n",
+    "ax_xy.contour(\n",
+    "    m_xy[..., 0], m_xy[..., 1], m_p_prior, levels=[0.05, 0.3], colors=\"C0\"\n",
+    ")\n",
+    "ax_xy.contour(\n",
+    "    m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n",
+    ")\n",
     "handles, labels = ax_xy.get_legend_handles_labels()\n",
-    "handles.extend([lines.Line2D([0], [0], label='prior density of $extend(f,\\ f2)$', color='C0'),\n",
-    "                lines.Line2D([0], [0], label='target denstity of $extend(f,\\ f2)$', color='C1')])\n",
+    "handles.extend([\n",
+    "    lines.Line2D(\n",
+    "        [0], [0], label=\"prior density of $extend(f,\\ f2)$\", color=\"C0\"\n",
+    "    ),\n",
+    "    lines.Line2D(\n",
+    "        [0], [0], label=\"target density of $extend(f,\\ f2)$\", color=\"C1\"\n",
+    "    ),\n",
+    "])\n",
     "ax_xy.legend(handles=handles, loc=\"lower left\");"
    ]
   },
   {
@@ -457,8 +585,16 @@
     }
    ],
    "source": [
     "f3 = lambda *args: f2(*f(*args))\n",
     "out_pri, trace_pri, _ = traced_evaluate(f3, seed=0)()\n",
     "out_ext, trace_ext, _ = traced_evaluate(p_ext, seed=0)()\n",
-    "print(\"Return value of f3\", out_pri, \"--> This is the value of y (check trace below)\")\n",
-    "print(\"Return value of p_ext\", out_ext, \"--> This is the value of x (check trace below)\")\n",
+    "print(\n",
+    "    \"Return value of f3\",\n",
+    "    out_pri,\n",
+    "    \"--> This is the value of y (check trace below)\",\n",
+    ")\n",
+    "print(\n",
+    "    \"Return value of p_ext\",\n",
+    "    out_ext,\n",
+    "    \"--> This is the value of x (check trace below)\",\n",
+    ")\n",
     "print(\"\\nThe traces of f3 and p_ext are identical:\")\n",
     "trace_pri, trace_ext"
    ]
   },
   {
@@ -569,13 +705,17 @@
    ],
    "source": [
     "def q():\n",
-    "    x = numpyro.sample(\"x\", dist.Normal(1, 0.5))\n",
-    "    return (x,)\n",
+    "  x = numpyro.sample(\"x\", dist.Normal(1, 0.5))\n",
+    "  return (x,)\n",
+    "\n",
     "\n",
     "def log_proposal_density(x):\n",
-    "    return dist.Normal(1, 0.5).log_prob(x)\n",
+    "  return dist.Normal(1, 0.5).log_prob(x)\n",
+    "\n",
     "\n",
-    "log_extended_proposal_density = lambda x, y: log_proposal_density(x) + log_extend_density(x, y)\n",
+    "log_extended_proposal_density = lambda x, y: log_proposal_density(\n",
+    "    x\n",
+    ") + log_extend_density(x, y)\n",
     "\n",
     "\n",
     "f_batch = numpyro.plate(\"particle_plate\", 10000)(f)\n",
@@ -589,14 +729,43 @@
     "weights_prior = np.exp(f_batch_metrics[\"log_weight\"])\n",
     "ess = q2_metrics[\"ess\"]\n",
     "# ess_prior = jnp.exp(f_batch_metrics[\"ess\"])\n",
-    "print(\"Variance and ess of importance weight using the prior as a proposal:\", np.var(weights_prior))\n",
-    "print(\"Variance and ess of importance weight using the new proposal as a proposal:\", np.var(weights))\n",
+    "print(\n",
+    "    \"Variance of the importance weights using the prior as proposal:\",\n",
+    "    np.var(weights_prior),\n",
+    ")\n",
+    "print(\n",
+    "    \"Variance of the importance weights using the new proposal:\",\n",
+    "    np.var(weights),\n",
+    ")\n",
     "\n",
-    "plt.plot(xrange_target, np.exp(normalized_log_target_density(xrange_target)), label=\"target density\", color=\"C1\")\n",
-    "plt.plot(xrange_target, np.exp(log_prior_density(xrange_target)), label=\"prior density\", color=\"C0\")\n",
-    "plt.plot(xrange_target, np.exp(log_proposal_density(xrange_target)), label=\"new proposal density\", color=\"C2\")\n",
-    "_ = plt.hist(approx_target_sampels, weights=weights, density=True, \n",
-    "             bins=100, range=(xrange_target[0], xrange_target[-1]), color=\"C1\", alpha=0.5)\n",
+    "plt.plot(\n",
+    "    xrange_target,\n",
+    "    np.exp(normalized_log_target_density(xrange_target)),\n",
+    "    label=\"target density\",\n",
+    "    color=\"C1\",\n",
+    ")\n",
+    "plt.plot(\n",
+    "    xrange_target,\n",
+    "    np.exp(log_prior_density(xrange_target)),\n",
+    "    label=\"prior density\",\n",
+    "    color=\"C0\",\n",
+    ")\n",
+    "plt.plot(\n",
+    "    xrange_target,\n",
+    "    np.exp(log_proposal_density(xrange_target)),\n",
+    "    label=\"new proposal density\",\n",
+    "    color=\"C2\",\n",
+    ")\n",
+    "_ = plt.hist(\n",
+    "    approx_target_sampels,\n",
+    "    weights=weights,\n",
+    "    density=True,\n",
+    "    bins=100,\n",
+    "    range=(xrange_target[0], xrange_target[-1]),\n",
color=\"C1\",\n", + " alpha=0.5,\n", + ")\n", "plt.legend();" ] }, @@ -661,9 +830,21 @@ "weights = jnp.exp(q3_metrics[\"log_weight\"])\n", "print(\"The log weights after resampling are all equal:\", weights)\n", "\n", - "plt.plot(xrange_target, np.exp(normalized_log_target_density(xrange_target)), label=\"target density\", color=\"C1\")\n", - "_ = plt.hist(approx_target_sampels, weights=weights, density=True, \n", - " bins=100, range=(xrange_target[0], xrange_target[-1]), color=\"C1\", alpha=0.5)\n", + "plt.plot(\n", + " xrange_target,\n", + " np.exp(normalized_log_target_density(xrange_target)),\n", + " label=\"target density\",\n", + " color=\"C1\",\n", + ")\n", + "_ = plt.hist(\n", + " approx_target_sampels,\n", + " weights=weights,\n", + " density=True,\n", + " bins=100,\n", + " range=(xrange_target[0], xrange_target[-1]),\n", + " color=\"C1\",\n", + " alpha=0.5,\n", + ")\n", "plt.legend();" ] }, @@ -718,24 +899,55 @@ "q_com = coix.compose(f2, q2)\n", "\n", "m_xy = np.dstack(np.meshgrid(xrange_ext, yrange_ext))\n", - "m_p_target = np.exp(log_extended_target_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x))\n", - "m_p_prior = np.exp(log_extended_prior_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x))\n", - "m_p_proposal = np.exp(log_extended_proposal_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x))\n", + "m_p_target = np.exp(\n", + " log_extended_target_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x)\n", + ")\n", + "m_p_prior = np.exp(\n", + " log_extended_prior_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(N_y, N_x)\n", + ")\n", + "m_p_proposal = np.exp(\n", + " log_extended_proposal_density(*m_xy.reshape(N_x * N_y, 2).T).reshape(\n", + " N_y, N_x\n", + " )\n", + ")\n", "\n", "ax_xy, ax_x, ax_y = plot_extended_density_samples(q_com, color1=\"C2\")\n", "ax_x.plot(xrange_prior, np.exp(log_prior_density(xrange_prior)), color=\"C0\")\n", "ax_x.plot(xrange_prior, np.exp(log_proposal_density(xrange_prior)), color=\"C2\")\n", - "ax_x.plot(xrange_prior, np.exp(normalized_log_target_density(xrange_prior)), color=\"C1\")\n", - "ax_xy.contour(m_xy[..., 0], m_xy[..., 1], m_p_prior, levels=[0.05, 0.3], colors=\"C0\")\n", - "ax_xy.contour(m_xy[..., 0], m_xy[..., 1], m_p_proposal, levels=[0.05, 0.3], colors=\"C2\")\n", - "ax_xy.contour(m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\")\n", + "ax_x.plot(\n", + " xrange_prior,\n", + " np.exp(normalized_log_target_density(xrange_prior)),\n", + " color=\"C1\",\n", + ")\n", + "ax_xy.contour(\n", + " m_xy[..., 0], m_xy[..., 1], m_p_prior, levels=[0.05, 0.3], colors=\"C0\"\n", + ")\n", + "ax_xy.contour(\n", + " m_xy[..., 0], m_xy[..., 1], m_p_proposal, levels=[0.05, 0.3], colors=\"C2\"\n", + ")\n", + "ax_xy.contour(\n", + " m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n", + ")\n", "handles, labels = ax_xy.get_legend_handles_labels()\n", - "handles.extend([lines.Line2D([0], [0], label='prior density of $extend(f,\\ f2)$', color='C0'),\n", - " lines.Line2D([0], [0], label='proposal denstity $compose(f2,\\ q2)$', color='C2'),\n", - " lines.Line2D([0], [0], label='target denstity $extend(f, f2)$ and $compose(f2,\\ q2)$', color='C1')])\n", + "handles.extend([\n", + " lines.Line2D(\n", + " [0], [0], label=\"prior density of $extend(f,\\ f2)$\", color=\"C0\"\n", + " ),\n", + " lines.Line2D(\n", + " [0], [0], label=\"proposal denstity $compose(f2,\\ q2)$\", color=\"C2\"\n", + " ),\n", + " lines.Line2D(\n", + " [0],\n", + " [0],\n", + " label=\"target denstity $extend(f, f2)$ and 
$compose(f2,\\ q2)$\",\n", + " color=\"C1\",\n", + " ),\n", + "])\n", "ax_xy.legend(handles=handles, loc=\"lower left\")\n", "\n", - "_, f_ext_trace, f_ext_metrics = traced_evaluate(numpyro.plate(\"particle_plate\", 10000)(p_ext), seed=0)()\n", + "_, f_ext_trace, f_ext_metrics = traced_evaluate(\n", + " numpyro.plate(\"particle_plate\", 10000)(p_ext), seed=0\n", + ")()\n", "_, _, q_com_metrics = traced_evaluate(q_com, seed=0)()\n", "w_ext = np.exp(f_ext_metrics[\"log_weight\"])\n", "w_com = np.exp(q_com_metrics[\"log_weight\"])\n", diff --git a/notebooks/tutorial_part2_vae.ipynb b/notebooks/tutorial_part2_vae.ipynb index 1502fb7..e2cc5fa 100644 --- a/notebooks/tutorial_part2_vae.ipynb +++ b/notebooks/tutorial_part2_vae.ipynb @@ -129,26 +129,27 @@ "import numpyro\n", "import numpyro.distributions as dist\n", "\n", + "\n", "def make_programs(f_enc, f_dec):\n", - " # prior_program does not use x argument but needs to pass it on to dec_program\n", - " def prior_program(params, x):\n", - " z = numpyro.sample(\"z\", dist.Normal(0, 1).expand((2,)).to_event(1)) \n", - " return (params, z, x)\n", - "\n", - " # arguments matche the output of prior_program\n", - " def dec_program(params, z, x=None):\n", - " _, dec_params = params\n", - " mean_x = f_dec.apply(dec_params, z)\n", - " x = numpyro.sample(\"x\", dist.Normal(mean_x, 1.).to_event(1), obs=x) \n", - " return (mean_x, x)\n", - " \n", - " def enc_program(params, x):\n", - " enc_params, _ = params\n", - " mean_z = f_enc.apply(enc_params, x)\n", - " z = numpyro.sample(\"z\", dist.Normal(mean_z, 0.01).to_event(1)) \n", - " return (mean_z, z)\n", - " \n", - " return prior_program, enc_program, dec_program" + " # prior_program does not use x argument but needs to pass it on to dec_program\n", + " def prior_program(params, x):\n", + " z = numpyro.sample(\"z\", dist.Normal(0, 1).expand((2,)).to_event(1))\n", + " return (params, z, x)\n", + "\n", + " # arguments matche the output of prior_program\n", + " def dec_program(params, z, x=None):\n", + " _, dec_params = params\n", + " mean_x = f_dec.apply(dec_params, z)\n", + " x = numpyro.sample(\"x\", dist.Normal(mean_x, 1.0).to_event(1), obs=x)\n", + " return (mean_x, x)\n", + "\n", + " def enc_program(params, x):\n", + " enc_params, _ = params\n", + " mean_z = f_enc.apply(enc_params, x)\n", + " z = numpyro.sample(\"z\", dist.Normal(mean_z, 0.01).to_event(1))\n", + " return (mean_z, z)\n", + "\n", + " return prior_program, enc_program, dec_program" ] }, { @@ -171,7 +172,9 @@ "import jax\n", "from flax import linen as nn\n", "\n", + "\n", "class mlp(nn.Module):\n", + "\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Dense(features=32)(x)\n", @@ -179,6 +182,7 @@ " x = nn.Dense(features=2)(x)\n", " return x\n", "\n", + "\n", "f_enc = f_dec = mlp()\n", "enc_key, dec_key = jax.random.split(jax.random.PRNGKey(0))\n", "enc_params = f_enc.init(enc_key, data_test.shape)\n", @@ -221,24 +225,38 @@ "outputs": [], "source": [ "import coix\n", + "\n", "coix.set_backend(\"coix.numpyro\")\n", "\n", + "\n", "def make_particle_plate(num_particles):\n", - " return numpyro.plate(\"particle\", num_particles, dim=-1)\n", + " return numpyro.plate(\"particle\", num_particles, dim=-1)\n", + "\n", "\n", "def model_compose(p1, p2):\n", - " def wrapper(*args, **kwargs):\n", - " return p2(*p1(*args, **kwargs))\n", - " return wrapper\n", + " def wrapper(*args, **kwargs):\n", + " return p2(*p1(*args, **kwargs))\n", + "\n", + " return wrapper\n", + "\n", + "\n", + "def make_target_and_inference_program(\n", + " prior_program, 
+    "):\n",
+    "  target_program = make_particle_plate(num_particles)(\n",
+    "      model_compose(prior_program, dec_program)\n",
+    "  )\n",
+    "  proposal_program = make_particle_plate(num_particles)(enc_program)\n",
+    "  inference_program = coix.propose(\n",
+    "      target_program, proposal_program, loss_fn=loss_fn\n",
+    "  )\n",
+    "  return target_program, inference_program\n",
     "\n",
-    "def make_target_and_inference_program(prior_program, enc_program, dec_program, num_particles, loss_fn=None):\n",
-    "    target_program = make_particle_plate(num_particles)(model_compose(prior_program, dec_program))\n",
-    "    proposal_program = make_particle_plate(num_particles)(enc_program)\n",
-    "    inference_program = coix.propose(target_program, proposal_program, loss_fn=loss_fn)\n",
-    "    return target_program, inference_program\n",
     "\n",
     "programs = make_programs(f_enc, f_dec)\n",
-    "target_program, inference_program = make_target_and_inference_program(*programs, num_particles=n_test)\n",
+    "target_program, inference_program = make_target_and_inference_program(\n",
+    "    *programs, num_particles=n_test\n",
+    ")\n",
     "out, _, _ = coix.traced_evaluate(target_program, seed=0)(params, x=data_test)"
    ]
   },
   {
@@ -280,7 +298,9 @@
     }
    ],
    "source": [
-    "(means, _), _, _ = coix.traced_evaluate(inference_program, seed=0)(params, x=data_test)\n",
+    "(means, _), _, _ = coix.traced_evaluate(inference_program, seed=0)(\n",
+    "    params, x=data_test\n",
+    ")\n",
     "\n",
     "plt.figure(figsize=(10, 3))\n",
     "plt.subplot(121)\n",
@@ -332,31 +352,40 @@
     "import optax\n",
     "\n",
     "n_batch = 300\n",
-    "_, inference_program = make_target_and_inference_program(*programs, num_particles=n_batch, loss_fn=coix.loss.elbo_loss)\n",
+    "_, inference_program = make_target_and_inference_program(\n",
+    "    *programs, num_particles=n_batch, loss_fn=coix.loss.elbo_loss\n",
+    ")\n",
+    "\n",
     "\n",
     "def loss_fn(rng_key, params, data):\n",
-    "    _, _, metrics = coix.traced_evaluate(inference_program, seed=rng_key)(params, x=data)\n",
-    "    return metrics[\"loss\"]\n",
+    "  _, _, metrics = coix.traced_evaluate(inference_program, seed=rng_key)(\n",
+    "      params, x=data\n",
+    "  )\n",
+    "  return metrics[\"loss\"]\n",
+    "\n",
     "\n",
     "optimizer = optax.adam(1e-1)\n",
     "opt_state = optimizer.init(params)\n",
     "\n",
     "rng_key = jax.random.PRNGKey(0)\n",
+    "\n",
+    "\n",
     "def step(step, params, opt_state, data):\n",
-    "    step_key = jax.random.fold_in(rng_key, step)\n",
-    "    batch_key, loss_key = jax.random.split(step_key, 2)\n",
-    "    batch = jax.random.choice(batch_key, data, (n_batch,))\n",
-    "    value, grads = jax.value_and_grad(loss_fn, argnums=1)(loss_key, params, batch)\n",
-    "    updates, opt_state = optimizer.update(grads, opt_state)\n",
-    "    params = optax.apply_updates(params, updates)\n",
-    "    return value, params, opt_state\n",
+    "  step_key = jax.random.fold_in(rng_key, step)\n",
+    "  batch_key, loss_key = jax.random.split(step_key, 2)\n",
+    "  batch = jax.random.choice(batch_key, data, (n_batch,))\n",
+    "  value, grads = jax.value_and_grad(loss_fn, argnums=1)(loss_key, params, batch)\n",
+    "  updates, opt_state = optimizer.update(grads, opt_state)\n",
+    "  params = optax.apply_updates(params, updates)\n",
+    "  return value, params, opt_state\n",
+    "\n",
     "\n",
     "losses = []\n",
-    "for i in range(100): \n",
-    "    loss, params, opt_state = step(i, params, opt_state, data_train)\n",
-    "    losses.append(loss)\n",
-    "    if i % 10 == 0:\n",
-    "        print(f\"Interation {i}: Loss {loss}\")\n",
+    "for i in range(100):\n",
+    "  loss, params, opt_state = step(i, params, opt_state, data_train)\n",
+    "  losses.append(loss)\n",
+    "  if i % 10 == 0:\n",
+    "    print(f\"Iteration {i}: Loss {loss}\")\n",
     "losses = np.stack(losses)"
    ]
   },
   {
@@ -396,14 +425,18 @@
     }
    ],
    "source": [
-    "_, inference_program = make_target_and_inference_program(*programs, num_particles=n_test)\n",
+    "_, inference_program = make_target_and_inference_program(\n",
+    "    *programs, num_particles=n_test\n",
+    ")\n",
     "\n",
     "plt.figure(figsize=(15, 3))\n",
     "plt.subplot(131)\n",
     "plt.title(\"ELBO\")\n",
     "plt.plot(-losses)\n",
     "plt.subplot(132)\n",
-    "out, trace, metrics = coix.traced_evaluate(inference_program, seed=0)(params, x=data_test)\n",
+    "out, trace, metrics = coix.traced_evaluate(inference_program, seed=0)(\n",
+    "    params, x=data_test\n",
+    ")\n",
     "samples, means = out\n",
     "plt.title(\"Test Data\")\n",
     "plt.scatter(*samples.T, alpha=0.1)\n",
diff --git a/notebooks/tutorial_part3_smcs.ipynb b/notebooks/tutorial_part3_smcs.ipynb
index 76cdadf..689583d 100644
--- a/notebooks/tutorial_part3_smcs.ipynb
+++ b/notebooks/tutorial_part3_smcs.ipynb
@@ -97,22 +97,28 @@
    ],
    "source": [
     "import matplotlib.pyplot as plt\n",
-    "import numpy as np \n",
+    "import numpy as np\n",
     "import jax.numpy as jnp\n",
     "import flax\n",
     "import flax.linen as nn\n",
     "import numpyro.distributions as dist\n",
     "import numpyro\n",
+    "\n",
     "numpyro.set_platform(\"cpu\")\n",
     "\n",
+    "\n",
     "def ring_gmm_log_density(x, M):\n",
-    "    angles = 2 * jnp.arange(1, M + 1) * jnp.pi / M\n",
-    "    mu = 10 * jnp.stack([jnp.sin(angles), jnp.cos(angles)], -1)\n",
-    "    sigma = jnp.sqrt(0.5)\n",
-    "    return nn.logsumexp(dist.Normal(mu, sigma).log_prob(x[..., None, :]).sum(-1), -1)\n",
+    "  angles = 2 * jnp.arange(1, M + 1) * jnp.pi / M\n",
+    "  mu = 10 * jnp.stack([jnp.sin(angles), jnp.cos(angles)], -1)\n",
+    "  sigma = jnp.sqrt(0.5)\n",
+    "  return nn.logsumexp(\n",
+    "      dist.Normal(mu, sigma).log_prob(x[..., None, :]).sum(-1), -1\n",
+    "  )\n",
+    "\n",
     "\n",
     "def proposal_log_density(x):\n",
-    "    return dist.Normal(0, 5).log_prob(x).sum(-1)\n",
+    "  return dist.Normal(0, 5).log_prob(x).sum(-1)\n",
+    "\n",
     "\n",
     "xrange = np.linspace(-12, 12, 100)\n",
     "m_xy = np.dstack(np.meshgrid(xrange, xrange))\n",
@@ -152,18 +158,20 @@
    "outputs": [],
    "source": [
     "class AnnealedDensity(nn.Module):\n",
-    "    M = 8\n",
-    "\n",
-    "    @nn.compact\n",
-    "    def __call__(self, x, index=0):\n",
-    "        beta_raw = self.param(\"beta_raw\", lambda _: -jnp.ones(self.M - 2))\n",
-    "        beta = nn.sigmoid(beta_raw[0] + jnp.pad(jnp.cumsum(nn.softplus(beta_raw[1:])), (1, 0)))\n",
-    "        beta = jnp.pad(beta, (1, 1), constant_values=(0, 1))\n",
-    "        beta_k = beta[index]\n",
-    "\n",
-    "        target_density = ring_gmm_log_density(x, self.M)\n",
-    "        init_proposal = proposal_log_density(x)\n",
-    "        return beta_k * target_density + (1 - beta_k) * init_proposal"
+    "  M = 8\n",
+    "\n",
+    "  @nn.compact\n",
+    "  def __call__(self, x, index=0):\n",
+    "    beta_raw = self.param(\"beta_raw\", lambda _: -jnp.ones(self.M - 2))\n",
+    "    beta = nn.sigmoid(\n",
+    "        beta_raw[0] + jnp.pad(jnp.cumsum(nn.softplus(beta_raw[1:])), (1, 0))\n",
+    "    )\n",
+    "    beta = jnp.pad(beta, (1, 1), constant_values=(0, 1))\n",
+    "    beta_k = beta[index]\n",
+    "\n",
+    "    target_density = ring_gmm_log_density(x, self.M)\n",
+    "    init_proposal = proposal_log_density(x)\n",
+    "    return beta_k * target_density + (1 - beta_k) * init_proposal"
    ]
   },
   {
@@ -186,39 +194,49 @@
    "outputs": [],
    "source": [
     "class VariationalKernelNetwork(nn.Module):\n",
-    "    @nn.compact\n",
-    "    def __call__(self, x):\n",
-    "        h = nn.Dense(50)(x)\n",
-    "        h = nn.relu(h)\n",
-    "        loc = nn.Dense(2, kernel_init=nn.initializers.zeros)(h) + x\n",
-    "        scale_raw = nn.Dense(2, kernel_init=nn.initializers.zeros)(h)\n",
-    "        return loc, nn.softplus(scale_raw)\n",
     "\n",
+    "  @nn.compact\n",
+    "  def __call__(self, x):\n",
+    "    h = nn.Dense(50)(x)\n",
+    "    h = nn.relu(h)\n",
+    "    loc = nn.Dense(2, kernel_init=nn.initializers.zeros)(h) + x\n",
+    "    scale_raw = nn.Dense(2, kernel_init=nn.initializers.zeros)(h)\n",
+    "    return loc, nn.softplus(scale_raw)\n",
+    "\n",
     "\n",
     "class VariationalKernelNetworks(nn.Module):\n",
-    "    M = 8\n",
-    "\n",
-    "    @nn.compact\n",
-    "    def __call__(self, x, index=0):\n",
-    "        if self.is_mutable_collection('params'):\n",
-    "            vmap_net = nn.vmap(VariationalKernelNetwork, variable_axes={'params': 0}, split_rngs={'params': True})\n",
-    "            out = vmap_net(name='kernel')(jnp.broadcast_to(x, (self.M - 1,) + x.shape))\n",
-    "            return jax.tree_util.tree_map(lambda x: x[index], out)\n",
-    "        params = self.scope.get_variable('params', 'kernel')\n",
-    "        params_i = jax.tree_util.tree_map(lambda x: x[index], params)\n",
-    "        return VariationalKernelNetwork(name='kernel').apply(flax.core.freeze({\"params\": params_i}), x)\n",
+    "  M = 8\n",
+    "\n",
+    "  @nn.compact\n",
+    "  def __call__(self, x, index=0):\n",
+    "    if self.is_mutable_collection('params'):\n",
+    "      vmap_net = nn.vmap(\n",
+    "          VariationalKernelNetwork,\n",
+    "          variable_axes={'params': 0},\n",
+    "          split_rngs={'params': True},\n",
+    "      )\n",
+    "      out = vmap_net(name='kernel')(\n",
+    "          jnp.broadcast_to(x, (self.M - 1,) + x.shape)\n",
+    "      )\n",
+    "      return jax.tree_util.tree_map(lambda x: x[index], out)\n",
+    "    params = self.scope.get_variable('params', 'kernel')\n",
+    "    params_i = jax.tree_util.tree_map(lambda x: x[index], params)\n",
+    "    return VariationalKernelNetwork(name='kernel').apply(\n",
+    "        flax.core.freeze({'params': params_i}), x\n",
+    "    )\n",
     "\n",
     "\n",
     "class Networks(nn.Module):\n",
-    "    def setup(self):\n",
-    "        self.forward_kernel_params = VariationalKernelNetworks()\n",
-    "        self.reverse_kernel_params = VariationalKernelNetworks()\n",
-    "        self.anneal_density = AnnealedDensity()\n",
-    "\n",
-    "    def __call__(self, x):\n",
-    "        self.reverse_kernel_params(x)\n",
-    "        self.anneal_density(x)\n",
-    "        return self.forward_kernel_params(x)"
+    "\n",
+    "  def setup(self):\n",
+    "    self.forward_kernel_params = VariationalKernelNetworks()\n",
+    "    self.reverse_kernel_params = VariationalKernelNetworks()\n",
+    "    self.anneal_density = AnnealedDensity()\n",
+    "\n",
+    "  def __call__(self, x):\n",
+    "    self.reverse_kernel_params(x)\n",
+    "    self.anneal_density(x)\n",
+    "    return self.forward_kernel_params(x)"
    ]
   },
   {
@@ -241,18 +259,22 @@
    "outputs": [],
    "source": [
     "def anneal_target(network, k=0):\n",
-    "    x = numpyro.sample(\"x\", dist.Normal(0, 5).expand([2]).mask(False).to_event())\n",
-    "    #numpyro.factor(\"anneal_density\", network.anneal_density(x, index=k))\n",
-    "    numpyro.sample(\"anneal_density\", dist.Unit(network.anneal_density(x, index=k)))\n",
-    "    return {\"x\": x},\n",
+    "  x = numpyro.sample(\"x\", dist.Normal(0, 5).expand([2]).mask(False).to_event())\n",
+    "  # numpyro.factor(\"anneal_density\", network.anneal_density(x, index=k))\n",
+    "  numpyro.sample(\n",
+    "      \"anneal_density\", dist.Unit(network.anneal_density(x, index=k))\n",
+    "  )\n",
+    "  return ({\"x\": x},)\n",
+    "\n",
     "\n",
     "def anneal_forward(network, inputs, k=0):\n",
-    "    mu, sigma = network.forward_kernel_params(inputs[\"x\"], index=k)\n",
-    "    return numpyro.sample(\"x\", dist.Normal(mu, sigma).to_event(1))\n",
+    "  mu, sigma = network.forward_kernel_params(inputs[\"x\"], index=k)\n",
network.forward_kernel_params(inputs[\"x\"], index=k)\n", + " return numpyro.sample(\"x\", dist.Normal(mu, sigma).to_event(1))\n", + "\n", "\n", "def anneal_reverse(network, inputs, k=0):\n", - " mu, sigma = network.reverse_kernel_params(inputs[\"x\"], index=k)\n", - " return numpyro.sample(\"x\", dist.Normal(mu, sigma).to_event(1))" + " mu, sigma = network.reverse_kernel_params(inputs[\"x\"], index=k)\n", + " return numpyro.sample(\"x\", dist.Normal(mu, sigma).to_event(1))" ] }, { @@ -278,21 +300,31 @@ "import jax\n", "from jax import random\n", "import coix\n", + "\n", "coix.set_backend(\"coix.numpyro\")\n", "\n", + "\n", "def make_anneal(params, unroll=False, num_particles=10, num_targets=8):\n", - " network = coix.util.BindModule(Networks(), params)\n", - " # Add particle dimension and construct a program.\n", - " make_particle_plate = lambda: numpyro.plate(\"particle\", num_particles, dim=-1)\n", - " targets = lambda k: make_particle_plate()(partial(anneal_target, network, k=k))\n", - " forwards = lambda k: make_particle_plate()(partial(anneal_forward, network, k=k))\n", - " reverses = lambda k: make_particle_plate()(partial(anneal_reverse, network, k=k))\n", - " if unroll: # to unroll the algorithm, we provide a list of programs\n", - " targets = [targets(k) for k in range(num_targets)]\n", - " forwards = [forwards(k) for k in range(num_targets-1)]\n", - " reverses = [reverses(k) for k in range(num_targets-1)]\n", - " program = coix.algo.nvi_rkl(targets, forwards, reverses, num_targets=num_targets)\n", - " return program" + " network = coix.util.BindModule(Networks(), params)\n", + " # Add particle dimension and construct a program.\n", + " make_particle_plate = lambda: numpyro.plate(\"particle\", num_particles, dim=-1)\n", + " targets = lambda k: make_particle_plate()(\n", + " partial(anneal_target, network, k=k)\n", + " )\n", + " forwards = lambda k: make_particle_plate()(\n", + " partial(anneal_forward, network, k=k)\n", + " )\n", + " reverses = lambda k: make_particle_plate()(\n", + " partial(anneal_reverse, network, k=k)\n", + " )\n", + " if unroll: # to unroll the algorithm, we provide a list of programs\n", + " targets = [targets(k) for k in range(num_targets)]\n", + " forwards = [forwards(k) for k in range(num_targets - 1)]\n", + " reverses = [reverses(k) for k in range(num_targets - 1)]\n", + " program = coix.algo.nvi_rkl(\n", + " targets, forwards, reverses, num_targets=num_targets\n", + " )\n", + " return program" ] }, { @@ -337,23 +369,31 @@ "from matplotlib.animation import FuncAnimation\n", "from matplotlib.patches import Ellipse, Rectangle\n", "\n", + "\n", "def eval_program(seed, params, num_particles):\n", - " with numpyro.handlers.seed(rng_seed=seed):\n", - " p = make_anneal(params, unroll=True, num_particles=num_particles)\n", - " out, trace, metrics = coix.traced_evaluate(p)()\n", - " return out, trace, metrics\n", + " with numpyro.handlers.seed(rng_seed=seed):\n", + " p = make_anneal(params, unroll=True, num_particles=num_particles)\n", + " out, trace, metrics = coix.traced_evaluate(p)()\n", + " return out, trace, metrics\n", + "\n", "\n", "anneal_net = Networks()\n", "init_params = anneal_net.init(random.PRNGKey(0), jnp.zeros(2))\n", - "_, trace, metrics = eval_program(random.PRNGKey(1), init_params, num_particles=100000)\n", + "_, trace, metrics = eval_program(\n", + " random.PRNGKey(1), init_params, num_particles=100000\n", + ")\n", "\n", "metrics.pop(\"log_weight\")\n", - "anneal_metrics = jax.tree_util.tree_map(lambda x: round(float(jnp.mean(x)), 4), 
+    "anneal_metrics = jax.tree_util.tree_map(\n",
+    "    lambda x: round(float(jnp.mean(x)), 4), metrics\n",
+    ")\n",
     "print(anneal_metrics)\n",
     "\n",
     "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)\n",
     "x = trace[\"x\"][\"value\"].reshape((-1, 2))\n",
-    "H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], range=[[-12, 12], [-12, 12]], bins=100)\n",
+    "H, xedges, yedges = np.histogram2d(\n",
+    "    x[:, 0], x[:, 1], range=[[-12, 12], [-12, 12]], bins=100\n",
+    ")\n",
     "ax1.set_title(\"Untrained Proposal Density\")\n",
     "ax1.imshow(H.T)\n",
     "xax1, yax1 = ax1.axes.get_xaxis(), ax1.axes.get_yaxis()\n",
@@ -425,12 +465,14 @@
    "source": [
     "import optax\n",
     "\n",
+    "\n",
     "def loss_fn(params, key, num_particles, unroll=False):\n",
-    "    # Run the program and get metrics.\n",
-    "    program = make_anneal(params, num_particles=num_particles, unroll=unroll)\n",
-    "    with numpyro.handlers.seed(rng_seed=key):\n",
-    "        _, _, metrics = coix.traced_evaluate(program)()\n",
-    "    return metrics[\"loss\"], metrics\n",
+    "  # Run the program and get metrics.\n",
+    "  program = make_anneal(params, num_particles=num_particles, unroll=unroll)\n",
+    "  with numpyro.handlers.seed(rng_seed=key):\n",
+    "    _, _, metrics = coix.traced_evaluate(program)()\n",
+    "  return metrics[\"loss\"], metrics\n",
+    "\n",
     "\n",
     "optimizer = optax.adam(1e-3)\n",
     "num_steps = 50000\n",
     "\n",
     "trained_params, metrics = coix.util.train(\n",
     "    partial(loss_fn, num_particles=num_particles, unroll=unroll),\n",
-    "    init_params, optimizer, num_steps, jit_compile=True)"
+    "    init_params,\n",
+    "    optimizer,\n",
+    "    num_steps,\n",
+    "    jit_compile=True,\n",
+    ")"
    ]
   },
   {
@@ -486,14 +532,20 @@
     }
    ],
    "source": [
-    "_, trace, metrics = eval_program(random.PRNGKey(1), trained_params, num_particles=100000)\n",
+    "_, trace, metrics = eval_program(\n",
+    "    random.PRNGKey(1), trained_params, num_particles=100000\n",
+    ")\n",
     "\n",
-    "anneal_metrics = jax.tree_util.tree_map(lambda x: round(float(jnp.mean(x)), 4), metrics)\n",
+    "anneal_metrics = jax.tree_util.tree_map(\n",
+    "    lambda x: round(float(jnp.mean(x)), 4), metrics\n",
+    ")\n",
     "print(anneal_metrics)\n",
     "\n",
     "fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)\n",
     "x = trace[\"x\"][\"value\"].reshape((-1, 2))\n",
-    "m_proposal, _, _= np.histogram2d(x[:, 0], x[:, 1], range=[[-12, 12], [-12, 12]], bins=100)\n",
+    "m_proposal, _, _ = np.histogram2d(\n",
+    "    x[:, 0], x[:, 1], range=[[-12, 12], [-12, 12]], bins=100\n",
+    ")\n",
     "ax1.set_title(\"Trained Proposal Density\")\n",
     "ax1.imshow(m_proposal.T)\n",
     "xax1, yax1 = ax1.axes.get_xaxis(), ax1.axes.get_yaxis()\n",