Skip to content

Commit

Permalink
Added reconciling experts notebook smoke test.
Browse files Browse the repository at this point in the history
  • Loading branch information
BenZickel committed Dec 31, 2024
1 parent 41f2946 commit 200fd36
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions tutorial/source/reconciling_experts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"$$\n",
"with the $x_i$ being the observations and the $\\epsilon_i$ being the innovations.\n",
"\n",
"We start with some imports."
"We start with some imports and setup."
]
},
{
Expand All @@ -35,14 +35,17 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import pyro\n",
"import torch, pyro, os\n",
"\n",
"import pyro.distributions as dist\n",
"from pyro.nn import PyroModuleList, PyroModule, PyroSample\n",
"from pyro.infer.reparam import SplitReparam\n",
"\n",
"from matplotlib import pyplot as plt"
"from matplotlib import pyplot as plt\n",
"\n",
"smoke_test = ('CI' in os.environ) # for use in continuous integration testing\n",
"num_samples = 2 if smoke_test else 1000\n",
"num_svi_iter = 2 if smoke_test else 1001"
]
},
{
Expand Down Expand Up @@ -240,7 +243,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x25859059a90>"
"<graphviz.graphs.Digraph at 0x24445b0f8f0>"
]
},
"execution_count": 6,
Expand Down Expand Up @@ -337,7 +340,7 @@
" optimizer = pyro.optim.Adam(dict(lr=0.01))\n",
" loss = pyro.infer.JitTrace_ELBO(num_particles=20, vectorize_particles=True, ignore_jit_warnings=True)\n",
" svi = pyro.infer.SVI(model, guide, optimizer, loss)\n",
" for count in range(1001):\n",
" for count in range(num_svi_iter):\n",
" loss = svi.step(*args, **kwargs)\n",
" if count % 100 == 0:\n",
" print(f\"iteration {count} loss = {loss}\")\n",
Expand Down Expand Up @@ -374,7 +377,7 @@
],
"source": [
"sampler = pyro.infer.Predictive(zero_ultimate_model, guide=guide,\n",
" num_samples=1000, parallel=True, return_sites=('_RETURN',))\n",
" num_samples=num_samples, parallel=True, return_sites=('_RETURN',))\n",
"samples = sampler()['_RETURN']\n",
"\n",
"plt.figure()\n",
Expand Down Expand Up @@ -464,7 +467,7 @@
],
"source": [
"sampler = pyro.infer.Predictive(same_ultimate_model, guide=guide,\n",
" num_samples=1000, parallel=True, return_sites=('_RETURN',))\n",
" num_samples=num_samples, parallel=True, return_sites=('_RETURN',))\n",
"samples = sampler()['_RETURN']\n",
"\n",
"plt.figure()\n",
Expand Down Expand Up @@ -581,7 +584,7 @@
],
"source": [
"new_sampler = pyro.infer.Predictive(conditioned_same_ultimate_model, guide=guide,\n",
" num_samples=1000, parallel=True, return_sites=('_RETURN',))\n",
" num_samples=num_samples, parallel=True, return_sites=('_RETURN',))\n",
"new_samples = new_sampler()['_RETURN']\n",
"\n",
"plt.figure()\n",
Expand Down

0 comments on commit 200fd36

Please sign in to comment.