Skip to content

Commit b16789a

Browse files
Run new pre-commit
1 parent b852140 commit b16789a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+1666
-697
lines changed

conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33

44
def pytest_addoption(parser):
5-
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
5+
parser.addoption(
6+
"--runslow", action="store_true", default=False, help="run slow tests"
7+
)
68

79

810
def pytest_configure(config):

docs/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@
171171

172172
# One entry per manual page. List of tuples
173173
# (source start file, name, description, authors, manual section).
174-
man_pages = [(master_doc, "pymc_experimental", "pymc_experimental Documentation", [author], 1)]
174+
man_pages = [
175+
(master_doc, "pymc_experimental", "pymc_experimental Documentation", [author], 1)
176+
]
175177

176178

177179
# -- Options for Texinfo output ----------------------------------------------

notebooks/Making a Custom Statespace Model.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,9 @@
10091009
"\n",
10101010
" # save_kalman_filter_outputs_in_idata=True is used to illustrate how setting coords works.\n",
10111011
" # In general, it's not necessary -- use post-estimation functions like sample_conditional_posterior instead\n",
1012-
" ar3.build_statespace_graph(data=data, mode=\"JAX\", save_kalman_filter_outputs_in_idata=True)\n",
1012+
" ar3.build_statespace_graph(\n",
1013+
" data=data, mode=\"JAX\", save_kalman_filter_outputs_in_idata=True\n",
1014+
" )\n",
10131015
" idata = pm.sample(nuts_sampler=\"numpyro\")"
10141016
]
10151017
},

notebooks/SARMA Example.ipynb

