@@ -376,15 +376,18 @@ class AdjointGenerator
376
376
377
377
#if LLVM_VERSION_MAJOR >= 10
378
378
void visitLoadLike (llvm::Instruction &I, MaybeAlign alignment,
379
- bool constantval, bool can_modref,
380
- Value *OrigOffset = nullptr )
379
+ bool constantval, Value *OrigOffset = nullptr ,
381
380
#else
382
381
void visitLoadLike (llvm::Instruction &I, unsigned alignment, bool constantval,
383
- bool can_modref, Value *OrigOffset = nullptr )
382
+ Value *OrigOffset = nullptr ,
384
383
#endif
385
- {
384
+ Value *mask = nullptr , Value *orig_maskInit = nullptr ) {
386
385
auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
387
386
387
+ assert (gutils->can_modref_map );
388
+ assert (gutils->can_modref_map ->find (&I) != gutils->can_modref_map ->end ());
389
+ bool can_modref = gutils->can_modref_map ->find (&I)->second ;
390
+
388
391
constantval |= gutils->isConstantValue (&I);
389
392
390
393
BasicBlock *parent = I.getParent ();
@@ -536,7 +539,8 @@ class AdjointGenerator
536
539
// the instruction if the value is a potential pointer. This may not be
537
540
// caught by type analysis if the result does not have a known type.
538
541
if (!gutils->isConstantInstruction (&I)) {
539
- bool isfloat = type->isFPOrFPVectorTy ();
542
+ Type *isfloat =
543
+ type->isFPOrFPVectorTy () ? type->getScalarType () : nullptr ;
540
544
if (!isfloat && type->isIntOrIntVectorTy ()) {
541
545
auto LoadSize = DL.getTypeSizeInBits (type) / 8 ;
542
546
ConcreteType vd = BaseType::Unknown;
@@ -560,8 +564,34 @@ class AdjointGenerator
560
564
getForwardBuilder (Builder2);
561
565
562
566
if (!gutils->isConstantValue (&I)) {
563
- auto diff = Builder2.CreateLoad (
564
- gutils->invertPointerM (I.getOperand (0 ), Builder2));
567
+ Value *diff;
568
+ if (!mask) {
569
+ auto LI = Builder2.CreateLoad (
570
+ gutils->invertPointerM (I.getOperand (0 ), Builder2));
571
+ if (alignment)
572
+ #if LLVM_VERSION_MAJOR >= 10
573
+ LI->setAlignment (*alignment);
574
+ #else
575
+ LI->setAlignment (alignment);
576
+ #endif
577
+ diff = LI;
578
+ } else {
579
+ Type *tys[] = {I.getType (), I.getOperand (0 )->getType ()};
580
+ auto F = Intrinsic::getDeclaration (gutils->oldFunc ->getParent (),
581
+ Intrinsic::masked_load, tys);
582
+ #if LLVM_VERSION_MAJOR >= 10
583
+ Value *alignv =
584
+ ConstantInt::get (Type::getInt32Ty (mask->getContext ()),
585
+ alignment ? alignment->value () : 0 );
586
+ #else
587
+ Value *alignv = ConstantInt::get (
588
+ Type::getInt32Ty (mask->getContext ()), alignment);
589
+ #endif
590
+ Value *args[] = {
591
+ gutils->invertPointerM (I.getOperand (0 ), Builder2), alignv,
592
+ mask, diffe (orig_maskInit, Builder2)};
593
+ diff = Builder2.CreateCall (F, args);
594
+ }
565
595
setDiffe (&I, diff, Builder2);
566
596
}
567
597
break ;
@@ -576,8 +606,13 @@ class AdjointGenerator
576
606
577
607
if (!gutils->isConstantValue (I.getOperand (0 ))) {
578
608
((DiffeGradientUtils *)gutils)
579
- ->addToInvertedPtrDiffe (I.getOperand (0 ), prediff, Builder2,
580
- alignment, OrigOffset);
609
+ ->addToInvertedPtrDiffe (
610
+ I.getOperand (0 ), prediff, Builder2, alignment, OrigOffset,
611
+ mask ? lookup (mask, Builder2) : nullptr );
612
+ }
613
+ if (mask && !gutils->isConstantValue (orig_maskInit)) {
614
+ addToDiffe (orig_maskInit, prediff, Builder2, isfloat,
615
+ Builder2.CreateNot (mask));
581
616
}
582
617
break ;
583
618
}
@@ -614,10 +649,7 @@ class AdjointGenerator
614
649
auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
615
650
616
651
bool constantval = parseTBAA (LI, DL).Inner0 ().isIntegral ();
617
- assert (gutils->can_modref_map );
618
- assert (gutils->can_modref_map ->find (&LI) != gutils->can_modref_map ->end ());
619
- bool can_modref = gutils->can_modref_map ->find (&LI)->second ;
620
- visitLoadLike (LI, alignment, constantval, can_modref);
652
+ visitLoadLike (LI, alignment, constantval);
621
653
eraseIfUnused (LI);
622
654
}
623
655
@@ -636,15 +668,9 @@ class AdjointGenerator
636
668
}
637
669
638
670
void visitStoreInst (llvm::StoreInst &SI) {
639
- Value *orig_ptr = SI.getPointerOperand ();
640
- Value *orig_val = SI.getValueOperand ();
641
- Value *val = gutils->getNewFromOriginal (orig_val);
642
- Type *valType = orig_val->getType ();
643
-
644
- auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
645
671
// If a store of an omp init argument, don't delete in reverse
646
672
// and don't do any adjoint propagation (assumed integral)
647
- for (auto U : orig_ptr ->users ()) {
673
+ for (auto U : SI. getPointerOperand () ->users ()) {
648
674
if (auto CI = dyn_cast<CallInst>(U)) {
649
675
if (auto F = CI->getCalledFunction ()) {
650
676
if (F->getName () == " __kmpc_for_static_init_4" ||
@@ -656,24 +682,47 @@ class AdjointGenerator
656
682
}
657
683
}
658
684
}
685
+ #if LLVM_VERSION_MAJOR >= 10
686
+ auto align = SI.getAlign ();
687
+ #else
688
+ auto align = SI.getAlignment ();
689
+ #endif
690
+ visitCommonStore (SI, SI.getPointerOperand (), SI.getValueOperand (), align,
691
+ SI.isVolatile (), SI.getOrdering (), SI.getSyncScopeID (),
692
+ /* mask=*/ nullptr );
693
+ eraseIfUnused (SI);
694
+ }
695
+
696
+ #if LLVM_VERSION_MAJOR >= 10
697
+ void visitCommonStore (llvm::Instruction &I, Value *orig_ptr, Value *orig_val,
698
+ MaybeAlign align, bool isVolatile,
699
+ AtomicOrdering ordering, SyncScope::ID syncScope,
700
+ Value *mask = nullptr )
701
+ #else
702
+ void visitCommonStore (llvm::Instruction &I, Value *orig_ptr, Value *orig_val,
703
+ unsigned align, bool isVolatile,
704
+ AtomicOrdering ordering, SyncScope::ID syncScope,
705
+ Value *mask = nullptr )
706
+ #endif
707
+ {
708
+ Value *val = gutils->getNewFromOriginal (orig_val);
709
+ Type *valType = orig_val->getType ();
710
+
711
+ auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
659
712
660
- if (unnecessaryStores.count (&SI)) {
661
- eraseIfUnused (SI);
713
+ if (unnecessaryStores.count (&I)) {
662
714
return ;
663
715
}
664
716
665
717
if (gutils->isConstantValue (orig_ptr)) {
666
- eraseIfUnused (SI);
667
718
return ;
668
719
}
669
720
670
721
bool constantval = gutils->isConstantValue (orig_val) ||
671
- parseTBAA (SI , DL).Inner0 ().isIntegral ();
722
+ parseTBAA (I , DL).Inner0 ().isIntegral ();
672
723
673
724
// TODO allow recognition of other types that could contain pointers [e.g.
674
725
// {void*, void*} or <2 x i64> ]
675
- StoreInst *ts = nullptr ;
676
-
677
726
auto storeSize = DL.getTypeSizeInBits (valType) / 8 ;
678
727
679
728
// ! Storing a floating point value
@@ -688,12 +737,12 @@ class AdjointGenerator
688
737
FT = fp.isFloat ();
689
738
} else if (isa<ConstantInt>(orig_val) ||
690
739
valType->isIntOrIntVectorTy ()) {
691
- llvm::errs () << " assuming type as integral for store: " << SI << " \n " ;
740
+ llvm::errs () << " assuming type as integral for store: " << I << " \n " ;
692
741
FT = nullptr ;
693
742
} else {
694
743
TR.firstPointer (storeSize, orig_ptr, /* errifnotfound*/ true ,
695
744
/* pointerIntSame*/ true );
696
- llvm::errs () << " cannot deduce type of store " << SI << " \n " ;
745
+ llvm::errs () << " cannot deduce type of store " << I << " \n " ;
697
746
assert (0 && " cannot deduce" );
698
747
}
699
748
} else {
@@ -710,35 +759,61 @@ class AdjointGenerator
710
759
break ;
711
760
case DerivativeMode::ReverseModeGradient:
712
761
case DerivativeMode::ReverseModeCombined: {
713
- IRBuilder<> Builder2 (SI .getParent ());
762
+ IRBuilder<> Builder2 (I .getParent ());
714
763
getReverseBuilder (Builder2);
715
764
716
765
if (constantval) {
717
- ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
766
+ gutils->setPtrDiffe (orig_ptr, Constant::getNullValue (valType),
767
+ Builder2, align, isVolatile, ordering, syncScope,
768
+ mask);
718
769
} else {
719
- auto dif1 = Builder2.CreateLoad (
720
- lookup (gutils->invertPointerM (orig_ptr, Builder2), Builder2));
770
+ Value *diff;
771
+ if (!mask) {
772
+ auto dif1 = Builder2.CreateLoad (
773
+ lookup (gutils->invertPointerM (orig_ptr, Builder2), Builder2),
774
+ isVolatile);
775
+ if (align)
776
+ #if LLVM_VERSION_MAJOR >= 10
777
+ dif1->setAlignment (*align);
778
+ #else
779
+ dif1->setAlignment (align);
780
+ #endif
781
+ dif1->setOrdering (ordering);
782
+ dif1->setSyncScopeID (syncScope);
783
+ diff = dif1;
784
+ } else {
785
+ mask = lookup (mask, Builder2);
786
+ Type *tys[] = {valType, orig_ptr->getType ()};
787
+ auto F = Intrinsic::getDeclaration (gutils->oldFunc ->getParent (),
788
+ Intrinsic::masked_load, tys);
721
789
#if LLVM_VERSION_MAJOR >= 10
722
- dif1->setAlignment (SI.getAlign ());
790
+ Value *alignv =
791
+ ConstantInt::get (Type::getInt32Ty (mask->getContext ()),
792
+ align ? align->value () : 0 );
723
793
#else
724
- dif1->setAlignment (SI.getAlignment ());
794
+ Value *alignv =
795
+ ConstantInt::get (Type::getInt32Ty (mask->getContext ()), align);
725
796
#endif
726
- ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
727
- addToDiffe (orig_val, dif1, Builder2, FT);
797
+ Value *args[] = {
798
+ lookup (gutils->invertPointerM (orig_ptr, Builder2), Builder2),
799
+ alignv, mask, Constant::getNullValue (valType)};
800
+ diff = Builder2.CreateCall (F, args);
801
+ }
802
+ gutils->setPtrDiffe (orig_ptr, Constant::getNullValue (valType),
803
+ Builder2, align, isVolatile, ordering, syncScope,
804
+ mask);
805
+ addToDiffe (orig_val, diff, Builder2, FT, mask);
728
806
}
729
807
break ;
730
808
}
731
809
case DerivativeMode::ForwardMode: {
732
- IRBuilder<> Builder2 (&SI );
810
+ IRBuilder<> Builder2 (&I );
733
811
getForwardBuilder (Builder2);
734
812
735
- if (constantval) {
736
- ts = setPtrDiffe (orig_ptr, Constant::getNullValue (valType), Builder2);
737
- } else {
738
- auto diff = diffe (orig_val, Builder2);
739
-
740
- ts = setPtrDiffe (orig_ptr, diff, Builder2);
741
- }
813
+ Value *diff = constantval ? Constant::getNullValue (valType)
814
+ : diffe (orig_val, Builder2);
815
+ gutils->setPtrDiffe (orig_ptr, diff, Builder2, align, isVolatile,
816
+ ordering, syncScope, mask);
742
817
break ;
743
818
}
744
819
}
@@ -749,7 +824,7 @@ class AdjointGenerator
749
824
if (Mode == DerivativeMode::ReverseModePrimal ||
750
825
Mode == DerivativeMode::ReverseModeCombined ||
751
826
Mode == DerivativeMode::ForwardMode) {
752
- IRBuilder<> storeBuilder (gutils->getNewFromOriginal (&SI ));
827
+ IRBuilder<> storeBuilder (gutils->getNewFromOriginal (&I ));
753
828
754
829
Value *valueop = nullptr ;
755
830
@@ -758,21 +833,10 @@ class AdjointGenerator
758
833
} else {
759
834
valueop = gutils->invertPointerM (orig_val, storeBuilder);
760
835
}
761
- ts = setPtrDiffe (orig_ptr, valueop, storeBuilder);
836
+ gutils->setPtrDiffe (orig_ptr, valueop, storeBuilder, align, isVolatile,
837
+ ordering, syncScope, mask);
762
838
}
763
839
}
764
-
765
- if (ts) {
766
- #if LLVM_VERSION_MAJOR >= 10
767
- ts->setAlignment (SI.getAlign ());
768
- #else
769
- ts->setAlignment (SI.getAlignment ());
770
- #endif
771
- ts->setVolatile (SI.isVolatile ());
772
- ts->setOrdering (SI.getOrdering ());
773
- ts->setSyncScopeID (SI.getSyncScopeID ());
774
- }
775
- eraseIfUnused (SI);
776
840
}
777
841
778
842
void visitGetElementPtrInst (llvm::GetElementPtrInst &gep) {
@@ -1366,13 +1430,11 @@ class AdjointGenerator
1366
1430
((DiffeGradientUtils *)gutils)->setDiffe (val, dif, Builder);
1367
1431
}
1368
1432
1369
- StoreInst *setPtrDiffe (Value *val, Value *dif, IRBuilder<> &Builder) {
1370
- return gutils->setPtrDiffe (val, dif, Builder);
1371
- }
1372
-
1373
1433
std::vector<SelectInst *> addToDiffe (Value *val, Value *dif,
1374
- IRBuilder<> &Builder, Type *T) {
1375
- return ((DiffeGradientUtils *)gutils)->addToDiffe (val, dif, Builder, T);
1434
+ IRBuilder<> &Builder, Type *T,
1435
+ Value *mask = nullptr ) {
1436
+ return ((DiffeGradientUtils *)gutils)
1437
+ ->addToDiffe (val, dif, Builder, T, /* idxs*/ {}, mask);
1376
1438
}
1377
1439
1378
1440
Value *lookup (Value *val, IRBuilder<> &Builder) {
@@ -2351,17 +2413,45 @@ class AdjointGenerator
2351
2413
auto CI = cast<ConstantInt>(I.getOperand (1 ));
2352
2414
#if LLVM_VERSION_MAJOR >= 10
2353
2415
visitLoadLike (I, /* Align*/ MaybeAlign (CI->getZExtValue ()),
2354
- /* constantval*/ false ,
2355
- /* can_modref*/ false );
2416
+ /* constantval*/ false );
2356
2417
#else
2357
- visitLoadLike (I, /* Align*/ CI->getZExtValue (), /* constantval*/ false ,
2358
- /* can_modref*/ false );
2418
+ visitLoadLike (I, /* Align*/ CI->getZExtValue (), /* constantval*/ false );
2359
2419
#endif
2360
2420
return ;
2361
2421
}
2362
2422
default :
2363
2423
break ;
2364
2424
}
2425
+
2426
+ if (ID == Intrinsic::masked_store) {
2427
+ auto align0 = cast<ConstantInt>(I.getOperand (2 ))->getZExtValue ();
2428
+ #if LLVM_VERSION_MAJOR >= 10
2429
+ auto align = MaybeAlign (align0);
2430
+ #else
2431
+ auto align = align0;
2432
+ #endif
2433
+ visitCommonStore (I, /* orig_ptr*/ I.getOperand (1 ),
2434
+ /* orig_val*/ I.getOperand (0 ), align,
2435
+ /* isVolatile*/ false , llvm::AtomicOrdering::NotAtomic,
2436
+ SyncScope::SingleThread,
2437
+ /* mask*/ gutils->getNewFromOriginal (I.getOperand (3 )));
2438
+ return ;
2439
+ }
2440
+ if (ID == Intrinsic::masked_load) {
2441
+ auto align0 = cast<ConstantInt>(I.getOperand (1 ))->getZExtValue ();
2442
+ #if LLVM_VERSION_MAJOR >= 10
2443
+ auto align = MaybeAlign (align0);
2444
+ #else
2445
+ auto align = align0;
2446
+ #endif
2447
+ auto &DL = gutils->newFunc ->getParent ()->getDataLayout ();
2448
+ bool constantval = parseTBAA (I, DL).Inner0 ().isIntegral ();
2449
+ visitLoadLike (I, align, constantval, /* OrigOffset*/ nullptr ,
2450
+ /* mask*/ gutils->getNewFromOriginal (I.getOperand (2 )),
2451
+ /* orig_maskInit*/ I.getOperand (3 ));
2452
+ return ;
2453
+ }
2454
+
2365
2455
switch (Mode) {
2366
2456
case DerivativeMode::ReverseModePrimal: {
2367
2457
switch (ID) {
0 commit comments