Skip to content

Commit b5abc8a

Browse files
authored
[SYCL] Add hasSpecialCaptures() constexpr function (#18386)
Detecting which kernel functions have special captures (i.e., anything that isn't a standard layout class or pointer) will enable us to introduce a fast path for those kernels. Note that supporting this functionality requires all of the kernel descriptor functions to be constexpr. This was already the case for the functions generated for the integration header, but was not true for some of the new builtins and placeholder functions in kernel_desc.hpp. --- As a note to reviewers: I do have some prototype functionality ready that uses this, but I thought that splitting things into two separate pull requests would make things easier to review. --------- Signed-off-by: John Pennycook <[email protected]>
1 parent cda5a47 commit b5abc8a

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

sycl/include/sycl/detail/kernel_desc.hpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,7 @@ template <auto &SpecName> const char *get_spec_constant_symbolic_ID();
9595
#ifndef __SYCL_UNNAMED_LAMBDA__
9696
template <class KernelNameType> struct KernelInfo {
9797
static constexpr unsigned getNumParams() { return 0; }
98-
static const kernel_param_desc_t &getParamDesc(int) {
99-
static kernel_param_desc_t Dummy;
98+
static constexpr const kernel_param_desc_t &getParamDesc(int) {
10099
return Dummy;
101100
}
102101
static constexpr const char *getName() { return ""; }
@@ -106,12 +105,14 @@ template <class KernelNameType> struct KernelInfo {
106105
static constexpr unsigned getLineNumber() { return 0; }
107106
static constexpr unsigned getColumnNumber() { return 0; }
108107
static constexpr int64_t getKernelSize() { return 0; }
108+
109+
private:
110+
static constexpr kernel_param_desc_t Dummy{};
109111
};
110112
#else
111113
template <char...> struct KernelInfoData {
112114
static constexpr unsigned getNumParams() { return 0; }
113-
static const kernel_param_desc_t &getParamDesc(int) {
114-
static kernel_param_desc_t Dummy;
115+
static constexpr const kernel_param_desc_t &getParamDesc(int) {
115116
return Dummy;
116117
}
117118
static constexpr const char *getName() { return ""; }
@@ -121,6 +122,9 @@ template <char...> struct KernelInfoData {
121122
static constexpr unsigned getLineNumber() { return 0; }
122123
static constexpr unsigned getColumnNumber() { return 0; }
123124
static constexpr int64_t getKernelSize() { return 0; }
125+
126+
private:
127+
static constexpr kernel_param_desc_t Dummy{};
124128
};
125129

126130
// C++14 like index_sequence and make_index_sequence
@@ -154,7 +158,7 @@ template <class KernelNameType> struct KernelInfo {
154158
static constexpr unsigned getNumParams() {
155159
return SubKernelInfo::getNumParams();
156160
}
157-
static const kernel_param_desc_t &getParamDesc(int Idx) {
161+
static constexpr const kernel_param_desc_t &getParamDesc(int Idx) {
158162
return SubKernelInfo::getParamDesc(Idx);
159163
}
160164
static constexpr const char *getName() { return SubKernelInfo::getName(); }
@@ -186,7 +190,7 @@ template <typename KernelNameType> constexpr unsigned getKernelNumParams() {
186190
}
187191

188192
template <typename KernelNameType>
189-
kernel_param_desc_t getKernelParamDesc(int Idx) {
193+
constexpr kernel_param_desc_t getKernelParamDesc(int Idx) {
190194
#ifndef __INTEL_SYCL_USE_INTEGRATION_HEADERS
191195
kernel_param_desc_t ParamDesc;
192196
ParamDesc.kind =
@@ -255,6 +259,19 @@ template <typename KernelNameType> constexpr int64_t getKernelSize() {
255259
// cases with external host compiler, which use integration headers.
256260
return KernelInfo<KernelNameType>::getKernelSize();
257261
}
262+
263+
template <typename KernelNameType> constexpr bool hasSpecialCaptures() {
264+
bool FoundSpecialCapture = false;
265+
for (int I = 0; I < getKernelNumParams<KernelNameType>(); ++I) {
266+
auto ParamDesc = getKernelParamDesc<KernelNameType>(I);
267+
bool IsSpecialCapture =
268+
(ParamDesc.kind != kernel_param_kind_t::kind_std_layout &&
269+
ParamDesc.kind != kernel_param_kind_t::kind_pointer);
270+
FoundSpecialCapture |= IsSpecialCapture;
271+
}
272+
return FoundSpecialCapture;
273+
}
274+
258275
} // namespace detail
259276
} // namespace _V1
260277
} // namespace sycl
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %{build} -fsyntax-only -o %t.out
2+
3+
#include <sycl/detail/kernel_desc.hpp>
4+
#include <sycl/queue.hpp>
5+
6+
using namespace sycl;
7+
8+
class A;
9+
class B;
10+
11+
int main() {
12+
13+
queue Queue;
14+
15+
// No special captures; only values and pointers.
16+
int Value;
17+
int *Pointer;
18+
Queue.parallel_for<A>(nd_range<1>{1, 1},
19+
[=](nd_item<1> Item) { *Pointer += Value; });
20+
#ifndef __SYCL_DEVICE_ONLY__
21+
static_assert(!detail::hasSpecialCaptures<A>());
22+
#endif
23+
24+
// An accessor is a special capture.
25+
accessor<int> Accessor;
26+
Queue.parallel_for<B>(nd_range<1>{1, 1}, [=](nd_item<1> Item) {
27+
*Pointer += Value;
28+
Accessor[0] += Value;
29+
});
30+
#ifndef __SYCL_DEVICE_ONLY__
31+
static_assert(detail::hasSpecialCaptures<B>());
32+
#endif
33+
}

0 commit comments

Comments
 (0)