|
419 | 419 | ],
|
420 | 420 | "source": [
|
421 | 421 | "with pm.Model(coords=ss_mod.coords) as arma_model:\n",
|
422 |
| - " state_sigmas = pm.Gamma(\"sigma_state\", alpha=10, beta=2, dims=ss_mod.param_dims[\"sigma_state\"])\n", |
| 422 | + " state_sigmas = pm.Gamma(\n", |
| 423 | + " \"sigma_state\", alpha=10, beta=2, dims=ss_mod.param_dims[\"sigma_state\"]\n", |
| 424 | + " )\n", |
423 | 425 | " rho = pm.Beta(\"ar_params\", alpha=5, beta=1, dims=ss_mod.param_dims[\"ar_params\"])\n",
|
424 |
| - " theta = pm.Normal(\"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"])\n", |
| 426 | + " theta = pm.Normal(\n", |
| 427 | + " \"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"]\n", |
| 428 | + " )\n", |
425 | 429 | "\n",
|
426 | 430 | " ss_mod.build_statespace_graph(df, mode=\"JAX\")\n",
|
427 | 431 | "\n",
|
|
509 | 513 | "source": [
|
510 | 514 | "fig, ax = plt.subplots(2, 1, figsize=(14, 6), dpi=144)\n",
|
511 | 515 | "for idx, (axis, state) in enumerate(zip(fig.axes, ss_mod.state_names)):\n",
|
512 |
| - " unconditional_prior.prior_latent.sel(state=state).stack(sample=[\"chain\", \"draw\"]).plot.line(\n", |
513 |
| - " x=\"time\", ax=axis, add_legend=False\n", |
514 |
| - " )\n", |
| 516 | + " unconditional_prior.prior_latent.sel(state=state).stack(\n", |
| 517 | + " sample=[\"chain\", \"draw\"]\n", |
| 518 | + " ).plot.line(x=\"time\", ax=axis, add_legend=False)\n", |
515 | 519 | " axis.set(title=state)\n",
|
516 | 520 | "\n",
|
517 | 521 | "fig.set_facecolor(\"w\")\n",
|
|
614 | 618 | "source": [
|
615 | 619 | "fig, ax = plt.subplots(2, 1, figsize=(14, 6), dpi=144)\n",
|
616 | 620 | "for idx, (axis, state) in enumerate(zip(fig.axes, ss_mod.state_names)):\n",
|
617 |
| - " conditional_prior.filtered_prior.sel(state=state).stack(sample=[\"chain\", \"draw\"]).plot.line(\n", |
618 |
| - " x=\"time\", ax=axis, add_legend=False\n", |
619 |
| - " )\n", |
| 621 | + " conditional_prior.filtered_prior.sel(state=state).stack(\n", |
| 622 | + " sample=[\"chain\", \"draw\"]\n", |
| 623 | + " ).plot.line(x=\"time\", ax=axis, add_legend=False)\n", |
620 | 624 | " axis.set(title=state)\n",
|
621 | 625 | "\n",
|
622 | 626 | "fig.set_facecolor(\"w\")\n",
|
|
660 | 664 | "source": [
|
661 | 665 | "fig, ax = plt.subplots(2, 1, figsize=(14, 6), dpi=144)\n",
|
662 | 666 | "for idx, (axis, state) in enumerate(zip(fig.axes, ss_mod.state_names)):\n",
|
663 |
| - " conditional_prior.predicted_prior.sel(state=state).stack(sample=[\"chain\", \"draw\"]).plot.line(\n", |
664 |
| - " x=\"time\", ax=axis, add_legend=False\n", |
665 |
| - " )\n", |
| 667 | + " conditional_prior.predicted_prior.sel(state=state).stack(\n", |
| 668 | + " sample=[\"chain\", \"draw\"]\n", |
| 669 | + " ).plot.line(x=\"time\", ax=axis, add_legend=False)\n", |
666 | 670 | " axis.set(title=state)\n",
|
667 | 671 | "\n",
|
668 | 672 | "fig.set_facecolor(\"w\")\n",
|
|
1278 | 1282 | " zip(fig.axes, [\"Observed State (Data)\", \"Hidden State (ARMA Dynamics)\"])\n",
|
1279 | 1283 | " ):\n",
|
1280 | 1284 | " post[f\"{filter_output}_posterior\"].isel(state=idx).mean(dim=\"sample\").plot.line(\n",
|
1281 |
| - " x=\"time\", ax=axis, lw=2, add_legend=False, label=filter_output.title(), color=color\n", |
| 1285 | + " x=\"time\",\n", |
| 1286 | + " ax=axis,\n", |
| 1287 | + " lw=2,\n", |
| 1288 | + " add_legend=False,\n", |
| 1289 | + " label=filter_output.title(),\n", |
| 1290 | + " color=color,\n", |
1282 | 1291 | " )\n",
|
1283 | 1292 | " axis.fill_between(\n",
|
1284 | 1293 | " hdi.coords[\"time\"], *hdi.isel(state=idx).values.T, alpha=0.25, color=color\n",
|
|
1554 | 1563 | " hdi_forecast.coords[\"time\"].values,\n",
|
1555 | 1564 | " *hdi_forecast.isel(observed_state=0).values.T,\n",
|
1556 | 1565 | " alpha=0.25,\n",
|
1557 |
| - " color=\"tab:blue\"\n", |
| 1566 | + " color=\"tab:blue\",\n", |
1558 | 1567 | " )\n",
|
1559 |
| - "ax.set_title(\"Porcupine Graph of 10-Period Forecasts (parameters estimated on all data)\")\n", |
| 1568 | + "ax.set_title(\n", |
| 1569 | + " \"Porcupine Graph of 10-Period Forecasts (parameters estimated on all data)\"\n", |
| 1570 | + ")\n", |
1560 | 1571 | "plt.show()"
|
1561 | 1572 | ]
|
1562 | 1573 | },
|
|
1586 | 1597 | " x_time = irf.coords[\"time\"]\n",
|
1587 | 1598 | "\n",
|
1588 | 1599 | " ax.plot(x_time, mean, color=\"k\", label=\"Mean\")\n",
|
1589 |
| - " ax.fill_between(x_time, *hdi.values.T, color=\"tab:blue\", alpha=0.25, label=\"HDI 94%\")\n", |
1590 |
| - " ax.fill_between(x_time, *hdi_50.values.T, color=\"tab:blue\", alpha=0.5, label=\"HDI 50%\")\n", |
| 1600 | + " ax.fill_between(\n", |
| 1601 | + " x_time, *hdi.values.T, color=\"tab:blue\", alpha=0.25, label=\"HDI 94%\"\n", |
| 1602 | + " )\n", |
| 1603 | + " ax.fill_between(\n", |
| 1604 | + " x_time, *hdi_50.values.T, color=\"tab:blue\", alpha=0.5, label=\"HDI 50%\"\n", |
| 1605 | + " )\n", |
1591 | 1606 | " ax.set(title=title)\n",
|
1592 | 1607 | "\n",
|
1593 | 1608 | " ax.legend()\n",
|
|
1670 | 1685 | "source": [
|
1671 | 1686 | "steps = 40\n",
|
1672 | 1687 | "irf = ss_mod.impulse_response_function(idata, n_steps=steps, orthogonalize_shocks=True)\n",
|
1673 |
| - "plot_irf(irf.isel(state=0), \"Impulse response function from estimated covariance matrix\")" |
| 1688 | + "plot_irf(\n", |
| 1689 | + " irf.isel(state=0), \"Impulse response function from estimated covariance matrix\"\n", |
| 1690 | + ")" |
1674 | 1691 | ]
|
1675 | 1692 | },
|
1676 | 1693 | {
|
|
1749 | 1766 | "shock_trajectory[20] = 0.5\n",
|
1750 | 1767 | "\n",
|
1751 | 1768 | "irf = ss_mod.impulse_response_function(idata, shock_trajectory=shock_trajectory)\n",
|
1752 |
| - "plot_irf(irf.isel(state=0), \"Impulse response function with shock of 1 at t=0 and 0.5 at t=20\")" |
| 1769 | + "plot_irf(\n", |
| 1770 | + " irf.isel(state=0),\n", |
| 1771 | + " \"Impulse response function with shock of 1 at t=0 and 0.5 at t=20\",\n", |
| 1772 | + ")" |
1753 | 1773 | ]
|
1754 | 1774 | },
|
1755 | 1775 | {
|
|
1789 | 1809 | ],
|
1790 | 1810 | "source": [
|
1791 | 1811 | "ss_mod = pmss.BayesianSARIMA(\n",
|
1792 |
| - " order=(1, 0, 1), state_structure=\"interpretable\", measurement_error=True, verbose=True\n", |
| 1812 | + " order=(1, 0, 1),\n", |
| 1813 | + " state_structure=\"interpretable\",\n", |
| 1814 | + " measurement_error=True,\n", |
| 1815 | + " verbose=True,\n", |
1793 | 1816 | ")"
|
1794 | 1817 | ]
|
1795 | 1818 | },
|
|
1945 | 1968 | ],
|
1946 | 1969 | "source": [
|
1947 | 1970 | "with pm.Model(coords=ss_mod.coords) as arma_model:\n",
|
1948 |
| - " state_sigmas = pm.HalfNormal(\"sigma_state\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_state\"])\n", |
1949 |
| - " obs_sigmas = pm.HalfNormal(\"sigma_obs\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_obs\"])\n", |
| 1971 | + " state_sigmas = pm.HalfNormal(\n", |
| 1972 | + " \"sigma_state\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_state\"]\n", |
| 1973 | + " )\n", |
| 1974 | + " obs_sigmas = pm.HalfNormal(\n", |
| 1975 | + " \"sigma_obs\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_obs\"]\n", |
| 1976 | + " )\n", |
1950 | 1977 | " rho = pm.TruncatedNormal(\n",
|
1951 |
| - " \"ar_params\", mu=0.0, sigma=0.5, lower=-1.0, upper=1.0, dims=ss_mod.param_dims[\"ar_params\"]\n", |
| 1978 | + " \"ar_params\",\n", |
| 1979 | + " mu=0.0,\n", |
| 1980 | + " sigma=0.5,\n", |
| 1981 | + " lower=-1.0,\n", |
| 1982 | + " upper=1.0,\n", |
| 1983 | + " dims=ss_mod.param_dims[\"ar_params\"],\n", |
| 1984 | + " )\n", |
| 1985 | + " theta = pm.Normal(\n", |
| 1986 | + " \"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"]\n", |
1952 | 1987 | " )\n",
|
1953 |
| - " theta = pm.Normal(\"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"])\n", |
1954 | 1988 | " ss_mod.build_statespace_graph(df, mode=\"JAX\")\n",
|
1955 | 1989 | "\n",
|
1956 | 1990 | " idata = pm.sample(**sample_kwargs)"
|
|
2090 | 2124 | " zip(fig.axes, [\"Observed State (Data)\", \"Hidden State (ARMA Dynamics)\"])\n",
|
2091 | 2125 | " ):\n",
|
2092 | 2126 | " post[f\"{filter_output}_posterior\"].isel(state=idx).mean(dim=\"sample\").plot.line(\n",
|
2093 |
| - " x=\"time\", ax=axis, lw=2, add_legend=False, label=filter_output.title(), color=color\n", |
| 2127 | + " x=\"time\",\n", |
| 2128 | + " ax=axis,\n", |
| 2129 | + " lw=2,\n", |
| 2130 | + " add_legend=False,\n", |
| 2131 | + " label=filter_output.title(),\n", |
| 2132 | + " color=color,\n", |
2094 | 2133 | " )\n",
|
2095 | 2134 | " axis.fill_between(\n",
|
2096 | 2135 | " hdi.coords[\"time\"], *hdi.isel(state=idx).values.T, alpha=0.25, color=color\n",
|
|
2307 | 2346 | ],
|
2308 | 2347 | "source": [
|
2309 | 2348 | "ss_mod = pmss.BayesianSARIMA(\n",
|
2310 |
| - " order=(2, 1, 2), seasonal_order=(2, 0, 2, 12), verbose=True, stationary_initialization=False\n", |
| 2349 | + " order=(2, 1, 2),\n", |
| 2350 | + " seasonal_order=(2, 0, 2, 12),\n", |
| 2351 | + " verbose=True,\n", |
| 2352 | + " stationary_initialization=False,\n", |
2311 | 2353 | ")"
|
2312 | 2354 | ]
|
2313 | 2355 | },
|
|
2455 | 2497 | " P0 = pt.set_subtensor(P0[0, 0], sigma_P0[0])\n",
|
2456 | 2498 | " P0 = pm.Deterministic(\"P0\", P0, dims=[\"state\", \"state_aux\"])\n",
|
2457 | 2499 | "\n",
|
2458 |
| - " ar_params = pm.Normal(\"ar_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ar_params\"])\n", |
| 2500 | + " ar_params = pm.Normal(\n", |
| 2501 | + " \"ar_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ar_params\"]\n", |
| 2502 | + " )\n", |
2459 | 2503 | " seasonal_ar_params = pm.Normal(\n",
|
2460 |
| - " \"seasonal_ar_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"seasonal_ar_params\"]\n", |
| 2504 | + " \"seasonal_ar_params\",\n", |
| 2505 | + " mu=0.0,\n", |
| 2506 | + " sigma=0.5,\n", |
| 2507 | + " dims=ss_mod.param_dims[\"seasonal_ar_params\"],\n", |
2461 | 2508 | " )\n",
|
2462 | 2509 | "\n",
|
2463 |
| - " ma_params = pm.Normal(\"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"])\n", |
| 2510 | + " ma_params = pm.Normal(\n", |
| 2511 | + " \"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"]\n", |
| 2512 | + " )\n", |
2464 | 2513 | " seasonal_ma_params = pm.Normal(\n",
|
2465 |
| - " \"seasonal_ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"seasonal_ma_params\"]\n", |
| 2514 | + " \"seasonal_ma_params\",\n", |
| 2515 | + " mu=0.0,\n", |
| 2516 | + " sigma=0.5,\n", |
| 2517 | + " dims=ss_mod.param_dims[\"seasonal_ma_params\"],\n", |
2466 | 2518 | " )\n",
|
2467 | 2519 | "\n",
|
2468 |
| - " state_sigmas = pm.Gamma(\"sigma_state\", alpha=2, beta=1.0, dims=ss_mod.param_dims[\"sigma_state\"])\n", |
| 2520 | + " state_sigmas = pm.Gamma(\n", |
| 2521 | + " \"sigma_state\", alpha=2, beta=1.0, dims=ss_mod.param_dims[\"sigma_state\"]\n", |
| 2522 | + " )\n", |
2469 | 2523 | "\n",
|
2470 | 2524 | " # Remember to log the data by hand\n",
|
2471 | 2525 | " ss_mod.build_statespace_graph(airpass.apply(np.log), mode=\"JAX\")\n",
|
|
2583 | 2637 | "fig, ax = plt.subplots()\n",
|
2584 | 2638 | "post = az.extract(post_pred).map(np.exp)\n",
|
2585 | 2639 | "hdi = az.hdi(post_pred.map(np.exp))[f\"predicted_posterior_observed\"]\n",
|
2586 |
| - "post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n", |
2587 |
| - " x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\"\n", |
2588 |
| - ")\n", |
| 2640 | + "post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(\n", |
| 2641 | + " dim=\"sample\"\n", |
| 2642 | + ").plot.line(x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\")\n", |
2589 | 2643 | "ax.fill_between(\n",
|
2590 | 2644 | " hdi.coords[\"time\"],\n",
|
2591 | 2645 | " *hdi.isel(observed_state=0).values.T,\n",
|
|
2682 | 2736 | }
|
2683 | 2737 | ],
|
2684 | 2738 | "source": [
|
2685 |
| - "forecast_hdi = az.hdi(forecast_idata.map(np.exp)).forecast_observed.isel(observed_state=0)\n", |
| 2739 | + "forecast_hdi = az.hdi(forecast_idata.map(np.exp)).forecast_observed.isel(\n", |
| 2740 | + " observed_state=0\n", |
| 2741 | + ")\n", |
2686 | 2742 | "forecast_mu = forecast_idata.map(np.exp).forecast_observed.mean(dim=[\"chain\", \"draw\"])\n",
|
2687 | 2743 | "fig, ax = plt.subplots()\n",
|
2688 | 2744 | "ax.plot(airpass.index, airpass.values, label=\"Data\")\n",
|
|
2692 | 2748 | " *forecast_hdi.values.T,\n",
|
2693 | 2749 | " label=\"Forecast 94% HDI\",\n",
|
2694 | 2750 | " color=\"tab:orange\",\n",
|
2695 |
| - " alpha=0.25\n", |
| 2751 | + " alpha=0.25,\n", |
2696 | 2752 | ")\n",
|
2697 | 2753 | "ax.legend()\n",
|
2698 | 2754 | "plt.show()"
|
|
0 commit comments