Skip to content

Commit 0121315

Browse files
Require all step methods to return stats
The reason for this change is the resulting simplification of code, including simpler branching and less type ambiguity. At the same time it allowed for fixing of a lot of type hints and method signatures on step methods. Closes #6270
1 parent 4acd98e commit 0121315

File tree

13 files changed

+97
-131
lines changed

13 files changed

+97
-131
lines changed

pymc/blocking.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@
2020
from __future__ import annotations
2121

2222
from functools import partial
23-
from typing import Callable, Dict, Generic, NamedTuple, TypeVar
23+
from typing import Any, Callable, Dict, Generic, List, NamedTuple, TypeVar
2424

2525
import numpy as np
2626

27+
from typing_extensions import TypeAlias
28+
2729
__all__ = ["DictToArrayBijection"]
2830

2931

3032
T = TypeVar("T")
31-
PointType = Dict[str, np.ndarray]
33+
PointType: TypeAlias = Dict[str, np.ndarray]
34+
StatsType: TypeAlias = List[Dict[str, Any]]
3235

3336
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
3437
# each of the raveled variables.

pymc/sampling/mcmc.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -912,14 +912,10 @@ def _iter_sample(
912912
step.iter_count = 0
913913
if i == tune:
914914
step.stop_tuning()
915-
if step.generates_stats:
916-
point, stats = step.step(point)
917-
strace.record(point, stats)
918-
log_warning_stats(stats)
919-
diverging = i > tune and stats and stats[0].get("diverging")
920-
else:
921-
point = step.step(point)
922-
strace.record(point, [])
915+
point, stats = step.step(point)
916+
strace.record(point, stats)
917+
log_warning_stats(stats)
918+
diverging = i > tune and stats and stats[0].get("diverging")
923919
if callback is not None:
924920
callback(
925921
trace=strace,

pymc/sampling/parallel.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _start_loop(self):
173173

174174
if draw < self._draws + self._tune:
175175
try:
176-
point, stats = self._compute_point()
176+
point, stats = self._step_method.step(self._point)
177177
except SamplingError as e:
178178
e = ExceptionWithTraceback(e, e.__traceback__)
179179
self._msg_pipe.send(("error", e))
@@ -191,14 +191,6 @@ def _start_loop(self):
191191
else:
192192
raise ValueError("Unknown message " + msg[0])
193193

194-
def _compute_point(self):
195-
if self._step_method.generates_stats:
196-
point, stats = self._step_method.step(self._point)
197-
else:
198-
point = self._step_method.step(self._point)
199-
stats = None
200-
return point, stats
201-
202194

203195
def _run_process(*args):
204196
_Process(*args).run()

pymc/sampling/population.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818

1919
from copy import copy
20-
from typing import Iterator, List, Sequence, Union
20+
from typing import Iterator, List, Sequence, Tuple, Union
2121

2222
import cloudpickle
2323
import numpy as np
@@ -31,7 +31,11 @@
3131
from pymc.model import modelcontext
3232
from pymc.stats.convergence import log_warning_stats
3333
from pymc.step_methods import CompoundStep
34-
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
34+
from pymc.step_methods.arraystep import (
35+
BlockedStep,
36+
PopulationArrayStepShared,
37+
StatsType,
38+
)
3539
from pymc.util import RandomSeed
3640

3741
__all__ = ()
@@ -224,7 +228,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
224228
_log.exception(f"ChainWalker{c}")
225229
return
226230

227-
def step(self, tune_stop: bool, population):
231+
def step(self, tune_stop: bool, population) -> List[Tuple[PointType, StatsType]]:
228232
"""Step the entire population of chains.
229233
230234
Parameters
@@ -239,18 +243,18 @@ def step(self, tune_stop: bool, population):
239243
update : list
240244
List of (Point, stats) tuples for all chains
241245
"""
242-
updates = [None] * self.nchains
246+
updates: List[Tuple[PointType, StatsType]] = []
243247
if self.is_parallelized:
244248
for c in range(self.nchains):
245249
self._primary_ends[c].send((tune_stop, population))
246250
# Blockingly get the step outcomes
247251
for c in range(self.nchains):
248-
updates[c] = self._primary_ends[c].recv()
252+
updates.append(self._primary_ends[c].recv())
249253
else:
250254
for c in range(self.nchains):
251255
if tune_stop:
252256
self._steppers[c].stop_tuning()
253-
updates[c] = self._steppers[c].step(population[c])
257+
updates.append(self._steppers[c].step(population[c]))
254258
return updates
255259

256260

@@ -378,13 +382,9 @@ def _iter_population(
378382

379383
# apply the update to the points and record to the traces
380384
for c, strace in enumerate(traces):
381-
if steppers[c].generates_stats:
382-
points[c], stats = updates[c]
383-
strace.record(points[c], stats)
384-
log_warning_stats(stats)
385-
else:
386-
points[c] = updates[c]
387-
strace.record(points[c])
385+
points[c], stats = updates[c]
386+
strace.record(points[c], stats)
387+
log_warning_stats(stats)
388388
# yield the state of all chains in parallel
389389
yield traces
390390
except KeyboardInterrupt:

pymc/step_methods/arraystep.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,20 @@
1414

1515
from abc import ABC, abstractmethod
1616
from enum import IntEnum, unique
17-
from typing import Dict, List, Tuple, TypeVar, Union
17+
from typing import Callable, Dict, List, Tuple, Union, cast
1818

1919
import numpy as np
2020

2121
from aesara.graph.basic import Variable
2222
from numpy.random import uniform
2323

24-
from pymc.blocking import DictToArrayBijection, PointType, RaveledVars
24+
from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
2525
from pymc.model import modelcontext
2626
from pymc.step_methods.compound import CompoundStep
2727
from pymc.util import get_var_name
2828

2929
__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"]
3030

31-
StatsType = TypeVar("StatsType")
32-
3331

3432
@unique
3533
class Competence(IntEnum):
@@ -49,7 +47,6 @@ class Competence(IntEnum):
4947

5048
class BlockedStep(ABC):
5149

52-
generates_stats = False
5350
stats_dtypes: List[Dict[str, type]] = []
5451
vars: List[Variable] = []
5552

@@ -103,7 +100,7 @@ def __getnewargs_ex__(self):
103100
return self.__newargs
104101

105102
@abstractmethod
106-
def step(point: PointType, *args, **kwargs) -> Union[PointType, Tuple[PointType, StatsType]]:
103+
def step(self, point: PointType) -> Tuple[PointType, StatsType]:
107104
"""Perform a single step of the sampler."""
108105

109106
@staticmethod
@@ -146,35 +143,28 @@ def __init__(self, vars, fs, allvars=False, blocked=True):
146143
self.allvars = allvars
147144
self.blocked = blocked
148145

149-
def step(self, point: PointType):
146+
def step(self, point: PointType) -> Tuple[PointType, StatsType]:
150147

151-
partial_funcs_and_point = [DictToArrayBijection.mapf(x, start_point=point) for x in self.fs]
148+
partial_funcs_and_point: List[Union[Callable, PointType]] = [
149+
DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
150+
]
152151
if self.allvars:
153152
partial_funcs_and_point.append(point)
154153

155-
apoint = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
156-
step_res = self.astep(apoint, *partial_funcs_and_point)
157-
158-
if self.generates_stats:
159-
apoint_new, stats = step_res
160-
else:
161-
apoint_new = step_res
154+
var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars}
155+
apoint = DictToArrayBijection.map(var_dict)
156+
apoint_new, stats = self.astep(apoint, *partial_funcs_and_point)
162157

163158
if not isinstance(apoint_new, RaveledVars):
164159
# We assume that the mapping has stayed the same
165160
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)
166161

167162
point_new = DictToArrayBijection.rmap(apoint_new, start_point=point)
168163

169-
if self.generates_stats:
170-
return point_new, stats
171-
172-
return point_new
164+
return point_new, stats
173165

174166
@abstractmethod
175-
def astep(
176-
self, apoint: RaveledVars, point: PointType, *args
177-
) -> Union[RaveledVars, Tuple[RaveledVars, StatsType]]:
167+
def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
178168
"""Perform a single sample step in a raveled and concatenated parameter space."""
179169

180170

@@ -198,30 +188,27 @@ def __init__(self, vars, shared, blocked=True):
198188
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
199189
self.blocked = blocked
200190

201-
def step(self, point):
191+
def step(self, point: PointType) -> Tuple[PointType, StatsType]:
202192

203193
for name, shared_var in self.shared.items():
204194
shared_var.set_value(point[name])
205195

206-
q = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
196+
var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars}
197+
q = DictToArrayBijection.map(var_dict)
207198

