Skip to content

Commit c4cfbcb

Browse files
committed
Refactor array_api namespace, relying more directly on jax.numpy
1 parent a949ce7 commit c4cfbcb

23 files changed

+192
-1167
lines changed

jax/_src/dtypes.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
435435
}
436436

437437

438-
def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bool:
438+
def isdtype(dtype: DTypeLike, kind: str | DTypeLike | tuple[DTypeLike, ...]) -> bool:
439439
"""Returns a boolean indicating whether a provided dtype is of a specified kind.
440440
441441
Args:
@@ -458,18 +458,22 @@ def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType, ...]) -> bo
458458
True or False
459459
"""
460460
the_dtype = np.dtype(dtype)
461-
kind_tuple: tuple[DType | str, ...] = kind if isinstance(kind, tuple) else (kind,)
461+
kind_tuple: tuple[DTypeLike, ...] = kind if isinstance(kind, tuple) else (kind,)
462462
options: set[DType] = set()
463463
for kind in kind_tuple:
464-
if isinstance(kind, str):
465-
if kind not in _dtype_kinds:
466-
raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}")
464+
if isinstance(kind, str) and kind in _dtype_kinds:
467465
options.update(_dtype_kinds[kind])
468-
elif isinstance(kind, np.dtype):
469-
options.add(kind)
470-
else:
471-
# TODO(jakevdp): should we handle scalar types or ScalarMeta here?
472-
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}")
466+
continue
467+
try:
468+
_dtype = np.dtype(kind)
469+
except TypeError as e:
470+
if isinstance(kind, str):
471+
raise ValueError(
472+
f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}, "
473+
"or a compatible input for jnp.dtype()")
474+
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}") from e
475+
options.add(_dtype)
476+
continue
473477
return the_dtype in options
474478

475479

jax/experimental/array_api/__init__.py

Lines changed: 93 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
>>> from jax.experimental import array_api as xp
2222
2323
>>> xp.__array_api_version__
24-
'2022.12'
24+
'2023.12'
2525
2626
>>> arr = xp.arange(1000)
2727
@@ -38,68 +38,19 @@
3838

3939
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
4040

41-
from jax.experimental.array_api import (
42-
fft as fft,
43-
linalg as linalg,
44-
)
45-
46-
from jax.experimental.array_api._constants import (
47-
e as e,
48-
inf as inf,
49-
nan as nan,
50-
newaxis as newaxis,
51-
pi as pi,
52-
)
53-
54-
from jax.experimental.array_api._creation_functions import (
55-
arange as arange,
56-
asarray as asarray,
57-
empty as empty,
58-
empty_like as empty_like,
59-
eye as eye,
60-
from_dlpack as from_dlpack,
61-
full as full,
62-
full_like as full_like,
63-
linspace as linspace,
64-
meshgrid as meshgrid,
65-
ones as ones,
66-
ones_like as ones_like,
67-
tril as tril,
68-
triu as triu,
69-
zeros as zeros,
70-
zeros_like as zeros_like,
71-
)
72-
73-
from jax.experimental.array_api._data_type_functions import (
74-
astype as astype,
75-
can_cast as can_cast,
76-
finfo as finfo,
77-
iinfo as iinfo,
78-
isdtype as isdtype,
79-
result_type as result_type,
80-
)
41+
from jax.experimental.array_api import fft as fft
42+
from jax.experimental.array_api import linalg as linalg
8143

