diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index e124ed74..cc3f01f8 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -14,7 +14,7 @@ import pytest -from array_api_extra._lib._utils._compat import is_dask_namespace, is_jax_namespace +from ._lib._utils._compat import is_dask_namespace, is_jax_namespace __all__ = ["lazy_xp_function", "patch_lazy_xp_functions"] diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index 914a0a1d..9402217b 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -39,6 +39,15 @@ def test_vendor_extra(): assert_array_equal(y, x) +def test_vendor_extra_testing(): + from .array_api_extra.testing import lazy_xp_function + + def f(x): + return x + + lazy_xp_function(f) + + def test_vendor_extra_uses_vendor_compat(): from ._array_api_compat_vendor import array_namespace as n1 from .array_api_extra._lib._utils._compat import array_namespace as n2