Skip to content

Commit d38d811

Browse files
authored
Functioning FBLAS and TBAA prop (rust-lang#528)
1 parent c88100b commit d38d811

20 files changed: +533 -289 lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 427 additions & 199 deletions
Large diffs are not rendered by default.

enzyme/Enzyme/CacheUtility.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -93,8 +93,11 @@ void CacheUtility::replaceAWithB(Value *A, Value *B, bool storeInCache) {
9393
scopeInstructions.erase(stfound);
9494
for (auto st : tmpInstructions)
9595
cast<StoreInst>(&*st)->eraseFromParent();
96+
MDNode *TBAA = nullptr;
97+
if (auto I = dyn_cast<Instruction>(A))
98+
TBAA = I->getMetadata(LLVMContext::MD_tbaa);
9699
storeInstructionInCache(found->second.second, cast<Instruction>(B),
97-
cache);
100+
cache, TBAA);
98101
}
99102
}
100103

@@ -1347,7 +1350,7 @@ CacheUtility::SubLimitType CacheUtility::getSubLimits(bool inForwardPass,
13471350
/// in the cache at the location defined in the given builder
13481351
void CacheUtility::storeInstructionInCache(LimitContext ctx,
13491352
IRBuilder<> &BuilderM, Value *val,
1350-
AllocaInst *cache) {
1353+
AllocaInst *cache, MDNode *TBAA) {
13511354
assert(BuilderM.GetInsertBlock()->getParent() == newFunc);
13521355
if (auto inst = dyn_cast<Instruction>(val))
13531356
assert(inst->getParent()->getParent() == newFunc);
@@ -1447,6 +1450,7 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx,
14471450
.getTypeAllocSizeInBits(val->getType()) /
14481451
8);
14491452
unsigned align = getCacheAlignment((unsigned)byteSizeOfType->getZExtValue());
1453+
storeinst->setMetadata(LLVMContext::MD_tbaa, TBAA);
14501454
#if LLVM_VERSION_MAJOR >= 10
14511455
storeinst->setAlignment(Align(align));
14521456
#else
@@ -1459,7 +1463,8 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx,
14591463
/// in the cache right after the instruction is executed
14601464
void CacheUtility::storeInstructionInCache(LimitContext ctx,
14611465
llvm::Instruction *inst,
1462-
llvm::AllocaInst *cache) {
1466+
llvm::AllocaInst *cache,
1467+
llvm::MDNode *TBAA) {
14631468
assert(ctx.Block);
14641469
assert(inst);
14651470
assert(cache);
@@ -1477,7 +1482,7 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx,
14771482
v.SetInsertPoint(putafter);
14781483
}
14791484
v.setFastMathFlags(getFast());
1480-
storeInstructionInCache(ctx, v, inst, cache);
1485+
storeInstructionInCache(ctx, v, inst, cache, TBAA);
14811486
}
14821487

14831488
/// Given an allocation specified by the LimitContext ctx and cache, compute a

enzyme/Enzyme/CacheUtility.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,12 +363,14 @@ class CacheUtility {
363363
/// Given an allocation defined at a particular ctx, store the value val
364364
/// in the cache at the location defined in the given builder
365365
void storeInstructionInCache(LimitContext ctx, llvm::IRBuilder<> &BuilderM,
366-
llvm::Value *val, llvm::AllocaInst *cache);
366+
llvm::Value *val, llvm::AllocaInst *cache,
367+
llvm::MDNode *TBAA = nullptr);
367368

368369
/// Given an allocation defined at a particular ctx, store the instruction
369370
/// in the cache right after the instruction is executed
370371
void storeInstructionInCache(LimitContext ctx, llvm::Instruction *inst,
371-
llvm::AllocaInst *cache);
372+
llvm::AllocaInst *cache,
373+
llvm::MDNode *TBAA = nullptr);
372374

373375
/// Given an allocation specified by the LimitContext ctx and cache, compute a
374376
/// pointer that can hold the underlying type being cached. This value should

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3177,7 +3177,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
31773177
llvm::raw_string_ostream ss(s);
31783178
ss << "No reverse pass found for " + key.todiff->getName() << "\n";
31793179
ss << *key.todiff << "\n";
3180-
CustomErrorHandler(ss.str().c_str());
31813180
if (CustomErrorHandler) {
31823181
CustomErrorHandler(ss.str().c_str());
31833182
} else {

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,10 @@ llvm::cl::opt<bool>
104104
extern void (*CustomErrorHandler)(const char *);
105105
}
106106

