Skip to content

Commit 4f187a8

Browse files
Merge pull request #557 from IntelPython/feature/DPCTLQueue_Memcpy_async
DPCTLQueue_Memcpy, _Prefetch, _Memadvise become asynchronous
2 parents 6b8e847 + 583fbf7 commit 4f187a8

7 files changed

+184
-51
lines changed

dpctl-capi/include/dpctl_sycl_queue_interface.h

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ DPCTLQueue_SubmitNDRange(__dpctl_keep const DPCTLSyclKernelRef KRef,
266266
size_t NDepEvents);
267267

268268
/*!
269-
* @brief Calls the ``sycl::queue.submit`` function to do a blocking wait on
269+
* @brief Calls the ``sycl::queue::wait`` function to do a blocking wait on
270270
* all enqueued tasks in the queue.
271271
*
272272
* @param QRef Opaque pointer to a ``sycl::queue``.
@@ -276,52 +276,55 @@ DPCTL_API
276276
void DPCTLQueue_Wait(__dpctl_keep const DPCTLSyclQueueRef QRef);
277277

278278
/*!
279-
* @brief C-API wrapper for ``sycl::queue::memcpy``, the function waits on an
280-
* event till the memcpy operation completes.
279+
* @brief C-API wrapper for ``sycl::queue::memcpy``.
281280
*
282281
* @param QRef An opaque pointer to the ``sycl::queue``.
283282
* @param Dest A USM pointer to the destination memory.
284283
* @param Src A USM pointer to the source memory.
285284
* @param Count A number of bytes to copy.
285+
* @return An opaque pointer to the ``sycl::event`` returned by the
286+
* ``sycl::queue::memcpy`` function.
286287
* @ingroup QueueInterface
287288
*/
288289
DPCTL_API
289-
void DPCTLQueue_Memcpy(__dpctl_keep const DPCTLSyclQueueRef QRef,
290-
void *Dest,
291-
const void *Src,
292-
size_t Count);
290+
DPCTLSyclEventRef DPCTLQueue_Memcpy(__dpctl_keep const DPCTLSyclQueueRef QRef,
291+
void *Dest,
292+
const void *Src,
293+
size_t Count);
293294

294295
/*!
295-
* @brief C-API wrapper for ``sycl::queue::prefetch``, the function waits on an
296-
* event till the prefetch operation completes.
296+
* @brief C-API wrapper for ``sycl::queue::prefetch``.
297297
*
298298
* @param QRef An opaque pointer to the ``sycl::queue``.
299299
* @param Ptr A USM pointer to memory.
300300
* @param Count A number of bytes to prefetch.
301+
* @return An opaque pointer to the ``sycl::event`` returned by the
302+
* ``sycl::queue::prefetch`` function.
301303
* @ingroup QueueInterface
302304
*/
303305
DPCTL_API
304-
void DPCTLQueue_Prefetch(__dpctl_keep DPCTLSyclQueueRef QRef,
305-
const void *Ptr,
306-
size_t Count);
306+
DPCTLSyclEventRef DPCTLQueue_Prefetch(__dpctl_keep DPCTLSyclQueueRef QRef,
307+
const void *Ptr,
308+
size_t Count);
307309

308310
/*!
309-
* @brief C-API wrapper for sycl::queue::mem_advise, the function waits on an
310-
* event till the operation completes.
311+
* @brief C-API wrapper for ``sycl::queue::mem_advise``.
311312
*
312313
* @param QRef An opaque pointer to the ``sycl::queue``.
313314
* @param Ptr A USM pointer to memory.
314315
* @param Count A number of bytes to which the advice applies.
315316
* @param Advice Device-defined advice for the specified allocation.
316317
* A value of 0 reverts the advice for Ptr to the
317318
* default behavior.
319+
* @return An opaque pointer to the ``sycl::event`` returned by the
320+
* ``sycl::queue::mem_advise`` function.
318321
* @ingroup QueueInterface
319322
*/
320323
DPCTL_API
321-
void DPCTLQueue_MemAdvise(__dpctl_keep DPCTLSyclQueueRef QRef,
322-
const void *Ptr,
323-
size_t Count,
324-
int Advice);
324+
DPCTLSyclEventRef DPCTLQueue_MemAdvise(__dpctl_keep DPCTLSyclQueueRef QRef,
325+
const void *Ptr,
326+
size_t Count,
327+
int Advice);
325328

