Skip to content

Commit 36e53d9

Browse files
committed
Refactor array_api namespace, relying more directly on jax.numpy
1 parent a949ce7 commit 36e53d9

23 files changed

+166
-1143
lines changed

.github/workflows/jax-array-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ jobs:
4242
JAX_ENABLE_X64: 'true'
4343
run: |
4444
cd ${GITHUB_WORKSPACE}/array-api-tests
45-
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt
45+
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt -Wignore::UserWarning

jax/_src/dtypes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ def check_valid_dtype(dtype: DType) -> None:
652652
raise TypeError(f"Dtype {dtype} is not a valid JAX array "
653653
"type. Only arrays of numeric types are supported by JAX.")
654654

655+
def is_valid_dtype(dtype: DType) -> bool:
656+
return dtype in _jax_dtype_set
657+
655658
def dtype(x: Any, *, canonicalize: bool = False) -> DType:
656659
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
657660
if x is None:

jax/experimental/array_api/__init__.py

Lines changed: 74 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
>>> from jax.experimental import array_api as xp
2222
2323
>>> xp.__array_api_version__
24-
'2022.12'
24+
'2023.12'
2525
2626
>>> arr = xp.arange(1000)
2727
@@ -38,64 +38,17 @@
3838

3939
from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
4040

41-
from jax.experimental.array_api import (
42-
fft as fft,
43-
linalg as linalg,
44-
)
41+
from jax.experimental.array_api import fft as fft
42+
from jax.experimental.array_api import linalg as linalg
4543

