Skip to content

Commit

Permalink
NB
Browse files Browse the repository at this point in the history
  • Loading branch information
victor committed Jul 14, 2024
1 parent a763c75 commit 0747d38
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions notebook/mauna-loa-co2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,110 @@
"start_time": "2024-07-14T15:06:40.960298Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"Alright, predictions look good seasonally wise. However, it looks like the drift component is highly underestimated, or that it might be time varying. So, let's alter our model and use an integrated random walk"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 6000/6000 [00:27<00:00, 215.43it/s, 31 steps of size 1.50e-01. acc. prob=0.90] \n",
"sample: 100%|██████████| 6000/6000 [00:20<00:00, 296.37it/s, 31 steps of size 1.56e-01. acc. prob=0.89] \n",
"sample: 100%|██████████| 6000/6000 [00:20<00:00, 295.69it/s, 31 steps of size 1.32e-01. acc. prob=0.92]\n",
"warmup: 14%|█▍ | 833/6000 [00:06<00:10, 488.38it/s, 15 steps of size 2.69e-01. acc. prob=0.79] "
]
}
],
"source": [
"from numpyro_sts import SmoothLocalLinearTrend\n",
"\n",
"\n",
"def model(n: int, y: np.ndarray = None, mask: bool = True, num_seasons: int = 12, future: int = 0):\n",
" # level component\n",
" with numpyro.handlers.scope(prefix=\"level\"):\n",
" std = numpyro.sample(\"std\", HalfNormal(scale=100.0))\n",
"\n",
" with numpyro.plate(\"factors\", 2):\n",
" x_0_loc = numpyro.sample(\"x_0_loc\", Normal(scale=1000.0))\n",
" x_0 = x_0_loc + numpyro.sample(\"x_0\", Normal())\n",
"\n",
" level_model = SmoothLocalLinearTrend(n, std, x_0)\n",
" level = numpyro.sample(\"x\", level_model)\n",
"\n",
" if future > 0:\n",
" future_level = numpyro.sample(\"x_future\", level_model.predict(future, level[-1]))\n",
"\n",
" # seasonality\n",
" with numpyro.handlers.scope(prefix=\"seasonality\"):\n",
" x_0 = numpyro.sample(\"x_0\", ZeroSumNormal(scale=10.0, event_shape=(num_seasons,)))[:-1]\n",
"\n",
" seasonality_model = periodic.TimeSeasonal(n, num_seasons, 0.0, x_0)\n",
" seasonality = numpyro.deterministic(\"x\", seasonality_model.deterministic())\n",
"\n",
" if future > 0:\n",
" future_seasonality = numpyro.deterministic(\n",
" \"x_future\", seasonality_model.predict(future, seasonality[-1]).deterministic()\n",
" )\n",
"\n",
" # observable\n",
" std = numpyro.sample(\"std\", HalfCauchy())\n",
" loc = level[..., 0] + seasonality[..., 0]\n",
"\n",
" with numpyro.handlers.mask(mask=mask):\n",
" y_ = numpyro.sample(\"y\", Normal(loc, std), obs=y)\n",
"\n",
" if future > 0:\n",
" loc_future = future_level[..., 0] + future_seasonality[..., 0]\n",
" y_future = numpyro.sample(\"y_future\", Normal(loc_future, std))\n",
"\n",
" return\n",
"\n",
"mcmc.run(key, n=train.shape[0], y=train, mask=np.isfinite(train))"
],
"metadata": {
"collapsed": false,
"is_executing": true,
"ExecuteTime": {
"start_time": "2024-07-14T15:16:37.774058Z"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"new_samples = mcmc.get_samples()\n",
"\n",
"predictive = Predictive(model, posterior_samples=new_samples)\n",
"new_predictions = predictive(key, n=train.shape[0], y=train, future=test.shape[0])\n",
"\n",
"low, high = np.quantile(new_predictions[\"y_future\"], [0.025, 0.0975], axis=0)\n",
"\n",
"x = np.arange(train.shape[0], test.shape[0] + train.shape[0])\n",
"\n",
"fig, ax = plt.subplots()\n",
"\n",
"ax.fill_between(x, low, high, alpha=0.25, label=\"Predicted\")\n",
"ax.plot(x, test, \"o\", label=\"Outcome\")\n",
"\n",
"ax.legend()"
],
"metadata": {
"collapsed": false,
"is_executing": true
}
}
],
"metadata": {
Expand Down

0 comments on commit 0747d38

Please sign in to comment.