From ea642d2d2005c8981ae96ef8cd3cc82782ad6eb1 Mon Sep 17 00:00:00 2001
From: Heiko Zimmermann
Date: Tue, 16 Apr 2024 21:05:46 +0200
Subject: [PATCH] Minor fixes in tutorials

---
 notebooks/tutorial_part2_api.ipynb | 64 ++++++++++++++----------------
 1 file changed, 30 insertions(+), 34 deletions(-)

diff --git a/notebooks/tutorial_part2_api.ipynb b/notebooks/tutorial_part2_api.ipynb
index 18eecb3..130ed7d 100644
--- a/notebooks/tutorial_part2_api.ipynb
+++ b/notebooks/tutorial_part2_api.ipynb
@@ -153,7 +153,7 @@
  "As mentioned above, we call a promitive program without any observe statements a *kernel program*.\n",
  "In combinators each primitive program denotes two densities:\n",
  "1. a **prior density**, which is defined as the joint density over all unobserverd variables in the program\n",
- "2. an **unnormalized target density**, which is defined as the joint density over all variables in the program\n",
+ "2. an **unnormalized target density**, which is defined as the prior density multiplied by the product over the densities of the observed variables in the program\n",
  "\n",
  "To get a better understanding of these densities and why their distinction is important, let's visualize these densities for the primitive program `f` that we defined above:\n",
  "1. The prior density is given by the denstity of the normal distribution\n",
@@ -326,7 +326,7 @@
  "_, f_batch_trace, f_batch_metrics = traced_evaluate(\n",
  "    numpyro.plate(\"particle_plate\", 10000)(f), seed=0\n",
  ")()\n",
- "approx_target_sampels = f_batch_trace[\"x\"][\"value\"]\n",
+ "approx_target_samples = f_batch_trace[\"x\"][\"value\"]\n",
  "weights = jnp.exp(f_batch_metrics[\"log_weight\"])\n",
  "\n",
  "print(\"Normalizing constant:\", Z_target)\n",
@@ -369,7 +369,7 @@
  "    color=\"C1\",\n",
  ")\n",
  "_ = plt.hist(\n",
- "    approx_target_sampels,\n",
+ "    approx_target_samples,\n",
  "    weights=weights,\n",
  "    density=True,\n",
  "    bins=100,\n",
@@ -564,16 +564,14 @@
  "    m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n",
  ")\n",
  "handles, labels = ax_xy.get_legend_handles_labels()\n",
- "handles.extend(\n",
- "    [\n",
- "        lines.Line2D(\n",
- "            [0], [0], label=\"prior density of $extend(f,\ k)$\", color=\"C0\"\n",
- "        ),\n",
- "        lines.Line2D(\n",
- "            [0], [0], label=\"target denstity of $extend(f,\ k)$\", color=\"C1\"\n",
- "        ),\n",
- "    ]\n",
- ")\n",
+ "handles.extend([\n",
+ "    lines.Line2D(\n",
+ "        [0], [0], label=\"prior density of $extend(f,\ k)$\", color=\"C0\"\n",
+ "    ),\n",
+ "    lines.Line2D(\n",
+ "        [0], [0], label=\"target density of $extend(f,\ k)$\", color=\"C1\"\n",
+ "    ),\n",
+ "])\n",
  "ax_xy.legend(handles=handles, loc=\"lower left\");"
  ]
 },
@@ -772,7 +770,7 @@
  "_, q2_trace, q2_metrics = traced_evaluate(q2, seed=0)()\n",
  "_, _, f_batch_metrics = traced_evaluate(f_batch, seed=0)()\n",
  "\n",
- "approx_target_sampels = q2_trace[\"x\"][\"value\"]\n",
+ "approx_target_samples = q2_trace[\"x\"][\"value\"]\n",
  "weights = jnp.exp(q2_metrics[\"log_weight\"])\n",
  "weights_prior = np.exp(f_batch_metrics[\"log_weight\"])\n",
  "ess = q2_metrics[\"ess\"]\n",
@@ -806,7 +804,7 @@
  "    color=\"C2\",\n",
  ")\n",
  "_ = plt.hist(\n",
- "    approx_target_sampels,\n",
+ "    approx_target_samples,\n",
  "    weights=weights,\n",
  "    density=True,\n",
  "    bins=100,\n",
@@ -874,7 +872,7 @@
  "source": [
  "q3 = coix.resample(q2)\n",
  "_, q3_trace, q3_metrics = traced_evaluate(q3, seed=0)()\n",
- "approx_target_sampels = q3_trace[\"x\"][\"value\"]\n",
+ "approx_target_samples = q3_trace[\"x\"][\"value\"]\n",
  "weights = jnp.exp(q3_metrics[\"log_weight\"])\n",
"print(\"The log weights after resampling are all equal:\", weights)\n", "\n", @@ -885,7 +883,7 @@ " color=\"C1\",\n", ")\n", "_ = plt.hist(\n", - " approx_target_sampels,\n", + " approx_target_samples,\n", " weights=weights,\n", " density=True,\n", " bins=100,\n", @@ -968,22 +966,20 @@ " m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n", ")\n", "handles, labels = ax_xy.get_legend_handles_labels()\n", - "handles.extend(\n", - " [\n", - " lines.Line2D(\n", - " [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n", - " ),\n", - " lines.Line2D(\n", - " [0], [0], label=\"proposal denstity $compose(k,\\ q2)$\", color=\"C2\"\n", - " ),\n", - " lines.Line2D(\n", - " [0],\n", - " [0],\n", - " label=\"target denstity $extend(f, k)$ and $compose(k,\\ q2)$\",\n", - " color=\"C1\",\n", - " ),\n", - " ]\n", - ")\n", + "handles.extend([\n", + " lines.Line2D(\n", + " [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n", + " ),\n", + " lines.Line2D(\n", + " [0], [0], label=\"proposal denstity $compose(k,\\ q2)$\", color=\"C2\"\n", + " ),\n", + " lines.Line2D(\n", + " [0],\n", + " [0],\n", + " label=\"target denstity $extend(f, k)$ and $compose(k,\\ q2)$\",\n", + " color=\"C1\",\n", + " ),\n", + "])\n", "ax_xy.legend(handles=handles, loc=\"lower left\")\n", "\n", "_, f_ext_trace, f_ext_metrics = traced_evaluate(\n", @@ -1014,7 +1010,7 @@ "source": [ "### Takeaway\n", "\n", - "We are now ready to start combining programs using inference combinators and as long as we follow the rules of the grammar the resulting programs are valid, in the sense that they produce propoerly weighted sampels for the target densities they define.\n", + "We are now ready to start combining programs using inference combinators and as long as we follow the rules of the grammar the resulting programs are valid, in the sense that they produce propoerly weighted samples for the target densities they define.\n", "\n", "To ensure that all evaluations are properly weighted, more general programs are more restricted in the ways they can be combined with other programs. If in doubt, check the grammar!\n", "\n",