46-
from jax.experimental.array_api._constants import (
44+
from jax.numpy import (
4745
e as e,
4846
inf as inf,
4947
nan as nan,
5048
newaxis as newaxis,
5149
pi as pi,
52-
)
53-
54-
from jax.experimental.array_api._creation_functions import (
55-
arange as arange,
56-
asarray as asarray,
57-
empty as empty,
58-
empty_like as empty_like,
59-
eye as eye,
60-
from_dlpack as from_dlpack,
61-
full as full,
62-
full_like as full_like,
63-
linspace as linspace,
64-
meshgrid as meshgrid,
65-
ones as ones,
66-
ones_like as ones_like,
6750
tril as tril,
6851
triu as triu,
69-
zeros as zeros,
70-
zeros_like as zeros_like,
71-
)
72-
73-
from jax.experimental.array_api._data_type_functions import (
74-
astype as astype,
75-
can_cast as can_cast,
76-
finfo as finfo,
77-
iinfo as iinfo,
78-
isdtype as isdtype,
79-
result_type as result_type,
80-
)
81-
82-
from jax.experimental.array_api._dtypes import (
83-
bool as bool,
84-
int8 as int8,
85-
int16 as int16,
86-
int32 as int32,
87-
int64 as int64,
88-
uint8 as uint8,
89-
uint16 as uint16,
90-
uint32 as uint32,
91-
uint64 as uint64,
92-
float32 as float32,
93-
float64 as float64,
94-
complex64 as complex64,
95-
complex128 as complex128,
96-
)
97-
98-
from jax.experimental.array_api._elementwise_functions import (
9952
abs as abs,
10053
acos as acos,
10154
acosh as acosh,
@@ -111,8 +64,6 @@
11164
bitwise_or as bitwise_or,
11265
bitwise_right_shift as bitwise_right_shift,
11366
bitwise_xor as bitwise_xor,
114-
ceil as ceil,
115-
clip as clip,
11667
conj as conj,
11768
copysign as copysign,
11869
cos as cos,
@@ -121,11 +72,9 @@
12172
equal as equal,
12273
exp as exp,
12374
expm1 as expm1,
124-
floor as floor,
12575
floor_divide as floor_divide,
12676
greater as greater,
12777
greater_equal as greater_equal,
128-
hypot as hypot,
12978
imag as imag,
13079
isfinite as isfinite,
13180
isinf as isinf,
@@ -137,10 +86,7 @@
13786
log1p as log1p,
13887
log2 as log2,
13988
logaddexp as logaddexp,
140-
logical_and as logical_and,
14189
logical_not as logical_not,
142-
logical_or as logical_or,
143-
logical_xor as logical_xor,
14490
maximum as maximum,
14591
minimum as minimum,
14692
multiply as multiply,
@@ -151,23 +97,15 @@
15197
real as real,
15298
remainder as remainder,
15399
round as round,
154-
sign as sign,
155100
signbit as signbit,
156101
sin as sin,
157102
sinh as sinh,
158103
sqrt as sqrt,
159104
square as square,
160105
subtract as subtract,
161106
tan as tan,
162-
tanh as tanh,
163-
trunc as trunc,
164-
)
165-
166-
from jax.experimental.array_api._indexing_functions import (
167107
take as take,
168-
)
169-
170-
from jax.experimental.array_api._manipulation_functions import (
108+
tanh as tanh,
171109
broadcast_arrays as broadcast_arrays,
172110
broadcast_to as broadcast_to,
173111
concat as concat,
@@ -176,58 +114,104 @@
176114
moveaxis as moveaxis,
177115
permute_dims as permute_dims,
178116
repeat as repeat,
179-
reshape as reshape,
180117
roll as roll,
181118
squeeze as squeeze,
182119
stack as stack,
183120
tile as tile,
184121
unstack as unstack,
185-
)
186-
187-
from jax.experimental.array_api._searching_functions import (
188122
argmax as argmax,
189123
argmin as argmin,
190-
nonzero as nonzero,
191124
searchsorted as searchsorted,
192125
where as where,
193-
)
194-
195-
from jax.experimental.array_api._set_functions import (
196126
unique_all as unique_all,
197127
unique_counts as unique_counts,
198128
unique_inverse as unique_inverse,
199129
unique_values as unique_values,
200-
)
201-
202-
from jax.experimental.array_api._sorting_functions import (
203130
argsort as argsort,
204131
sort as sort,
205-
)
206-
207-
from jax.experimental.array_api._statistical_functions import (
208132
cumulative_sum as cumulative_sum,
209133
max as max,
210134
mean as mean,
211135
min as min,
212-
prod as prod,
213-
std as std,
214-
sum as sum,
215-
var as var
216-
)
217-
218-
from jax.experimental.array_api._utility_functions import (
219-
__array_namespace_info__ as __array_namespace_info__,
220136
all as all,
221137
any as any,
138+
from_dlpack as from_dlpack,
139+
meshgrid as meshgrid,
140+
empty as empty,
141+
empty_like as empty_like,
142+
full as full,
143+
full_like as full_like,
144+
ones as ones,
145+
ones_like as ones_like,
146+
zeros as zeros,
147+
zeros_like as zeros_like,
148+
can_cast as can_cast,
149+
isdtype as isdtype,
150+
result_type as result_type,
151+
iinfo as iinfo,
152+
sign as sign,
153+
nonzero as nonzero,
154+
prod as prod,
155+
sum as sum,
222156
)
223157

224-
from jax.experimental.array_api._linear_algebra_functions import (
158+
# TODO(mickey): Remove these imports once we have add them to jax.numpy namespace
159+
from jax.numpy.linalg import (
225160
matmul as matmul,
226161
matrix_transpose as matrix_transpose,
227162
tensordot as tensordot,
228163
vecdot as vecdot,
229164
)
230165

166+
from jax.experimental.array_api._manipulation_functions import (
167+
reshape as reshape,
168+
)
169+
170+
from jax.experimental.array_api._creation_functions import (
171+
arange as arange,
172+
asarray as asarray,
173+
eye as eye,
174+
linspace as linspace,
175+
)
176+
177+
from jax.experimental.array_api._data_type_functions import (
178+
bool as bool,
179+
int8 as int8,
180+
int16 as int16,
181+
int32 as int32,
182+
int64 as int64,
183+
uint8 as uint8,
184+
uint16 as uint16,
185+
uint32 as uint32,
186+
uint64 as uint64,
187+
float32 as float32,
188+
float64 as float64,
189+
complex64 as complex64,
190+
complex128 as complex128,
191+
astype as astype,
192+
finfo as finfo,
193+
)
194+
195+
from jax.experimental.array_api._elementwise_functions import (
196+
ceil as ceil,
197+
clip as clip,
198+
floor as floor,
199+
hypot as hypot,
200+
logical_and as logical_and,
201+
logical_or as logical_or,
202+
logical_xor as logical_xor,
203+
trunc as trunc,
204+
)
205+
206+
from jax.experimental.array_api._statistical_functions import (
207+
std as std,
208+
var as var,
209+
)
210+
211+
from jax.experimental.array_api._utility_functions import (
212+
__array_namespace_info__ as __array_namespace_info__,
213+
)
214+
231215
from jax.experimental.array_api import _array_methods
232216
_array_methods.add_array_object_methods()
233217
del _array_methods

jax/experimental/array_api/_array_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any, Callable
17+
from typing import Any
1818

1919
import jax
2020
from jax._src.array import ArrayImpl

jax/experimental/array_api/_constants.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

jax/experimental/array_api/_creation_functions.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,53 +16,16 @@
1616

1717
import jax
1818
import jax.numpy as jnp
19-
from jax._src.lib import xla_client as xc
20-
from jax._src.sharding import Sharding
2119

20+
# TODO(micky774): Deprecate after adding device argument to jax.numpy functions
2221
def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
2322
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
2423

2524
def asarray(obj, /, *, dtype=None, device=None, copy=None):
2625
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)
2726

28-
def empty(shape, *, dtype=None, device=None):
29-
return jax.device_put(jnp.empty(shape, dtype=dtype), device=device)
30-
31-
def empty_like(x, /, *, dtype=None, device=None):
32-
return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device)
33-
3427
def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
3528
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
3629

37-
def from_dlpack(x, /, *, device: xc.Device | Sharding | None = None, copy: bool | None = None):
38-
return jnp.from_dlpack(x, device=device, copy=copy)
39-
40-
def full(shape, fill_value, *, dtype=None, device=None):
41-
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
42-
43-
def full_like(x, /, fill_value, *, dtype=None, device=None):
44-
return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device)
45-
4630
def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
4731
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)
48-
49-
def meshgrid(*arrays, indexing='xy'):
50-
return jnp.meshgrid(*arrays, indexing=indexing)
51-
52-
def ones(shape, *, dtype=None, device=None):
53-
return jax.device_put(jnp.ones(shape, dtype=dtype), device=device)
54-
55-
def ones_like(x, /, *, dtype=None, device=None):
56-
return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device)
57-
58-
def tril(x, /, *, k=0):
59-
return jnp.tril(x, k=k)
60-
61-
def triu(x, /, *, k=0):
62-
return jnp.triu(x, k=k)
63-
64-
def zeros(shape, *, dtype=None, device=None):
65-
return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device)
66-
67-
def zeros_like(x, /, *, dtype=None, device=None):
68-
return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device)

0 commit comments

Comments
 (0)