Skip to content

Commit d6397f8

Browse files
committed
ENH: add a test that result_type does not depend on the order of arguments
1 parent 0b89c52 commit d6397f8

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
12-
integers, just, lists, none, one_of,
13-
sampled_from, shared, builds, nothing)
12+
integers, complex_numbers, just, lists, none, one_of,
13+
sampled_from, shared, builds, nothing, permutations)
1414

1515
from . import _array_module as xp, api_version
1616
from . import array_helpers as ah
@@ -148,6 +148,13 @@ def mutually_promotable_dtypes(
148148
return one_of(strats).map(tuple)
149149

150150

151+
@composite
152+
def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes):
153+
sample = draw(mutually_promotable_dtypes( max_size, dtypes=dtypes))
154+
permuted = draw(permutations(sample))
155+
return sample, permuted
156+
157+
151158
class OnewayPromotableDtypes(NamedTuple):
152159
input_dtype: DataType
153160
result_dtype: DataType

array_api_tests/test_data_type_functions.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,17 @@ def test_isdtype(dtype, kind):
208208
assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
209209

210210

211-
@given(hh.mutually_promotable_dtypes(None))
212-
def test_result_type(dtypes):
213-
out = xp.result_type(*dtypes)
214-
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
211+
class TestResultType:
212+
@given(dtypes=hh.mutually_promotable_dtypes(None))
213+
def test_result_type(self, dtypes):
214+
out = xp.result_type(*dtypes)
215+
ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
216+
217+
@given(pair=hh.pair_of_mutually_promotable_dtypes(None))
218+
def test_shuffled(self, pair):
219+
"""Test that result_type is insensitive to the order of arguments."""
220+
s1, s2 = pair
221+
out1 = xp.result_type(*s1)
222+
out2 = xp.result_type(*s2)
223+
assert out1 == out2
224+

0 commit comments

Comments
 (0)