Skip to content

Commit 187b2ac

Browse files
author
jax authors
committed
Merge pull request #21013 from Micky774:array-api-trim
PiperOrigin-RevId: 630146636
2 parents d5983e1 + b88e2e8 commit 187b2ac

23 files changed

+189
-1179
lines changed

jax/_src/dtypes.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
435435
}
436436

437437

438-
def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bool:
438+
def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[str | DTypeLike, ...]) -> bool:
439439
"""Returns a boolean indicating whether a provided dtype is of a specified kind.
440440
441441
Args:
@@ -458,18 +458,25 @@ def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bo
458458
True or False
459459
"""
460460
the_dtype = np.dtype(dtype)
461-
kind_tuple: tuple[DType | str, ...] = kind if isinstance(kind, tuple) else (kind,)
461+
kind_tuple: tuple[str | DTypeLike, ...] = (
462+
kind if isinstance(kind, tuple) else (kind,)
463+
)
462464
options: set[DType] = set()
463465
for kind in kind_tuple:
464-
if isinstance(kind, str):
465-
if kind not in _dtype_kinds:
466-
raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}")
466+
if isinstance(kind, str) and kind in _dtype_kinds:
467467
options.update(_dtype_kinds[kind])
468-
elif isinstance(kind, np.dtype):
469-
options.add(kind)
470-
else:
471-
# TODO(jakevdp): should we handle scalar types or ScalarMeta here?
472-
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}")
468+
continue
469+
try:
470+
_dtype = np.dtype(kind)
471+
except TypeError as e:
472+
if isinstance(kind, str):
473+
raise ValueError(
474+
f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}, "
475+
"or a compatible input for jnp.dtype()")
476+
raise TypeError(
477+
f"Expected kind to be a dtype, string, or tuple; got {kind=}"
478+
) from e
479+
options.add(_dtype)
473480
return the_dtype in options
474481

475482

jax/experimental/array_api/__init__.py

Lines changed: 90 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
>>> from jax.experimental import array_api as xp
2222
2323
>>> xp.__array_api_version__
24-
'2022.12'
24+
'2023.12'
2525
2626
>>> arr = xp.arange(1000)
2727
@@ -38,68 +38,19 @@
3838

3939
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
4040

41-
from jax.experimental.array_api import (
42-
fft as fft,
43-
linalg as linalg,
44-
)
45-
46-
from jax.experimental.array_api._constants import (
47-
e as e,
48-
inf as inf,
49-
nan as nan,
50-
newaxis as newaxis,
51-
pi as pi,
52-
)
53-
54-
from jax.experimental.array_api._creation_functions import (
55-
arange as arange,
56-
asarray as asarray,
57-
empty as empty,
58-
empty_like as empty_like,
59-
eye as eye,
60-
from_dlpack as from_dlpack,
61-
full as full,
62-
full_like as full_like,
63-
linspace as linspace,
64-
meshgrid as meshgrid,
65-
ones as ones,
66-
ones_like as ones_like,
67-
tril as tril,
68-
triu as triu,
69-
zeros as zeros,
70-
zeros_like as zeros_like,
71-
)
72-
73-
from jax.experimental.array_api._data_type_functions import (
74-
astype as astype,
75-
can_cast as can_cast,
76-
finfo as finfo,
77-
iinfo as iinfo,
78-
isdtype as isdtype,
79-
result_type as result_type,
80-
)
81-
82-
from jax.experimental.array_api._dtypes import (
83-
bool as bool,
84-
int8 as int8,
85-
int16 as int16,
86-
int32 as int32,
87-
int64 as int64,
88-
uint8 as uint8,
89-
uint16 as uint16,
90-
uint32 as uint32,
91-
uint64 as uint64,
92-
float32 as float32,
93-
float64 as float64,
94-
complex64 as complex64,
95-
complex128 as complex128,
96-
)
41+
from jax.experimental.array_api import fft as fft
42+
from jax.experimental.array_api import linalg as linalg
9743

