Skip to content

Commit 200a577

Browse files
Adjust statespace to match statsmodels
1 parent 50e71ba commit 200a577

File tree

2 files changed

+156
-94
lines changed

2 files changed

+156
-94
lines changed

pymc_experimental/statespace/models/ETS.py

Lines changed: 39 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,9 @@ def __init__(
208208
@property
209209
def param_names(self):
210210
names = [
211-
"x0",
211+
"initial_level",
212+
"initial_trend",
213+
"initial_seasonal",
212214
"P0",
213215
"alpha",
214216
"beta",
@@ -218,10 +220,12 @@ def param_names(self):
218220
"sigma_obs",
219221
]
220222
if not self.trend:
223+
names.remove("initial_trend")
221224
names.remove("beta")
222225
if not self.damped_trend:
223226
names.remove("phi")
224227
if not self.seasonal:
228+
names.remove("initial_seasonal")
225229
names.remove("gamma")
226230
if not self.measurement_error:
227231
names.remove("sigma_obs")
@@ -231,14 +235,19 @@ def param_names(self):
231235
@property
232236
def param_info(self) -> dict[str, dict[str, Any]]:
233237
info = {
234-
"x0": {
235-
"shape": (self.k_states,),
236-
"constraints": None,
237-
},
238238
"P0": {
239239
"shape": (self.k_states, self.k_states),
240240
"constraints": "Positive Semi-definite",
241241
},
242+
"initial_level": {
243+
"shape": None if self.k_endog == 1 else (self.k_endog,),
244+
"constraints": None,
245+
},
246+
"initial_trend": {
247+
"shape": None if self.k_endog == 1 else (self.k_endog,),
248+
"constraints": None,
249+
},
250+
"initial_seasonal": {"shape": (self.seasonal_periods,), "constraints": None},
242251
"sigma_obs": {
243252
"shape": None if self.k_endog == 1 else (self.k_endog,),
244253
"constraints": "Positive",
@@ -291,16 +300,20 @@ def shock_names(self):
291300
@property
292301
def param_dims(self):
293302
coord_map = {
294-
"x0": (ALL_STATE_DIM,),
295303
"P0": (ALL_STATE_DIM, ALL_STATE_AUX_DIM),
296304
"sigma_obs": (OBS_STATE_DIM,),
297305
"sigma_state": (OBS_STATE_DIM,),
306+
"initial_level": (OBS_STATE_DIM,),
307+
"initial_trend": (OBS_STATE_DIM,),
308+
"initial_seasonal": (ETS_SEASONAL_DIM,),
298309
"seasonal_param": (ETS_SEASONAL_DIM,),
299310
}
300311

301312
if self.k_endog == 1:
302-
coord_map["sigma_state"] = ()
303-
coord_map["sigma_obs"] = ()
313+
coord_map["sigma_state"] = None
314+
coord_map["sigma_obs"] = None
315+
coord_map["initial_level"] = None
316+
coord_map["initial_trend"] = None
304317
if not self.measurement_error:
305318
del coord_map["sigma_obs"]
306319
if not self.seasonal:
@@ -317,15 +330,16 @@ def coords(self) -> dict[str, Sequence]:
317330
return coords
318331

319332
def make_symbolic_graph(self) -> None:
320-
x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX)
321333
P0 = self.make_and_register_variable(
322334
"P0", shape=(self.k_states, self.k_states), dtype=floatX
323335
)
324-
325-
# x0, P0, Z, and R do not depend on the user config beyond the shape
326-
self.ssm["initial_state", :] = x0
327336
self.ssm["initial_state_cov"] = P0
328337

338+
initial_level = self.make_and_register_variable(
339+
"initial_level", shape=(self.k_endog,) if self.k_endog > 1 else (), dtype=floatX
340+
)
341+
self.ssm["initial_state", 1] = initial_level
342+
329343
# The shape of R can be pre-allocated, then filled with the required parameters
330344
R = pt.zeros((self.k_states, self.k_posdef))
331345
R = pt.set_subtensor(R[0, :], 1.0) # We will always have y_t = ... + e_t
@@ -337,6 +351,11 @@ def make_symbolic_graph(self) -> None:
337351
T_base = pt.as_tensor_variable(np.array([[0.0, 0.0], [0.0, 1.0]]))
338352

339353
if self.trend:
354+
initial_trend = self.make_and_register_variable(
355+
"initial_trend", shape=(self.k_endog,) if self.k_endog > 1 else (), dtype=floatX
356+
)
357+
self.ssm["initial_state", 2] = initial_trend
358+
340359
beta = self.make_and_register_variable("beta", shape=(), dtype=floatX)
341360
R = pt.set_subtensor(R[2, 0], beta)
342361

@@ -358,13 +377,19 @@ def make_symbolic_graph(self) -> None:
358377
T_components = [T_base]
359378

360379
if self.seasonal:
380+
initial_seasonal = self.make_and_register_variable(
381+
"initial_seasonal", shape=(self.seasonal_periods,), dtype=floatX
382+
)
383+
384+
self.ssm["initial_state", 2 + int(self.trend) :] = initial_seasonal
385+
361386
gamma = self.make_and_register_variable("gamma", shape=(), dtype=floatX)
362-
R = pt.set_subtensor(R[3, 0], gamma)
387+
R = pt.set_subtensor(R[2 + int(self.trend), 0], gamma)
363388

364389
# The seasonal component is always going to look like a TimeFrequency structural component, see that
365390
# docstring for more details
366391
T_seasonal = pt.eye(self.seasonal_periods, k=-1)
367-
T_seasonal = pt.set_subtensor(T_seasonal[0, :], -1)
392+
T_seasonal = pt.set_subtensor(T_seasonal[0, -1], 1.0)
368393
T_components += [T_seasonal]
369394

370395
self.ssm["selection"] = R
@@ -375,8 +400,6 @@ def make_symbolic_graph(self) -> None:
375400
Z = np.zeros((self.k_endog, self.k_states))
376401
Z[0, 0] = 1.0 # innovation
377402
Z[0, 1] = 1.0 # level
378-
if self.trend:
379-
Z[0, 2] = 1.0
380403
if self.seasonal:
381404
Z[0, 2 + int(self.trend)] = 1.0
382405
self.ssm["design"] = Z

tests/statespace/test_ETS.py

Lines changed: 117 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,15 @@
11
import numpy as np
22
import pytensor
33
import pytest
4+
import statsmodels.api as sm
45
from numpy.testing import assert_allclose
56
from pytensor.graph.basic import explicit_graph_inputs
67
from scipy import linalg
8+
from statespace.utils.constants import LONG_MATRIX_NAMES
79

810
from pymc_experimental.statespace.models.ETS import BayesianETS
9-
from tests.statespace.utilities.test_helpers import (
10-
load_nile_test_data,
11-
simulate_from_numpy_model,
12-
)
11+
from tests.statespace.utilities.shared_fixtures import rng
12+
from tests.statespace.utilities.test_helpers import load_nile_test_data
1313

1414

1515
@pytest.fixture(scope="session")
@@ -43,78 +43,66 @@ def tests_invalid_order_raises():
4343
BayesianETS(order=("A", "Ad", "A"))
4444

4545

46+
orders = (
47+
("A", "N", "N"),
48+
("A", "A", "N"),
49+
("A", "Ad", "N"),
50+
("A", "N", "A"),
51+
("A", "A", "A"),
52+
("A", "Ad", "A"),
53+
)
54+
order_names = (
55+
"Basic",
56+
"Trend",
57+
"Damped Trend",
58+
"Seasonal",
59+
"Trend and Seasonal",
60+
"Trend, Damped Trend, Seasonal",
61+
)
62+
63+
order_expected_flags = (
64+
{"trend": False, "damped_trend": False, "seasonal": False},
65+
{"trend": True, "damped_trend": False, "seasonal": False},
66+
{"trend": True, "damped_trend": True, "seasonal": False},
67+
{"trend": False, "damped_trend": False, "seasonal": True},
68+
{"trend": True, "damped_trend": False, "seasonal": True},
69+
{"trend": True, "damped_trend": True, "seasonal": True},
70+
)
71+
72+
order_params = (
73+
["alpha", "initial_level"],
74+
["alpha", "initial_level", "beta", "initial_trend"],
75+
["alpha", "initial_level", "beta", "initial_trend", "phi"],
76+
["alpha", "initial_level", "gamma", "initial_seasonal"],
77+
["alpha", "initial_level", "beta", "initial_trend", "gamma", "initial_seasonal"],
78+
["alpha", "initial_level", "beta", "initial_trend", "gamma", "initial_seasonal", "phi"],
79+
)
80+
81+
4682
@pytest.mark.parametrize(
47-
"order, expected_flags",
48-
[
49-
(("A", "N", "N"), {"trend": False, "damped_trend": False, "seasonal": False}),
50-
(("A", "A", "N"), {"trend": True, "damped_trend": False, "seasonal": False}),
51-
(("A", "Ad", "N"), {"trend": True, "damped_trend": True, "seasonal": False}),
52-
(("A", "N", "A"), {"trend": False, "damped_trend": False, "seasonal": True}),
53-
(("A", "A", "A"), {"trend": True, "damped_trend": False, "seasonal": True}),
54-
(("A", "Ad", "A"), {"trend": True, "damped_trend": True, "seasonal": True}),
55-
],
56-
ids=[
57-
"Basic",
58-
"Trend",
59-
"Damped Trend",
60-
"Seasonal",
61-
"Trend and Seasonal",
62-
"Trend, Damped Trend, Seasonal",
63-
],
83+
"order, expected_flags", zip(orders, order_expected_flags), ids=order_names
6484
)
6585
def test_order_flags(order, expected_flags):
6686
mod = BayesianETS(order=order, seasonal_periods=4)
6787
for key, value in expected_flags.items():
6888
assert getattr(mod, key) == value
6989

7090

71-
@pytest.mark.parametrize(
72-
"order, expected_params",
73-
[
74-
(("A", "N", "N"), ["alpha"]),
75-
(("A", "A", "N"), ["alpha", "beta"]),
76-
(("A", "Ad", "N"), ["alpha", "beta", "phi"]),
77-
(("A", "N", "A"), ["alpha", "gamma"]),
78-
(("A", "A", "A"), ["alpha", "beta", "gamma"]),
79-
(("A", "Ad", "A"), ["alpha", "beta", "gamma", "phi"]),
80-
],
81-
ids=[
82-
"Basic",
83-
"Trend",
84-
"Damped Trend",
85-
"Seasonal",
86-
"Trend and Seasonal",
87-
"Trend, Damped Trend, Seasonal",
88-
],
89-
)
91+
@pytest.mark.parametrize("order, expected_params", zip(orders, order_params), ids=order_names)
9092
def test_param_info(order: tuple[str, str, str], expected_params):
9193
mod = BayesianETS(order=order, seasonal_periods=4)
9294

93-
all_expected_params = [*expected_params, "sigma_state", "x0", "P0"]
95+
all_expected_params = [*expected_params, "sigma_state", "P0"]
9496
assert all(param in mod.param_names for param in all_expected_params)
9597
assert all(param in all_expected_params for param in mod.param_names)
96-
assert all(mod.param_info[param]["dims"] is None for param in expected_params)
98+
assert all(
99+
mod.param_info[param]["dims"] is None
100+
for param in expected_params
101+
if "seasonal" not in param
102+
)
97103

98104

99-
@pytest.mark.parametrize(
100-
"order, expected_params",
101-
[
102-
(("A", "N", "N"), ["alpha"]),
103-
(("A", "A", "N"), ["alpha", "beta"]),
104-
(("A", "Ad", "N"), ["alpha", "beta", "phi"]),
105-
(("A", "N", "A"), ["alpha", "gamma"]),
106-
(("A", "A", "A"), ["alpha", "beta", "gamma"]),
107-
(("A", "Ad", "A"), ["alpha", "beta", "gamma", "phi"]),
108-
],
109-
ids=[
110-
"Basic",
111-
"Trend",
112-
"Damped Trend",
113-
"Seasonal",
114-
"Trend and Seasonal",
115-
"Trend, Damped Trend, Seasonal",
116-
],
117-
)
105+
@pytest.mark.parametrize("order, expected_params", zip(orders, order_params), ids=order_names)
118106
def test_statespace_matrices(order: tuple[str, str, str], expected_params: list[str]):
119107
seasonal_periods = np.random.randint(3, 12)
120108
mod = BayesianETS(order=order, seasonal_periods=seasonal_periods, measurement_error=True)
@@ -127,7 +115,9 @@ def test_statespace_matrices(order: tuple[str, str, str], expected_params: list[
127115
"phi": 0.95,
128116
"sigma_state": 0.1,
129117
"sigma_obs": 0.1,
130-
"x0": np.zeros(expected_states),
118+
"initial_level": 3.0,
119+
"initial_trend": 1.0,
120+
"initial_seasonal": np.ones(seasonal_periods),
131121
"initial_state_cov": np.eye(expected_states),
132122
}
133123

@@ -161,42 +151,91 @@ def test_statespace_matrices(order: tuple[str, str, str], expected_params: list[
161151
Z_val[0, 0] = 1.0
162152
Z_val[0, 1] = 1.0
163153

154+
x0_val = np.zeros((expected_states,))
155+
x0_val[1] = test_values["initial_level"]
156+
164157
if order[1] == "N":
165158
T_val = np.array([[0.0, 0.0], [0.0, 1.0]])
166159
else:
160+
x0_val[2] = test_values["initial_trend"]
167161
R_val[2] = test_values["beta"]
168162
T_val = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
169-
Z_val[0, 2] = 1.0
170163

171164
if order[1] == "Ad":
172165
T_val[1:, -1] *= test_values["phi"]
173166

174167
if order[2] == "A":
175-
R_val[3] = test_values["gamma"]
168+
x0_val[2 + int(order[1] != "N") :] = test_values["initial_seasonal"]
169+
R_val[2 + int(order[1] != "N")] = test_values["gamma"]
176170
S = np.eye(seasonal_periods, k=-1)
177-
S[0, :] = -1
171+
S[0, -1] = 1.0
178172
Z_val[0, 2 + int(order[1] != "N")] = 1.0
179173
else:
180174
S = np.eye(0)
181175

182176
T_val = linalg.block_diag(T_val, S)
183177

178+
assert_allclose(x0, x0_val)
184179
assert_allclose(T, T_val)
185180
assert_allclose(R, R_val)
186181
assert_allclose(Z, Z_val)
187182

188183

189-
def test_deterministic_simulation_matches_statsmodels():
190-
mod = BayesianETS(order=("A", "Ad", "A"), seasonal_periods=4, measurement_error=False)
184+
@pytest.mark.parametrize("order, params", zip(orders, order_params), ids=order_names)
185+
def test_statespace_matches_statsmodels(rng, order: tuple[str, str, str], params):
186+
seasonal_periods = rng.integers(3, 12)
187+
data = rng.normal(size=(100,))
188+
mod = BayesianETS(order=order, seasonal_periods=seasonal_periods, measurement_error=False)
189+
sm_mod = sm.tsa.statespace.ExponentialSmoothing(
190+
data,
191+
trend=mod.trend,
192+
damped_trend=mod.damped_trend,
193+
seasonal=seasonal_periods if mod.seasonal else None,
194+
)
195+
196+
simplex_params = ["alpha", "beta", "gamma"]
197+
test_values = dict(zip(simplex_params, rng.dirichlet(alpha=np.ones(3))))
198+
test_values["phi"] = rng.beta(1, 1)
199+
200+
test_values["initial_level"] = rng.normal()
201+
test_values["initial_trend"] = rng.normal()
202+
test_values["initial_seasonal"] = rng.normal(size=seasonal_periods)
203+
test_values["initial_state_cov"] = np.eye(mod.k_states)
204+
test_values["sigma_state"] = 1.0
205+
206+
sm_test_values = test_values.copy()
207+
sm_test_values["smoothing_level"] = test_values["alpha"]
208+
sm_test_values["smoothing_trend"] = test_values["beta"]
209+
sm_test_values["smoothing_seasonal"] = test_values["gamma"]
210+
sm_test_values["damping_trend"] = test_values["phi"]
211+
sm_test_values["initial_seasonal"] = test_values["initial_seasonal"][0]
212+
for i in range(1, seasonal_periods):
213+
sm_test_values[f"initial_seasonal.L{i}"] = test_values["initial_seasonal"][i]
214+
215+
x0 = np.r_[
216+
0, *[test_values[name] for name in ["initial_level", "initial_trend", "initial_seasonal"]]
217+
]
218+
mask = [True, True, order[1] != "N", *(order[2] != "N",) * seasonal_periods]
219+
220+
sm_mod.initialize_known(initial_state=x0[mask], initial_state_cov=np.eye(mod.k_states))
221+
sm_mod.fit_constrained({name: sm_test_values[name] for name in sm_mod.param_names})
222+
223+
matrices = mod._unpack_statespace_with_placeholders()
224+
inputs = list(explicit_graph_inputs(matrices))
225+
input_names = [x.name for x in inputs]
191226

192-
rng = np.random.default_rng()
193-
test_values = {
194-
"alpha": 0.7,
195-
"beta": 0.15,
196-
"gamma": 0.15,
197-
"phi": 0.95,
198-
"sigma_state": 0.0,
199-
"x0": rng.normal(size=(7,)),
200-
"initial_state_cov": np.eye(7),
201-
}
202-
hidden_states, observed = simulate_from_numpy_model(mod, rng, test_values)
227+
f_matrices = pytensor.function(inputs, matrices)
228+
test_values_subset = {name: test_values[name] for name in input_names}
229+
230+
matrices = f_matrices(**test_values_subset)
231+
sm_matrices = [sm_mod.ssm[name] for name in LONG_MATRIX_NAMES[2:]]
232+
233+
for matrix, sm_matrix, name in zip(matrices[2:], sm_matrices, LONG_MATRIX_NAMES[2:]):
234+
if name == "selection":
235+
# The statsmodels selection matrix seems to be wrong? They set the first element of the selection matrix to
236+
# 1 - sum(alpha, beta, gamma), which doesn't match the equations presented in fpp3
237+
assert_allclose(matrix[1:], sm_matrix[1:], err_msg=f"{name} does not match")
238+
assert matrix[0] == 1.0
239+
assert sm_matrix[0] != 1.0
240+
else:
241+
assert_allclose(matrix, sm_matrix, err_msg=f"{name} does not match")

0 commit comments

Comments
 (0)