|
15 | 15 | from pymc.util import RandomState
|
16 | 16 | from pytensor import Variable, graph_replace
|
17 | 17 | from pytensor.compile import get_mode
|
| 18 | +from rich.box import SIMPLE_HEAD |
| 19 | +from rich.console import Console |
| 20 | +from rich.table import Table |
18 | 21 |
|
19 | 22 | from pymc_extras.statespace.core.representation import PytensorRepresentation
|
20 | 23 | from pymc_extras.statespace.filters import (
|
@@ -254,53 +257,72 @@ def __init__(
|
254 | 257 | self.kalman_smoother = KalmanSmoother()
|
255 | 258 | self.make_symbolic_graph()
|
256 | 259 |
|
257 |
| - if verbose: |
258 |
| - # These are split into separate try-except blocks, because it will be quite rare of models to implement |
259 |
| - # _print_data_requirements, but we still want to print the prior requirements. |
260 |
| - try: |
261 |
| - self._print_prior_requirements() |
262 |
| - except NotImplementedError: |
263 |
| - pass |
264 |
| - try: |
265 |
| - self._print_data_requirements() |
266 |
| - except NotImplementedError: |
267 |
| - pass |
268 |
| - |
269 |
| - def _print_prior_requirements(self) -> None: |
270 |
| - """ |
271 |
| - Prints a short report to the terminal about the priors needed for the model, including their names, |
| 260 | + self.requirement_table = None |
| 261 | + self._populate_prior_requirements() |
| 262 | + self._populate_data_requirements() |
| 263 | + |
| 264 | + if verbose and self.requirement_table: |
| 265 | + console = Console() |
| 266 | + console.print(self.requirement_table) |
| 267 | + |
| 268 | + def _populate_prior_requirements(self) -> None: |
| 269 | + """ |
| 270 | + Add requirements about priors needed for the model to a rich table, including their names, |
272 | 271 | shapes, named dimensions, and any parameter constraints.
|
273 | 272 | """
|
274 |
| - out = "" |
275 |
| - for param, info in self.param_info.items(): |
276 |
| - out += f'\t{param} -- shape: {info["shape"]}, constraints: {info["constraints"]}, dims: {info["dims"]}\n' |
277 |
| - out = out.rstrip() |
| 273 | + # Check that the param_info class is implemented, and also that it's a dictionary. We can't proceed if either |
| 274 | + # is not true. |
| 275 | + try: |
| 276 | + if not isinstance(self.param_info, dict): |
| 277 | + return |
| 278 | + except NotImplementedError: |
| 279 | + return |
278 | 280 |
|
279 |
| - _log.info( |
280 |
| - "The following parameters should be assigned priors inside a PyMC " |
281 |
| - f"model block: \n" |
282 |
| - f"{out}" |
283 |
| - ) |
| 281 | + if self.requirement_table is None: |
| 282 | + self._initialize_requirement_table() |
284 | 283 |
|
285 |
| - def _print_data_requirements(self) -> None: |
| 284 | + for param, info in self.param_info.items(): |
| 285 | + self.requirement_table.add_row( |
| 286 | + param, str(info["shape"]), info["constraints"], str(info["dims"]) |
| 287 | + ) |
| 288 | + |
| 289 | + def _populate_data_requirements(self) -> None: |
286 | 290 | """
|
287 |
| - Prints a short report to the terminal about the data needed for the model, including their names, shapes, |
288 |
| - and named dimensions. |
| 291 | + Add requirements about the data needed for the model, including their names, shapes, and named dimensions. |
289 | 292 | """
|
290 |
| - if not self.data_info: |
| 293 | + try: |
| 294 | + if not isinstance(self.data_info, dict): |
| 295 | + return |
| 296 | + except NotImplementedError: |
291 | 297 | return
|
292 | 298 |
|
293 |
| - out = "" |
| 299 | + if self.requirement_table is None: |
| 300 | + self._initialize_requirement_table() |
| 301 | + else: |
| 302 | + self.requirement_table.add_section() |
| 303 | + |
294 | 304 | for data, info in self.data_info.items():
|
295 |
| - out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n' |
296 |
| - out = out.rstrip() |
| 305 | + self.requirement_table.add_row(data, str(info["shape"]), "pm.Data", str(info["dims"])) |
| 306 | + |
| 307 | + def _initialize_requirement_table(self) -> None: |
| 308 | + self.requirement_table = Table( |
| 309 | + show_header=True, |
| 310 | + show_edge=True, |
| 311 | + box=SIMPLE_HEAD, |
| 312 | + highlight=True, |
| 313 | + ) |
297 | 314 |
|
298 |
| - _log.info( |
299 |
| - "The following Data variables should be assigned to the model inside a PyMC " |
300 |
| - f"model block: \n" |
301 |
| - f"{out}" |
| 315 | + self.requirement_table.title = "Model Requirements" |
| 316 | + self.requirement_table.caption = ( |
| 317 | + "These parameters should be assigned priors inside a PyMC model block before " |
| 318 | + "calling the build_statespace_graph method." |
302 | 319 | )
|
303 | 320 |
|
| 321 | + self.requirement_table.add_column("Variable", justify="left") |
| 322 | + self.requirement_table.add_column("Shape", justify="left") |
| 323 | + self.requirement_table.add_column("Constraints", justify="left") |
| 324 | + self.requirement_table.add_column("Dimensions", justify="right") |
| 325 | + |
304 | 326 | def _unpack_statespace_with_placeholders(
|
305 | 327 | self,
|
306 | 328 | ) -> tuple[
|
|
0 commit comments