Skip to content

Commit 1843339

Browse files
authored
Trace interface improvements (rust-lang#990)
* simplify trace interface * move trace interface into separate header * replace strings with constexpr * move sampe_func detection into TraceInterface
1 parent 248dbbb commit 1843339

File tree

4 files changed

+337
-336
lines changed

4 files changed

+337
-336
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
#include "ActivityAnalysis.h"
6363
#include "EnzymeLogic.h"
6464
#include "GradientUtils.h"
65+
#include "TraceInterface.h"
6566
#include "TraceUtils.h"
6667
#include "Utils.h"
6768

@@ -1842,17 +1843,17 @@ class EnzymeBase {
18421843
}
18431844

18441845
// Interface
1845-
1846-
Function *sample = nullptr;
1847-
for (auto &&interface_func : F->getParent()->functions()) {
1848-
if (interface_func.getName().contains("__enzyme_sample")) {
1849-
assert(interface_func.getFunctionType()->getNumParams() >= 3);
1850-
sample = &interface_func;
1851-
}
1846+
bool has_dynamic_interface = dynamic_interface != nullptr;
1847+
std::unique_ptr<TraceInterface> interface;
1848+
if (has_dynamic_interface) {
1849+
interface =
1850+
std::unique_ptr<DynamicTraceInterface>(new DynamicTraceInterface(
1851+
dynamic_interface, CI->getParent()->getParent()));
1852+
} else {
1853+
interface = std::unique_ptr<StaticTraceInterface>(
1854+
new StaticTraceInterface(F->getParent()));
18521855
}
18531856

1854-
assert(sample);
1855-
18561857
if (dynamic_interface)
18571858
args.push_back(dynamic_interface);
18581859

@@ -1862,8 +1863,8 @@ class EnzymeBase {
18621863
// Determine generative functions
18631864
SmallPtrSet<Function *, 4> generativeFunctions;
18641865
SetVector<Function *, std::deque<Function *>> workList;
1865-
workList.insert(sample);
1866-
generativeFunctions.insert(sample);
1866+
workList.insert(interface->getSampleFunction());
1867+
generativeFunctions.insert(interface->getSampleFunction());
18671868

18681869
while (!workList.empty()) {
18691870
auto todo = *workList.begin();
@@ -1889,9 +1890,8 @@ class EnzymeBase {
18891890
}
18901891
#endif
18911892
}
1892-
1893-
auto newFunc = Logic.CreateTrace(F, generativeFunctions, mode,
1894-
dynamic_interface != nullptr);
1893+
auto newFunc =
1894+
Logic.CreateTrace(F, generativeFunctions, mode, has_dynamic_interface);
18951895

18961896
Value *trace =
18971897
Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
@@ -2588,8 +2588,8 @@ class EnzymeBase {
25882588
for (auto &&Inst : BB) {
25892589
if (auto CI = dyn_cast<CallInst>(&Inst)) {
25902590
Function *enzyme_sample = CI->getCalledFunction();
2591-
if (enzyme_sample &&
2592-
enzyme_sample->getName().contains("__enzyme_sample")) {
2591+
if (enzyme_sample && enzyme_sample->getName().contains(
2592+
TraceInterface::sampleFunctionName)) {
25932593
if (CI->getNumOperands() < 3) {
25942594
EmitFailure(
25952595
"IllegalNumberOfArguments", CI->getDebugLoc(), CI,

enzyme/Enzyme/TraceGenerator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
1717

1818
#include "FunctionUtils.h"
19+
#include "TraceInterface.h"
20+
#include "TraceUtils.h"
1921
#include "Utils.h"
2022

2123
using namespace llvm;

enzyme/Enzyme/TraceInterface.h

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
#ifndef TraceInterface_h
2+
#define TraceInterface_h
3+
4+
#include "llvm/IR/IRBuilder.h"
5+
#include "llvm/IR/Instructions.h"
6+
#include "llvm/IR/Type.h"
7+
#include "llvm/IR/Value.h"
8+
9+
using namespace llvm;
10+
11+
class TraceInterface {
12+
private:
13+
LLVMContext &C;
14+
15+
public:
16+
TraceInterface(LLVMContext &C) : C(C) {}
17+
18+
virtual ~TraceInterface() = default;
19+
20+
public:
21+
// implemented by enzyme
22+
virtual Function *getSampleFunction() = 0;
23+
static constexpr const char sampleFunctionName[] = "__enzyme_sample";
24+
25+
// user implemented
26+
virtual Value *getTrace() = 0;
27+
virtual Value *getChoice() = 0;
28+
virtual Value *insertCall() = 0;
29+
virtual Value *insertChoice() = 0;
30+
virtual Value *newTrace() = 0;
31+
virtual Value *freeTrace() = 0;
32+
virtual Value *hasCall() = 0;
33+
virtual Value *hasChoice() = 0;
34+
35+
public:
36+
static IntegerType *sizeType(LLVMContext &C) {
37+
return IntegerType::getInt64Ty(C);
38+
}
39+
static Type *stringType(LLVMContext &C) {
40+
return IntegerType::getInt8PtrTy(C);
41+
}
42+
43+
public:
44+
FunctionType *getTraceTy() { return getTraceTy(C); }
45+
FunctionType *getChoiceTy() { return getChoiceTy(C); }
46+
FunctionType *insertCallTy() { return insertCallTy(C); }
47+
FunctionType *insertChoiceTy() { return insertChoiceTy(C); }
48+
FunctionType *newTraceTy() { return newTraceTy(C); }
49+
FunctionType *freeTraceTy() { return freeTraceTy(C); }
50+
FunctionType *hasCallTy() { return hasCallTy(C); }
51+
FunctionType *hasChoiceTy() { return hasChoiceTy(C); }
52+
53+
static FunctionType *getTraceTy(LLVMContext &C) {
54+
return FunctionType::get(PointerType::getInt8PtrTy(C),
55+
{PointerType::getInt8PtrTy(C), stringType(C)},
56+
false);
57+
}
58+
59+
static FunctionType *getChoiceTy(LLVMContext &C) {
60+
return FunctionType::get(sizeType(C),
61+
{PointerType::getInt8PtrTy(C), stringType(C),
62+
PointerType::getInt8PtrTy(C), sizeType(C)},
63+
false);
64+
}
65+
66+
static FunctionType *insertCallTy(LLVMContext &C) {
67+
return FunctionType::get(Type::getVoidTy(C),
68+
{PointerType::getInt8PtrTy(C), stringType(C),
69+
PointerType::getInt8PtrTy(C)},
70+
false);
71+
}
72+
73+
static FunctionType *insertChoiceTy(LLVMContext &C) {
74+
return FunctionType::get(Type::getVoidTy(C),
75+
{PointerType::getInt8PtrTy(C), stringType(C),
76+
Type::getDoubleTy(C),
77+
PointerType::getInt8PtrTy(C), sizeType(C)},
78+
false);
79+
}
80+
81+
static FunctionType *newTraceTy(LLVMContext &C) {
82+
return FunctionType::get(PointerType::getInt8PtrTy(C), {}, false);
83+
}
84+
85+
static FunctionType *freeTraceTy(LLVMContext &C) {
86+
return FunctionType::get(Type::getVoidTy(C), {PointerType::getInt8PtrTy(C)},
87+
false);
88+
}
89+
90+
static FunctionType *hasCallTy(LLVMContext &C) {
91+
return FunctionType::get(Type::getInt1Ty(C),
92+
{PointerType::getInt8PtrTy(C), stringType(C)},
93+
false);
94+
}
95+
96+
static FunctionType *hasChoiceTy(LLVMContext &C) {
97+
return FunctionType::get(Type::getInt1Ty(C),
98+
{PointerType::getInt8PtrTy(C), stringType(C)},
99+
false);
100+
}
101+
};
102+
103+
class StaticTraceInterface final : public TraceInterface {
104+
private:
105+
Function *sampleFunction = nullptr;
106+
// user implemented
107+
Function *getTraceFunction = nullptr;
108+
Function *getChoiceFunction = nullptr;
109+
Function *insertCallFunction = nullptr;
110+
Function *insertChoiceFunction = nullptr;
111+
Function *newTraceFunction = nullptr;
112+
Function *freeTraceFunction = nullptr;
113+
Function *hasCallFunction = nullptr;
114+
Function *hasChoiceFunction = nullptr;
115+
116+
public:
117+
StaticTraceInterface(Module *M) : TraceInterface(M->getContext()) {
118+
for (auto &&F : M->functions()) {
119+
if (F.getName().contains("__enzyme_newtrace")) {
120+
assert(F.getFunctionType() == newTraceTy());
121+
newTraceFunction = &F;
122+
} else if (F.getName().contains("__enzyme_freetrace")) {
123+
assert(F.getFunctionType() == freeTraceTy());
124+
freeTraceFunction = &F;
125+
} else if (F.getName().contains("__enzyme_get_trace")) {
126+
assert(F.getFunctionType() == getTraceTy());
127+
getTraceFunction = &F;
128+
} else if (F.getName().contains("__enzyme_get_choice")) {
129+
assert(F.getFunctionType() == getChoiceTy());
130+
getChoiceFunction = &F;
131+
} else if (F.getName().contains("__enzyme_insert_call")) {
132+
assert(F.getFunctionType() == insertCallTy());
133+
insertCallFunction = &F;
134+
} else if (F.getName().contains("__enzyme_insert_choice")) {
135+
assert(F.getFunctionType() == insertChoiceTy());
136+
insertChoiceFunction = &F;
137+
} else if (F.getName().contains("__enzyme_has_call")) {
138+
assert(F.getFunctionType() == hasCallTy());
139+
hasCallFunction = &F;
140+
} else if (F.getName().contains("__enzyme_has_choice")) {
141+
assert(F.getFunctionType() == hasChoiceTy());
142+
hasChoiceFunction = &F;
143+
} else if (F.getName().contains(sampleFunctionName)) {
144+
assert(F.getFunctionType()->getNumParams() >= 3);
145+
sampleFunction = &F;
146+
}
147+
}
148+
149+
assert(newTraceFunction != nullptr && freeTraceFunction != nullptr &&
150+
getTraceFunction != nullptr && getChoiceFunction != nullptr &&
151+
insertCallFunction != nullptr && insertChoiceFunction != nullptr &&
152+
hasCallFunction != nullptr && hasChoiceFunction != nullptr &&
153+
sampleFunction != nullptr);
154+
}
155+
156+
~StaticTraceInterface() = default;
157+
158+
public:
159+
// implemented by enzyme
160+
Function *getSampleFunction() { return sampleFunction; }
161+
162+
// user implemented
163+
Value *getTrace() { return getTraceFunction; }
164+
Value *getChoice() { return getChoiceFunction; }
165+
Value *insertCall() { return insertCallFunction; }
166+
Value *insertChoice() { return insertChoiceFunction; }
167+
Value *newTrace() { return newTraceFunction; }
168+
Value *freeTrace() { return freeTraceFunction; }
169+
Value *hasCall() { return hasCallFunction; }
170+
Value *hasChoice() { return hasChoiceFunction; }
171+
};
172+
173+
class DynamicTraceInterface final : public TraceInterface {
174+
private:
175+
Function *sampleFunction = nullptr;
176+
Value *dynamicInterface;
177+
Function *F;
178+
179+
private:
180+
Value *getTraceFunction = nullptr;
181+
Value *getChoiceFunction = nullptr;
182+
Value *insertCallFunction = nullptr;
183+
Value *insertChoiceFunction = nullptr;
184+
Value *newTraceFunction = nullptr;
185+
Value *freeTraceFunction = nullptr;
186+
Value *hasCallFunction = nullptr;
187+
Value *hasChoiceFunction = nullptr;
188+
189+
public:
190+
DynamicTraceInterface(Value *dynamicInterface, Function *F)
191+
: TraceInterface(F->getContext()), dynamicInterface(dynamicInterface),
192+
F(F) {
193+
194+
for (auto &&interface_func : F->getParent()->functions()) {
195+
if (interface_func.getName().contains(
196+
TraceInterface::sampleFunctionName)) {
197+
assert(interface_func.getFunctionType()->getNumParams() >= 3);
198+
sampleFunction = &interface_func;
199+
}
200+
}
201+
202+
assert(sampleFunction);
203+
}
204+
205+
~DynamicTraceInterface() = default;
206+
207+
public:
208+
// implemented by enzyme
209+
Function *getSampleFunction() { return sampleFunction; }
210+
211+
// user implemented
212+
Value *getTrace() {
213+
if (getTraceFunction)
214+
return getTraceFunction;
215+
216+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
217+
218+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
219+
dynamicInterface, Builder.getInt32(0));
220+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
221+
return getTraceFunction = Builder.CreatePointerCast(
222+
load, PointerType::getUnqual(getTraceTy()), "get_trace");
223+
}
224+
225+
Value *getChoice() {
226+
if (getChoiceFunction)
227+
return getChoiceFunction;
228+
229+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
230+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
231+
dynamicInterface, Builder.getInt32(1));
232+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
233+
return getChoiceFunction = Builder.CreatePointerCast(
234+
load, PointerType::getUnqual(getChoiceTy()), "get_choice");
235+
}
236+
237+
Value *insertCall() {
238+
if (insertCallFunction)
239+
return insertCallFunction;
240+
241+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
242+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
243+
dynamicInterface, Builder.getInt32(2));
244+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
245+
return insertCallFunction = Builder.CreatePointerCast(
246+
load, PointerType::getUnqual(insertCallTy()), "insert_call");
247+
}
248+
249+
Value *insertChoice() {
250+
if (insertChoiceFunction)
251+
return insertChoiceFunction;
252+
253+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
254+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
255+
dynamicInterface, Builder.getInt32(3));
256+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
257+
return insertChoiceFunction = Builder.CreatePointerCast(
258+
load, PointerType::getUnqual(insertChoiceTy()), "insert_choice");
259+
}
260+
261+
Value *newTrace() {
262+
if (newTraceFunction)
263+
return newTraceFunction;
264+
265+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
266+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
267+
dynamicInterface, Builder.getInt32(4));
268+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
269+
return newTraceFunction = Builder.CreatePointerCast(
270+
load, PointerType::getUnqual(newTraceTy()), "new_trace");
271+
}
272+
273+
Value *freeTrace() {
274+
if (freeTraceFunction)
275+
return freeTraceFunction;
276+
277+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
278+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
279+
dynamicInterface, Builder.getInt32(5));
280+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
281+
return freeTraceFunction = Builder.CreatePointerCast(
282+
load, PointerType::getUnqual(freeTraceTy()), "free_trace");
283+
}
284+
285+
Value *hasCall() {
286+
if (hasCallFunction)
287+
return hasCallFunction;
288+
289+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
290+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
291+
dynamicInterface, Builder.getInt32(6));
292+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
293+
return hasCallFunction = Builder.CreatePointerCast(
294+
load, PointerType::getUnqual(hasCallTy()), "has_call");
295+
}
296+
297+
Value *hasChoice() {
298+
if (hasChoiceFunction)
299+
return hasChoiceFunction;
300+
301+
IRBuilder<> Builder(F->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
302+
auto ptr = Builder.CreateInBoundsGEP(Builder.getInt8PtrTy(),
303+
dynamicInterface, Builder.getInt32(7));
304+
auto load = Builder.CreateLoad(Builder.getInt8PtrTy(), ptr);
305+
return hasChoiceFunction = Builder.CreatePointerCast(
306+
load, PointerType::getUnqual(hasChoiceTy()), "has_choice");
307+
}
308+
};
309+
310+
#endif

0 commit comments

Comments
 (0)