208-
step_res = self.astep(q)
209-
210-
if self.generates_stats:
211-
apoint, stats = step_res
212-
else:
213-
apoint = step_res
199+
apoint, stats = self.astep(q)
214200

215201
if not isinstance(apoint, RaveledVars):
216202
# We assume that the mapping has stayed the same
217203
apoint = RaveledVars(apoint, q.point_map_info)
218204

219205
new_point = DictToArrayBijection.rmap(apoint, start_point=point)
220206

221-
if self.generates_stats:
222-
return new_point, stats
207+
return new_point, stats
223208

224-
return new_point
209+
@abstractmethod
210+
def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
211+
"""Perform a single sample step in a raveled and concatenated parameter space."""
225212

226213

227214
class PopulationArrayStepShared(ArrayStepShared):
@@ -281,7 +268,7 @@ def __init__(
281268

282269
super().__init__(vars, func._extra_vars_shared, blocked)
283270

284-
def step(self, point):
271+
def step(self, point) -> Tuple[PointType, StatsType]:
285272
self._logp_dlogp_func._extra_are_set = True
286273
return super().step(point)
287274

pymc/step_methods/compound.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
1818
@author: johnsalvatier
1919
"""
20-
from collections import namedtuple
2120

22-
import numpy as np
21+
22+
from typing import Tuple
23+
24+
from pymc.blocking import PointType, StatsType
2325

2426

2527
class CompoundStep:
@@ -28,36 +30,23 @@ class CompoundStep:
2830

2931
def __init__(self, methods):
3032
self.methods = list(methods)
31-
self.generates_stats = any(method.generates_stats for method in self.methods)
3233
self.stats_dtypes = []
3334
for method in self.methods:
34-
if method.generates_stats:
35-
self.stats_dtypes.extend(method.stats_dtypes)
35+
self.stats_dtypes.extend(method.stats_dtypes)
3636
self.name = (
3737
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
3838
)
3939

40-
def step(self, point):
41-
if self.generates_stats:
42-
states = []
43-
for method in self.methods:
44-
if method.generates_stats:
45-
point, state = method.step(point)
46-
states.extend(state)
47-
else:
48-
point = method.step(point)
49-
# Model logp can only be the logp of the _last_ state, if there is
50-
# one. Pop all others (if dict), or set to np.nan (if namedtuple).
51-
for state in states[:-1]:
52-
if isinstance(state, dict):
53-
state.pop("model_logp", None)
54-
elif isinstance(state, namedtuple):
55-
state = state._replace(logp=np.nan)
56-
return point, states
57-
else:
58-
for method in self.methods:
59-
point = method.step(point)
60-
return point
40+
def step(self, point) -> Tuple[PointType, StatsType]:
41+
stats = []
42+
for method in self.methods:
43+
point, sts = method.step(point)
44+
stats.extend(sts)
45+
# Model logp can only be the logp of the _last_ stats,
46+
# if there is one. Pop all others.
47+
for sts in stats[:-1]:
48+
sts.pop("model_logp", None)
49+
return point, stats
6150

6251
def stop_tuning(self):
6352
for method in self.methods:

pymc/step_methods/hmc/base_hmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import numpy as np
2424

2525
from pymc.aesaraf import floatX
26-
from pymc.blocking import DictToArrayBijection, RaveledVars
26+
from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
2727
from pymc.exceptions import SamplingError
2828
from pymc.model import Point, modelcontext
2929
from pymc.stats.convergence import SamplerWarning, WarningType
@@ -157,7 +157,7 @@ def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
157157
Subclasses must overwrite this abstract method and return an `HMCStepData` object.
158158
"""
159159

160-
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]:
160+
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
161161
"""Perform a single HMC iteration."""
162162
perf_start = time.perf_counter()
163163
process_start = time.process_time()

pymc/step_methods/hmc/hmc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class HamiltonianMC(BaseHMC):
3939

4040
name = "hmc"
4141
default_blocked = True
42-
generates_stats = True
4342
stats_dtypes = [
4443
{
4544
"step_size": np.float64,

pymc/step_methods/hmc/nuts.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ class NUTS(BaseHMC):
9797
name = "nuts"
9898

9999
default_blocked = True
100-
generates_stats = True
101100
stats_dtypes = [
102101
{
103102
"depth": np.int64,

0 commit comments

Comments
 (0)