107+
unsigned int MD_ToCopy[5] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa,
108+
LLVMContext::MD_tbaa_struct, LLVMContext::MD_range,
109+
LLVMContext::MD_nonnull};
110+
107111
Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
108112
const ValueToValueMapTy &available,
109113
UnwrapMode unwrapMode, BasicBlock *scope,
@@ -686,6 +690,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
686690
#else
687691
auto toreturn = BuilderM.CreateLoad(pidx, load->getName() + "_unwrap");
688692
#endif
693+
toreturn->copyMetadata(*load, MD_ToCopy);
689694
toreturn->copyIRFlags(load);
690695
unwrappedLoads[toreturn] = load;
691696
if (toreturn->getParent()->getParent() != load->getParent()->getParent())
@@ -2217,8 +2222,11 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
22172222
return malloc;
22182223
}
22192224

2220-
ensureLookupCached(cast<Instruction>(malloc),
2221-
/*shouldFree=*/reverseBlocks.size() > 0);
2225+
ensureLookupCached(
2226+
cast<Instruction>(malloc),
2227+
/*shouldFree=*/reverseBlocks.size() > 0,
2228+
/*scope*/ nullptr,
2229+
cast<Instruction>(malloc)->getMetadata(LLVMContext::MD_tbaa));
22222230
auto found2 = scopeMap.find(malloc);
22232231
assert(found2 != scopeMap.end());
22242232
assert(found2->second.first);
@@ -2427,6 +2435,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
24272435
available),
24282436
lookupM(getNewFromOriginal(SI->getPointerOperand()), NB,
24292437
available));
2438+
ts->copyMetadata(*SI, MD_ToCopy);
24302439
#if LLVM_VERSION_MAJOR >= 10
24312440
ts->setAlignment(SI->getAlign());
24322441
#else
@@ -2516,6 +2525,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
25162525
auto Defs = getInvertedBundles(CI, BundleTypes, NB,
25172526
/*lookup*/ true, available);
25182527
auto cal = NB.CreateCall(CI->getCalledFunction(), args, Defs);
2528+
cal->copyMetadata(*CI, MD_ToCopy);
25192529
cal->setName("remat_" + CI->getName());
25202530
cal->setAttributes(CI->getAttributes());
25212531
cal->setCallingConv(CI->getCallingConv());
@@ -2622,6 +2632,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB,
26222632
/*lookup*/ true, available);
26232633
auto cal =
26242634
NB.CreateCall(MS->getCalledFunction(), args, Defs);
2635+
cal->copyMetadata(*MS, MD_ToCopy);
26252636
cal->setAttributes(MS->getAttributes());
26262637
cal->setCallingConv(MS->getCallingConv());
26272638
cal->setTailCallKind(MS->getTailCallKind());
@@ -4235,6 +4246,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
42354246
#else
42364247
auto li = bb.CreateLoad(ip, arg->getName() + "'ipl");
42374248
#endif
4249+
li->copyMetadata(*arg, MD_ToCopy);
42384250
li->copyIRFlags(arg);
42394251
#if LLVM_VERSION_MAJOR >= 10
42404252
li->setAlignment(arg->getAlign());
@@ -5591,14 +5603,18 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
55915603
}
55925604
}
55935605

