|
21 | 21 | >>> from jax.experimental import array_api as xp
|
22 | 22 |
|
23 | 23 | >>> xp.__array_api_version__
|
24 |
| - '2022.12' |
| 24 | + '2023.12' |
25 | 25 |
|
26 | 26 | >>> arr = xp.arange(1000)
|
27 | 27 |
|
|
38 | 38 |
|
39 | 39 | from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__
|
40 | 40 |
|
41 |
| -from jax.experimental.array_api import ( |
42 |
| - fft as fft, |
43 |
| - linalg as linalg, |
44 |
| -) |
45 |
| - |
46 |
| -from jax.experimental.array_api._constants import ( |
47 |
| - e as e, |
48 |
| - inf as inf, |
49 |
| - nan as nan, |
50 |
| - newaxis as newaxis, |
51 |
| - pi as pi, |
52 |
| -) |
53 |
| - |
54 |
| -from jax.experimental.array_api._creation_functions import ( |
55 |
| - arange as arange, |
56 |
| - asarray as asarray, |
57 |
| - empty as empty, |
58 |
| - empty_like as empty_like, |
59 |
| - eye as eye, |
60 |
| - from_dlpack as from_dlpack, |
61 |
| - full as full, |
62 |
| - full_like as full_like, |
63 |
| - linspace as linspace, |
64 |
| - meshgrid as meshgrid, |
65 |
| - ones as ones, |
66 |
| - ones_like as ones_like, |
67 |
| - tril as tril, |
68 |
| - triu as triu, |
69 |
| - zeros as zeros, |
70 |
| - zeros_like as zeros_like, |
71 |
| -) |
72 |
| - |
73 |
| -from jax.experimental.array_api._data_type_functions import ( |
74 |
| - astype as astype, |
75 |
| - can_cast as can_cast, |
76 |
| - finfo as finfo, |
77 |
| - iinfo as iinfo, |
78 |
| - isdtype as isdtype, |
79 |
| - result_type as result_type, |
80 |
| -) |
81 |
| - |
82 |
| -from jax.experimental.array_api._dtypes import ( |
83 |
| - bool as bool, |
84 |
| - int8 as int8, |
85 |
| - int16 as int16, |
86 |
| - int32 as int32, |
87 |
| - int64 as int64, |
88 |
| - uint8 as uint8, |
89 |
| - uint16 as uint16, |
90 |
| - uint32 as uint32, |
91 |
| - uint64 as uint64, |
92 |
| - float32 as float32, |
93 |
| - float64 as float64, |
94 |
| - complex64 as complex64, |
95 |
| - complex128 as complex128, |
96 |
| -) |
| 41 | +from jax.experimental.array_api import fft as fft |
| 42 | +from jax.experimental.array_api import linalg as linalg |
97 | 43 |
|
98 |
| -from jax.experimental.array_api._elementwise_functions import ( |
| 44 | +from jax.numpy import ( |
99 | 45 | abs as abs,
|
100 | 46 | acos as acos,
|
101 | 47 | acosh as acosh,
|
102 | 48 | add as add,
|
| 49 | + all as all, |
| 50 | + any as any, |
| 51 | + argmax as argmax, |
| 52 | + argmin as argmin, |
| 53 | + argsort as argsort, |
103 | 54 | asin as asin,
|
104 | 55 | asinh as asinh,
|
105 | 56 | atan as atan,
|
|
111 | 62 | bitwise_or as bitwise_or,
|
112 | 63 | bitwise_right_shift as bitwise_right_shift,
|
113 | 64 | bitwise_xor as bitwise_xor,
|
114 |
| - ceil as ceil, |
115 |
| - clip as clip, |
| 65 | + bool as bool, |
| 66 | + broadcast_arrays as broadcast_arrays, |
| 67 | + broadcast_to as broadcast_to, |
| 68 | + can_cast as can_cast, |
| 69 | + complex128 as complex128, |
| 70 | + complex64 as complex64, |
| 71 | + concat as concat, |
116 | 72 | conj as conj,
|
117 | 73 | copysign as copysign,
|
118 | 74 | cos as cos,
|
119 | 75 | cosh as cosh,
|
| 76 | + cumulative_sum as cumulative_sum, |
120 | 77 | divide as divide,
|
| 78 | + e as e, |
| 79 | + empty as empty, |
| 80 | + empty_like as empty_like, |
121 | 81 | equal as equal,
|
122 | 82 | exp as exp,
|
| 83 | + expand_dims as expand_dims, |
123 | 84 | expm1 as expm1,
|
124 |
| - floor as floor, |
| 85 | + flip as flip, |
| 86 | + float32 as float32, |
| 87 | + float64 as float64, |
125 | 88 | floor_divide as floor_divide,
|
| 89 | + from_dlpack as from_dlpack, |
| 90 | + full as full, |
| 91 | + full_like as full_like, |
126 | 92 | greater as greater,
|
127 | 93 | greater_equal as greater_equal,
|
128 |
| - hypot as hypot, |
| 94 | + iinfo as iinfo, |
129 | 95 | imag as imag,
|
| 96 | + inf as inf, |
| 97 | + int16 as int16, |
| 98 | + int32 as int32, |
| 99 | + int64 as int64, |
| 100 | + int8 as int8, |
| 101 | + isdtype as isdtype, |
130 | 102 | isfinite as isfinite,
|
131 | 103 | isinf as isinf,
|
132 | 104 | isnan as isnan,
|
|
141 | 113 | logical_not as logical_not,
|
142 | 114 | logical_or as logical_or,
|
143 | 115 | logical_xor as logical_xor,
|
| 116 | + matmul as matmul, |
| 117 | + matrix_transpose as matrix_transpose, |
| 118 | + max as max, |
144 | 119 | maximum as maximum,
|
| 120 | + mean as mean, |
| 121 | + meshgrid as meshgrid, |
| 122 | + min as min, |
145 | 123 | minimum as minimum,
|
| 124 | + moveaxis as moveaxis, |
146 | 125 | multiply as multiply,
|
| 126 | + nan as nan, |
147 | 127 | negative as negative,
|
| 128 | + newaxis as newaxis, |
| 129 | + nonzero as nonzero, |
148 | 130 | not_equal as not_equal,
|
| 131 | + ones as ones, |
| 132 | + ones_like as ones_like, |
| 133 | + permute_dims as permute_dims, |
| 134 | + pi as pi, |
149 | 135 | positive as positive,
|
150 | 136 | pow as pow,
|
| 137 | + prod as prod, |
151 | 138 | real as real,
|
152 | 139 | remainder as remainder,
|
| 140 | + repeat as repeat, |
| 141 | + result_type as result_type, |
| 142 | + roll as roll, |
153 | 143 | round as round,
|
| 144 | + searchsorted as searchsorted, |
154 | 145 | sign as sign,
|
155 | 146 | signbit as signbit,
|
156 | 147 | sin as sin,
|
157 | 148 | sinh as sinh,
|
| 149 | + sort as sort, |
158 | 150 | sqrt as sqrt,
|
159 | 151 | square as square,
|
| 152 | + squeeze as squeeze, |
| 153 | + stack as stack, |
160 | 154 | subtract as subtract,
|
| 155 | + sum as sum, |
| 156 | + take as take, |
161 | 157 | tan as tan,
|
162 | 158 | tanh as tanh,
|
163 |
| - trunc as trunc, |
164 |
| -) |
165 |
| - |
166 |
| -from jax.experimental.array_api._indexing_functions import ( |
167 |
| - take as take, |
| 159 | + tensordot as tensordot, |
| 160 | + tile as tile, |
| 161 | + tril as tril, |
| 162 | + triu as triu, |
| 163 | + uint16 as uint16, |
| 164 | + uint32 as uint32, |
| 165 | + uint64 as uint64, |
| 166 | + uint8 as uint8, |
| 167 | + unique_all as unique_all, |
| 168 | + unique_counts as unique_counts, |
| 169 | + unique_inverse as unique_inverse, |
| 170 | + unique_values as unique_values, |
| 171 | + unstack as unstack, |
| 172 | + vecdot as vecdot, |
| 173 | + where as where, |
| 174 | + zeros as zeros, |
| 175 | + zeros_like as zeros_like, |
168 | 176 | )
|
169 | 177 |
|
170 | 178 | from jax.experimental.array_api._manipulation_functions import (
|
171 |
| - broadcast_arrays as broadcast_arrays, |
172 |
| - broadcast_to as broadcast_to, |
173 |
| - concat as concat, |
174 |
| - expand_dims as expand_dims, |
175 |
| - flip as flip, |
176 |
| - moveaxis as moveaxis, |
177 |
| - permute_dims as permute_dims, |
178 |
| - repeat as repeat, |
179 | 179 | reshape as reshape,
|
180 |
| - roll as roll, |
181 |
| - squeeze as squeeze, |
182 |
| - stack as stack, |
183 |
| - tile as tile, |
184 |
| - unstack as unstack, |
185 | 180 | )
|
186 | 181 |
|
187 |
| -from jax.experimental.array_api._searching_functions import ( |
188 |
| - argmax as argmax, |
189 |
| - argmin as argmin, |
190 |
| - nonzero as nonzero, |
191 |
| - searchsorted as searchsorted, |
192 |
| - where as where, |
| 182 | +from jax.experimental.array_api._creation_functions import ( |
| 183 | + arange as arange, |
| 184 | + asarray as asarray, |
| 185 | + eye as eye, |
| 186 | + linspace as linspace, |
193 | 187 | )
|
194 | 188 |
|
195 |
| -from jax.experimental.array_api._set_functions import ( |
196 |
| - unique_all as unique_all, |
197 |
| - unique_counts as unique_counts, |
198 |
| - unique_inverse as unique_inverse, |
199 |
| - unique_values as unique_values, |
| 189 | +from jax.experimental.array_api._data_type_functions import ( |
| 190 | + astype as astype, |
| 191 | + finfo as finfo, |
200 | 192 | )
|
201 | 193 |
|
202 |
| -from jax.experimental.array_api._sorting_functions import ( |
203 |
| - argsort as argsort, |
204 |
| - sort as sort, |
| 194 | +from jax.experimental.array_api._elementwise_functions import ( |
| 195 | + ceil as ceil, |
| 196 | + clip as clip, |
| 197 | + floor as floor, |
| 198 | + hypot as hypot, |
| 199 | + trunc as trunc, |
205 | 200 | )
|
206 | 201 |
|
207 | 202 | from jax.experimental.array_api._statistical_functions import (
|
208 |
| - cumulative_sum as cumulative_sum, |
209 |
| - max as max, |
210 |
| - mean as mean, |
211 |
| - min as min, |
212 |
| - prod as prod, |
213 | 203 | std as std,
|
214 |
| - sum as sum, |
215 |
| - var as var |
| 204 | + var as var, |
216 | 205 | )
|
217 | 206 |
|
218 | 207 | from jax.experimental.array_api._utility_functions import (
|
219 | 208 | __array_namespace_info__ as __array_namespace_info__,
|
220 |
| - all as all, |
221 |
| - any as any, |
222 |
| -) |
223 |
| - |
224 |
| -from jax.experimental.array_api._linear_algebra_functions import ( |
225 |
| - matmul as matmul, |
226 |
| - matrix_transpose as matrix_transpose, |
227 |
| - tensordot as tensordot, |
228 |
| - vecdot as vecdot, |
229 | 209 | )
|
230 | 210 |
|
231 | 211 | from jax.experimental.array_api import _array_methods
|
|
0 commit comments