82-
from jax.experimental.array_api._dtypes import (
83-
bool as bool,
84-
int8 as int8,
85-
int16 as int16,
86-
int32 as int32,
87-
int64 as int64,
88-
uint8 as uint8,
89-
uint16 as uint16,
90-
uint32 as uint32,
91-
uint64 as uint64,
92-
float32 as float32,
93-
float64 as float64,
94-
complex64 as complex64,
95-
complex128 as complex128,
96-
)
97-
98-
from jax.experimental.array_api._elementwise_functions import (
44+
from jax.numpy import (
9945
abs as abs,
10046
acos as acos,
10147
acosh as acosh,
10248
add as add,
49+
all as all,
50+
any as any,
51+
argmax as argmax,
52+
argmin as argmin,
53+
argsort as argsort,
10354
asin as asin,
10455
asinh as asinh,
10556
atan as atan,
@@ -111,22 +62,43 @@
11162
bitwise_or as bitwise_or,
11263
bitwise_right_shift as bitwise_right_shift,
11364
bitwise_xor as bitwise_xor,
114-
ceil as ceil,
115-
clip as clip,
65+
bool as bool,
66+
broadcast_arrays as broadcast_arrays,
67+
broadcast_to as broadcast_to,
68+
can_cast as can_cast,
69+
complex128 as complex128,
70+
complex64 as complex64,
71+
concat as concat,
11672
conj as conj,
11773
copysign as copysign,
11874
cos as cos,
11975
cosh as cosh,
76+
cumulative_sum as cumulative_sum,
12077
divide as divide,
78+
e as e,
79+
empty as empty,
80+
empty_like as empty_like,
12181
equal as equal,
12282
exp as exp,
83+
expand_dims as expand_dims,
12384
expm1 as expm1,
124-
floor as floor,
85+
flip as flip,
86+
float32 as float32,
87+
float64 as float64,
12588
floor_divide as floor_divide,
89+
from_dlpack as from_dlpack,
90+
full as full,
91+
full_like as full_like,
12692
greater as greater,
12793
greater_equal as greater_equal,
128-
hypot as hypot,
94+
iinfo as iinfo,
12995
imag as imag,
96+
inf as inf,
97+
int16 as int16,
98+
int32 as int32,
99+
int64 as int64,
100+
int8 as int8,
101+
isdtype as isdtype,
130102
isfinite as isfinite,
131103
isinf as isinf,
132104
isnan as isnan,
@@ -137,95 +109,103 @@
137109
log1p as log1p,
138110
log2 as log2,
139111
logaddexp as logaddexp,
140-
logical_and as logical_and,
141112
logical_not as logical_not,
142-
logical_or as logical_or,
143-
logical_xor as logical_xor,
113+
matmul as matmul,
114+
matrix_transpose as matrix_transpose,
115+
max as max,
144116
maximum as maximum,
117+
mean as mean,
118+
meshgrid as meshgrid,
119+
min as min,
145120
minimum as minimum,
121+
moveaxis as moveaxis,
146122
multiply as multiply,
123+
nan as nan,
147124
negative as negative,
125+
newaxis as newaxis,
126+
nonzero as nonzero,
148127
not_equal as not_equal,
128+
ones as ones,
129+
ones_like as ones_like,
130+
permute_dims as permute_dims,
131+
pi as pi,
149132
positive as positive,
150133
pow as pow,
134+
prod as prod,
151135
real as real,
152136
remainder as remainder,
137+
repeat as repeat,
138+
result_type as result_type,
139+
roll as roll,
153140
round as round,
141+
searchsorted as searchsorted,
154142
sign as sign,
155143
signbit as signbit,
156144
sin as sin,
157145
sinh as sinh,
146+
sort as sort,
158147
sqrt as sqrt,
159148
square as square,
149+
squeeze as squeeze,
150+
stack as stack,
160151
subtract as subtract,
152+
sum as sum,
153+
take as take,
161154
tan as tan,
162155
tanh as tanh,
163-
trunc as trunc,
164-
)
165-
166-
from jax.experimental.array_api._indexing_functions import (
167-
take as take,
156+
tensordot as tensordot,
157+
tile as tile,
158+
tril as tril,
159+
triu as triu,
160+
uint16 as uint16,
161+
uint32 as uint32,
162+
uint64 as uint64,
163+
uint8 as uint8,
164+
unique_all as unique_all,
165+
unique_counts as unique_counts,
166+
unique_inverse as unique_inverse,
167+
unique_values as unique_values,
168+
unstack as unstack,
169+
vecdot as vecdot,
170+
where as where,
171+
zeros as zeros,
172+
zeros_like as zeros_like,
168173
)
169174

170175
from jax.experimental.array_api._manipulation_functions import (
171-
broadcast_arrays as broadcast_arrays,
172-
broadcast_to as broadcast_to,
173-
concat as concat,
174-
expand_dims as expand_dims,
175-
flip as flip,
176-
moveaxis as moveaxis,
177-
permute_dims as permute_dims,
178-
repeat as repeat,
179176
reshape as reshape,
180-
roll as roll,
181-
squeeze as squeeze,
182-
stack as stack,
183-
tile as tile,
184-
unstack as unstack,
185177
)
186178

187-
from jax.experimental.array_api._searching_functions import (
188-
argmax as argmax,
189-
argmin as argmin,
190-
nonzero as nonzero,
191-
searchsorted as searchsorted,
192-
where as where,
179+
from jax.experimental.array_api._creation_functions import (
180+
arange as arange,
181+
asarray as asarray,
182+
eye as eye,
183+
linspace as linspace,
193184
)
194185

195-
from jax.experimental.array_api._set_functions import (
196-
unique_all as unique_all,
197-
unique_counts as unique_counts,
198-
unique_inverse as unique_inverse,
199-
unique_values as unique_values,
186+
from jax.experimental.array_api._data_type_functions import (
187+
astype as astype,
188+
finfo as finfo,
200189
)
201190

202-
from jax.experimental.array_api._sorting_functions import (
203-
argsort as argsort,
204-
sort as sort,
191+
from jax.experimental.array_api._elementwise_functions import (
192+
ceil as ceil,
193+
clip as clip,
194+
floor as floor,
195+
hypot as hypot,
196+
logical_and as logical_and,
197+
logical_or as logical_or,
198+
logical_xor as logical_xor,
199+
trunc as trunc,
205200
)
206201

207202
from jax.experimental.array_api._statistical_functions import (
208-
cumulative_sum as cumulative_sum,
209-
max as max,
210-
mean as mean,
211-
min as min,
212-
prod as prod,
213203
std as std,
214-
sum as sum,
215-
var as var
204+
var as var,
216205
)
217206

218207
from jax.experimental.array_api._utility_functions import (
219208
__array_namespace_info__ as __array_namespace_info__,
220-
all as all,
221-
any as any,
222-
)
223-
224-
from jax.experimental.array_api._linear_algebra_functions import (
225-
matmul as matmul,
226-
matrix_transpose as matrix_transpose,
227-
tensordot as tensordot,
228-
vecdot as vecdot,
229209
)
230210

231211
from jax.experimental.array_api import _array_methods

jax/experimental/array_api/_array_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any, Callable
17+
from typing import Any
1818

1919
import jax
2020
from jax._src.array import ArrayImpl

jax/experimental/array_api/_constants.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)