5594-
ensureLookupCached(inst, /*shouldFree*/ true, scope);
5606+
ensureLookupCached(inst, /*shouldFree*/ true, scope,
5607+
inst->getMetadata(LLVMContext::MD_tbaa));
55955608
bool isi1 = inst->getType()->isIntegerTy() &&
55965609
cast<IntegerType>(inst->getType())->getBitWidth() == 1;
55975610
assert(!isOriginalBlock(*BuilderM.GetInsertBlock()));
55985611
auto found = findInMap(scopeMap, (Value *)inst);
55995612
Value *result =
56005613
lookupValueFromCache(/*isForwardPass*/ false, BuilderM, found->second,
56015614
found->first, isi1, available);
5615+
if (auto LI2 = dyn_cast<LoadInst>(result))
5616+
if (auto LI1 = dyn_cast<LoadInst>(inst))
5617+
LI2->copyMetadata(*LI1, MD_ToCopy);
56025618
if (result->getType() != inst->getType()) {
56035619
llvm::errs() << "newFunc: " << *newFunc << "\n";
56045620
llvm::errs() << "result: " << *result << "\n";

enzyme/Enzyme/GradientUtils.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,7 @@ extern llvm::cl::opt<bool> EnzymeInactiveDynamic;
9999
extern llvm::cl::opt<bool> EnzymeFreeInternalAllocations;
100100
extern llvm::cl::opt<bool> EnzymeRematerialize;
101101
}
102+
extern unsigned int MD_ToCopy[5];
102103

103104
struct InvertedPointerConfig : ValueMapConfig<const llvm::Value *> {
104105
typedef GradientUtils *ExtraData;
@@ -1444,7 +1445,8 @@ class GradientUtils : public CacheUtility {
14441445
bool permitCache = true) override final;
14451446

14461447
void ensureLookupCached(Instruction *inst, bool shouldFree = true,
1447-
BasicBlock *scope = nullptr) {
1448+
BasicBlock *scope = nullptr,
1449+
llvm::MDNode *TBAA = nullptr) {
14481450
assert(inst);
14491451
if (scopeMap.find(inst) != scopeMap.end())
14501452
return;
@@ -1463,7 +1465,7 @@ class GradientUtils : public CacheUtility {
14631465
insert_or_assign(
14641466
scopeMap, Val,
14651467
std::pair<AssertingVH<AllocaInst>, LimitContext>(cache, lctx));
1466-
storeInstructionInCache(lctx, inst, cache);
1468+
storeInstructionInCache(lctx, inst, cache, TBAA);
14671469
}
14681470

14691471
std::map<Instruction *, ValueMap<BasicBlock *, WeakTrackingVH>> lcssaFixes;

enzyme/Enzyme/Utils.cpp

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -185,17 +185,16 @@ Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType,
185185
return F;
186186
}
187187

188-
Function *getOrInsertMemcpyStrided(Module &M, PointerType *T, unsigned dstalign,
189-
unsigned srcalign) {
188+
Function *getOrInsertMemcpyStrided(Module &M, PointerType *T, Type *IT,
189+
unsigned dstalign, unsigned srcalign) {
190190
Type *elementType = T->getPointerElementType();
191191
assert(elementType->isFloatingPointTy());
192-
std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "da" +
193-
std::to_string(dstalign) + "sa" +
192+
std::string name = "__enzyme_memcpy_" + tofltstr(elementType) + "_" +
193+
std::to_string(cast<IntegerType>(IT)->getBitWidth()) +
194+
"_da" + std::to_string(dstalign) + "sa" +
194195
std::to_string(srcalign) + "stride";
195-
FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()),
196-
{T, T, Type::getInt32Ty(M.getContext()),
197-
Type::getInt32Ty(M.getContext())},
198-
false);
196+
FunctionType *FT =
197+
FunctionType::get(Type::getVoidTy(M.getContext()), {T, T, IT, IT}, false);
199198

200199
#if LLVM_VERSION_MAJOR >= 9
201200
Function *F = cast<Function>(M.getOrInsertFunction(name, FT).getCallee());

enzyme/Enzyme/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -573,7 +573,8 @@ getOrInsertDifferentialFloatMemcpy(llvm::Module &M, llvm::Type *T,
573573

574574
/// Create function for type that performs memcpy with a stride
575575
llvm::Function *getOrInsertMemcpyStrided(llvm::Module &M, llvm::PointerType *T,
576-
unsigned dstalign, unsigned srcalign);
576+
llvm::Type *IT, unsigned dstalign,
577+
unsigned srcalign);
577578

578579
/// Create function for type that performs the derivative memmove on floating
579580
/// point memory

enzyme/test/Enzyme/ReverseMode/blas/cblas_ddot.ll

Lines changed: 21 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -115,20 +115,18 @@ entry:
115115

116116
; CHECK: define internal { double*, double* } @[[augMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
117117
; CHECK-NEXT: entry:
118-
; CHECK-NEXT: %0 = zext i32 %len to i64
119-
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
120-
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
121-
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
122-
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %m, i32 %len, i32 %incm)
123-
; CHECK-NEXT: %2 = zext i32 %len to i64
124-
; CHECK-NEXT: %mallocsize1 = mul i64 %2, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
125-
; CHECK-NEXT: %malloccall2 = tail call i8* @malloc(i64 %mallocsize1)
126-
; CHECK-NEXT: %3 = bitcast i8* %malloccall2 to double*
127-
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %3, double* %n, i32 %len, i32 %incn)
128-
; CHECK-NEXT: %4 = insertvalue { double*, double* } undef, double* %1, 0
129-
; CHECK-NEXT: %5 = insertvalue { double*, double* } %4, double* %3, 1
118+
; CHECK-NEXT: %mallocsize = mul i32 %len, 8
119+
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i32 %mallocsize)
120+
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
121+
; CHECK-NEXT: call void @__enzyme_memcpy_double_32_da0sa0stride(double* %0, double* %m, i32 %len, i32 %incm)
122+
; CHECK-NEXT: %mallocsize1 = mul i32 %len, 8
123+
; CHECK-NEXT: %malloccall2 = tail call i8* @malloc(i32 %mallocsize1)
124+
; CHECK-NEXT: %1 = bitcast i8* %malloccall2 to double*
125+
; CHECK-NEXT: call void @__enzyme_memcpy_double_32_da0sa0stride(double* %1, double* %n, i32 %len, i32 %incn)
126+
; CHECK-NEXT: %2 = insertvalue { double*, double* } undef, double* %0, 0
127+
; CHECK-NEXT: %3 = insertvalue { double*, double* } %2, double* %1, 1
130128
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
131-
; CHECK-NEXT: ret { double*, double* } %5
129+
; CHECK-NEXT: ret { double*, double* } %3
132130
; CHECK-NEXT: }
133131

