Skip to content

Commit 7594769

Browse files
committed
Adding functions from array.h and device.h
- is_sparse - is_locked_array - modified eval to use af_eval_multiple - set_manual_eval_flag - get_manual_eval_flag - Added necessary tests
1 parent 2f8503d commit 7594769

File tree

4 files changed

+119
-10
lines changed

4 files changed

+119
-10
lines changed

arrayfire/array.py

+8
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,14 @@ def is_vector(self):
667667
safe_call(backend.get().af_is_vector(ct.pointer(res), self.arr))
668668
return res.value
669669

670+
def is_sparse(self):
671+
"""
672+
Check if the array is a sparse matrix.
673+
"""
674+
res = ct.c_bool(False)
675+
safe_call(backend.get().af_is_sparse(ct.pointer(res), self.arr))
676+
return res.value
677+
670678
def is_complex(self):
671679
"""
672680
Check if the array is of complex type.

arrayfire/device.py

+90-10
Original file line numberDiff line numberDiff line change
@@ -163,24 +163,87 @@ def sync(device=None):
163163
safe_call(backend.get().af_sync(dev))
164164

165165
def __eval(*args):
166-
for A in args:
167-
if isinstance(A, tuple):
168-
__eval(*A)
169-
if isinstance(A, list):
170-
__eval(*A)
171-
if isinstance(A, Array):
172-
safe_call(backend.get().af_eval(A.arr))
166+
nargs = len(args)
167+
if (nargs == 1):
168+
safe_call(backend.get().af_eval(args[0].arr))
169+
else:
170+
c_void_p_n = ct.c_void_p * nargs
171+
arrs = c_void_p_n()
172+
for n in range(nargs):
173+
arrs[n] = args[n].arr
174+
safe_call(backend.get().af_eval_multiple(ct.c_int(nargs), ct.pointer(arrs)))
175+
return
173176

174177
def eval(*args):
175178
"""
176-
Evaluate the input
179+
Evaluate one or more inputs together
177180
178181
Parameters
179182
-----------
180183
args : arguments to be evaluated
184+
185+
Note
186+
-----
187+
188+
All the input arrays to this function should be of the same size.
189+
190+
Examples
191+
--------
192+
193+
>>> a = af.constant(1, 3, 3)
194+
>>> b = af.constant(2, 3, 3)
195+
>>> c = a + b
196+
>>> d = a - b
197+
>>> af.eval(c, d) # A single kernel is launched here
198+
>>> c
199+
arrayfire.Array()
200+
Type: float
201+
[3 3 1 1]
202+
3.0000 3.0000 3.0000
203+
3.0000 3.0000 3.0000
204+
3.0000 3.0000 3.0000
205+
206+
>>> d
207+
arrayfire.Array()
208+
Type: float
209+
[3 3 1 1]
210+
-1.0000 -1.0000 -1.0000
211+
-1.0000 -1.0000 -1.0000
212+
-1.0000 -1.0000 -1.0000
213+
"""
214+
for arg in args:
215+
if not isinstance(arg, Array):
216+
raise RuntimeError("All inputs to eval must be of type arrayfire.Array")
217+
218+
__eval(*args)
219+
220+
def set_manual_eval_flag(flag):
221+
"""
222+
Tells the backend JIT engine to disable heuristics for determining when to evaluate a JIT tree.
223+
224+
Parameters
225+
----------
226+
227+
flag : optional: bool.
228+
- Specifies if the heuristic evaluation of the JIT tree needs to be disabled.
229+
230+
Note
231+
----
232+
This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
181233
"""
234+
safe_call(backend.get().af_set_manual_eval_flag(flag))
182235

183-
__eval(args)
236+
def get_manual_eval_flag():
237+
"""
238+
Query the backend JIT engine to see if the user disabled heuristic evaluation of the JIT tree.
239+
240+
Note
241+
----
242+
This does not affect the evaluation that occurs when a non JIT function forces the evaluation.
243+
"""
244+
res = ct.c_bool(False)
245+
safe_call(backend.get().af_get_manual_eval_flag(ct.pointer(res)))
246+
return res.value
184247

185248
def device_mem_info():
186249
"""
@@ -258,10 +321,27 @@ def lock_array(a):
258321
259322
Note
260323
-----
261-
- The device pointer of `a` is not freed by memory manager until `unlock_device_ptr()` is called.
324+
- The device pointer of `a` is not freed by memory manager until `unlock_array()` is called.
262325
"""
263326
safe_call(backend.get().af_lock_array(a.arr))
264327

328+
def is_locked_array(a):
329+
"""
330+
Check if the input array is locked by the user.
331+
332+
Parameters
333+
----------
334+
a: af.Array
335+
- A multi dimensional arrayfire array.
336+
337+
Returns
338+
-----------
339+
A bool specifying if the input array is locked.
340+
"""
341+
res = ct.c_bool(False)
342+
safe_call(backend.get().af_is_locked_array(ct.pointer(res), a.arr))
343+
return res.value
344+
265345
def unlock_device_ptr(a):
266346
"""
267347
This functions is deprecated. Please use unlock_array instead.

arrayfire/tests/simple/array_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,6 @@ def simple_array(verbose=False):
6060
print_func(arr)
6161
print_func(lst)
6262

63+
print_func(a.is_sparse())
64+
6365
_util.tests['array'] = simple_array

arrayfire/tests/simple/device.py

+19
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,23 @@ def simple_device(verbose=False):
5151
af.lock_array(c)
5252
af.unlock_array(c)
5353

54+
a = af.constant(1, 3, 3)
55+
b = af.constant(2, 3, 3)
56+
af.eval(a)
57+
af.eval(b)
58+
print_func(a)
59+
print_func(b)
60+
c = a + b
61+
d = a - b
62+
af.eval(c, d)
63+
print_func(c)
64+
print_func(d)
65+
66+
print_func(af.set_manual_eval_flag(True))
67+
assert(af.get_manual_eval_flag() == True)
68+
print_func(af.set_manual_eval_flag(False))
69+
assert(af.get_manual_eval_flag() == False)
70+
71+
display_func(af.is_locked_array(a))
72+
5473
_util.tests['device'] = simple_device

0 commit comments

Comments
 (0)