Skip to content

Commit dcc353c

Browse files
Use rich tables for statespace build reports (#411)
* Use rich table in build report * Justify left-most column to left * Re-run example notebooks * Re-run example notebooks * Refactor table initialization * set requirement_table to None during initialization
1 parent c9134fe commit dcc353c

6 files changed

+3470
-2265
lines changed

notebooks/Exponential Trend Smoothing.ipynb

Lines changed: 751 additions & 492 deletions
Large diffs are not rendered by default.

notebooks/Making a Custom Statespace Model.ipynb

Lines changed: 424 additions & 154 deletions
Large diffs are not rendered by default.

notebooks/SARMA Example.ipynb

Lines changed: 922 additions & 698 deletions
Large diffs are not rendered by default.

notebooks/Structural Timeseries Modeling.ipynb

Lines changed: 966 additions & 580 deletions
Large diffs are not rendered by default.

notebooks/VARMAX Example.ipynb

Lines changed: 350 additions & 306 deletions
Large diffs are not rendered by default.

pymc_extras/statespace/core/statespace.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from pymc.util import RandomState
1616
from pytensor import Variable, graph_replace
1717
from pytensor.compile import get_mode
18+
from rich.box import SIMPLE_HEAD
19+
from rich.console import Console
20+
from rich.table import Table
1821

1922
from pymc_extras.statespace.core.representation import PytensorRepresentation
2023
from pymc_extras.statespace.filters import (
@@ -254,53 +257,72 @@ def __init__(
254257
self.kalman_smoother = KalmanSmoother()
255258
self.make_symbolic_graph()
256259

257-
if verbose:
258-
# These are split into separate try-except blocks, because it will be quite rare of models to implement
259-
# _print_data_requirements, but we still want to print the prior requirements.
260-
try:
261-
self._print_prior_requirements()
262-
except NotImplementedError:
263-
pass
264-
try:
265-
self._print_data_requirements()
266-
except NotImplementedError:
267-
pass
268-
269-
def _print_prior_requirements(self) -> None:
270-
"""
271-
Prints a short report to the terminal about the priors needed for the model, including their names,
260+
self.requirement_table = None
261+
self._populate_prior_requirements()
262+
self._populate_data_requirements()
263+
264+
if verbose and self.requirement_table:
265+
console = Console()
266+
console.print(self.requirement_table)
267+
268+
def _populate_prior_requirements(self) -> None:
269+
"""
270+
Add requirements about priors needed for the model to a rich table, including their names,
272271
shapes, named dimensions, and any parameter constraints.
273272
"""
274-
out = ""
275-
for param, info in self.param_info.items():
276-
out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n'
277-
out = out.rstrip()
273+
# Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either
274+
# is not true.
275+
try:
276+
if not isinstance(self.param_info, dict):
277+
return
278+
except NotImplementedError:
279+
return
278280

279-
_log.info(
280-
"The following parameters should be assigned priors inside a PyMC "
281-
f"model block: \n"
282-
f"{out}"
283-
)
281+
if self.requirement_table is None:
282+
self._initialize_requirement_table()
284283

285-
def _print_data_requirements(self) -> None:
284+
for param, info in self.param_info.items():
285+
self.requirement_table.add_row(
286+
param, str(info["shape"]), info["constraints"], str(info["dims"])
287+
)
288+
289+
def _populate_data_requirements(self) -> None:
286290
"""
287-
Prints a short report to the terminal about the data needed for the model, including their names, shapes,
288-
and named dimensions.
291+
Add requirements about the data needed for the model, including their names, shapes, and named dimensions.
289292
"""
290-
if not self.data_info:
293+
try:
294+
if not isinstance(self.data_info, dict):
295+
return
296+
except NotImplementedError:
291297
return
292298

293-
out = ""
299+
if self.requirement_table is None:
300+
self._initialize_requirement_table()
301+
else:
302+
self.requirement_table.add_section()
303+
294304
for data, info in self.data_info.items():
295-
out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
296-
out = out.rstrip()
305+
self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"]))
306+
307+
def _initialize_requirement_table(self) -> None:
308+
self.requirement_table = Table(
309+
show_header=True,
310+
show_edge=True,
311+
box=SIMPLE_HEAD,
312+
highlight=True,
313+
)
297314

298-
_log.info(
299-
"The following Data variables should be assigned to the model inside a PyMC "
300-
f"model block: \n"
301-
f"{out}"
315+
self.requirement_table.title = "Model Requirements"
316+
self.requirement_table.caption = (
317+
"These parameters should be assigned priors inside a PyMC model block before "
318+
"calling the build_statespace_graph method."
302319
)
303320

321+
self.requirement_table.add_column("Variable", justify="left")
322+
self.requirement_table.add_column("Shape", justify="left")
323+
self.requirement_table.add_column("Constraints", justify="left")
324+
self.requirement_table.add_column("Dimensions", justify="right")
325+
304326
def _unpack_statespace_with_placeholders(
305327
self,
306328
) -> tuple[

0 commit comments

Comments
 (0)