@@ -163,24 +163,87 @@ def sync(device=None):
163
163
safe_call (backend .get ().af_sync (dev ))
164
164
165
165
def __eval (* args ):
166
- for A in args :
167
- if isinstance (A , tuple ):
168
- __eval (* A )
169
- if isinstance (A , list ):
170
- __eval (* A )
171
- if isinstance (A , Array ):
172
- safe_call (backend .get ().af_eval (A .arr ))
166
+ nargs = len (args )
167
+ if (nargs == 1 ):
168
+ safe_call (backend .get ().af_eval (args [0 ].arr ))
169
+ else :
170
+ c_void_p_n = ct .c_void_p * nargs
171
+ arrs = c_void_p_n ()
172
+ for n in range (nargs ):
173
+ arrs [n ] = args [n ].arr
174
+ safe_call (backend .get ().af_eval_multiple (ct .c_int (nargs ), ct .pointer (arrs )))
175
+ return
173
176
174
177
def eval (* args ):
175
178
"""
176
- Evaluate the input
179
+ Evaluate one or more inputs together
177
180
178
181
Parameters
179
182
-----------
180
183
args : arguments to be evaluated
184
+
185
+ Note
186
+ -----
187
+
188
+ All the input arrays to this function should be of the same size.
189
+
190
+ Examples
191
+ --------
192
+
193
+ >>> a = af.constant(1, 3, 3)
194
+ >>> b = af.constant(2, 3, 3)
195
+ >>> c = a + b
196
+ >>> d = a - b
197
+ >>> af.eval(c, d) # A single kernel is launched here
198
+ >>> c
199
+ arrayfire.Array()
200
+ Type: float
201
+ [3 3 1 1]
202
+ 3.0000 3.0000 3.0000
203
+ 3.0000 3.0000 3.0000
204
+ 3.0000 3.0000 3.0000
205
+
206
+ >>> d
207
+ arrayfire.Array()
208
+ Type: float
209
+ [3 3 1 1]
210
+ -1.0000 -1.0000 -1.0000
211
+ -1.0000 -1.0000 -1.0000
212
+ -1.0000 -1.0000 -1.0000
213
+ """
214
+ for arg in args :
215
+ if not isinstance (arg , Array ):
216
+ raise RuntimeError ("All inputs to eval must be of type arrayfire.Array" )
217
+
218
+ __eval (* args )
219
+
220
+ def set_manual_eval_flag (flag ):
221
+ """
222
+ Tells the backend JIT engine to disable heuristics for determining when to evaluate a JIT tree.
223
+
224
+ Parameters
225
+ ----------
226
+
227
+ flag : optional: bool.
228
+ - Specifies if the heuristic evaluation of the JIT tree needs to be disabled.
229
+
230
+ Note
231
+ ----
232
+ This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
181
233
"""
234
+ safe_call (backend .get ().af_set_manual_eval_flag (flag ))
182
235
183
- __eval (args )
236
+ def get_manual_eval_flag ():
237
+ """
238
+ Query the backend JIT engine to see if the user disabled heuristic evaluation of the JIT tree.
239
+
240
+ Note
241
+ ----
242
+ This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
243
+ """
244
+ res = ct .c_bool (False )
245
+ safe_call (backend .get ().af_get_manual_eval_flag (ct .pointer (res )))
246
+ return res .value
184
247
185
248
def device_mem_info ():
186
249
"""
@@ -258,10 +321,27 @@ def lock_array(a):
258
321
259
322
Note
260
323
-----
261
- - The device pointer of `a` is not freed by memory manager until `unlock_device_ptr ()` is called.
324
+ - The device pointer of `a` is not freed by memory manager until `unlock_array ()` is called.
262
325
"""
263
326
safe_call (backend .get ().af_lock_array (a .arr ))
264
327
328
+ def is_locked_array (a ):
329
+ """
330
+ Check if the input array is locked by the user.
331
+
332
+ Parameters
333
+ ----------
334
+ a: af.Array
335
+ - A multi dimensional arrayfire array.
336
+
337
+ Returns
338
+ -----------
339
+ A bool specifying if the input array is locked.
340
+ """
341
+ res = ct .c_bool (False )
342
+ safe_call (backend .get ().af_is_locked_array (ct .pointer (res ), a .arr ))
343
+ return res .value
344
+
265
345
def unlock_device_ptr (a ):
266
346
"""
267
347
This functions is deprecated. Please use unlock_array instead.
0 commit comments