Minor fixes in tutorials #30

Merged 1 commit on Apr 16, 2024
64 changes: 30 additions & 34 deletions notebooks/tutorial_part2_api.ipynb
@@ -153,7 +153,7 @@
"As mentioned above, we call a primitive program without any observe statements a *kernel program*.\n",
"In combinators each primitive program denotes two densities:\n",
"1. a **prior density**, which is defined as the joint density over all unobserved variables in the program\n",
-"2. an **unnormalized target density**, which is defined as the joint density over all variables in the program\n",
+"2. an **unnormalized target density**, which is defined as the prior density multiplied by the product of the densities of the observed variables in the program\n",
"\n",
"To get a better understanding of these densities and why their distinction is important, let's visualize these densities for the primitive program `f` that we defined above:\n",
"1. The prior density is given by the density of the normal distribution\n",
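To make the distinction concrete, here is a minimal sketch in plain Python (not the coix API) of the two densities for a hypothetical program with one latent site `x ~ Normal(0, 1)` and one observed site `y ~ Normal(x, 1)` observed at `y = 0.5`; the model and observed value are assumptions for illustration, not the tutorial's actual `f`:

```python
import math

def normal_pdf(value, loc, scale):
    # Density of Normal(loc, scale) evaluated at `value`.
    return math.exp(-0.5 * ((value - loc) / scale) ** 2) / (
        scale * math.sqrt(2 * math.pi)
    )

Y_OBS = 0.5  # hypothetical observed value

def prior_density(x):
    # Joint density over the unobserved variables only.
    return normal_pdf(x, 0.0, 1.0)

def unnormalized_target_density(x):
    # Prior density times the densities of the observed variables.
    return prior_density(x) * normal_pdf(Y_OBS, x, 1.0)

print(prior_density(0.0))
print(unnormalized_target_density(0.0))
```

Note that the unnormalized target integrates to the marginal likelihood of `y`, not to one; that gap is exactly what the importance weights below account for.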
@@ -326,7 +326,7 @@
"_, f_batch_trace, f_batch_metrics = traced_evaluate(\n",
" numpyro.plate(\"particle_plate\", 10000)(f), seed=0\n",
")()\n",
-"approx_target_sampels = f_batch_trace[\"x\"][\"value\"]\n",
+"approx_target_samples = f_batch_trace[\"x\"][\"value\"]\n",
"weights = jnp.exp(f_batch_metrics[\"log_weight\"])\n",
"\n",
"print(\"Normalizing constant:\", Z_target)\n",
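The comparison printed above relies on a standard importance-sampling identity: when the samples come from the prior, the mean of the weights estimates the normalizing constant. A hedged NumPy sketch for a toy conjugate model (`x ~ Normal(0, 1)`, `y ~ Normal(x, 1)` observed at `y = 0.5`; an assumption for illustration, not the tutorial's `f`):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal_pdf(value, loc, scale):
    return -0.5 * ((value - loc) / scale) ** 2 - np.log(
        scale * np.sqrt(2 * np.pi)
    )

x = rng.normal(0.0, 1.0, size=100_000)  # samples from the prior
log_w = log_normal_pdf(0.5, x, 1.0)     # log weight = observed-site log density
Z_hat = np.exp(log_w).mean()            # Monte Carlo estimate of Z

# For this conjugate toy model, y is marginally Normal(0, sqrt(2)),
# so the exact normalizing constant is available in closed form.
Z_exact = np.exp(log_normal_pdf(0.5, 0.0, np.sqrt(2.0)))
print(Z_hat, Z_exact)
```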
@@ -369,7 +369,7 @@
" color=\"C1\",\n",
")\n",
"_ = plt.hist(\n",
-"    approx_target_sampels,\n",
+"    approx_target_samples,\n",
" weights=weights,\n",
" density=True,\n",
" bins=100,\n",
@@ -564,16 +564,14 @@
" m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n",
")\n",
"handles, labels = ax_xy.get_legend_handles_labels()\n",
-"handles.extend(\n",
-"    [\n",
-"        lines.Line2D(\n",
-"            [0], [0], label=\"prior density of $extend(f,\ k)$\", color=\"C0\"\n",
-"        ),\n",
-"        lines.Line2D(\n",
-"            [0], [0], label=\"target denstity of $extend(f,\ k)$\", color=\"C1\"\n",
-"        ),\n",
-"    ]\n",
-")\n",
+"handles.extend([\n",
+"    lines.Line2D(\n",
+"        [0], [0], label=\"prior density of $extend(f,\ k)$\", color=\"C0\"\n",
+"    ),\n",
+"    lines.Line2D(\n",
+"        [0], [0], label=\"target density of $extend(f,\ k)$\", color=\"C1\"\n",
+"    ),\n",
+"])\n",
"ax_xy.legend(handles=handles, loc=\"lower left\");"
]
},
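The legend labels refer to the combinators' density semantics. As a hedged sketch of the calculus the tutorial describes (plain Python callables, not coix objects or the coix implementation): `extend(f, k)` multiplies `f`'s unnormalized target by the kernel density, while `compose(k, q)` multiplies the proposal's prior by the same kernel, which is why the two contours share the extended space:

```python
# `target_f`, `prior_q`, and `kernel` are hypothetical callables used
# only to illustrate the density-level reading of the combinators.
def extend_target(target_f, kernel):
    # Unnormalized target of extend(f, k): target_f(x) * k(y | x).
    return lambda x, y: target_f(x) * kernel(y, x)

def compose_prior(prior_q, kernel):
    # Prior of compose(k, q): prior_q(x) * k(y | x).
    return lambda x, y: prior_q(x) * kernel(y, x)

density = extend_target(lambda x: 0.5 * x, lambda y, x: y + x)
print(density(2.0, 3.0))  # 0.5 * 2.0 * (3.0 + 2.0) = 5.0
```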
@@ -772,7 +770,7 @@
"_, q2_trace, q2_metrics = traced_evaluate(q2, seed=0)()\n",
"_, _, f_batch_metrics = traced_evaluate(f_batch, seed=0)()\n",
"\n",
-"approx_target_sampels = q2_trace[\"x\"][\"value\"]\n",
+"approx_target_samples = q2_trace[\"x\"][\"value\"]\n",
"weights = jnp.exp(q2_metrics[\"log_weight\"])\n",
"weights_prior = np.exp(f_batch_metrics[\"log_weight\"])\n",
"ess = q2_metrics[\"ess\"]\n",
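`q2_metrics["ess"]` reports an effective sample size for the weighted particles. The standard estimator from the weights is `(Σ w)² / Σ w²`; that coix computes exactly this quantity is an assumption here, but the sketch shows the usual definition:

```python
import numpy as np

def effective_sample_size(log_weights):
    # (sum w)^2 / sum w^2, computed stably by shifting the log weights
    # before exponentiating.
    log_w = np.asarray(log_weights)
    w = np.exp(log_w - log_w.max())
    return w.sum() ** 2 / (w ** 2).sum()

print(effective_sample_size(np.zeros(10)))                      # equal weights
print(effective_sample_size(np.array([0.0, -100.0, -100.0])))   # one dominant weight
```

Equal weights give an ESS equal to the number of particles; a single dominant weight drives the ESS toward one, which is the degeneracy that resampling below addresses.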
@@ -806,7 +804,7 @@
" color=\"C2\",\n",
")\n",
"_ = plt.hist(\n",
-"    approx_target_sampels,\n",
+"    approx_target_samples,\n",
" weights=weights,\n",
" density=True,\n",
" bins=100,\n",
@@ -874,7 +872,7 @@
"source": [
"q3 = coix.resample(q2)\n",
"_, q3_trace, q3_metrics = traced_evaluate(q3, seed=0)()\n",
-"approx_target_sampels = q3_trace[\"x\"][\"value\"]\n",
+"approx_target_samples = q3_trace[\"x\"][\"value\"]\n",
"weights = jnp.exp(q3_metrics[\"log_weight\"])\n",
"print(\"The weights after resampling are all equal:\", weights)\n",
"\n",
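`coix.resample` replaces the weighted particle population with an equally weighted one by duplicating heavy particles and dropping light ones. A hedged sketch of one common scheme, multinomial resampling (coix may use a different scheme internally):

```python
import numpy as np

rng = np.random.default_rng(0)

def multinomial_resample(samples, log_weights):
    # Draw len(samples) indices with probability proportional to the
    # weights; every surviving particle then carries equal weight.
    w = np.exp(log_weights - np.max(log_weights))
    idx = rng.choice(len(samples), size=len(samples), p=w / w.sum())
    return samples[idx]

samples = np.array([-1.0, 0.0, 1.0, 2.0])
log_w = np.array([0.0, 2.0, 0.0, -np.inf])  # the last particle has zero weight
print(multinomial_resample(samples, log_w))
```

A zero-weight particle can never survive resampling, while the heaviest particle is typically duplicated several times.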
@@ -885,7 +883,7 @@
" color=\"C1\",\n",
")\n",
"_ = plt.hist(\n",
-"    approx_target_sampels,\n",
+"    approx_target_samples,\n",
" weights=weights,\n",
" density=True,\n",
" bins=100,\n",
@@ -968,22 +966,20 @@
" m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n",
")\n",
"handles, labels = ax_xy.get_legend_handles_labels()\n",
-"handles.extend(\n",
-"    [\n",
-"        lines.Line2D(\n",
-"            [0], [0], label=\"prior density of $extend(f,\ k)$\", color=\"C0\"\n",
-"        ),\n",
-"        lines.Line2D(\n",
-"            [0], [0], label=\"proposal denstity $compose(k,\ q2)$\", color=\"C2\"\n",
-"        ),\n",
-"        lines.Line2D(\n",
-"            [0],\n",
-"            [0],\n",
-"            label=\"target denstity $extend(f, k)$ and $compose(k,\ q2)$\",\n",
-"            color=\"C1\",\n",
-"        ),\n",
-"    ]\n",
-")\n",
+"handles.extend([\n",
+"    lines.Line2D(\n",
+"        [0], [0], label=\"prior density of $extend(f,\ k)$\", color=\"C0\"\n",
+"    ),\n",
+"    lines.Line2D(\n",
+"        [0], [0], label=\"proposal density $compose(k,\ q2)$\", color=\"C2\"\n",
+"    ),\n",
+"    lines.Line2D(\n",
+"        [0],\n",
+"        [0],\n",
+"        label=\"target density $extend(f, k)$ and $compose(k,\ q2)$\",\n",
+"        color=\"C1\",\n",
+"    ),\n",
+"])\n",
"ax_xy.legend(handles=handles, loc=\"lower left\")\n",
"\n",
"_, f_ext_trace, f_ext_metrics = traced_evaluate(\n",
@@ -1014,7 +1010,7 @@
"source": [
"### Takeaway\n",
"\n",
-"We are now ready to start combining programs using inference combinators and as long as we follow the rules of the grammar the resulting programs are valid, in the sense that they produce propoerly weighted sampels for the target densities they define.\n",
+"We are now ready to start combining programs using inference combinators, and as long as we follow the rules of the grammar the resulting programs are valid, in the sense that they produce properly weighted samples for the target densities they define.\n",
"\n",
"To ensure that all evaluations are properly weighted, more general programs are more restricted in the ways they can be combined with other programs. If in doubt, check the grammar!\n",
"\n",
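"Properly weighted" means that for any function `g`, the self-normalized weighted average of `g` over the samples converges to the expectation of `g` under the normalized target. A hedged NumPy check on a toy conjugate model (`x ~ Normal(0, 1)`, `y ~ Normal(x, 1)` observed at `y = 0.5`, assumed here for illustration), whose exact posterior is `Normal(0.25, sqrt(0.5))`:

```python
import numpy as np

rng = np.random.default_rng(0)

x = rng.normal(0.0, 1.0, size=200_000)        # prior samples act as proposals
w = np.exp(-0.5 * (0.5 - x) ** 2)             # unnormalized importance weights
posterior_mean_hat = (w * x).sum() / w.sum()  # self-normalized estimate of E[x | y]
print(posterior_mean_hat)  # should be close to the exact posterior mean, 0.25
```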