Lines changed: 91 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,13 @@
419419
],
420420
"source": [
421421
"with pm.Model(coords=ss_mod.coords) as arma_model:\n",
422-
" state_sigmas = pm.Gamma(\"sigma_state\", alpha=10, beta=2, dims=ss_mod.param_dims[\"sigma_state\"])\n",
422+
" state_sigmas = pm.Gamma(\n",
423+
" \"sigma_state\", alpha=10, beta=2, dims=ss_mod.param_dims[\"sigma_state\"]\n",
424+
" )\n",
423425
" rho = pm.Beta(\"ar_params\", alpha=5, beta=1, dims=ss_mod.param_dims[\"ar_params\"])\n",
424-
" theta = pm.Normal(\"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"])\n",
426+
" theta = pm.Normal(\n",
427+
" \"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"]\n",
428+
" )\n",
425429
"\n",
426430
" ss_mod.build_statespace_graph(df, mode=\"JAX\")\n",
427431
"\n",
@@ -509,9 +513,9 @@
509513
"source": [
510514
"fig, ax = plt.subplots(2, 1, figsize=(14, 6), dpi=144)\n",
511515
"for idx, (axis, state) in enumerate(zip(fig.axes, ss_mod.state_names)):\n",
512-
" unconditional_prior.prior_latent.sel(state=state).stack(sample=[\"chain\", \"draw\"]).plot.line(\n",
513-
" x=\"time\", ax=axis, add_legend=False\n",
514-
" )\n",
516+
" unconditional_prior.prior_latent.sel(state=state).stack(\n",
517+
" sample=[\"chain\", \"draw\"]\n",
518+
" ).plot.line(x=\"time\", ax=axis, add_legend=False)\n",
515519
" axis.set(title=state)\n",
516520
"\n",
517521
"fig.set_facecolor(\"w\")\n",
@@ -614,9 +618,9 @@
614618
"source": [
615619
"fig, ax = plt.subplots(2, 1, figsize=(14, 6), dpi=144)\n",
616620
"for idx, (axis, state) in enumerate(zip(fig.axes, ss_mod.state_names)):\n",
617-
" conditional_prior.filtered_prior.sel(state=state).stack(sample=[\"chain\", \"draw\"]).plot.line(\n",
618-
" x=\"time\", ax=axis, add_legend=False\n",
619-
" )\n",
621+
" conditional_prior.filtered_prior.sel(state=state).stack(\n",
622+
" sample=[\"chain\", \"draw\"]\n",
623+
" ).plot.line(x=\"time\", ax=axis, add_legend=False)\n",
620624
" axis.set(title=state)\n",
621625
"\n",
622626
"fig.set_facecolor(\"w\")\n",
@@ -660,9 +664,9 @@
660664
"source": [
661665
"fig, ax = plt.subplots(2, 1, figsize=(14, 6), dpi=144)\n",
662666
"for idx, (axis, state) in enumerate(zip(fig.axes, ss_mod.state_names)):\n",
663-
" conditional_prior.predicted_prior.sel(state=state).stack(sample=[\"chain\", \"draw\"]).plot.line(\n",
664-
" x=\"time\", ax=axis, add_legend=False\n",
665-
" )\n",
667+
" conditional_prior.predicted_prior.sel(state=state).stack(\n",
668+
" sample=[\"chain\", \"draw\"]\n",
669+
" ).plot.line(x=\"time\", ax=axis, add_legend=False)\n",
666670
" axis.set(title=state)\n",
667671
"\n",
668672
"fig.set_facecolor(\"w\")\n",
@@ -1278,7 +1282,12 @@
12781282
" zip(fig.axes, [\"Observed State (Data)\", \"Hidden State (ARMA Dynamics)\"])\n",
12791283
" ):\n",
12801284
" post[f\"{filter_output}_posterior\"].isel(state=idx).mean(dim=\"sample\").plot.line(\n",
1281-
" x=\"time\", ax=axis, lw=2, add_legend=False, label=filter_output.title(), color=color\n",
1285+
" x=\"time\",\n",
1286+
" ax=axis,\n",
1287+
" lw=2,\n",
1288+
" add_legend=False,\n",
1289+
" label=filter_output.title(),\n",
1290+
" color=color,\n",
12821291
" )\n",
12831292
" axis.fill_between(\n",
12841293
" hdi.coords[\"time\"], *hdi.isel(state=idx).values.T, alpha=0.25, color=color\n",
@@ -1554,9 +1563,11 @@
15541563
" hdi_forecast.coords[\"time\"].values,\n",
15551564
" *hdi_forecast.isel(observed_state=0).values.T,\n",
15561565
" alpha=0.25,\n",
1557-
" color=\"tab:blue\"\n",
1566+
" color=\"tab:blue\",\n",
15581567
" )\n",
1559-
"ax.set_title(\"Porcupine Graph of 10-Period Forecasts (parameters estimated on all data)\")\n",
1568+
"ax.set_title(\n",
1569+
" \"Porcupine Graph of 10-Period Forecasts (parameters estimated on all data)\"\n",
1570+
")\n",
15601571
"plt.show()"
15611572
]
15621573
},
@@ -1586,8 +1597,12 @@
15861597
" x_time = irf.coords[\"time\"]\n",
15871598
"\n",
15881599
" ax.plot(x_time, mean, color=\"k\", label=\"Mean\")\n",
1589-
" ax.fill_between(x_time, *hdi.values.T, color=\"tab:blue\", alpha=0.25, label=\"HDI 94%\")\n",
1590-
" ax.fill_between(x_time, *hdi_50.values.T, color=\"tab:blue\", alpha=0.5, label=\"HDI 50%\")\n",
1600+
" ax.fill_between(\n",
1601+
" x_time, *hdi.values.T, color=\"tab:blue\", alpha=0.25, label=\"HDI 94%\"\n",
1602+
" )\n",
1603+
" ax.fill_between(\n",
1604+
" x_time, *hdi_50.values.T, color=\"tab:blue\", alpha=0.5, label=\"HDI 50%\"\n",
1605+
" )\n",
15911606
" ax.set(title=title)\n",
15921607
"\n",
15931608
" ax.legend()\n",
@@ -1670,7 +1685,9 @@
16701685
"source": [
16711686
"steps = 40\n",
16721687
"irf = ss_mod.impulse_response_function(idata, n_steps=steps, orthogonalize_shocks=True)\n",
1673-
"plot_irf(irf.isel(state=0), \"Impulse response function from estimated covariance matrix\")"
1688+
"plot_irf(\n",
1689+
" irf.isel(state=0), \"Impulse response function from estimated covariance matrix\"\n",
1690+
")"
16741691
]
16751692
},
16761693
{
@@ -1749,7 +1766,10 @@
17491766
"shock_trajectory[20] = 0.5\n",
17501767
"\n",
17511768
"irf = ss_mod.impulse_response_function(idata, shock_trajectory=shock_trajectory)\n",
1752-
"plot_irf(irf.isel(state=0), \"Impulse response function with shock of 1 at t=0 and 0.5 at t=20\")"
1769+
"plot_irf(\n",
1770+
" irf.isel(state=0),\n",
1771+
" \"Impulse response function with shock of 1 at t=0 and 0.5 at t=20\",\n",
1772+
")"
17531773
]
17541774
},
17551775
{
@@ -1789,7 +1809,10 @@
17891809
],
17901810
"source": [
17911811
"ss_mod = pmss.BayesianSARIMA(\n",
1792-
" order=(1, 0, 1), state_structure=\"interpretable\", measurement_error=True, verbose=True\n",
1812+
" order=(1, 0, 1),\n",
1813+
" state_structure=\"interpretable\",\n",
1814+
" measurement_error=True,\n",
1815+
" verbose=True,\n",
17931816
")"
17941817
]
17951818
},
@@ -1945,12 +1968,23 @@
19451968
],
19461969
"source": [
19471970
"with pm.Model(coords=ss_mod.coords) as arma_model:\n",
1948-
" state_sigmas = pm.HalfNormal(\"sigma_state\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_state\"])\n",
1949-
" obs_sigmas = pm.HalfNormal(\"sigma_obs\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_obs\"])\n",
1971+
" state_sigmas = pm.HalfNormal(\n",
1972+
" \"sigma_state\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_state\"]\n",
1973+
" )\n",
1974+
" obs_sigmas = pm.HalfNormal(\n",
1975+
" \"sigma_obs\", sigma=1.0, dims=ss_mod.param_dims[\"sigma_obs\"]\n",
1976+
" )\n",
19501977
" rho = pm.TruncatedNormal(\n",
1951-
" \"ar_params\", mu=0.0, sigma=0.5, lower=-1.0, upper=1.0, dims=ss_mod.param_dims[\"ar_params\"]\n",
1978+
" \"ar_params\",\n",
1979+
" mu=0.0,\n",
1980+
" sigma=0.5,\n",
1981+
" lower=-1.0,\n",
1982+
" upper=1.0,\n",
1983+
" dims=ss_mod.param_dims[\"ar_params\"],\n",
1984+
" )\n",
1985+
" theta = pm.Normal(\n",
1986+
" \"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"]\n",
19521987
" )\n",
1953-
" theta = pm.Normal(\"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"])\n",
19541988
" ss_mod.build_statespace_graph(df, mode=\"JAX\")\n",
19551989
"\n",
19561990
" idata = pm.sample(**sample_kwargs)"
@@ -2090,7 +2124,12 @@
20902124
" zip(fig.axes, [\"Observed State (Data)\", \"Hidden State (ARMA Dynamics)\"])\n",
20912125
" ):\n",
20922126
" post[f\"{filter_output}_posterior\"].isel(state=idx).mean(dim=\"sample\").plot.line(\n",
2093-
" x=\"time\", ax=axis, lw=2, add_legend=False, label=filter_output.title(), color=color\n",
2127+
" x=\"time\",\n",
2128+
" ax=axis,\n",
2129+
" lw=2,\n",
2130+
" add_legend=False,\n",
2131+
" label=filter_output.title(),\n",
2132+
" color=color,\n",
20942133
" )\n",
20952134
" axis.fill_between(\n",
20962135
" hdi.coords[\"time\"], *hdi.isel(state=idx).values.T, alpha=0.25, color=color\n",
@@ -2307,7 +2346,10 @@
23072346
],
23082347
"source": [
23092348
"ss_mod = pmss.BayesianSARIMA(\n",
2310-
" order=(2, 1, 2), seasonal_order=(2, 0, 2, 12), verbose=True, stationary_initialization=False\n",
2349+
" order=(2, 1, 2),\n",
2350+
" seasonal_order=(2, 0, 2, 12),\n",
2351+
" verbose=True,\n",
2352+
" stationary_initialization=False,\n",
23112353
")"
23122354
]
23132355
},
@@ -2455,17 +2497,29 @@
24552497
" P0 = pt.set_subtensor(P0[0, 0], sigma_P0[0])\n",
24562498
" P0 = pm.Deterministic(\"P0\", P0, dims=[\"state\", \"state_aux\"])\n",
24572499
"\n",
2458-
" ar_params = pm.Normal(\"ar_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ar_params\"])\n",
2500+
" ar_params = pm.Normal(\n",
2501+
" \"ar_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ar_params\"]\n",
2502+
" )\n",
24592503
" seasonal_ar_params = pm.Normal(\n",
2460-
" \"seasonal_ar_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"seasonal_ar_params\"]\n",
2504+
" \"seasonal_ar_params\",\n",
2505+
" mu=0.0,\n",
2506+
" sigma=0.5,\n",
2507+
" dims=ss_mod.param_dims[\"seasonal_ar_params\"],\n",
24612508
" )\n",
24622509
"\n",
2463-
" ma_params = pm.Normal(\"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"])\n",
2510+
" ma_params = pm.Normal(\n",
2511+
" \"ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"ma_params\"]\n",
2512+
" )\n",
24642513
" seasonal_ma_params = pm.Normal(\n",
2465-
" \"seasonal_ma_params\", mu=0.0, sigma=0.5, dims=ss_mod.param_dims[\"seasonal_ma_params\"]\n",
2514+
" \"seasonal_ma_params\",\n",
2515+
" mu=0.0,\n",
2516+
" sigma=0.5,\n",
2517+
" dims=ss_mod.param_dims[\"seasonal_ma_params\"],\n",
24662518
" )\n",
24672519
"\n",
2468-
" state_sigmas = pm.Gamma(\"sigma_state\", alpha=2, beta=1.0, dims=ss_mod.param_dims[\"sigma_state\"])\n",
2520+
" state_sigmas = pm.Gamma(\n",
2521+
" \"sigma_state\", alpha=2, beta=1.0, dims=ss_mod.param_dims[\"sigma_state\"]\n",
2522+
" )\n",
24692523
"\n",
24702524
" # Remember to log the data by hand\n",
24712525
" ss_mod.build_statespace_graph(airpass.apply(np.log), mode=\"JAX\")\n",
@@ -2583,9 +2637,9 @@
25832637
"fig, ax = plt.subplots()\n",
25842638
"post = az.extract(post_pred).map(np.exp)\n",
25852639
"hdi = az.hdi(post_pred.map(np.exp))[f\"predicted_posterior_observed\"]\n",
2586-
"post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
2587-
" x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\"\n",
2588-
")\n",
2640+
"post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(\n",
2641+
" dim=\"sample\"\n",
2642+
").plot.line(x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\")\n",
25892643
"ax.fill_between(\n",
25902644
" hdi.coords[\"time\"],\n",
25912645
" *hdi.isel(observed_state=0).values.T,\n",
@@ -2682,7 +2736,9 @@
26822736
}
26832737
],
26842738
"source": [
2685-
"forecast_hdi = az.hdi(forecast_idata.map(np.exp)).forecast_observed.isel(observed_state=0)\n",
2739+
"forecast_hdi = az.hdi(forecast_idata.map(np.exp)).forecast_observed.isel(\n",
2740+
" observed_state=0\n",
2741+
")\n",
26862742
"forecast_mu = forecast_idata.map(np.exp).forecast_observed.mean(dim=[\"chain\", \"draw\"])\n",
26872743
"fig, ax = plt.subplots()\n",
26882744
"ax.plot(airpass.index, airpass.values, label=\"Data\")\n",
@@ -2692,7 +2748,7 @@
26922748
" *forecast_hdi.values.T,\n",
26932749
" label=\"Forecast 94% HDI\",\n",
26942750
" color=\"tab:orange\",\n",
2695-
" alpha=0.25\n",
2751+
" alpha=0.25,\n",
26962752
")\n",
26972753
"ax.legend()\n",
26982754
"plt.show()"

0 commit comments

Comments
 (0)