Skip to content

Commit 54aad19

Browse files
committed
[NullBackend] Add tests; fix issue with preallocate=0
Tests were copied from test_backend_numpy and the parts checking the `_samples` array removed. Fixed issue: - Reset default preallocation to 1_000, like with NumPyBackend: it is still used for the stats array, so it makes sense to use a reasonable default. - Preallocate = 0 no longer switches the allocation to object arrays, in contrast to NumPyBackend - IMO this is a bug in NumPyBackend: `grow_append` cannot know if ``preallocate = 0`` was used; it only looks at the `rigid` value to determine how to append. - Without this change, `grow_append` will always fail when we use `preallocate = 0` with multivariate statistics.
1 parent 5e6e90b commit 54aad19

File tree

2 files changed

+217
-4
lines changed

2 files changed

+217
-4
lines changed

mcbackend/backends/null.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
5050
and grow the allocated memory by 10 % when needed.
5151
Exceptions are variables with non-rigid shapes (indicated by 0 in the shape tuple)
5252
where the correct amount of memory cannot be pre-allocated.
53-
In these cases, and when ``preallocate == 0`` object arrays are used.
53+
In these cases object arrays are used.
5454
"""
5555
self._stat_is_rigid: Dict[str, bool] = {}
5656
self._stats: Dict[str, numpy.ndarray] = {}
@@ -63,7 +63,7 @@ def __init__(self, cmeta: ChainMeta, rmeta: RunMeta, *, preallocate: int) -> Non
6363
for var in variables:
6464
rigid = is_rigid(var.shape) and not var.undefined_ndim and var.dtype != "str"
6565
rigid_dict[var.name] = rigid
66-
if preallocate > 0 and rigid:
66+
if rigid:
6767
reserve = (preallocate, *var.shape)
6868
target_dict[var.name] = numpy.empty(reserve, var.dtype)
6969
else:
@@ -101,7 +101,7 @@ def get_stats_at(self, idx: int, stat_names: Sequence[str]) -> Dict[str, numpy.n
101101
class NullRun(Run):
102102
"""An MCMC run where samples are immediately discarded."""
103103

104-
def __init__(self, meta: RunMeta, *, preallocate: int=0) -> None:
104+
def __init__(self, meta: RunMeta, *, preallocate: int) -> None:
105105
self._settings = {"preallocate": preallocate}
106106
self._chains: List[NullChain] = []
107107
super().__init__(meta)
@@ -119,7 +119,7 @@ def get_chains(self) -> Tuple[NullChain, ...]:
119119
class NullBackend(Backend):
120120
"""A backend which discards samples immediately."""
121121

122-
def __init__(self, preallocate: int=0) -> None:
122+
def __init__(self, preallocate: int = 1_000) -> None:
123123
self._settings = {"preallocate": preallocate}
124124
super().__init__()
125125

mcbackend/test_backend_null.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import random
2+
3+
import hagelkorn
4+
import numpy
5+
import pytest
6+
7+
from mcbackend.backends.null import NullBackend, NullChain, NullRun
8+
from mcbackend.core import RunMeta, is_rigid
9+
from mcbackend.meta import Variable
10+
from mcbackend.test_utils import CheckBehavior, CheckPerformance, make_runmeta, make_draw
11+
12+
class CheckNullBehavior(CheckBehavior):
13+
"""
14+
Overrides tests which assert that data are recorded correctly
15+
We perform all the operations of the original test, but in the
16+
end we do the opposite: assert that an exception is raised
17+
when either `get_draws` or `get_draws_at` is called.
18+
Stats are still recorded, so that part of the tests is reproduced unchanged.
19+
"""
20+
21+
@pytest.mark.parametrize("with_stats", [False, True])
22+
def test__append_get_at(self, with_stats):
23+
rmeta = make_runmeta()
24+
run = self.backend.init_run(rmeta)
25+
chain = run.init_chain(7)
26+
27+
# Generate data
28+
draw = make_draw(rmeta.variables)
29+
stats = make_draw(rmeta.sample_stats) if with_stats else None
30+
31+
# Append to the chain
32+
assert len(chain) == 0
33+
chain.append(draw, stats)
34+
assert len(chain) == 1
35+
36+
# Retrieve by index - Raises exception
37+
with pytest.raises(RuntimeError):
38+
chain.get_draws_at(0, [v.name for v in rmeta.variables])
39+
40+
# NB: Stats are still recorded and can be retrieved as with other chains
41+
if with_stats:
42+
actual = chain.get_stats_at(0, [v.name for v in rmeta.sample_stats])
43+
assert isinstance(actual, dict)
44+
assert set(actual) == set(stats)
45+
for vn, act in actual.items():
46+
numpy.testing.assert_array_equal(act, stats[vn])
47+
pass
48+
49+
@pytest.mark.parametrize("with_stats", [False, True])
50+
def test__append_get_with_changelings(self, with_stats):
51+
rmeta = make_runmeta(flexibility=True)
52+
run = self.backend.init_run(rmeta)
53+
chain = run.init_chain(7)
54+
55+
# Generate draws and add them to the chain
56+
n = 10
57+
draws = [make_draw(rmeta.variables) for _ in range(n)]
58+
if with_stats:
59+
stats = [make_draw(rmeta.sample_stats) for _ in range(n)]
60+
else:
61+
stats = [None] * n
62+
63+
for d, s in zip(draws, stats):
64+
chain.append(d, s)
65+
66+
# Fetching variables raises exception
67+
for var in rmeta.variables:
68+
expected = [draw[var.name] for draw in draws]
69+
with pytest.raises(RuntimeError):
70+
chain.get_draws(var.name)
71+
72+
if with_stats:
73+
for var in rmeta.sample_stats:
74+
expected = [stat[var.name] for stat in stats]
75+
actual = chain.get_stats(var.name)
76+
assert isinstance(actual, numpy.ndarray)
77+
if var.dtype == "str":
78+
assert tuple(actual.shape) == tuple(numpy.shape(expected))
79+
# String dtypes have strange names
80+
assert "str" in actual.dtype.name
81+
elif is_rigid(var.shape):
82+
assert tuple(actual.shape) == tuple(numpy.shape(expected))
83+
assert actual.dtype.name == var.dtype
84+
numpy.testing.assert_array_equal(actual, expected)
85+
else:
86+
# Non-ridid variables are returned as object-arrays.
87+
assert actual.shape == (len(expected),)
88+
assert actual.dtype == object
89+
# Their values must be asserted elementwise to avoid shape problems.
90+
for act, exp in zip(actual, expected):
91+
numpy.testing.assert_array_equal(act, exp)
92+
pass
93+
94+
@pytest.mark.parametrize(
95+
"slc",
96+
[
97+
None,
98+
slice(None, None, None),
99+
slice(2, None, None),
100+
slice(2, 10, None),
101+
slice(2, 15, 3), # every 3rd
102+
slice(15, 2, -3), # backwards every 3rd
103+
slice(2, 15, -3), # empty
104+
slice(-8, None, None), # the last 8
105+
slice(-8, -2, 2),
106+
slice(-50, -2, 2),
107+
slice(15, 10), # empty
108+
slice(1, 1), # empty
109+
],
110+
)
111+
def test__get_slicing(self, slc: slice):
112+
# "A" are just numbers to make diagnosis easier.
113+
# "B" are dynamically shaped to cover the edge cases.
114+
rmeta = RunMeta(
115+
variables=[Variable("A", "uint8"), Variable("M", "str", [2, 3])],
116+
sample_stats=[Variable("B", "uint8", [2, -1])],
117+
data=[],
118+
)
119+
run = self.backend.init_run(rmeta)
120+
chain = run.init_chain(0)
121+
122+
# Generate draws and add them to the chain
123+
N = 20
124+
draws = [make_draw(rmeta.variables) for n in range(N)]
125+
stats = [make_draw(rmeta.sample_stats) for n in range(N)]
126+
for d, s in zip(draws, stats):
127+
chain.append(d, s)
128+
assert len(chain) == N
129+
130+
# slc=None in this test means "don't pass it".
131+
# The implementations should default to slc=slice(None, None, None).
132+
kwargs = dict(slc=slc) if slc is not None else {}
133+
with pytest.raises(RuntimeError):
134+
chain.get_draws("A", **kwargs)
135+
with pytest.raises(RuntimeError):
136+
chain.get_draws("M", **kwargs)
137+
act_stats = chain.get_stats("B", **kwargs)
138+
expected_stats = [s["B"] for s in stats][slc or slice(None, None, None)]
139+
140+
# Stat "B" is dynamically shaped, which means we're dealing with
141+
# dtype=object arrays. These must be checked elementwise.
142+
assert len(act_stats) == len(expected_stats)
143+
assert act_stats.dtype == object
144+
for a, e in zip(act_stats, expected_stats):
145+
numpy.testing.assert_array_equal(a, e)
146+
pass
147+
148+
def test__to_inferencedata(self):
149+
"""
150+
NullBackend doesn’t support `to_inferencedata`, so there isn’t
151+
anything to test here.
152+
"""
153+
pass
154+
155+
class TestNullBackend(CheckNullBehavior, CheckPerformance):
156+
cls_backend = NullBackend
157+
cls_run = NullRun
158+
cls_chain = NullChain
159+
160+
# `test_targets` and `test_growing` are copied over from TestNumPyBackend.
161+
# The lines testing sample storage removed, since neither `_samples`
162+
# nor `_var_is_rigid` are not supported by NullBackend.
163+
# However if one were to add tests for `_stats` and `_stat_is_rigid`
164+
# to the NumPy suite, we could port those here.
165+
166+
def test_targets(self):
167+
imb = NullBackend(preallocate=123)
168+
rm = RunMeta(
169+
rid=hagelkorn.random(),
170+
variables=[
171+
Variable("tensor", "int8", (3, 4, 5)),
172+
Variable("scalar", "float64", ()),
173+
Variable("changeling", "uint16", (3, -1)),
174+
],
175+
)
176+
run = imb.init_run(rm)
177+
chain = run.init_chain(0)
178+
pass
179+
180+
@pytest.mark.parametrize("preallocate", [0, 75])
181+
def test_growing(self, preallocate):
182+
imb = NullBackend(preallocate=preallocate)
183+
rm = RunMeta(
184+
rid=hagelkorn.random(),
185+
variables=[
186+
Variable(
187+
"A",
188+
"float32",
189+
(2,),
190+
),
191+
Variable(
192+
"B",
193+
"float32",
194+
(-1,),
195+
),
196+
],
197+
)
198+
run = imb.init_run(rm)
199+
chain = run.init_chain(0)
200+
# TODO: Check dimensions of stats array ?
201+
for _ in range(130):
202+
draw = {
203+
"A": numpy.random.uniform(size=(2,)),
204+
"B": numpy.random.uniform(size=(random.randint(0, 10),)),
205+
}
206+
chain.append(draw)
207+
# TODO: Check dimensions of stats array ?
208+
pass
209+
210+
if __name__ == "__main__":
211+
tc = TestNullBackend()
212+
df = tc.run_all_benchmarks()
213+
print(df)

0 commit comments

Comments
 (0)