diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index cd8d939f..ebe9bc5d 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,6 +2,7 @@ Internal helpers """ +import importlib from collections.abc import Callable from functools import wraps from inspect import signature @@ -52,8 +53,25 @@ def wrapped_f(*args: object, **kwargs: object) -> object: return inner -__all__ = ["get_xp"] +def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: + """Import everything from module, updating globals(). + Returns __all__. + """ + mod = importlib.import_module(mod_name) + # Neither of these two methods is sufficient by itself, + # depending on various idiosyncrasies of the libraries we're wrapping. + objs = {} + exec(f"from {mod.__name__} import *", objs) + + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + objs[n] = getattr(mod, n) + + globals_.update(objs) + return list(objs) + +__all__ = ["get_xp", "clone_module"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..44ef6834 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -720,8 +720,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "finfo", "iinfo", ] -_all_ignore = ["inspect", "array_namespace", "NamedTuple"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 77175d0d..453a38f2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -1052,7 +1052,5 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7ad87a1b..6fea96f0 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -225,8 +225,6 @@ def trace( 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] -_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 9a30f95d..af003c5a 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,3 +1,4 @@ +from typing import Final from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names @@ -5,9 +6,19 @@ # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + {name for name in globals() if not name.startswith("__")} + - {"Final", "_aliases", "_info", "_typing"} + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..daa1aa95 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -7,7 +7,6 @@ from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = cp.bool_ @@ -146,11 +145,12 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', +__all__ = _aliases.__all__ + ['asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'take_along_axis'] -_all_ignore = ['cp', 'get_xp'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index d8e49ca7..e5c202dc 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,7 +1,6 @@ from __future__ import annotations __all__ = ["Array", "DType", "Device"] -_all_ignore = ["cp"] from typing import TYPE_CHECKING diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..06af566b 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -30,7 +30,6 @@ __all__ = fft_all + _fft.__all__ -del get_xp -del cp -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..cd94be84 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -43,7 +43,5 @@ __all__ = linalg_all + _linalg.__all__ -del get_xp -del cp -del linalg_all -del _linalg +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1e47b960..fb1e0b94 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,12 +1,26 @@ from typing import Final -from dask.array import * # noqa: F403 +from ..._internal import clone_module + +__all__ = clone_module("dask.array", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 __array_api_version__: Final = "2024.12" +del Final # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..2442fe4b 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -41,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -355,7 +354,6 @@ def count_nonzero( __all__ = [ - "__array_namespace_info__", "count_nonzero", "bool", "int8", "int16", "int32", "int64", @@ -369,8 +367,6 @@ def count_nonzero( "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", ] # fmt: skip __all__ += _aliases.__all__ -_all_ignore = ["array_namespace", "get_xp", "da", "np"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 3f40dffe..44b68e73 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -1,13 +1,6 @@ -from dask.array.fft import * # noqa: F403 -# dask.array.fft doesn't have __all__. If it is added, replace this with -# -# from dask.array.fft import __all__ as linalg_all -_n = {} -exec('from dask.array.fft import *', _n) -for k in ("__builtins__", "Sequence", "annotations", "warnings"): - _n.pop(k, None) -fft_all = list(_n) -del _n, k +from ..._internal import clone_module + +__all__ = clone_module("dask.array.fft", globals()) from ...common import _fft from ..._internal import get_xp @@ -17,5 +10,7 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = fft_all + ["fftfreq", "rfftfreq"] -_all_ignore = ["da", "fft_all", "get_xp", "warnings"] +__all__ += ["fftfreq", "rfftfreq"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 0825386e..7c80620c 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -8,22 +8,13 @@ from dask.array import matmul, outer, tensordot # Exports -from dask.array.linalg import * # noqa: F403 - -from ..._internal import get_xp +from ..._internal import clone_module, get_xp from ...common import _linalg from ...common._typing import Array as _Array -from ._aliases import matrix_transpose, vecdot -# dask.array.linalg doesn't have __all__. If it is added, replace this with -# -# from dask.array.linalg import __all__ as linalg_all -_n = {} -exec('from dask.array.linalg import *', _n) -for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): - _n.pop(k, None) -linalg_all = list(_n) -del _n, k +__all__ = clone_module("dask.array.linalg", globals()) + +from ._aliases import matrix_transpose, vecdot EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -63,10 +54,11 @@ def svdvals(x: _Array) -> _Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", - "matrix_transpose", "vecdot", "EighResult", - "QRResult", "SlogdetResult", "SVDResult", "qr", - "cholesky", "matrix_rank", "matrix_norm", "svdvals", - "vector_norm", "diagonal"] +__all__ += ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 3e138f53..cc9842b4 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,16 +1,17 @@ # ruff: noqa: PLC0414 from typing import Final -from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] +from .._internal import clone_module -# from numpy import * doesn't overwrite these builtin names -from numpy import abs as abs -from numpy import max as max -from numpy import min as min -from numpy import round as round +# This needs to be loaded explicitly before cloning +import numpy.typing # noqa: F401 + +__all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -26,3 +27,12 @@ from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..e5fcceeb 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -9,7 +9,6 @@ from .._internal import get_xp from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType if TYPE_CHECKING: @@ -162,8 +161,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): else: unstack = get_xp(np)(_aliases.unstack) -__all__ = [ - "__array_namespace_info__", +__all__ = _aliases.__all__ + [ "asarray", "astype", "acos", @@ -182,8 +180,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): "pow", "take_along_axis" ] -__all__ += _aliases.__all__ -_all_ignore = ["np", "get_xp"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index e771c788..b5fa188c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -23,7 +23,6 @@ Array: TypeAlias = np.ndarray __all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 06875f00..a492feb8 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,6 +1,8 @@ import numpy as np -from numpy.fft import __all__ as fft_all -from numpy.fft import fft2, ifft2, irfft2, rfft2 + +from .._internal import clone_module + +__all__ = clone_module("numpy.fft", globals()) from .._internal import get_xp from ..common import _fft @@ -21,15 +23,8 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] -__all__ += _fft.__all__ - +__all__ = sorted(set(__all__) | set(_fft.__all__)) def __dir__() -> list[str]: return __all__ - -del get_xp -del np -del fft_all -del _fft diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 2d3e731d..ca540880 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -7,26 +7,11 @@ import numpy as np -# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` -from numpy.linalg import ( - LinAlgError, - cond, - det, - eig, - eigvals, - eigvalsh, - inv, - lstsq, - matrix_power, - multi_dot, - norm, - tensorinv, - tensorsolve, -) - -from .._internal import get_xp +from .._internal import clone_module, get_xp from ..common import _linalg +__all__ = clone_module("numpy.linalg", globals()) + # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 from ._typing import Array @@ -120,7 +105,7 @@ def solve(x1: Array, x2: Array, /) -> Array: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = [ +_all = [ "LinAlgError", "cond", "det", @@ -132,12 +117,12 @@ def solve(x1: Array, x2: Array, /) -> Array: "matrix_power", "multi_dot", "norm", + "solve", "tensorinv", "tensorsolve", + "vector_norm", ] -__all__ += _linalg.__all__ -__all__ += ["solve", "vector_norm"] - +__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all)) def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 69fd19ce..6cbb6ec2 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,22 +1,25 @@ -from torch import * # noqa: F403 +from typing import Final -# Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(f"{n} = torch.{n}") -del n +from .._internal import clone_module + +__all__ = clone_module("torch", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index de5d1a5d..be31c4d5 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -9,7 +9,6 @@ from .._internal import get_xp from ..common import _aliases from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType _int_dtypes = { @@ -834,7 +833,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array return list(torch.meshgrid(*arrays, indexing='xy')) -__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', +__all__ = ['asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', @@ -851,5 +850,3 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] - -_all_ignore = ['torch', 'get_xp'] diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 50e6a0d0..242c92db 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -4,9 +4,11 @@ import torch import torch.fft -from torch.fft import * # noqa: F403 from ._typing import Array +from .._internal import clone_module + +__all__ = clone_module("torch.fft", globals()) # Several torch fft functions do not map axes to dim @@ -73,13 +75,7 @@ def ifftshift( return torch.fft.ifftshift(x, dim=axes, **kwargs) -__all__ = torch.fft.__all__ + [ - "fftn", - "ifftn", - "rfftn", - "irfftn", - "fftshift", - "ifftshift", -] +__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"] -_all_ignore = ['torch'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 70d72405..df94c351 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,14 +1,12 @@ from __future__ import annotations import torch +import torch.linalg from typing import Optional, Union, Tuple -from torch.linalg import * # noqa: F403 +from .._internal import clone_module -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] +__all__ = clone_module("torch.linalg", globals()) # outer is implemented in torch but aren't in the linalg namespace from torch import outer @@ -30,7 +28,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if not (x1.shape[axis] == x2.shape[axis] == 3): raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") x1, x2 = torch.broadcast_tensors(x1, x2) - return torch_linalg.cross(x1, x2, dim=axis) + return torch.linalg.cross(x1, x2, dim=axis) def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: from ._aliases import isdtype @@ -110,12 +108,8 @@ def vector_norm( return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) -__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', - 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] - -_all_ignore = ['torch_linalg', 'sum'] - -del linalg_all +__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot', + 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] def __dir__() -> list[str]: return __all__ diff --git a/tests/test_all.py b/tests/test_all.py index 271cd189..c36aef67 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,63 +1,311 @@ -""" -Test that files that define __all__ aren't missing any exports. +"""Test exported names""" -You can add names that shouldn't be exported to _all_ignore, like +import builtins -_all_ignore = ['sys'] +import numpy as np +import pytest -This is preferable to del-ing the names as this will break any name that is -used inside of a function. Note that names starting with an underscore are automatically ignored. -""" +from array_api_compat._internal import clone_module +from ._helpers import wrapped_libraries -import sys +NAMES = { + "": [ + # Inspection + "__array_api_version__", + "__array_namespace_info__", + # Submodules + "fft", + "linalg", + # Constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # Creation Functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # Data Type Functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # Data Types + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + # Elementwise Functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clip", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # Indexing Functions + "take", + "take_along_axis", + # Linear Algebra Functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # Manipulation Functions + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "repeat", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", + # Searching Functions + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "where", + # Set Functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # Sorting Functions + "argsort", + "sort", + # Statistical Functions + "cumulative_prod", + "cumulative_sum", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # Utility Functions + "all", + "any", + "diff", + ], + "fft": [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + ], + "linalg": [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", + ], +} -from ._helpers import import_, wrapped_libraries +XFAILS = { + ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], + ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", "linalg"): [ + "cross", + "det", + "eigh", + "eigvalsh", + "matrix_power", + "pinv", + "slogdet", + ], +} -import pytest -import typing - -TYPING_NAMES = frozenset(( - "Array", - "Device", - "DType", - "Namespace", - "NestedSequence", - "SupportsBufferProtocol", -)) - -@pytest.mark.parametrize("library", ["common"] + wrapped_libraries) -def test_all(library): - if library == "common": - import array_api_compat.common # noqa: F401 - else: - import_(library, wrapper=True) - - # NB: iterate over a copy to avoid a "dictionary size changed" error - for mod_name in sys.modules.copy(): - if not mod_name.startswith('array_api_compat.' + library): - continue - - module = sys.modules[mod_name] - - # TODO: We should define __all__ in the __init__.py files and test it - # there too. - if not hasattr(module, '__all__'): - continue - - dir_names = [n for n in dir(module) if not n.startswith('_')] - if '__array_namespace_info__' in dir(module): - dir_names.append('__array_namespace_info__') - ignore_all_names = set(getattr(module, '_all_ignore', ())) - ignore_all_names |= set(dir(typing)) - ignore_all_names |= {"annotations"} - if not module.__name__.endswith("._typing"): - ignore_all_names |= TYPING_NAMES - dir_names = set(dir_names) - set(ignore_all_names) - all_names = module.__all__ - - if set(dir_names) != set(all_names): - extra_dir = set(dir_names) - set(all_names) - extra_all = set(all_names) - set(dir_names) - assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" - assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" + +def all_names(mod): + """Return all names available in a module.""" + objs = {} + clone_module(mod.__name__, objs) + return set(objs) + + +def get_mod(library, module, *, compat): + if compat: + library = f"array_api_compat.{library}" + xp = pytest.importorskip(library) + return getattr(xp, module) if module else xp + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_array_api_names(library, module): + """Test that __all__ isn't missing any exports + dictated by the Standard. + """ + mod = get_mod(library, module, compat=True) + missing = set(NAMES[module]) - all_names(mod) + xfail = set(XFAILS.get((library, module), [])) + xpass = xfail - missing + fails = missing - xfail + assert not xpass, f"Names in XFAILS are defined: {xpass}" + assert not fails, f"Missing exports: {fails}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_hide_names(library, module): + """The base namespace can have more names than the ones explicitly exported + by array-api-compat. Test that we're not suppressing them. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + missing = all_names(bare_mod) - all_names(compat_mod) + missing = {name for name in missing if not name.startswith("_")} + assert not missing, f"Non-Array API names have been hidden: {missing}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_add_names(library, module): + """Test that array-api-compat isn't adding names to the namespace + besides those defined by the Array API Standard. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + aapi_names = set(NAMES[module]) + spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names + # Quietly ignore *Result dataclasses + spurious = {name for name in spurious if not name.endswith("Result")} + assert not spurious, ( + f"array-api-compat is adding non-Array API names: {spurious}" + ) + + +@pytest.mark.parametrize( + "name", [name for name in NAMES[""] if hasattr(builtins, name)] +) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_builtins_collision(library, name): + """Test that xp.bool is not accidentally builtins.bool, etc.""" + xp = pytest.importorskip(f"array_api_compat.{library}") + assert getattr(xp, name) is not getattr(builtins, name)