326329
/*!
327330
* @brief C-API wrapper for sycl::queue::is_in_order that indicates whether

dpctl-capi/source/dpctl_sycl_queue_interface.cpp

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -484,39 +484,74 @@ void DPCTLQueue_Wait(__dpctl_keep DPCTLSyclQueueRef QRef)
484484
}
485485
}
486486

487-
void DPCTLQueue_Memcpy(__dpctl_keep const DPCTLSyclQueueRef QRef,
488-
void *Dest,
489-
const void *Src,
490-
size_t Count)
487+
DPCTLSyclEventRef DPCTLQueue_Memcpy(__dpctl_keep const DPCTLSyclQueueRef QRef,
488+
void *Dest,
489+
const void *Src,
490+
size_t Count)
491491
{
492492
auto Q = unwrap(QRef);
493493
if (Q) {
494-
auto event = Q->memcpy(Dest, Src, Count);
495-
event.wait();
494+
sycl::event ev;
495+
try {
496+
ev = Q->memcpy(Dest, Src, Count);
497+
} catch (const sycl::runtime_error &re) {
498+
// todo: log error
499+
std::cerr << re.what() << '\n';
500+
return nullptr;
501+
}
502+
return wrap(new event(ev));
503+
}
504+
else {
505+
// todo: log error
506+
std::cerr << "QRef passed to memcpy was NULL" << '\n';
507+
return nullptr;
496508
}
497509
}
498510

499-
void DPCTLQueue_Prefetch(__dpctl_keep DPCTLSyclQueueRef QRef,
500-
const void *Ptr,
501-
size_t Count)
511+
DPCTLSyclEventRef DPCTLQueue_Prefetch(__dpctl_keep DPCTLSyclQueueRef QRef,
512+
const void *Ptr,
513+
size_t Count)
502514
{
503515
auto Q = unwrap(QRef);
504516
if (Q) {
505-
auto event = Q->prefetch(Ptr, Count);
506-
event.wait();
517+
sycl::event ev;
518+
try {
519+
ev = Q->prefetch(Ptr, Count);
520+
} catch (sycl::runtime_error &re) {
521+
// todo: log error
522+
std::cerr << re.what() << '\n';
523+
return nullptr;
524+
}
525+
return wrap(new event(ev));
526+
}
527+
else {
528+
// todo: log error
529+
std::cerr << "QRef passed to prefetch was NULL" << '\n';
530+
return nullptr;
507531
}
508532
}
509533

510-
void DPCTLQueue_MemAdvise(__dpctl_keep DPCTLSyclQueueRef QRef,
511-
const void *Ptr,
512-
size_t Count,
513-
int Advice)
534+
DPCTLSyclEventRef DPCTLQueue_MemAdvise(__dpctl_keep DPCTLSyclQueueRef QRef,
535+
const void *Ptr,
536+
size_t Count,
537+
int Advice)
514538
{
515539
auto Q = unwrap(QRef);
516540
if (Q) {
517-
auto event =
518-
Q->mem_advise(Ptr, Count, static_cast<pi_mem_advice>(Advice));
519-
event.wait();
541+
sycl::event ev;
542+
try {
543+
ev = Q->mem_advise(Ptr, Count, static_cast<pi_mem_advice>(Advice));
544+
} catch (const sycl::runtime_error &re) {
545+
// todo: log error
546+
std::cerr << re.what() << '\n';
547+
return nullptr;
548+
}
549+
return wrap(new event(ev));
550+
}
551+
else {
552+
// todo: log error
553+
std::cerr << "QRef passed to prefetch was NULL" << '\n';
554+
return nullptr;
520555
}
521556
}
522557

dpctl-capi/tests/test_sycl_queue_interface.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,24 @@ TEST(TestDPCTLSyclQueueInterface, CheckPropertyHandling)
327327
EXPECT_NO_FATAL_FAILURE(DPCTLDeviceSelector_Delete(DSRef));
328328
}
329329

330+
TEST(TestDPCTLSyclQueueInterface, CheckMemOpsZeroQRef)
331+
{
332+
DPCTLSyclQueueRef QRef = nullptr;
333+
void *p1 = nullptr;
334+
void *p2 = nullptr;
335+
size_t n_bytes = 0;
336+
DPCTLSyclEventRef ERef = nullptr;
337+
338+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_Memcpy(QRef, p1, p2, n_bytes));
339+
ASSERT_FALSE(bool(ERef));
340+
341+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_Prefetch(QRef, p1, n_bytes));
342+
ASSERT_FALSE(bool(ERef));
343+
344+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_MemAdvise(QRef, p1, n_bytes, 0));
345+
ASSERT_FALSE(bool(ERef));
346+
}
347+
330348
TEST_P(TestDPCTLQueueMemberFunctions, CheckGetBackend)
331349
{
332350
auto q = unwrap(QRef);
@@ -364,6 +382,31 @@ TEST_P(TestDPCTLQueueMemberFunctions, CheckGetDevice)
364382
EXPECT_NO_FATAL_FAILURE(DPCTLDevice_Delete(D));
365383
}
366384

385+
TEST_P(TestDPCTLQueueMemberFunctions, CheckMemOpsNullPtr)
386+
{
387+
void *p1 = nullptr;
388+
void *p2 = nullptr;
389+
size_t n_bytes = 256;
390+
DPCTLSyclEventRef ERef = nullptr;
391+
392+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_Memcpy(QRef, p1, p2, n_bytes));
393+
ASSERT_FALSE(bool(ERef));
394+
395+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_Prefetch(QRef, p1, n_bytes));
396+
if (ERef) {
397+
ASSERT_NO_FATAL_FAILURE(DPCTLEvent_Wait(ERef));
398+
ASSERT_NO_FATAL_FAILURE(DPCTLEvent_Delete(ERef));
399+
ERef = nullptr;
400+
}
401+
402+
ASSERT_NO_FATAL_FAILURE(ERef = DPCTLQueue_MemAdvise(QRef, p1, n_bytes, 0));
403+
if (ERef) {
404+
ASSERT_NO_FATAL_FAILURE(DPCTLEvent_Wait(ERef));
405+
ASSERT_NO_FATAL_FAILURE(DPCTLEvent_Delete(ERef));
406+
ERef = nullptr;
407+
}
408+
}
409+
367410
INSTANTIATE_TEST_SUITE_P(
368411
DPCTLQueueMemberFuncTests,
369412
TestDPCTLQueueMemberFunctions,

dpctl-capi/tests/test_sycl_usm_interface.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,23 @@ void common_test_body(size_t nbytes,
5656
auto QueueDev = DPCTLQueue_GetDevice(Q);
5757
EXPECT_TRUE(DPCTLDevice_AreEq(Dev, QueueDev));
5858

59-
EXPECT_NO_FATAL_FAILURE(DPCTLQueue_Prefetch(Q, Ptr, nbytes));
60-
EXPECT_NO_FATAL_FAILURE(DPCTLQueue_MemAdvise(Q, Ptr, nbytes, 0));
61-
59+
DPCTLSyclEventRef E1Ref = nullptr, E2Ref = nullptr, E3Ref = nullptr;
60+
EXPECT_NO_FATAL_FAILURE(E1Ref = DPCTLQueue_Prefetch(Q, Ptr, nbytes));
61+
EXPECT_TRUE(E1Ref != nullptr);
62+
EXPECT_NO_FATAL_FAILURE(E2Ref = DPCTLQueue_MemAdvise(Q, Ptr, nbytes, 0));
63+
EXPECT_TRUE(E2Ref != nullptr);
64+
65+
EXPECT_NO_FATAL_FAILURE(DPCTLEvent_Wait(E1Ref));
66+
DPCTLEvent_Delete(E1Ref);
67+
EXPECT_NO_FATAL_FAILURE(DPCTLEvent_Wait(E2Ref));
68+
DPCTLEvent_Delete(E2Ref);
6269
try {
6370
unsigned short *host_ptr = new unsigned short[nbytes];
64-
EXPECT_NO_FATAL_FAILURE(DPCTLQueue_Memcpy(Q, host_ptr, Ptr, nbytes));
71+
EXPECT_NO_FATAL_FAILURE(
72+
E3Ref = DPCTLQueue_Memcpy(Q, host_ptr, Ptr, nbytes));
73+
EXPECT_TRUE(E3Ref != nullptr);
74+
EXPECT_NO_FATAL_FAILURE(DPCTLEvent_Wait(E3Ref));
75+
DPCTLEvent_Delete(E3Ref);
6576
delete[] host_ptr;
6677
} catch (std::bad_alloc const &ba) {
6778
// pass

dpctl/_backend.pxd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,16 +355,16 @@ cdef extern from "dpctl_sycl_queue_interface.h":
355355
const DPCTLSyclEventRef *DepEvents,
356356
size_t NDepEvents)
357357
cdef void DPCTLQueue_Wait(const DPCTLSyclQueueRef QRef)
358-
cdef void DPCTLQueue_Memcpy(
358+
cdef DPCTLSyclEventRef DPCTLQueue_Memcpy(
359359
const DPCTLSyclQueueRef Q,
360360
void *Dest,
361361
const void *Src,
362362
size_t Count)
363-
cdef void DPCTLQueue_Prefetch(
363+
cdef DPCTLSyclEventRef DPCTLQueue_Prefetch(
364364
const DPCTLSyclQueueRef Q,
365365
const void *Src,
366366
size_t Count)
367-
cdef void DPCTLQueue_MemAdvise(
367+
cdef DPCTLSyclEventRef DPCTLQueue_MemAdvise(
368368
const DPCTLSyclQueueRef Q,
369369
const void *Src,
370370
size_t Count,

dpctl/_sycl_queue.pyx

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ from ._backend cimport ( # noqa: E211
3030
DPCTLDevice_Delete,
3131
DPCTLDeviceMgr_GetCachedContext,
3232
DPCTLDeviceSelector_Delete,
33+
DPCTLEvent_Delete,
34+
DPCTLEvent_Wait,
3335
DPCTLFilterSelector_Create,
3436
DPCTLQueue_AreEq,
3537
DPCTLQueue_Copy,
@@ -812,6 +814,7 @@ cdef class SyclQueue(_SyclQueue):
812814
cpdef memcpy(self, dest, src, size_t count):
813815
cdef void *c_dest
814816
cdef void *c_src
817+
cdef DPCTLSyclEventRef ERef = NULL
815818

816819
if isinstance(dest, _Memory):
817820
c_dest = <void*>(<_Memory>dest).memory_ptr
@@ -823,10 +826,17 @@ cdef class SyclQueue(_SyclQueue):
823826
else:
824827
raise TypeError("Parameter `src` should have type _Memory.")
825828

826-
DPCTLQueue_Memcpy(self._queue_ref, c_dest, c_src, count)
829+
ERef = DPCTLQueue_Memcpy(self._queue_ref, c_dest, c_src, count)
830+
if (ERef is NULL):
831+
raise RuntimeError(
832+
"SyclQueue.memcpy operation encountered an error"
833+
)
834+
DPCTLEvent_Wait(ERef)
835+
DPCTLEvent_Delete(ERef)
827836

828837
cpdef prefetch(self, mem, size_t count=0):
829838
cdef void *ptr
839+
cdef DPCTLSyclEventRef ERef = NULL
830840

831841
if isinstance(mem, _Memory):
832842
ptr = <void*>(<_Memory>mem).memory_ptr
@@ -836,10 +846,17 @@ cdef class SyclQueue(_SyclQueue):
836846
if (count <=0 or count > self.nbytes):
837847
count = self.nbytes
838848

839-
DPCTLQueue_Prefetch(self._queue_ref, ptr, count)
849+
ERef = DPCTLQueue_Prefetch(self._queue_ref, ptr, count)
850+
if (ERef is NULL):
851+
raise RuntimeError(
852+
"SyclQueue.prefetch encountered an error"
853+
)
854+
DPCTLEvent_Wait(ERef)
855+
DPCTLEvent_Delete(ERef)
840856

841857
cpdef mem_advise(self, mem, size_t count, int advice):
842858
cdef void *ptr
859+
cdef DPCTLSyclEventRef ERef = NULL
843860

844861
if isinstance(mem, _Memory):
845862
ptr = <void*>(<_Memory>mem).memory_ptr
@@ -849,7 +866,13 @@ cdef class SyclQueue(_SyclQueue):
849866
if (count <=0 or count > self.nbytes):
850867
count = self.nbytes
851868

852-
DPCTLQueue_MemAdvise(self._queue_ref, ptr, count, advice)
869+
ERef = DPCTLQueue_MemAdvise(self._queue_ref, ptr, count, advice)
870+
if (ERef is NULL):
871+
raise RuntimeError(
872+
"SyclQueue.mem_advise operation encountered an error"
873+
)
874+
DPCTLEvent_Wait(ERef)
875+
DPCTLEvent_Delete(ERef)
853876

854877
@property
855878
def is_in_order(self):

0 commit comments

Comments
 (0)