|
21 | 21 | >>> from jax.experimental import array_api as xp
|
22 | 22 |
|
23 | 23 | >>> xp.__array_api_version__
|
24 |
| - '2022.12' |
| 24 | + '2023.12' |
25 | 25 |
|
26 | 26 | >>> arr = xp.arange(1000)
|
27 | 27 |
|
|
38 | 38 |
|
39 | 39 | from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
|
40 | 40 |
|
41 |
| -from jax.experimental.array_api import ( |
42 |
| - fft as fft, |
43 |
| - linalg as linalg, |
44 |
| -) |
| 41 | +from jax.experimental.array_api import fft as fft |
| 42 | +from jax.experimental.array_api import linalg as linalg |
45 | 43 |
|
46 |
| -from jax.experimental.array_api._constants import ( |
| 44 | +from jax.numpy import ( |
47 | 45 | e as e,
|
48 | 46 | inf as inf,
|
49 | 47 | nan as nan,
|
50 | 48 | newaxis as newaxis,
|
51 | 49 | pi as pi,
|
52 |
| -) |
53 |
| - |
54 |
| -from jax.experimental.array_api._creation_functions import ( |
55 |
| - arange as arange, |
56 |
| - asarray as asarray, |
57 |
| - empty as empty, |
58 |
| - empty_like as empty_like, |
59 |
| - eye as eye, |
60 |
| - from_dlpack as from_dlpack, |
61 |
| - full as full, |
62 |
| - full_like as full_like, |
63 |
| - linspace as linspace, |
64 |
| - meshgrid as meshgrid, |
65 |
| - ones as ones, |
66 |
| - ones_like as ones_like, |
67 | 50 | tril as tril,
|
68 | 51 | triu as triu,
|
69 |
| - zeros as zeros, |
70 |
| - zeros_like as zeros_like, |
71 |
| -) |
72 |
| - |
73 |
| -from jax.experimental.array_api._data_type_functions import ( |
74 |
| - astype as astype, |
75 |
| - can_cast as can_cast, |
76 |
| - finfo as finfo, |
77 |
| - iinfo as iinfo, |
78 |
| - isdtype as isdtype, |
79 |
| - result_type as result_type, |
80 |
| -) |
81 |
| - |
82 |
| -from jax.experimental.array_api._dtypes import ( |
83 |
| - bool as bool, |
84 |
| - int8 as int8, |
85 |
| - int16 as int16, |
86 |
| - int32 as int32, |
87 |
| - int64 as int64, |
88 |
| - uint8 as uint8, |
89 |
| - uint16 as uint16, |
90 |
| - uint32 as uint32, |
91 |
| - uint64 as uint64, |
92 |
| - float32 as float32, |
93 |
| - float64 as float64, |
94 |
| - complex64 as complex64, |
95 |
| - complex128 as complex128, |
96 |
| -) |
97 |
| - |
98 |
| -from jax.experimental.array_api._elementwise_functions import ( |
99 | 52 | abs as abs,
|
100 | 53 | acos as acos,
|
101 | 54 | acosh as acosh,
|
|
111 | 64 | bitwise_or as bitwise_or,
|
112 | 65 | bitwise_right_shift as bitwise_right_shift,
|
113 | 66 | bitwise_xor as bitwise_xor,
|
114 |
| - ceil as ceil, |
115 |
| - clip as clip, |
116 | 67 | conj as conj,
|
117 | 68 | copysign as copysign,
|
118 | 69 | cos as cos,
|
|
121 | 72 | equal as equal,
|
122 | 73 | exp as exp,
|
123 | 74 | expm1 as expm1,
|
124 |
| - floor as floor, |
125 | 75 | floor_divide as floor_divide,
|
126 | 76 | greater as greater,
|
127 | 77 | greater_equal as greater_equal,
|
128 |
| - hypot as hypot, |
129 | 78 | imag as imag,
|
130 | 79 | isfinite as isfinite,
|
131 | 80 | isinf as isinf,
|
|
137 | 86 | log1p as log1p,
|
138 | 87 | log2 as log2,
|
139 | 88 | logaddexp as logaddexp,
|
140 |
| - logical_and as logical_and, |
141 | 89 | logical_not as logical_not,
|
142 |
| - logical_or as logical_or, |
143 |
| - logical_xor as logical_xor, |
144 | 90 | maximum as maximum,
|
145 | 91 | minimum as minimum,
|
146 | 92 | multiply as multiply,
|
|
151 | 97 | real as real,
|
152 | 98 | remainder as remainder,
|
153 | 99 | round as round,
|
154 |
| - sign as sign, |
155 | 100 | signbit as signbit,
|
156 | 101 | sin as sin,
|
157 | 102 | sinh as sinh,
|
158 | 103 | sqrt as sqrt,
|
159 | 104 | square as square,
|
160 | 105 | subtract as subtract,
|
161 | 106 | tan as tan,
|
162 |
| - tanh as tanh, |
163 |
| - trunc as trunc, |
164 |
| -) |
165 |
| - |
166 |
| -from jax.experimental.array_api._indexing_functions import ( |
167 | 107 | take as take,
|
168 |
| -) |
169 |
| - |
170 |
| -from jax.experimental.array_api._manipulation_functions import ( |
| 108 | + tanh as tanh, |
171 | 109 | broadcast_arrays as broadcast_arrays,
|
172 | 110 | broadcast_to as broadcast_to,
|
173 | 111 | concat as concat,
|
|
176 | 114 | moveaxis as moveaxis,
|
177 | 115 | permute_dims as permute_dims,
|
178 | 116 | repeat as repeat,
|
179 |
| - reshape as reshape, |
180 | 117 | roll as roll,
|
181 | 118 | squeeze as squeeze,
|
182 | 119 | stack as stack,
|
183 | 120 | tile as tile,
|
184 | 121 | unstack as unstack,
|
185 |
| -) |
186 |
| - |
187 |
| -from jax.experimental.array_api._searching_functions import ( |
188 | 122 | argmax as argmax,
|
189 | 123 | argmin as argmin,
|
190 |
| - nonzero as nonzero, |
191 | 124 | searchsorted as searchsorted,
|
192 | 125 | where as where,
|
193 |
| -) |
194 |
| - |
195 |
| -from jax.experimental.array_api._set_functions import ( |
196 | 126 | unique_all as unique_all,
|
197 | 127 | unique_counts as unique_counts,
|
198 | 128 | unique_inverse as unique_inverse,
|
199 | 129 | unique_values as unique_values,
|
200 |
| -) |
201 |
| - |
202 |
| -from jax.experimental.array_api._sorting_functions import ( |
203 | 130 | argsort as argsort,
|
204 | 131 | sort as sort,
|
205 |
| -) |
206 |
| - |
207 |
| -from jax.experimental.array_api._statistical_functions import ( |
208 | 132 | cumulative_sum as cumulative_sum,
|
209 | 133 | max as max,
|
210 | 134 | mean as mean,
|
211 | 135 | min as min,
|
212 |
| - prod as prod, |
213 |
| - std as std, |
214 |
| - sum as sum, |
215 |
| - var as var |
216 |
| -) |
217 |
| - |
218 |
| -from jax.experimental.array_api._utility_functions import ( |
219 |
| - __array_namespace_info__ as __array_namespace_info__, |
220 | 136 | all as all,
|
221 | 137 | any as any,
|
| 138 | + from_dlpack as from_dlpack, |
| 139 | + meshgrid as meshgrid, |
| 140 | + empty as empty, |
| 141 | + empty_like as empty_like, |
| 142 | + full as full, |
| 143 | + full_like as full_like, |
| 144 | + ones as ones, |
| 145 | + ones_like as ones_like, |
| 146 | + zeros as zeros, |
| 147 | + zeros_like as zeros_like, |
| 148 | + can_cast as can_cast, |
| 149 | + isdtype as isdtype, |
| 150 | + result_type as result_type, |
| 151 | + iinfo as iinfo, |
| 152 | + sign as sign, |
| 153 | + nonzero as nonzero, |
| 154 | + prod as prod, |
| 155 | + sum as sum, |
222 | 156 | )
|
223 | 157 |
|
224 |
| -from jax.experimental.array_api._linear_algebra_functions import ( |
| 158 | +# TODO(mickey): Remove these imports once we have add them to jax.numpy namespace |
| 159 | +from jax.numpy.linalg import ( |
225 | 160 | matmul as matmul,
|
226 | 161 | matrix_transpose as matrix_transpose,
|
227 | 162 | tensordot as tensordot,
|
228 | 163 | vecdot as vecdot,
|
229 | 164 | )
|
230 | 165 |
|
| 166 | +from jax.experimental.array_api._manipulation_functions import ( |
| 167 | + reshape as reshape, |
| 168 | +) |
| 169 | + |
| 170 | +from jax.experimental.array_api._creation_functions import ( |
| 171 | + arange as arange, |
| 172 | + asarray as asarray, |
| 173 | + eye as eye, |
| 174 | + linspace as linspace, |
| 175 | +) |
| 176 | + |
| 177 | +from jax.experimental.array_api._data_type_functions import ( |
| 178 | + bool as bool, |
| 179 | + int8 as int8, |
| 180 | + int16 as int16, |
| 181 | + int32 as int32, |
| 182 | + int64 as int64, |
| 183 | + uint8 as uint8, |
| 184 | + uint16 as uint16, |
| 185 | + uint32 as uint32, |
| 186 | + uint64 as uint64, |
| 187 | + float32 as float32, |
| 188 | + float64 as float64, |
| 189 | + complex64 as complex64, |
| 190 | + complex128 as complex128, |
| 191 | + astype as astype, |
| 192 | + finfo as finfo, |
| 193 | +) |
| 194 | + |
| 195 | +from jax.experimental.array_api._elementwise_functions import ( |
| 196 | + ceil as ceil, |
| 197 | + clip as clip, |
| 198 | + floor as floor, |
| 199 | + hypot as hypot, |
| 200 | + logical_and as logical_and, |
| 201 | + logical_or as logical_or, |
| 202 | + logical_xor as logical_xor, |
| 203 | + trunc as trunc, |
| 204 | +) |
| 205 | + |
| 206 | +from jax.experimental.array_api._statistical_functions import ( |
| 207 | + std as std, |
| 208 | + var as var, |
| 209 | +) |
| 210 | + |
| 211 | +from jax.experimental.array_api._utility_functions import ( |
| 212 | + __array_namespace_info__ as __array_namespace_info__, |
| 213 | +) |
| 214 | + |
231 | 215 | from jax.experimental.array_api import _array_methods
|
232 | 216 | _array_methods.add_array_object_methods()
|
233 | 217 | del _array_methods
|
0 commit comments