Skip to content

Commit 3530428

Browse files
authored
[OpenMP][NFC] Extract OffloadPolicy into a helper class (#74029)
OpenMP allows 3 different offload policies, handling of which we want to encapsulate.
1 parent 70187eb commit 3530428

File tree

4 files changed

+68
-41
lines changed

4 files changed

+68
-41
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//===-- OffloadPolicy.h - Configuration of offload behavior -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Configuration for offload behavior, e.g., if offload is disabled, can be
10+
// disabled, is mandatory, etc.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef OMPTARGET_OFFLOAD_POLICY_H
15+
#define OMPTARGET_OFFLOAD_POLICY_H
16+
17+
#include "PluginManager.h"
18+
19+
enum kmp_target_offload_kind_t {
20+
tgt_disabled = 0,
21+
tgt_default = 1,
22+
tgt_mandatory = 2
23+
};
24+
25+
extern "C" int __kmpc_get_target_offload(void) __attribute__((weak));
26+
27+
class OffloadPolicy {
28+
29+
OffloadPolicy(PluginManager &PM) {
30+
// TODO: Check for OpenMP.
31+
switch ((kmp_target_offload_kind_t)__kmpc_get_target_offload()) {
32+
case tgt_disabled:
33+
Kind = DISABLED;
34+
return;
35+
case tgt_mandatory:
36+
Kind = MANDATORY;
37+
return;
38+
default:
39+
if (PM.getNumDevices()) {
40+
DP("Default TARGET OFFLOAD policy is now mandatory "
41+
"(devices were found)\n");
42+
Kind = MANDATORY;
43+
} else {
44+
DP("Default TARGET OFFLOAD policy is now disabled "
45+
"(no devices were found)\n");
46+
Kind = DISABLED;
47+
}
48+
return;
49+
};
50+
}
51+
52+
public:
53+
static const OffloadPolicy &get(PluginManager &PM) {
54+
static OffloadPolicy OP(PM);
55+
return OP;
56+
}
57+
58+
enum OffloadPolicyKind { DISABLED, MANDATORY };
59+
60+
OffloadPolicyKind Kind = MANDATORY;
61+
};
62+
63+
#endif // OMPTARGET_OFFLOAD_POLICY_H

openmp/libomptarget/include/PluginManager.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,6 @@ struct PluginManager {
107107
HostPtrToTableMapTy HostPtrToTableMap;
108108
std::mutex TblMapMtx; ///< For HostPtrToTableMap
109109

110-
// Store target policy (disabled, mandatory, default)
111-
kmp_target_offload_kind_t TargetOffloadPolicy = tgt_default;
112-
std::mutex TargetOffloadMtx; ///< For TargetOffloadPolicy
113-
114110
// Work around for plugins that call dlopen on shared libraries that call
115111
// tgt_register_lib during their initialisation. Stash the pointers in a
116112
// vector until the plugins are all initialised and then register them.

openmp/libomptarget/include/device.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,6 @@ struct PluginAdaptorTy;
3333
struct __tgt_bin_desc;
3434
struct __tgt_target_table;
3535

36-
// enum for OMP_TARGET_OFFLOAD; keep in sync with kmp.h definition
37-
enum kmp_target_offload_kind {
38-
tgt_disabled = 0,
39-
tgt_default = 1,
40-
tgt_mandatory = 2
41-
};
42-
typedef enum kmp_target_offload_kind kmp_target_offload_kind_t;
43-
4436
///
4537
struct PendingCtorDtorListsTy {
4638
std::list<void *> PendingCtors;

openmp/libomptarget/src/omptarget.cpp

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "omptarget.h"
15+
#include "OffloadPolicy.h"
1516
#include "OpenMP/OMPT/Callback.h"
1617
#include "OpenMP/OMPT/Interface.h"
1718
#include "PluginManager.h"
@@ -281,17 +282,13 @@ static int initLibrary(DeviceTy &Device) {
281282
}
282283

283284
void handleTargetOutcome(bool Success, ident_t *Loc) {
284-
switch (PM->TargetOffloadPolicy) {
285-
case tgt_disabled:
285+
switch (OffloadPolicy::get(*PM).Kind) {
286+
case OffloadPolicy::DISABLED:
286287
if (Success) {
287288
FATAL_MESSAGE0(1, "expected no offloading while offloading is disabled");
288289
}
289290
break;
290-
case tgt_default:
291-
FATAL_MESSAGE0(1, "default offloading policy must be switched to "
292-
"mandatory or disabled");
293-
break;
294-
case tgt_mandatory:
291+
case OffloadPolicy::MANDATORY:
295292
if (!Success) {
296293
if (getInfoLevel() & OMP_INFOTYPE_DUMP_TABLE)
297294
for (auto &Device : PM->Devices)
@@ -329,27 +326,6 @@ void handleTargetOutcome(bool Success, ident_t *Loc) {
329326
}
330327
}
331328

332-
static void handleDefaultTargetOffload() {
333-
std::lock_guard<decltype(PM->TargetOffloadMtx)> LG(PM->TargetOffloadMtx);
334-
if (PM->TargetOffloadPolicy == tgt_default) {
335-
if (omp_get_num_devices() > 0) {
336-
DP("Default TARGET OFFLOAD policy is now mandatory "
337-
"(devices were found)\n");
338-
PM->TargetOffloadPolicy = tgt_mandatory;
339-
} else {
340-
DP("Default TARGET OFFLOAD policy is now disabled "
341-
"(no devices were found)\n");
342-
PM->TargetOffloadPolicy = tgt_disabled;
343-
}
344-
}
345-
}
346-
347-
static bool isOffloadDisabled() {
348-
if (PM->TargetOffloadPolicy == tgt_default)
349-
handleDefaultTargetOffload();
350-
return PM->TargetOffloadPolicy == tgt_disabled;
351-
}
352-
353329
// If offload is enabled, ensure that device DeviceID has been initialized,
354330
// global ctors have been executed, and global data has been mapped.
355331
//
@@ -363,7 +339,7 @@ static bool isOffloadDisabled() {
363339
// If DeviceID == OFFLOAD_DEVICE_DEFAULT, set DeviceID to the default device.
364340
// This step might be skipped if offload is disabled.
365341
bool checkDeviceAndCtors(int64_t &DeviceID, ident_t *Loc) {
366-
if (isOffloadDisabled()) {
342+
if (OffloadPolicy::get(*PM).Kind == OffloadPolicy::DISABLED) {
367343
DP("Offload is disabled\n");
368344
return true;
369345
}

0 commit comments

Comments
 (0)