Skip to content

Commit bc861ab

Browse files
MemoryUSM* classes have an alignment option
The queue can no longer be specified via a positional argument, only through a keyword, to allow a user to specify alignment but not queue. The SYCL spec says that aligned allocation may return a null pointer when the requested alignment is not supported by the device. Non-positive alignments silently go unused (i.e. DPPLmalloc_* is used instead of DPPLaligned_alloc_*)
1 parent 4138b76 commit bc861ab

File tree

4 files changed

+61
-21
lines changed

4 files changed

+61
-21
lines changed

dpctl/_memory.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ cdef class Memory:
3232
cdef object refobj
3333

3434
cdef _cinit_empty(self)
35-
cdef _cinit_alloc(self, Py_ssize_t nbytes, bytes ptr_type, SyclQueue queue)
35+
cdef _cinit_alloc(self, Py_ssize_t alignment, Py_ssize_t nbytes,
36+
bytes ptr_type, SyclQueue queue)
3637
cdef _cinit_other(self, object other)
3738
cdef _getbuffer(self, Py_buffer *buffer, int flags)
3839

dpctl/_memory.pyx

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ cdef class Memory:
160160
self.queue = None
161161
self.refobj = None
162162

163-
cdef _cinit_alloc(self, Py_ssize_t nbytes, bytes ptr_type, SyclQueue queue):
163+
cdef _cinit_alloc(self, Py_ssize_t alignment, Py_ssize_t nbytes,
164+
bytes ptr_type, SyclQueue queue):
164165
cdef DPPLSyclUSMRef p
165166

166167
self._cinit_empty()
@@ -170,11 +171,23 @@ cdef class Memory:
170171
queue = get_current_queue()
171172

172173
if (ptr_type == b"shared"):
173-
p = DPPLmalloc_shared(nbytes, queue.get_queue_ref())
174+
if alignment > 0:
175+
p = DPPLaligned_alloc_shared(alignment, nbytes,
176+
queue.get_queue_ref())
177+
else:
178+
p = DPPLmalloc_shared(nbytes, queue.get_queue_ref())
174179
elif (ptr_type == b"host"):
175-
p = DPPLmalloc_host(nbytes, queue.get_queue_ref())
180+
if alignment > 0:
181+
p = DPPLaligned_alloc_host(alignment, nbytes,
182+
queue.get_queue_ref())
183+
else:
184+
p = DPPLmalloc_host(nbytes, queue.get_queue_ref())
176185
elif (ptr_type == b"device"):
177-
p = DPPLmalloc_device(nbytes, queue.get_queue_ref())
186+
if (alignment > 0):
187+
p = DPPLaligned_alloc_device(alignment, nbytes,
188+
queue.get_queue_ref())
189+
else:
190+
p = DPPLmalloc_device(nbytes, queue.get_queue_ref())
178191
else:
179192
raise RuntimeError("Pointer type is unknown: {}" \
180193
.format(ptr_type.decode("UTF-8")))
@@ -391,10 +404,19 @@ cdef class Memory:
391404

392405

393406
cdef class MemoryUSMShared(Memory):
407+
"""
408+
MemoryUSMShared(nbytes, alignment=0, queue=None) allocates nbytes of USM shared memory.
409+
410+
Non-positive alignments are not used (malloc_shared is used instead).
411+
If queue=None, the current queue returned by `dpctl.get_current_queue()` is used to allocate memory.
394412
395-
def __cinit__(self, other, SyclQueue queue=None):
396-
if isinstance(other, int):
397-
self._cinit_alloc(<Py_ssize_t>other, b"shared", queue)
413+
MemoryUSMShared(usm_obj) constructor creates an instance from `usm_obj`, which is expected to
414+
implement the `__sycl_usm_array_interface__` protocol and expose a contiguous block of
415+
USM memory.
416+
"""
417+
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None):
418+
if (isinstance(other, int)):
419+
self._cinit_alloc(alignment, <Py_ssize_t>other, b"shared", queue)
398420
else:
399421
self._cinit_other(other)
400422

@@ -404,9 +426,9 @@ cdef class MemoryUSMShared(Memory):
404426

405427
cdef class MemoryUSMHost(Memory):
406428

