Skip to content

Commit f853557

Browse files
authored
Add fromstack utility (rust-lang#776)
* Add fromstack utility
* Add addrspace fromstack
1 parent 43b737e commit f853557

File tree

4 files changed

+178
-22
lines changed

4 files changed

+178
-22
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10379,8 +10379,7 @@ class AdjointGenerator
1037910379
Attribute::NonNull);
1038010380
#endif
1038110381

10382-
if (called->getName() == "malloc" ||
10383-
called->getName() == "_Znwm") {
10382+
if (funcName == "malloc" || funcName == "_Znwm") {
1038410383
if (auto ci = dyn_cast<ConstantInt>(args[0])) {
1038510384
unsigned derefBytes = ci->getLimitedValue();
1038610385
CallInst *cal =
@@ -10438,18 +10437,37 @@ class AdjointGenerator
1043810437
bb, anti, getIndex(orig, CacheType::Shadow));
1043910438
else {
1044010439
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
10441-
AllocaInst *replacement = bb.CreateAlloca(
10442-
Type::getInt8Ty(orig->getContext()), args[0]);
10440+
Value *Size;
10441+
if (funcName == "malloc")
10442+
Size = args[0];
10443+
else if (funcName == "julia.gc_alloc_obj")
10444+
Size = args[1];
10445+
else
10446+
llvm_unreachable("Unknown allocation to upgrade");
10447+
Value *replacement = bb.CreateAlloca(
10448+
Type::getInt8Ty(orig->getContext()), Size);
1044310449
replacement->takeName(anti);
1044410450
auto Alignment = cast<ConstantInt>(cast<ConstantAsMetadata>(
1044510451
MD->getOperand(0))
1044610452
->getValue())
1044710453
->getLimitedValue();
1044810454
#if LLVM_VERSION_MAJOR >= 10
10449-
replacement->setAlignment(Align(Alignment));
10455+
cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
1045010456
#else
10451-
replacement->setAlignment(Alignment);
10457+
cast<AllocaInst>(replacement)->setAlignment(Alignment);
1045210458
#endif
10459+
if (!anti->getType()->getPointerElementType()->isIntegerTy(8))
10460+
replacement = bb.CreatePointerCast(
10461+
replacement,
10462+
PointerType::getUnqual(
10463+
anti->getType()->getPointerElementType()));
10464+
10465+
if (int AS =
10466+
cast<PointerType>(anti->getType())->getAddressSpace())
10467+
replacement = bb.CreateAddrSpaceCast(
10468+
replacement,
10469+
PointerType::get(
10470+
anti->getType()->getPointerElementType(), AS));
1045310471

1045410472
gutils->replaceAWithB(cast<Instruction>(anti), replacement);
1045510473
gutils->erase(cast<Instruction>(anti));
@@ -10583,23 +10601,43 @@ class AdjointGenerator
1058310601
// allocation where possible.
1058410602
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
1058510603
IRBuilder<> B(newCall);
10586-
if (auto CI = dyn_cast<ConstantInt>(orig->getArgOperand(0))) {
10604+
10605+
Value *Size;
10606+
if (funcName == "malloc")
10607+
Size = orig->getArgOperand(0);
10608+
else if (funcName == "julia.gc_alloc_obj")
10609+
Size = orig->getArgOperand(1);
10610+
else
10611+
llvm_unreachable("Unknown allocation to upgrade");
10612+
Size = gutils->getNewFromOriginal(Size);
10613+
10614+
if (auto CI = dyn_cast<ConstantInt>(Size)) {
1058710615
B.SetInsertPoint(gutils->inversionAllocs);
1058810616
}
10589-
1059010617
auto rule = [&]() {
10591-
auto replacement = B.CreateAlloca(
10592-
Type::getInt8Ty(orig->getContext()),
10593-
gutils->getNewFromOriginal(orig->getArgOperand(0)));
10618+
Value *replacement =
10619+
B.CreateAlloca(Type::getInt8Ty(orig->getContext()), Size);
1059410620
auto Alignment =
1059510621
cast<ConstantInt>(
1059610622
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
1059710623
->getLimitedValue();
1059810624
#if LLVM_VERSION_MAJOR >= 10
10599-
replacement->setAlignment(Align(Alignment));
10600-
#else
10601-
replacement->setAlignment(Alignment);
10602-
#endif
10625+
cast<AllocaInst>(replacement)->setAlignment(Align(Alignment));
10626+
#else
10627+
cast<AllocaInst>(replacement)->setAlignment(Alignment);
10628+
#endif
10629+
if (!orig->getType()->getPointerElementType()->isIntegerTy(8))
10630+
replacement = B.CreatePointerCast(
10631+
replacement,
10632+
PointerType::getUnqual(
10633+
orig->getType()->getPointerElementType()));
10634+
10635+
if (int AS =
10636+
cast<PointerType>(orig->getType())->getAddressSpace())
10637+
replacement = B.CreateAddrSpaceCast(
10638+
replacement,
10639+
PointerType::get(orig->getType()->getPointerElementType(),
10640+
AS));
1060310641
return replacement;
1060410642
};
1060510643

enzyme/Enzyme/CApi.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,11 @@ void EnzymeSetMustCache(LLVMValueRef inst1) {
600600
I1->setMetadata("enzyme_mustcache", MDNode::get(I1->getContext(), {}));
601601
}
602602

603+
/// C API query: reports whether the instruction wrapped by \p inst1 carries
/// "enzyme_fromstack" metadata.
///
/// \param inst1 wrapped value; it must be an llvm::Instruction (checked by
///              cast<>).
/// \return 1 if the metadata is present, 0 otherwise.
uint8_t EnzymeHasFromStack(LLVMValueRef inst1) {
  auto *Inst = cast<Instruction>(unwrap(inst1));
  if (hasMetadata(Inst, "enzyme_fromstack"))
    return 1;
  return 0;
}
607+
603608
/// C API entry point: unwraps the module handle \p M and forwards it to
/// ReplaceFunctionImplementation.
void EnzymeReplaceFunctionImplementation(LLVMModuleRef M) {
  auto &Mod = *unwrap(M);
  ReplaceFunctionImplementation(Mod);
}

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,116 @@ static inline bool OnlyUsedInOMP(AllocaInst *AI) {
265265
return true;
266266
}
267267

268+
/// Re-base every (transitive) use of \p AI onto \p rep, a replacement pointer
/// whose type may live in a different address space than \p AI.
///
/// The users of \p AI are walked with a worklist. Depending on the kind of
/// user:
///  - addrspacecast: if it already casts to the replacement's address space
///    its uses are forwarded to the replacement and the cast is dropped,
///    otherwise its operand is simply swapped;
///  - other casts: non-pointer results just get their operand swapped;
///    pointer results are re-created in the replacement's address space and
///    their users are visited in turn;
///  - getelementptr: re-created on the replacement, users visited in turn;
///  - load / store-through-pointer: the pointer operand is swapped;
///  - memset: re-emitted with intrinsic types matching the new pointer;
///  - calls to "julia.write_barrier": removed, but only when \p legal is set.
///
/// \param AI    value being replaced; erased at the end if it is an
///              Instruction.
/// \param rep   replacement pointer value.
/// \param legal when true, any user not covered above is fixed up by casting
///              the replacement back to the original address space and
///              replacing all uses of the old value with that cast. When
///              false, such a user is a fatal error (llvm_unreachable).
void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
  // Worklist entries: (replacement, value being replaced, user of that value).
  SmallVector<std::tuple<Value *, Value *, Instruction *>, 1> Todo;
  for (auto U : AI->users()) {
    Todo.push_back(
        std::make_tuple((Value *)rep, (Value *)AI, cast<Instruction>(U)));
  }

  // Instructions made dead by the rewrite. A value is always queued before
  // any of its users, so erasing in reverse order removes users first.
  SmallVector<Instruction *, 1> toErase;
  // Queue each instruction at most once: a user that references the same
  // pointer through several operands is visited several times, and erasing it
  // twice would be a double free.
  auto scheduleErase = [&toErase](Instruction *I) {
    if (!llvm::is_contained(toErase, I))
      toErase.push_back(I);
  };
  if (auto I = dyn_cast<Instruction>(AI))
    scheduleErase(I);

  while (Todo.size()) {
    auto cur = Todo.back();
    Todo.pop_back();
    Value *curRep = std::get<0>(cur);
    Value *prev = std::get<1>(cur);
    Instruction *inst = std::get<2>(cur);

    if (auto ASC = dyn_cast<AddrSpaceCastInst>(inst)) {
      auto AS = cast<PointerType>(curRep->getType())->getAddressSpace();
      if (AS == ASC->getDestAddressSpace()) {
        // The replacement is already in the destination address space, so the
        // cast is redundant. It still holds `prev` as an operand, so it has to
        // be erased along with `prev` rather than left dangling.
        ASC->replaceAllUsesWith(curRep);
        scheduleErase(ASC);
        continue;
      }
      ASC->setOperand(0, curRep);
      continue;
    }

    if (auto CI = dyn_cast<CastInst>(inst)) {
      if (!CI->getType()->isPointerTy()) {
        // Result does not carry an address space; swapping the operand is
        // enough.
        CI->setOperand(0, curRep);
        continue;
      }
      IRBuilder<> B(CI);
      // IRBuilder may hand back the operand itself (no-op cast) or a folded
      // constant, so the result is not necessarily a fresh CastInst.
      Value *nCI = B.CreateCast(
          CI->getOpcode(), curRep,
          PointerType::get(
              CI->getType()->getPointerElementType(),
              cast<PointerType>(curRep->getType())->getAddressSpace()));
      if (nCI != curRep)
        if (auto nInst = dyn_cast<Instruction>(nCI))
          nInst->takeName(CI);
      for (auto U : CI->users()) {
        Todo.push_back(
            std::make_tuple((Value *)nCI, (Value *)CI, cast<Instruction>(U)));
      }
      // The old cast still uses `prev`, which is about to be erased; drop it
      // once its own users have been rewritten.
      scheduleErase(CI);
      continue;
    }

    if (auto GEP = dyn_cast<GetElementPtrInst>(inst)) {
      IRBuilder<> B(GEP);
      SmallVector<Value *, 1> ind(GEP->indices());
#if LLVM_VERSION_MAJOR > 7
      auto nGEP = cast<GetElementPtrInst>(
          B.CreateGEP(GEP->getSourceElementType(), curRep, ind));
#else
      auto nGEP = cast<GetElementPtrInst>(B.CreateGEP(curRep, ind));
#endif
      nGEP->takeName(GEP);
      for (auto U : GEP->users()) {
        Todo.push_back(
            std::make_tuple((Value *)nGEP, (Value *)GEP, cast<Instruction>(U)));
      }
      scheduleErase(GEP);
      continue;
    }

    if (auto LI = dyn_cast<LoadInst>(inst)) {
      LI->setOperand(0, curRep);
      continue;
    }

    if (auto SI = dyn_cast<StoreInst>(inst)) {
      // Only a store *through* the pointer can be retargeted here; a store
      // *of* the pointer value falls through to the generic handling below.
      if (SI->getPointerOperand() == prev) {
        SI->setOperand(1, curRep);
        continue;
      }
    }

    if (auto MS = dyn_cast<MemSetInst>(inst)) {
      IRBuilder<> B(MS);

      // Forward the remaining call arguments unchanged; only the destination
      // pointer is replaced.
      Value *nargs[] = {curRep, MS->getArgOperand(1), MS->getArgOperand(2),
                        MS->getArgOperand(3)};

      // The intrinsic is overloaded on the destination pointer type and the
      // length type, so the declaration must be re-fetched for the new
      // pointer type.
      Type *tys[] = {nargs[0]->getType(), nargs[2]->getType()};

      auto nMS = cast<CallInst>(B.CreateCall(
          Intrinsic::getDeclaration(MS->getParent()->getParent()->getParent(),
                                    Intrinsic::memset, tys),
          nargs));
      nMS->copyIRFlags(MS);
      scheduleErase(MS);
      continue;
    }

    if (auto CI = dyn_cast<CallInst>(inst)) {
      if (auto F = CI->getCalledFunction()) {
        if (F->getName() == "julia.write_barrier" && legal) {
          scheduleErase(CI);
          continue;
        }
      }
    }

    if (legal) {
      // Unknown user: cast the replacement back to the original address
      // space right after its definition and let every remaining user of the
      // old value consume that cast instead.
      IRBuilder<> B(cast<Instruction>(curRep)->getNextNode());
      curRep = B.CreateAddrSpaceCast(
          curRep, PointerType::get(
                      curRep->getType()->getPointerElementType(),
                      cast<PointerType>(prev->getType())->getAddressSpace()));
      prev->replaceAllUsesWith(curRep);
      continue;
    }

    llvm::errs() << " rep: " << *curRep << " prev: " << *prev
                 << " inst: " << *inst << "\n";
    llvm_unreachable("Illegal address space propagation");
  }

  for (auto I : llvm::reverse(toErase))
    I->eraseFromParent();
}
377+
268378
/// Convert necessary stack allocations into mallocs for use in the reverse
269379
/// pass. Specifically if we're not topLevel all allocations must be upgraded
270380
/// Even if topLevel any allocations that aren't in the entry block (and
@@ -312,13 +422,13 @@ static inline void UpgradeAllocasToMallocs(Function *NewF,
312422

313423
auto PT0 = cast<PointerType>(rep->getType());
314424
auto PT1 = cast<PointerType>(AI->getType());
315-
if (PT0->getAddressSpace() != PT1->getAddressSpace())
316-
rep = B.CreateAddrSpaceCast(rep,
317-
PointerType::get(PT0->getPointerElementType(),
318-
PT1->getAddressSpace()));
319-
assert(rep->getType() == AI->getType());
320-
AI->replaceAllUsesWith(rep);
321-
AI->eraseFromParent();
425+
if (PT0->getAddressSpace() != PT1->getAddressSpace()) {
426+
RecursivelyReplaceAddressSpace(AI, rep, /*legal*/ false);
427+
} else {
428+
assert(rep->getType() == AI->getType());
429+
AI->replaceAllUsesWith(rep);
430+
AI->eraseFromParent();
431+
}
322432
}
323433
}
324434

enzyme/Enzyme/FunctionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ static inline void calculateUnusedStores(
346346
}
347347
}
348348

349+
/// Rewrite the (transitive) users of \p AI so that they operate on \p rep, a
/// replacement pointer that may be in a different address space. When
/// \p legal is true, users that cannot be rewritten directly receive an
/// addrspacecast back to the original address space; when false such a user
/// is treated as a fatal error.
void RecursivelyReplaceAddressSpace(llvm::Value *AI, llvm::Value *rep,
350+
bool legal);
351+
349352
void ReplaceFunctionImplementation(llvm::Module &M);
350353

351354
/// Is the use of value val as an argument of call CI potentially captured

0 commit comments

Comments
 (0)