
Commit edc8a10

betatim authored and thomasjpfan committed
ENH Add Array API compatibility to MinMaxScaler (scikit-learn#26243)
Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 74daec7 commit edc8a10

File tree

7 files changed: +174 -18 lines changed


doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ Estimators
 - :class:`decomposition.PCA` (with `svd_solver="full"`,
   `svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
 - :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
+- :class:`preprocessing.MinMaxScaler`

 Tools
 -----
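
Note: a minimal usage sketch of the behavior documented above (not part of this
diff), assuming PyTorch and the array_api_compat package are installed:

import torch
from sklearn import config_context
from sklearn.preprocessing import MinMaxScaler

X = torch.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# With array_api_dispatch enabled, the estimator computes directly in the
# input's namespace instead of converting to NumPy first.
with config_context(array_api_dispatch=True):
    X_scaled = MinMaxScaler().fit_transform(X)

print(type(X_scaled))  # the result stays a torch.Tensor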

doc/whats_new/v1.3.rst

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ Changelog
 - |Feature| Compute a custom out-of-bag score by passing a callable to
   :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`,
   :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`.
-  :pr:`25177` by :user:`Tim Head <betatim>`.
+  :pr:`25177` by `Tim Head`_.

 - |Feature| :class:`ensemble.GradientBoostingClassifier` now exposes
   out-of-bag scores via the `oob_scores_` or `oob_score_` attributes.

doc/whats_new/v1.4.rst

Lines changed: 8 additions & 2 deletions
@@ -174,6 +174,9 @@ Changelog
   is enabled and should be passed via the `params` parameter. :pr:`26896` by
   `Adrin Jalali`_.

+- |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
+  Array API compatible inputs. :pr:`26855` by `Tim Head`_.
+
 :mod:`sklearn.neighbors`
 ........................

@@ -197,8 +200,11 @@ Changelog
   when `sparse_output=True` and the output is configured to be pandas.
   :pr:`26931` by `Thomas Fan`_.

-- |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
-  Array API compatible inputs. :pr:`26855` by `Tim Head`_.
+- |MajorFeature| :class:`preprocessing.MinMaxScaler` now
+  supports the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
+  support is considered experimental and might evolve without being subject to
+  our usual rolling deprecation cycle policy. See
+  :ref:`array_api` for more details. :pr:`26243` by `Tim Head`_.

 :mod:`sklearn.tree`
 ...................

sklearn/preprocessing/_data.py

Lines changed: 24 additions & 12 deletions
@@ -22,7 +22,8 @@
     TransformerMixin,
     _fit_context,
 )
-from ..utils import check_array
+from ..utils import _array_api, check_array
+from ..utils._array_api import get_namespace
 from ..utils._param_validation import Interval, Options, StrOptions, validate_params
 from ..utils.extmath import _incremental_mean_and_var, row_norms
 from ..utils.sparsefuncs import (
@@ -103,16 +104,18 @@ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
         if scale == 0.0:
             scale = 1.0
         return scale
-    elif isinstance(scale, np.ndarray):
+    # scale is an array
+    else:
+        xp, _ = get_namespace(scale)
         if constant_mask is None:
             # Detect near constant values to avoid dividing by a very small
             # value that could lead to surprising results and numerical
             # stability issues.
-            constant_mask = scale < 10 * np.finfo(scale.dtype).eps
+            constant_mask = scale < 10 * xp.finfo(scale.dtype).eps

         if copy:
             # New array to avoid side-effects
-            scale = scale.copy()
+            scale = xp.asarray(scale, copy=True)
         scale[constant_mask] = 1.0
         return scale

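Note: the constant_mask logic above guards against dividing by a near-zero
range. A standalone NumPy sketch of the same rule, with hypothetical values
(not part of this diff):

import numpy as np

scale = np.asarray([2.0, 1e-17, 0.0])
# Ranges within 10 * eps of zero are treated as constant features...
constant_mask = scale < 10 * np.finfo(scale.dtype).eps
safe = scale.copy()
# ...and replaced by 1.0 so the later division leaves them unchanged.
safe[constant_mask] = 1.0
print(safe)  # [2. 1. 1.]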
@@ -468,22 +471,24 @@ def partial_fit(self, X, y=None):
                 "Consider using MaxAbsScaler instead."
             )

+        xp, _ = get_namespace(X)
+
         first_pass = not hasattr(self, "n_samples_seen_")
         X = self._validate_data(
             X,
             reset=first_pass,
-            dtype=FLOAT_DTYPES,
+            dtype=_array_api.supported_float_dtypes(xp),
             force_all_finite="allow-nan",
         )

-        data_min = np.nanmin(X, axis=0)
-        data_max = np.nanmax(X, axis=0)
+        data_min = _array_api._nanmin(X, axis=0)
+        data_max = _array_api._nanmax(X, axis=0)

         if first_pass:
             self.n_samples_seen_ = X.shape[0]
         else:
-            data_min = np.minimum(self.data_min_, data_min)
-            data_max = np.maximum(self.data_max_, data_max)
+            data_min = xp.minimum(self.data_min_, data_min)
+            data_max = xp.maximum(self.data_max_, data_max)
             self.n_samples_seen_ += X.shape[0]

         data_range = data_max - data_min
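
Note: partial_fit keeps running minima/maxima, so batch-wise fitting matches a
single full fit. A NumPy-only sketch (not part of this diff):

import numpy as np
from sklearn.preprocessing import MinMaxScaler

X = np.asarray([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0], [2.0, 25.0]])

full = MinMaxScaler().fit(X)
batched = MinMaxScaler()
for batch in (X[:2], X[2:]):  # two incremental batches
    batched.partial_fit(batch)

np.testing.assert_allclose(full.data_min_, batched.data_min_)  # [ 0. 10.]
np.testing.assert_allclose(full.data_max_, batched.data_max_)  # [10. 30.]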
@@ -511,18 +516,20 @@ def transform(self, X):
         """
         check_is_fitted(self)

+        xp, _ = get_namespace(X)
+
         X = self._validate_data(
             X,
             copy=self.copy,
-            dtype=FLOAT_DTYPES,
+            dtype=_array_api.supported_float_dtypes(xp),
             force_all_finite="allow-nan",
             reset=False,
         )

         X *= self.scale_
         X += self.min_
         if self.clip:
-            np.clip(X, self.feature_range[0], self.feature_range[1], out=X)
+            xp.clip(X, self.feature_range[0], self.feature_range[1], out=X)
         return X

     def inverse_transform(self, X):
@@ -540,8 +547,13 @@ def inverse_transform(self, X):
         """
         check_is_fitted(self)

+        xp, _ = get_namespace(X)
+
         X = check_array(
-            X, copy=self.copy, dtype=FLOAT_DTYPES, force_all_finite="allow-nan"
+            X,
+            copy=self.copy,
+            dtype=_array_api.supported_float_dtypes(xp),
+            force_all_finite="allow-nan",
         )

         X -= self.min_
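
Note: for reference, a small worked example of the affine mapping these
methods implement, using the fitted attributes from the MinMaxScaler
docstring (not part of this diff):

import numpy as np
from sklearn.preprocessing import MinMaxScaler

X = np.asarray([[1.0], [3.0], [5.0]])
scaler = MinMaxScaler(feature_range=(0, 1)).fit(X)

# scale_ = (range max - range min) / (data_max_ - data_min_) = 0.25
# min_   = feature_range[0] - data_min_ * scale_              = -0.25
print(scaler.scale_, scaler.min_)   # [0.25] [-0.25]
print(scaler.transform(X).ravel())  # [0.  0.5 1. ]

# inverse_transform undoes X * scale_ + min_, recovering the input.
np.testing.assert_allclose(scaler.inverse_transform(scaler.transform(X)), X)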

sklearn/preprocessing/tests/test_data.py

Lines changed: 27 additions & 0 deletions
@@ -41,6 +41,9 @@
 from sklearn.preprocessing._data import BOUNDS_THRESHOLD, _handle_zeros_in_scale
 from sklearn.svm import SVR
 from sklearn.utils import gen_batches, shuffle
+from sklearn.utils._array_api import (
+    yield_namespace_device_dtype_combinations,
+)
 from sklearn.utils._testing import (
     _convert_container,
     assert_allclose,
@@ -51,6 +54,10 @@
     assert_array_less,
     skip_if_32bit,
 )
+from sklearn.utils.estimator_checks import (
+    _get_check_estimator_ids,
+    check_array_api_input_and_values,
+)
 from sklearn.utils.sparsefuncs import mean_variance_axis

 iris = datasets.load_iris()
@@ -684,6 +691,26 @@ def test_standard_check_array_of_inverse_transform():
     scaler.inverse_transform(x)


+@pytest.mark.parametrize(
+    "array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
+)
+@pytest.mark.parametrize(
+    "check",
+    [check_array_api_input_and_values],
+    ids=_get_check_estimator_ids,
+)
+@pytest.mark.parametrize(
+    "estimator",
+    [MinMaxScaler()],
+    ids=_get_check_estimator_ids,
+)
+def test_minmaxscaler_array_api_compliance(
+    estimator, check, array_namespace, device, dtype
+):
+    name = estimator.__class__.__name__
+    check(name, estimator, array_namespace, device=device, dtype=dtype)
+
+
 def test_min_max_scaler_iris():
     X = iris.data
     scaler = MinMaxScaler()
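
Note: roughly, check_array_api_input_and_values asserts that fitting on an
Array API input matches the NumPy reference. A simplified sketch of that
contract, assuming array_api_compat is installed (the real check also covers
fitted attributes, devices and dtypes):

import numpy as np
import numpy.array_api as xp  # one possible namespace; emits an experimental warning

from sklearn import config_context
from sklearn.preprocessing import MinMaxScaler

X_np = np.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
expected = MinMaxScaler().fit_transform(X_np)  # NumPy reference result

with config_context(array_api_dispatch=True):
    result = MinMaxScaler().fit_transform(xp.asarray(X_np))

# Values agree with the NumPy path and stay in the input namespace.
np.testing.assert_allclose(np.asarray(result), expected)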

sklearn/utils/_array_api.py

Lines changed: 57 additions & 3 deletions
@@ -146,7 +146,7 @@ def _isdtype_single(dtype, kind, *, xp):
             for k in ("signed integer", "unsigned integer")
         )
     elif kind == "real floating":
-        return dtype in {xp.float32, xp.float64}
+        return dtype in supported_float_dtypes(xp)
     elif kind == "complex floating":
         # Some name spaces do not have complex, such as cupy.array_api
         # and numpy.array_api
@@ -167,14 +167,29 @@ def _isdtype_single(dtype, kind, *, xp):
         return dtype == kind


+def supported_float_dtypes(xp):
+    """Supported floating point types for the namespace
+
+    Note: float16 is not officially part of the Array API spec at the
+    time of writing but scikit-learn estimators and functions can choose
+    to accept it when xp.float16 is defined.
+
+    https://data-apis.org/array-api/latest/API_specification/data_types.html
+    """
+    if hasattr(xp, "float16"):
+        return (xp.float64, xp.float32, xp.float16)
+    else:
+        return (xp.float64, xp.float32)
+
+
 class _ArrayAPIWrapper:
     """sklearn specific Array API compatibility wrapper

     This wrapper makes it possible for scikit-learn maintainers to
     deal with discrepancies between different implementations of the
-    Python array API standard and its evolution over time.
+    Python Array API standard and its evolution over time.

-    The Python array API standard specification:
+    The Python Array API standard specification:
     https://data-apis.org/array-api/latest/

     Documentation of the NumPy implementation:
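
Note: a quick illustration of the new helper (not part of this diff); the
exact tuples depend on which namespaces are installed:

import torch  # exposes float16, so three dtypes come back
from sklearn.utils._array_api import supported_float_dtypes

print(supported_float_dtypes(torch))
# (torch.float64, torch.float32, torch.float16)

import numpy.array_api as xp  # strict spec namespace without float16
print(supported_float_dtypes(xp))
# (float64, float32)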
@@ -269,6 +284,9 @@ class _NumPyAPIWrapper:
         "uint16",
         "uint32",
         "uint64",
+        # XXX: float16 is not part of the Array API spec but exposed by
+        # some namespaces.
+        "float16",
         "float32",
         "float64",
         "complex64",
@@ -394,6 +412,8 @@ def get_namespace(*arrays):

     namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True

+    # These namespaces need additional wrapping to smooth out small differences
+    # between implementations
     if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}:
         namespace = _ArrayAPIWrapper(namespace)

@@ -466,6 +486,40 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
     return float(xp.sum(sample_score))


+def _nanmin(X, axis=None):
+    # TODO: refactor once nan-aware reductions are standardized:
+    # https://github.com/data-apis/array-api/issues/621
+    xp, _ = get_namespace(X)
+    if _is_numpy_namespace(xp):
+        return xp.asarray(numpy.nanmin(X, axis=axis))
+
+    else:
+        mask = xp.isnan(X)
+        X = xp.min(xp.where(mask, xp.asarray(+xp.inf), X), axis=axis)
+        # Replace Infs from all NaN slices with NaN again
+        mask = xp.all(mask, axis=axis)
+        if xp.any(mask):
+            X = xp.where(mask, xp.asarray(xp.nan), X)
+        return X
+
+
+def _nanmax(X, axis=None):
+    # TODO: refactor once nan-aware reductions are standardized:
+    # https://github.com/data-apis/array-api/issues/621
+    xp, _ = get_namespace(X)
+    if _is_numpy_namespace(xp):
+        return xp.asarray(numpy.nanmax(X, axis=axis))
+
+    else:
+        mask = xp.isnan(X)
+        X = xp.max(xp.where(mask, xp.asarray(-xp.inf), X), axis=axis)
+        # Replace Infs from all NaN slices with NaN again
+        mask = xp.all(mask, axis=axis)
+        if xp.any(mask):
+            X = xp.where(mask, xp.asarray(xp.nan), X)
+        return X
+
+
 def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None):
     """Helper to support the order kwarg only for NumPy-backed arrays

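Note: the masking trick in _nanmin/_nanmax can be reproduced with plain NumPy
to see why the infinity substitution works (sketch, not part of this diff):

import numpy as np

X = np.asarray([[1.0, np.nan], [np.nan, np.nan]])
mask = np.isnan(X)
# NaNs become +inf so they can never win the minimum...
out = np.min(np.where(mask, np.inf, X), axis=0)
# ...then all-NaN columns, which produced +inf, are reset to NaN.
out = np.where(np.all(mask, axis=0), np.nan, out)
print(out)  # [ 1. nan]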
sklearn/utils/tests/test_array_api.py

Lines changed: 56 additions & 0 deletions
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy
 import pytest
 from numpy.testing import assert_allclose, assert_array_equal
@@ -9,8 +11,11 @@
     _asarray_with_order,
     _convert_to_numpy,
     _estimator_with_converted_arrays,
+    _nanmax,
+    _nanmin,
     _NumPyAPIWrapper,
     get_namespace,
+    supported_float_dtypes,
 )
 from sklearn.utils._testing import skip_if_array_api_compat_not_configured

@@ -159,6 +164,54 @@ def test_asarray_with_order_ignored():
     assert not X_new_np.flags["F_CONTIGUOUS"]


+@skip_if_array_api_compat_not_configured
+@pytest.mark.parametrize(
+    "library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"]
+)
+@pytest.mark.parametrize(
+    "X,reduction,expected",
+    [
+        ([1, 2, numpy.nan], _nanmin, 1),
+        ([1, -2, -numpy.nan], _nanmin, -2),
+        ([numpy.inf, numpy.inf], _nanmin, numpy.inf),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmin, axis=0),
+            [1.0, 2.0, 3.0],
+        ),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmin, axis=1),
+            [1.0, numpy.nan, 4.0],
+        ),
+        ([1, 2, numpy.nan], _nanmax, 2),
+        ([1, 2, numpy.nan], _nanmax, 2),
+        ([-numpy.inf, -numpy.inf], _nanmax, -numpy.inf),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmax, axis=0),
+            [4.0, 5.0, 6.0],
+        ),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmax, axis=1),
+            [3.0, numpy.nan, 6.0],
+        ),
+    ],
+)
+def test_nan_reductions(library, X, reduction, expected):
+    """Check NaN reductions like _nanmin and _nanmax"""
+    xp = pytest.importorskip(library)
+
+    if isinstance(expected, list):
+        expected = xp.asarray(expected)
+
+    with config_context(array_api_dispatch=True):
+        result = reduction(xp.asarray(X))
+
+    assert_allclose(result, expected)
+
+
 @skip_if_array_api_compat_not_configured
 @pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"])
 def test_convert_to_numpy_gpu(library):  # pragma: nocover
@@ -256,6 +309,9 @@ def test_get_namespace_array_api_isdtype(wrapper):
     assert xp.isdtype(xp.float64, "real floating")
     assert not xp.isdtype(xp.int32, "real floating")

+    for dtype in supported_float_dtypes(xp):
+        assert xp.isdtype(dtype, "real floating")
+
     assert xp.isdtype(xp.bool, "bool")
     assert not xp.isdtype(xp.float32, "bool")
