Skip to content

Commit 249121f

Browse files
authored
Masked load/store support (rust-lang#344)
* Masked store support
* Handle masked loads
* Add tests
* Fix format
* Test fwd masked
* fixup
1 parent a171f7e commit 249121f

File tree

9 files changed

+660
-310
lines changed

9 files changed

+660
-310
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 158 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -376,15 +376,18 @@ class AdjointGenerator
376376

377377
#if LLVM_VERSION_MAJOR >= 10
378378
void visitLoadLike(llvm::Instruction &I, MaybeAlign alignment,
379-
bool constantval, bool can_modref,
380-
Value *OrigOffset = nullptr)
379+
bool constantval, Value *OrigOffset = nullptr,
381380
#else
382381
void visitLoadLike(llvm::Instruction &I, unsigned alignment, bool constantval,
383-
bool can_modref, Value *OrigOffset = nullptr)
382+
Value *OrigOffset = nullptr,
384383
#endif
385-
{
384+
Value *mask = nullptr, Value *orig_maskInit = nullptr) {
386385
auto &DL = gutils->newFunc->getParent()->getDataLayout();
387386

387+
assert(gutils->can_modref_map);
388+
assert(gutils->can_modref_map->find(&I) != gutils->can_modref_map->end());
389+
bool can_modref = gutils->can_modref_map->find(&I)->second;
390+
388391
constantval |= gutils->isConstantValue(&I);
389392

390393
BasicBlock *parent = I.getParent();
@@ -536,7 +539,8 @@ class AdjointGenerator
536539
// the instruction if the value is a potential pointer. This may not be
537540
// caught by type analysis if the result does not have a known type.
538541
if (!gutils->isConstantInstruction(&I)) {
539-
bool isfloat = type->isFPOrFPVectorTy();
542+
Type *isfloat =
543+
type->isFPOrFPVectorTy() ? type->getScalarType() : nullptr;
540544
if (!isfloat && type->isIntOrIntVectorTy()) {
541545
auto LoadSize = DL.getTypeSizeInBits(type) / 8;
542546
ConcreteType vd = BaseType::Unknown;
@@ -560,8 +564,34 @@ class AdjointGenerator
560564
getForwardBuilder(Builder2);
561565

562566
if (!gutils->isConstantValue(&I)) {
563-
auto diff = Builder2.CreateLoad(
564-
gutils->invertPointerM(I.getOperand(0), Builder2));
567+
Value *diff;
568+
if (!mask) {
569+
auto LI = Builder2.CreateLoad(
570+
gutils->invertPointerM(I.getOperand(0), Builder2));
571+
if (alignment)
572+
#if LLVM_VERSION_MAJOR >= 10
573+
LI->setAlignment(*alignment);
574+
#else
575+
LI->setAlignment(alignment);
576+
#endif
577+
diff = LI;
578+
} else {
579+
Type *tys[] = {I.getType(), I.getOperand(0)->getType()};
580+
auto F = Intrinsic::getDeclaration(gutils->oldFunc->getParent(),
581+
Intrinsic::masked_load, tys);
582+
#if LLVM_VERSION_MAJOR >= 10
583+
Value *alignv =
584+
ConstantInt::get(Type::getInt32Ty(mask->getContext()),
585+
alignment ? alignment->value() : 0);
586+
#else
587+
Value *alignv = ConstantInt::get(
588+
Type::getInt32Ty(mask->getContext()), alignment);
589+
#endif
590+
Value *args[] = {
591+
gutils->invertPointerM(I.getOperand(0), Builder2), alignv,
592+
mask, diffe(orig_maskInit, Builder2)};
593+
diff = Builder2.CreateCall(F, args);
594+
}
565595
setDiffe(&I, diff, Builder2);
566596
}
567597
break;
@@ -576,8 +606,13 @@ class AdjointGenerator
576606

577607
if (!gutils->isConstantValue(I.getOperand(0))) {
578608
((DiffeGradientUtils *)gutils)
579-
->addToInvertedPtrDiffe(I.getOperand(0), prediff, Builder2,
580-
alignment, OrigOffset);
609+
->addToInvertedPtrDiffe(
610+
I.getOperand(0), prediff, Builder2, alignment, OrigOffset,
611+
mask ? lookup(mask, Builder2) : nullptr);
612+
}
613+
if (mask && !gutils->isConstantValue(orig_maskInit)) {
614+
addToDiffe(orig_maskInit, prediff, Builder2, isfloat,
615+
Builder2.CreateNot(mask));
581616
}
582617
break;
583618
}
@@ -614,10 +649,7 @@ class AdjointGenerator
614649
auto &DL = gutils->newFunc->getParent()->getDataLayout();
615650

616651
bool constantval = parseTBAA(LI, DL).Inner0().isIntegral();
617-
assert(gutils->can_modref_map);
618-
assert(gutils->can_modref_map->find(&LI) != gutils->can_modref_map->end());
619-
bool can_modref = gutils->can_modref_map->find(&LI)->second;
620-
visitLoadLike(LI, alignment, constantval, can_modref);
652+
visitLoadLike(LI, alignment, constantval);
621653
eraseIfUnused(LI);
622654
}
623655

@@ -636,15 +668,9 @@ class AdjointGenerator
636668
}
637669

638670
void visitStoreInst(llvm::StoreInst &SI) {
639-
Value *orig_ptr = SI.getPointerOperand();
640-
Value *orig_val = SI.getValueOperand();
641-
Value *val = gutils->getNewFromOriginal(orig_val);
642-
Type *valType = orig_val->getType();
643-
644-
auto &DL = gutils->newFunc->getParent()->getDataLayout();
645671
// If a store of an omp init argument, don't delete in reverse
646672
// and don't do any adjoint propagation (assumed integral)
647-
for (auto U : orig_ptr->users()) {
673+
for (auto U : SI.getPointerOperand()->users()) {
648674
if (auto CI = dyn_cast<CallInst>(U)) {
649675
if (auto F = CI->getCalledFunction()) {
650676
if (F->getName() == "__kmpc_for_static_init_4" ||
@@ -656,24 +682,47 @@ class AdjointGenerator
656682
}
657683
}
658684
}
685+
#if LLVM_VERSION_MAJOR >= 10
686+
auto align = SI.getAlign();
687+
#else
688+
auto align = SI.getAlignment();
689+
#endif
690+
visitCommonStore(SI, SI.getPointerOperand(), SI.getValueOperand(), align,
691+
SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(),
692+
/*mask=*/nullptr);
693+
eraseIfUnused(SI);
694+
}
695+
696+
#if LLVM_VERSION_MAJOR >= 10
697+
void visitCommonStore(llvm::Instruction &I, Value *orig_ptr, Value *orig_val,
698+
MaybeAlign align, bool isVolatile,
699+
AtomicOrdering ordering, SyncScope::ID syncScope,
700+
Value *mask = nullptr)
701+
#else
702+
void visitCommonStore(llvm::Instruction &I, Value *orig_ptr, Value *orig_val,
703+
unsigned align, bool isVolatile,
704+
AtomicOrdering ordering, SyncScope::ID syncScope,
705+
Value *mask = nullptr)
706+
#endif
707+
{
708+
Value *val = gutils->getNewFromOriginal(orig_val);
709+
Type *valType = orig_val->getType();
710+
711+
auto &DL = gutils->newFunc->getParent()->getDataLayout();
659712

660-
if (unnecessaryStores.count(&SI)) {
661-
eraseIfUnused(SI);
713+
if (unnecessaryStores.count(&I)) {
662714
return;
663715
}
664716

665717
if (gutils->isConstantValue(orig_ptr)) {
666-
eraseIfUnused(SI);
667718
return;
668719
}
669720

670721
bool constantval = gutils->isConstantValue(orig_val) ||
671-
parseTBAA(SI, DL).Inner0().isIntegral();
722+
parseTBAA(I, DL).Inner0().isIntegral();
672723

673724
// TODO allow recognition of other types that could contain pointers [e.g.
674725
// {void*, void*} or <2 x i64> ]
675-
StoreInst *ts = nullptr;
676-
677726
auto storeSize = DL.getTypeSizeInBits(valType) / 8;
678727

679728
//! Storing a floating point value
@@ -688,12 +737,12 @@ class AdjointGenerator
688737
FT = fp.isFloat();
689738
} else if (isa<ConstantInt>(orig_val) ||
690739
valType->isIntOrIntVectorTy()) {
691-
llvm::errs() << "assuming type as integral for store: " << SI << "\n";
740+
llvm::errs() << "assuming type as integral for store: " << I << "\n";
692741
FT = nullptr;
693742
} else {
694743
TR.firstPointer(storeSize, orig_ptr, /*errifnotfound*/ true,
695744
/*pointerIntSame*/ true);
696-
llvm::errs() << "cannot deduce type of store " << SI << "\n";
745+
llvm::errs() << "cannot deduce type of store " << I << "\n";
697746
assert(0 && "cannot deduce");
698747
}
699748
} else {
@@ -710,35 +759,61 @@ class AdjointGenerator
710759
break;
711760
case DerivativeMode::ReverseModeGradient:
712761
case DerivativeMode::ReverseModeCombined: {
713-
IRBuilder<> Builder2(SI.getParent());
762+
IRBuilder<> Builder2(I.getParent());
714763
getReverseBuilder(Builder2);
715764

716765
if (constantval) {
717-
ts = setPtrDiffe(orig_ptr, Constant::getNullValue(valType), Builder2);
766+
gutils->setPtrDiffe(orig_ptr, Constant::getNullValue(valType),
767+
Builder2, align, isVolatile, ordering, syncScope,
768+
mask);
718769
} else {
719-
auto dif1 = Builder2.CreateLoad(
720-
lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2));
770+
Value *diff;
771+
if (!mask) {
772+
auto dif1 = Builder2.CreateLoad(
773+
lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2),
774+
isVolatile);
775+
if (align)
776+
#if LLVM_VERSION_MAJOR >= 10
777+
dif1->setAlignment(*align);
778+
#else
779+
dif1->setAlignment(align);
780+
#endif
781+
dif1->setOrdering(ordering);
782+
dif1->setSyncScopeID(syncScope);
783+
diff = dif1;
784+
} else {
785+
mask = lookup(mask, Builder2);
786+
Type *tys[] = {valType, orig_ptr->getType()};
787+
auto F = Intrinsic::getDeclaration(gutils->oldFunc->getParent(),
788+
Intrinsic::masked_load, tys);
721789
#if LLVM_VERSION_MAJOR >= 10
722-
dif1->setAlignment(SI.getAlign());
790+
Value *alignv =
791+
ConstantInt::get(Type::getInt32Ty(mask->getContext()),
792+
align ? align->value() : 0);
723793
#else
724-
dif1->setAlignment(SI.getAlignment());
794+
Value *alignv =
795+
ConstantInt::get(Type::getInt32Ty(mask->getContext()), align);
725796
#endif
726-
ts = setPtrDiffe(orig_ptr, Constant::getNullValue(valType), Builder2);
727-
addToDiffe(orig_val, dif1, Builder2, FT);
797+
Value *args[] = {
798+
lookup(gutils->invertPointerM(orig_ptr, Builder2), Builder2),
799+
alignv, mask, Constant::getNullValue(valType)};
800+
diff = Builder2.CreateCall(F, args);
801+
}
802+
gutils->setPtrDiffe(orig_ptr, Constant::getNullValue(valType),
803+
Builder2, align, isVolatile, ordering, syncScope,
804+
mask);
805+
addToDiffe(orig_val, diff, Builder2, FT, mask);
728806
}
729807
break;
730808
}
731809
case DerivativeMode::ForwardMode: {
732-
IRBuilder<> Builder2(&SI);
810+
IRBuilder<> Builder2(&I);
733811
getForwardBuilder(Builder2);
734812

735-
if (constantval) {
736-
ts = setPtrDiffe(orig_ptr, Constant::getNullValue(valType), Builder2);
737-
} else {
738-
auto diff = diffe(orig_val, Builder2);
739-
740-
ts = setPtrDiffe(orig_ptr, diff, Builder2);
741-
}
813+
Value *diff = constantval ? Constant::getNullValue(valType)
814+
: diffe(orig_val, Builder2);
815+
gutils->setPtrDiffe(orig_ptr, diff, Builder2, align, isVolatile,
816+
ordering, syncScope, mask);
742817
break;
743818
}
744819
}
@@ -749,7 +824,7 @@ class AdjointGenerator
749824
if (Mode == DerivativeMode::ReverseModePrimal ||
750825
Mode == DerivativeMode::ReverseModeCombined ||
751826
Mode == DerivativeMode::ForwardMode) {
752-
IRBuilder<> storeBuilder(gutils->getNewFromOriginal(&SI));
827+
IRBuilder<> storeBuilder(gutils->getNewFromOriginal(&I));
753828

754829
Value *valueop = nullptr;
755830

@@ -758,21 +833,10 @@ class AdjointGenerator
758833
} else {
759834
valueop = gutils->invertPointerM(orig_val, storeBuilder);
760835
}
761-
ts = setPtrDiffe(orig_ptr, valueop, storeBuilder);
836+
gutils->setPtrDiffe(orig_ptr, valueop, storeBuilder, align, isVolatile,
837+
ordering, syncScope, mask);
762838
}
763839
}
764-
765-
if (ts) {
766-
#if LLVM_VERSION_MAJOR >= 10
767-
ts->setAlignment(SI.getAlign());
768-
#else
769-
ts->setAlignment(SI.getAlignment());
770-
#endif
771-
ts->setVolatile(SI.isVolatile());
772-
ts->setOrdering(SI.getOrdering());
773-
ts->setSyncScopeID(SI.getSyncScopeID());
774-
}
775-
eraseIfUnused(SI);
776840
}
777841

