Skip to content

Commit f2666dd

Browse files
Avoid code duplication
1 parent 1540224 commit f2666dd

File tree

4 files changed

+38
-38
lines changed

4 files changed

+38
-38
lines changed

mcbackend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
A framework agnostic implementation for storage of MCMC draws.
33
"""
44

5-
from .backends.numpy import NumPyBackend
65
from .backends.null import NullBackend
6+
from .backends.numpy import NumPyBackend
77
from .core import Backend, Chain, Run
88
from .meta import ChainMeta, Coordinate, DataVariable, ExtendedValue, RunMeta, Variable
99

mcbackend/backends/null.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010

1111
import numpy
1212

13-
from ..core import Backend, Chain, Run, is_rigid
13+
from ..core import Backend, Chain, Run
1414
from ..meta import ChainMeta, RunMeta
15+
from .numpy import grow_append, prepare_storage
1516

16-
from .numpy import grow_append
1717

1818
class NullChain(Chain):
1919
"""A null storage: discards values immediately and allocates no memory.
@@ -52,26 +52,14 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
5252
where the correct amount of memory cannot be pre-allocated.
5353
In these cases object arrays are used.
5454
"""
55-
self._stat_is_rigid: Dict[str, bool] = {}
56-
self._stats: Dict[str, numpy.ndarray] = {}
5755
self._draw_idx = 0
5856

59-
# Create storage ndarrays for each model variable and sampler stat.
60-
for target_dict, rigid_dict, variables in [
61-
(self._stats, self._stat_is_rigid, rmeta.sample_stats),
62-
]:
63-
for var in variables:
64-
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
65-
rigid_dict[var.name] = rigid
66-
if rigid:
67-
reserve = (preallocate, *var.shape)
68-
target_dict[var.name] = numpy.empty(reserve, var.dtype)
69-
else:
70-
target_dict[var.name] = numpy.array([None] * preallocate, dtype=object)
57+
# Create storage ndarrays only for sampler stats.
58+
self._stats, self._stat_is_rigid = prepare_storage(rmeta.sample_stats, preallocate)
7159

7260
super().__init__(cmeta, rmeta)
7361

74-
def append(
62+
def append( # pylint: disable=duplicate-code
7563
self, draw: Mapping[str, numpy.ndarray], stats: Optional[Mapping[str, numpy.ndarray]] = None
7664
):
7765
if stats:
@@ -88,7 +76,9 @@ def get_draws(self, var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
8876
def get_draws_at(self, idx: int, var_names: Sequence[str]) -> Dict[str, numpy.ndarray]:
8977
raise RuntimeError("NullChain does not save draws.")
9078

91-
def get_stats(self, stat_name: str, slc: slice = slice(None)) -> numpy.ndarray:
79+
def get_stats( # pylint: disable=duplicate-code
80+
self, stat_name: str, slc: slice = slice(None)
81+
) -> numpy.ndarray:
9282
data = self._stats[stat_name][: self._draw_idx][slc]
9383
if self.sample_stats[stat_name].dtype == "str":
9484
return numpy.array(data.tolist(), dtype=str)

mcbackend/backends/numpy.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
"""
44

55
import math
6-
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
6+
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
77

88
import numpy
99

1010
from ..core import Backend, Chain, Run, is_rigid
11-
from ..meta import ChainMeta, RunMeta
11+
from ..meta import ChainMeta, RunMeta, Variable
1212

1313

1414
def grow_append(
@@ -34,6 +34,22 @@ def grow_append(
3434
return
3535

3636

37+
def prepare_storage(
38+
variables: Iterable[Variable], preallocate: int
39+
) -> Tuple[Dict[str, numpy.ndarray], Dict[str, bool]]:
40+
storage: Dict[str, numpy.ndarray] = {}
41+
rigid_dict: Dict[str, bool] = {}
42+
for var in variables:
43+
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
44+
rigid_dict[var.name] = rigid
45+
if rigid:
46+
reserve = (preallocate, *var.shape)
47+
storage[var.name] = numpy.empty(reserve, var.dtype)
48+
else:
49+
storage[var.name] = numpy.array([None] * preallocate, dtype=object)
50+
return storage, rigid_dict
51+
52+
3753
class NumPyChain(Chain):
3854
"""Stores value draws in NumPy arrays and can pre-allocate memory."""
3955

@@ -54,25 +70,11 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
5470
where the correct amount of memory cannot be pre-allocated.
5571
In these cases object arrays are used.
5672
"""
57-
self._var_is_rigid: Dict[str, bool] = {}
58-
self._samples: Dict[str, numpy.ndarray] = {}
59-
self._stat_is_rigid: Dict[str, bool] = {}
60-
self._stats: Dict[str, numpy.ndarray] = {}
6173
self._draw_idx = 0
6274

6375
# Create storage ndarrays for each model variable and sampler stat.
64-
for target_dict, rigid_dict, variables in [
65-
(self._samples, self._var_is_rigid, rmeta.variables),
66-
(self._stats, self._stat_is_rigid, rmeta.sample_stats),
67-
]:
68-
for var in variables:
69-
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
70-
rigid_dict[var.name] = rigid
71-
if rigid:
72-
reserve = (preallocate, *var.shape)
73-
target_dict[var.name] = numpy.empty(reserve, var.dtype)
74-
else:
75-
target_dict[var.name] = numpy.array([None] * preallocate, dtype=object)
76+
self._samples, self._var_is_rigid = prepare_storage(rmeta.variables, preallocate)
77+
self._stats, self._stat_is_rigid = prepare_storage(rmeta.sample_stats, preallocate)
7678

7779
super().__init__(cmeta, rmeta)
7880

mcbackend/test_backend_null.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
from mcbackend.backends.null import NullBackend, NullChain, NullRun
88
from mcbackend.core import RunMeta, is_rigid
99
from mcbackend.meta import Variable
10-
from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta, make_draw
10+
from mcbackend.test_utils import (
11+
CheckBehavior,
12+
CheckPerformance,
13+
make_draw,
14+
make_runmeta,
15+
)
16+
1117

1218
class CheckNullBehavior(CheckBehavior):
1319
"""
@@ -152,6 +158,7 @@ def test__to_inferencedata(self):
152158
"""
153159
pass
154160

161+
155162
class TestNullBackend(CheckNullBehavior, CheckPerformance):
156163
cls_backend = NullBackend
157164
cls_run = NullRun
@@ -207,6 +214,7 @@ def test_growing(self, preallocate):
207214
# TODO: Check dimensions of stats array ?
208215
pass
209216

217+
210218
if __name__ == "__main__":
211219
tc = TestNullBackend()
212220
df = tc.run_all_benchmarks()

0 commit comments

Comments
 (0)