diff --git a/docs/api-reference.md b/docs/api-reference.md
index b43c960f..279c84c4 100644
--- a/docs/api-reference.md
+++ b/docs/api-reference.md
@@ -12,6 +12,7 @@
    create_diagonal
    expand_dims
    kron
+   nunique
    setdiff1d
    sinc
 ```
diff --git a/pixi.lock b/pixi.lock
index 2790b207..ed262f2f 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -2469,7 +2469,7 @@ packages:
 - pypi: .
   name: array-api-extra
   version: 0.5.1.dev0
-  sha256: 8b4533cc75534abb69425a1e5c9f6a4ab96949562d2e90d41ea0e22187a02c1b
+  sha256: 09d6a4b1405fd64596379826065a09bc3787a4fc4e1535dc369f74a3b96f86e3
   requires_dist:
   - array-api-compat>=1.10.0,<2
   - furo>=2023.8.17 ; extra == 'docs'
diff --git a/pyproject.toml b/pyproject.toml
index 4f5ddac0..a5594541 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -299,6 +299,7 @@ messages_control.disable = [
   "line-too-long",
   "missing-module-docstring",
   "missing-function-docstring",
+  "too-many-lines",
   "wrong-import-position",
 ]
 
diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py
index a4f6815f..3f973307 100644
--- a/src/array_api_extra/__init__.py
+++ b/src/array_api_extra/__init__.py
@@ -7,6 +7,7 @@
     create_diagonal,
     expand_dims,
     kron,
+    nunique,
     pad,
     setdiff1d,
     sinc,
@@ -23,6 +24,7 @@
     "create_diagonal",
     "expand_dims",
     "kron",
+    "nunique",
     "pad",
     "setdiff1d",
     "sinc",
diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py
index 7502561a..017c7297 100644
--- a/src/array_api_extra/_funcs.py
+++ b/src/array_api_extra/_funcs.py
@@ -3,6 +3,7 @@
 # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
 from __future__ import annotations
 
+import math
 import operator
 import warnings
 from collections.abc import Callable
@@ -25,6 +26,7 @@
     "create_diagonal",
     "expand_dims",
     "kron",
+    "nunique",
     "pad",
     "setdiff1d",
     "sinc",
@@ -638,6 +640,44 @@ def pad(
     return at(padded, tuple(slices)).set(x)
 
 
+def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
+    """
+    Count the number of unique elements in an array.
+
+    Compatible with JAX and Dask, whose laziness would otherwise be
+    problematic.
+
+    Parameters
+    ----------
+    x : Array
+        Input array.
+    xp : array_namespace, optional
+        The standard-compatible namespace for `x`. Default: infer.
+
+    Returns
+    -------
+    array: 0-dimensional integer array
+        The number of unique elements in `x`. It can be lazy.
+    """
+    if xp is None:
+        xp = array_namespace(x)
+
+    if is_jax_array(x):
+        # size= is JAX-specific
+        # https://github.com/data-apis/array-api/issues/883
+        _, counts = xp.unique_counts(x, size=_compat.size(x))
+        # Count non-zero entries; xp.sum() since the standard has no Array.sum().
+        return xp.sum(xp.astype(counts, xp.bool))
+
+    _, counts = xp.unique_counts(x)
+    n = _compat.size(counts)
+    # FIXME https://github.com/data-apis/array-api-compat/pull/231
+    if n is None or math.isnan(n):  # e.g. Dask, ndonnx
+        # Size unknown until computed: count the entries with an array op.
+        return xp.sum(xp.astype(counts, xp.bool))
+    return xp.asarray(n, device=_compat.device(x))
+
+
 class _AtOp(Enum):
     """Operations for use in `xpx.at`."""
 
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index 5f18ef61..201295da 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -11,6 +11,7 @@
     create_diagonal,
     expand_dims,
     kron,
+    nunique,
     pad,
     setdiff1d,
     sinc,
@@ -448,3 +449,21 @@ def test_list_of_tuples_width(self, xp: ModuleType):
 
         padded = pad(a, [(1, 0), (0, 0)])
         assert padded.shape == (4, 4)
+
+
+class TestNUnique:
+    def test_simple(self, xp: ModuleType):
+        a = xp.asarray([[1, 1], [0, 2], [2, 2]])
+        xp_assert_equal(nunique(a), xp.asarray(3))
+
+    def test_empty(self, xp: ModuleType):
+        a = xp.asarray([])
+        xp_assert_equal(nunique(a), xp.asarray(0))
+
+    def test_device(self, xp: ModuleType, device: Device):
+        a = xp.asarray(0.0, device=device)
+        assert get_device(nunique(a)) == device
+
+    def test_xp(self, xp: ModuleType):
+        a = xp.asarray([[1, 1], [0, 2], [2, 2]])
+        xp_assert_equal(nunique(a, xp=xp), xp.asarray(3))