778842
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) {
@@ -1366,13 +1430,11 @@ class AdjointGenerator
13661430
((DiffeGradientUtils *)gutils)->setDiffe(val, dif, Builder);
13671431
}
13681432

1369-
StoreInst *setPtrDiffe(Value *val, Value *dif, IRBuilder<> &Builder) {
1370-
return gutils->setPtrDiffe(val, dif, Builder);
1371-
}
1372-
13731433
std::vector<SelectInst *> addToDiffe(Value *val, Value *dif,
1374-
IRBuilder<> &Builder, Type *T) {
1375-
return ((DiffeGradientUtils *)gutils)->addToDiffe(val, dif, Builder, T);
1434+
IRBuilder<> &Builder, Type *T,
1435+
Value *mask = nullptr) {
1436+
return ((DiffeGradientUtils *)gutils)
1437+
->addToDiffe(val, dif, Builder, T, /*idxs*/ {}, mask);
13761438
}
13771439

13781440
Value *lookup(Value *val, IRBuilder<> &Builder) {
@@ -2351,17 +2413,45 @@ class AdjointGenerator
23512413
auto CI = cast<ConstantInt>(I.getOperand(1));
23522414
#if LLVM_VERSION_MAJOR >= 10
23532415
visitLoadLike(I, /*Align*/ MaybeAlign(CI->getZExtValue()),
2354-
/*constantval*/ false,
2355-
/*can_modref*/ false);
2416+
/*constantval*/ false);
23562417
#else
2357-
visitLoadLike(I, /*Align*/ CI->getZExtValue(), /*constantval*/ false,
2358-
/*can_modref*/ false);
2418+
visitLoadLike(I, /*Align*/ CI->getZExtValue(), /*constantval*/ false);
23592419
#endif
23602420
return;
23612421
}
23622422
default:
23632423
break;
23642424
}
2425+
2426+
if (ID == Intrinsic::masked_store) {
2427+
auto align0 = cast<ConstantInt>(I.getOperand(2))->getZExtValue();
2428+
#if LLVM_VERSION_MAJOR >= 10
2429+
auto align = MaybeAlign(align0);
2430+
#else
2431+
auto align = align0;
2432+
#endif
2433+
visitCommonStore(I, /*orig_ptr*/ I.getOperand(1),
2434+
/*orig_val*/ I.getOperand(0), align,
2435+
/*isVolatile*/ false, llvm::AtomicOrdering::NotAtomic,
2436+
SyncScope::SingleThread,
2437+
/*mask*/ gutils->getNewFromOriginal(I.getOperand(3)));
2438+
return;
2439+
}
2440+
if (ID == Intrinsic::masked_load) {
2441+
auto align0 = cast<ConstantInt>(I.getOperand(1))->getZExtValue();
2442+
#if LLVM_VERSION_MAJOR >= 10
2443+
auto align = MaybeAlign(align0);
2444+
#else
2445+
auto align = align0;
2446+
#endif
2447+
auto &DL = gutils->newFunc->getParent()->getDataLayout();
2448+
bool constantval = parseTBAA(I, DL).Inner0().isIntegral();
2449+
visitLoadLike(I, align, constantval, /*OrigOffset*/ nullptr,
2450+
/*mask*/ gutils->getNewFromOriginal(I.getOperand(2)),
2451+
/*orig_maskInit*/ I.getOperand(3));
2452+
return;
2453+
}
2454+
23652455
switch (Mode) {
23662456
case DerivativeMode::ReverseModePrimal: {
23672457
switch (ID) {

0 commit comments

Comments
 (0)