407-
def __cinit__(self, other, SyclQueue queue=None):
408-
if isinstance(other, int):
409-
self._cinit_alloc(<Py_ssize_t>other, b"host", queue)
429+
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None):
430+
if (isinstance(other, int)):
431+
self._cinit_alloc(alignment, <Py_ssize_t>other, b"host", queue)
410432
else:
411433
self._cinit_other(other)
412434

@@ -416,8 +438,8 @@ cdef class MemoryUSMHost(Memory):
416438

417439
cdef class MemoryUSMDevice(Memory):
418440

419-
def __cinit__(self, other, SyclQueue queue=None):
420-
if isinstance(other, int):
421-
self._cinit_alloc(<Py_ssize_t>other, b"device", queue)
441+
def __cinit__(self, other, *, Py_ssize_t alignment=0, SyclQueue queue=None):
442+
if (isinstance(other, int)):
443+
self._cinit_alloc(alignment, <Py_ssize_t>other, b"device", queue)
422444
else:
423445
self._cinit_other(other)

dpctl/tests/test_sycl_queue_memcpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
##===----------------------------------------------------------------------===##
2424

2525
import dpctl
26+
import dpctl.memory
2627
import unittest
2728

2829

2930
class TestQueueMemcpy(unittest.TestCase):
3031
def _create_memory(self):
3132
nbytes = 1024
32-
queue = dpctl.get_current_queue()
33-
mobj = dpctl._memory.MemoryUSMShared(nbytes, queue)
33+
mobj = dpctl.memory.MemoryUSMShared(nbytes)
3434
return mobj
3535

3636
@unittest.skipUnless(

dpctl/tests/test_sycl_usm.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ class TestMemory(unittest.TestCase):
3434
def test_memory_create(self):
3535
nbytes = 1024
3636
queue = dpctl.get_current_queue()
37-
mobj = MemoryUSMShared(nbytes, queue)
37+
mobj = MemoryUSMShared(nbytes, alignment=64, queue=queue)
3838
self.assertEqual(mobj.nbytes, nbytes)
3939
self.assertTrue(hasattr(mobj, "__sycl_usm_array_interface__"))
4040

4141
def _create_memory(self):
4242
nbytes = 1024
4343
queue = dpctl.get_current_queue()
44-
mobj = MemoryUSMShared(nbytes, queue)
44+
mobj = MemoryUSMShared(nbytes, alignment=64, queue=queue)
4545
return mobj
4646

4747
def _create_host_buf(self, nbytes):
@@ -156,16 +156,33 @@ class TestMemoryUSMBase:
156156
@unittest.skipUnless(
157157
dpctl.has_sycl_platforms(), "No SYCL devices except the default host device."
158158
)
159-
def test_create_with_queue(self):
159+
def test_create_with_size_and_alignment_and_queue(self):
160160
q = dpctl.get_current_queue()
161-
m = self.MemoryUSMClass(1024, q)
161+
m = self.MemoryUSMClass(1024, alignment=64, queue=q)
162162
self.assertEqual(m.nbytes, 1024)
163163
self.assertEqual(m.get_usm_type(), self.usm_type)
164164

165165
@unittest.skipUnless(
166166
dpctl.has_sycl_platforms(), "No SYCL devices except the default host device."
167167
)
168-
def test_create_without_queue(self):
168+
def test_create_with_size_and_queue(self):
169+
q = dpctl.get_current_queue()
170+
m = self.MemoryUSMClass(1024, queue=q)
171+
self.assertEqual(m.nbytes, 1024)
172+
self.assertEqual(m.get_usm_type(), self.usm_type)
173+
174+
@unittest.skipUnless(
175+
dpctl.has_sycl_platforms(), "No SYCL devices except the default host device."
176+
)
177+
def test_create_with_size_and_alignment(self):
178+
m = self.MemoryUSMClass(1024, alignment=64)
179+
self.assertEqual(m.nbytes, 1024)
180+
self.assertEqual(m.get_usm_type(), self.usm_type)
181+
182+
@unittest.skipUnless(
183+
dpctl.has_sycl_platforms(), "No SYCL devices except the default host device."
184+
)
185+
def test_create_with_only_size(self):
169186
m = self.MemoryUSMClass(1024)
170187
self.assertEqual(m.nbytes, 1024)
171188
self.assertEqual(m.get_usm_type(), self.usm_type)

0 commit comments

Comments
 (0)