Skip to content

Commit 1fb5536

Browse files
Rebase from main and run new pre-commit
1 parent 2394448 commit 1fb5536

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

pymc_experimental/statespace/models/ETS.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Sequence
1+
from collections.abc import Sequence
2+
from typing import Any
23

34
import numpy as np
45
import pytensor.tensor as pt
@@ -159,7 +160,6 @@ def __init__(
159160
filter_type: str = "standard",
160161
verbose: bool = True,
161162
):
162-
163163
if order is not None:
164164
if len(order) != 3 or any(not isinstance(o, str) for o in order):
165165
raise ValueError("Order must be a tuple of three strings.")
@@ -405,14 +405,14 @@ def make_symbolic_graph(self) -> None:
405405
self.ssm["design"] = Z
406406

407407
# Set up the state covariance matrix
408-
state_cov_idx = ("state_cov",) + np.diag_indices(self.k_posdef)
408+
state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
409409
state_cov = self.make_and_register_variable(
410410
"sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
411411
)
412412
self.ssm[state_cov_idx] = state_cov**2
413413

414414
if self.measurement_error:
415-
obs_cov_idx = ("obs_cov",) + np.diag_indices(self.k_endog)
415+
obs_cov_idx = ("obs_cov", *np.diag_indices(self.k_endog))
416416
obs_cov = self.make_and_register_variable(
417417
"sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
418418
)

tests/statespace/test_ETS.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytensor
33
import pytest
44
import statsmodels.api as sm
5+
56
from numpy.testing import assert_allclose
67
from pytensor.graph.basic import explicit_graph_inputs
78
from scipy import linalg
@@ -20,11 +21,11 @@ def data():
2021
def tests_invalid_order_raises():
2122
# Order must be length 3
2223
with pytest.raises(ValueError, match="Order must be a tuple of three strings"):
23-
BayesianETS(order=("A", "N")) # noqa
24+
BayesianETS(order=("A", "N"))
2425

2526
# Order must be strings
2627
with pytest.raises(ValueError, match="Order must be a tuple of three strings"):
27-
BayesianETS(order=(2, 1, 1)) # noqa
28+
BayesianETS(order=(2, 1, 1))
2829

2930
# Only additive errors allowed
3031
with pytest.raises(ValueError, match="Only additive errors are supported"):

0 commit comments

Comments
 (0)