98-
from jax.experimental.array_api._elementwise_functions import (
44+
from jax.numpy import (
9945
abs as abs,
10046
acos as acos,
10147
acosh as acosh,
10248
add as add,
49+
all as all,
50+
any as any,
51+
argmax as argmax,
52+
argmin as argmin,
53+
argsort as argsort,
10354
asin as asin,
10455
asinh as asinh,
10556
atan as atan,
@@ -111,22 +62,43 @@
11162
bitwise_or as bitwise_or,
11263
bitwise_right_shift as bitwise_right_shift,
11364
bitwise_xor as bitwise_xor,
114-
ceil as ceil,
115-
clip as clip,
65+
bool as bool,
66+
broadcast_arrays as broadcast_arrays,
67+
broadcast_to as broadcast_to,
68+
can_cast as can_cast,
69+
complex128 as complex128,
70+
complex64 as complex64,
71+
concat as concat,
11672
conj as conj,
11773
copysign as copysign,
11874
cos as cos,
11975
cosh as cosh,
76+
cumulative_sum as cumulative_sum,
12077
divide as divide,
78+
e as e,
79+
empty as empty,
80+
empty_like as empty_like,
12181
equal as equal,
12282
exp as exp,
83+
expand_dims as expand_dims,
12384
expm1 as expm1,
124-
floor as floor,
85+
flip as flip,
86+
float32 as float32,
87+
float64 as float64,
12588
floor_divide as floor_divide,
89+
from_dlpack as from_dlpack,
90+
full as full,
91+
full_like as full_like,
12692
greater as greater,
12793
greater_equal as greater_equal,
128-
hypot as hypot,
94+
iinfo as iinfo,
12995
imag as imag,
96+
inf as inf,
97+
int16 as int16,
98+
int32 as int32,
99+
int64 as int64,
100+
int8 as int8,
101+
isdtype as isdtype,
130102
isfinite as isfinite,
131103
isinf as isinf,
132104
isnan as isnan,
@@ -141,91 +113,99 @@
141113
logical_not as logical_not,
142114
logical_or as logical_or,
143115
logical_xor as logical_xor,
116+
matmul as matmul,
117+
matrix_transpose as matrix_transpose,
118+
max as max,
144119
maximum as maximum,
120+
mean as mean,
121+
meshgrid as meshgrid,
122+
min as min,
145123
minimum as minimum,
124+
moveaxis as moveaxis,
146125
multiply as multiply,
126+
nan as nan,
147127
negative as negative,
128+
newaxis as newaxis,
129+
nonzero as nonzero,
148130
not_equal as not_equal,
131+
ones as ones,
132+
ones_like as ones_like,
133+
permute_dims as permute_dims,
134+
pi as pi,
149135
positive as positive,
150136
pow as pow,
137+
prod as prod,
151138
real as real,
152139
remainder as remainder,
140+
repeat as repeat,
141+
result_type as result_type,
142+
roll as roll,
153143
round as round,
144+
searchsorted as searchsorted,
154145
sign as sign,
155146
signbit as signbit,
156147
sin as sin,
157148
sinh as sinh,
149+
sort as sort,
158150
sqrt as sqrt,
159151
square as square,
152+
squeeze as squeeze,
153+
stack as stack,
160154
subtract as subtract,
155+
sum as sum,
156+
take as take,
161157
tan as tan,
162158
tanh as tanh,
163-
trunc as trunc,
164-
)
165-
166-
from jax.experimental.array_api._indexing_functions import (
167-
take as take,
159+
tensordot as tensordot,
160+
tile as tile,
161+
tril as tril,
162+
triu as triu,
163+
uint16 as uint16,
164+
uint32 as uint32,
165+
uint64 as uint64,
166+
uint8 as uint8,
167+
unique_all as unique_all,
168+
unique_counts as unique_counts,
169+
unique_inverse as unique_inverse,
170+
unique_values as unique_values,
171+
unstack as unstack,
172+
vecdot as vecdot,
173+
where as where,
174+
zeros as zeros,
175+
zeros_like as zeros_like,
168176
)
169177

170178
from jax.experimental.array_api._manipulation_functions import (
171-
broadcast_arrays as broadcast_arrays,
172-
broadcast_to as broadcast_to,
173-
concat as concat,
174-
expand_dims as expand_dims,
175-
flip as flip,
176-
moveaxis as moveaxis,
177-
permute_dims as permute_dims,
178-
repeat as repeat,
179179
reshape as reshape,
180-
roll as roll,
181-
squeeze as squeeze,
182-
stack as stack,
183-
tile as tile,
184-
unstack as unstack,
185180
)
186181

187-
from jax.experimental.array_api._searching_functions import (
188-
argmax as argmax,
189-
argmin as argmin,
190-
nonzero as nonzero,
191-
searchsorted as searchsorted,
192-
where as where,
182+
from jax.experimental.array_api._creation_functions import (
183+
arange as arange,
184+
asarray as asarray,
185+
eye as eye,
186+
linspace as linspace,
193187
)
194188

195-
from jax.experimental.array_api._set_functions import (
196-
unique_all as unique_all,
197-
unique_counts as unique_counts,
198-
unique_inverse as unique_inverse,
199-
unique_values as unique_values,
189+
from jax.experimental.array_api._data_type_functions import (
190+
astype as astype,
191+
finfo as finfo,
200192
)
201193

202-
from jax.experimental.array_api._sorting_functions import (
203-
argsort as argsort,
204-
sort as sort,
194+
from jax.experimental.array_api._elementwise_functions import (
195+
ceil as ceil,
196+
clip as clip,
197+
floor as floor,
198+
hypot as hypot,
199+
trunc as trunc,
205200
)
206201

207202
from jax.experimental.array_api._statistical_functions import (
208-
cumulative_sum as cumulative_sum,
209-
max as max,
210-
mean as mean,
211-
min as min,
212-
prod as prod,
213203
std as std,
214-
sum as sum,
215-
var as var
204+
var as var,
216205
)
217206

218207
from jax.experimental.array_api._utility_functions import (
219208
__array_namespace_info__ as __array_namespace_info__,
220-
all as all,
221-
any as any,
222-
)
223-
224-
from jax.experimental.array_api._linear_algebra_functions import (
225-
matmul as matmul,
226-
matrix_transpose as matrix_transpose,
227-
tensordot as tensordot,
228-
vecdot as vecdot,
229209
)
230210

231211
from jax.experimental.array_api import _array_methods

jax/experimental/array_api/_array_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any, Callable
17+
from typing import Any
1818

1919
import jax
2020
from jax._src.array import ArrayImpl

jax/experimental/array_api/_constants.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)