134132
; CHECK: define internal void @[[revMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, { double*, double* }
@@ -153,13 +151,12 @@ entry:
153151

154152
; CHECK: define internal double* @[[augModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
155153
; CHECK-NEXT: entry:
156-
; CHECK-NEXT: %0 = zext i32 %len to i64
157-
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
158-
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
159-
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
160-
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %m, i32 %len, i32 %incm)
154+
; CHECK-NEXT: %mallocsize = mul i32 %len, 8
155+
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i32 %mallocsize)
156+
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
157+
; CHECK-NEXT: call void @__enzyme_memcpy_double_32_da0sa0stride(double* %0, double* %m, i32 %len, i32 %incm)
161158
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
162-
; CHECK-NEXT: ret double* %1
159+
; CHECK-NEXT: ret double* %0
163160
; CHECK-NEXT: }
164161

165162
; CHECK: define internal void @[[revModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, double*
@@ -179,13 +176,12 @@ entry:
179176

180177
; CHECK: define internal double* @[[augModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn)
181178
; CHECK-NEXT: entry:
182-
; CHECK-NEXT: %0 = zext i32 %len to i64
183-
; CHECK-NEXT: %mallocsize = mul i64 %0, ptrtoint (double* getelementptr (double, double* null, i32 1) to i64)
184-
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 %mallocsize)
185-
; CHECK-NEXT: %1 = bitcast i8* %malloccall to double*
186-
; CHECK-NEXT: call void @__enzyme_memcpy_doubleda0sa0stride(double* %1, double* %n, i32 %len, i32 %incn)
179+
; CHECK-NEXT: %mallocsize = mul i32 %len, 8
180+
; CHECK-NEXT: %malloccall = tail call i8* @malloc(i32 %mallocsize)
181+
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
182+
; CHECK-NEXT: call void @__enzyme_memcpy_double_32_da0sa0stride(double* %0, double* %n, i32 %len, i32 %incn)
187183
; CHECK-NEXT: %call = call double @cblas_ddot(i32 %len, double* nocapture readonly %m, i32 %incm, double* nocapture readonly %n, i32 %incn)
188-
; CHECK-NEXT: ret double* %1
184+
; CHECK-NEXT: ret double* %0
189185
; CHECK-NEXT: }
190186

191187
; CHECK: define internal void @[[revModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn, double %differeturn, double*

0 commit comments

Comments
 (0)