Skip to content

Commit 146643f

Browse files
authored
Handle non-strict select / phi (rust-lang#763)
* Handle non-strict select * Fix non-strict phi
1 parent a959be9 commit 146643f

File tree

4 files changed

+210
-13
lines changed

4 files changed

+210
-13
lines changed

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,7 +1491,8 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
14911491
#else
14921492
APInt ai(DL.getPointerSize(gep.getPointerAddressSpace()) * 8, 0);
14931493
#endif
1494-
g2->accumulateConstantOffset(DL, ai);
1494+
bool valid = g2->accumulateConstantOffset(DL, ai);
1495+
assert(valid);
14951496
// Using destructor rather than eraseFromParent
14961497
// as g2 has no parent
14971498
delete g2;
@@ -1539,14 +1540,31 @@ void TypeAnalyzer::visitPHINode(PHINode &phi) {
15391540
TypeTree upVal = getAnalysis(&phi);
15401541
// only propagate anything's up if there is one
15411542
// incoming value
1542-
if (phi.getNumIncomingValues() >= 2) {
1543+
Value *seen = phi.getIncomingValue(0);
1544+
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
1545+
if (seen != phi.getIncomingValue(i)) {
1546+
seen = nullptr;
1547+
break;
1548+
}
1549+
}
1550+
1551+
if (!seen) {
15431552
upVal = upVal.PurgeAnything();
15441553
}
1545-
auto L = LI.getLoopFor(phi.getParent());
1546-
bool isHeader = L && L->getHeader() == phi.getParent();
1547-
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
1548-
if (!isHeader || !L->contains(phi.getIncomingBlock(i))) {
1549-
updateAnalysis(phi.getIncomingValue(i), upVal, &phi);
1554+
1555+
if (EnzymeStrictAliasing || seen) {
1556+
auto L = LI.getLoopFor(phi.getParent());
1557+
bool isHeader = L && L->getHeader() == phi.getParent();
1558+
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
1559+
if (!isHeader || !L->contains(phi.getIncomingBlock(i))) {
1560+
updateAnalysis(phi.getIncomingValue(i), upVal, &phi);
1561+
}
1562+
}
1563+
} else {
1564+
if (EnzymePrintType) {
1565+
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i)
1566+
llvm::errs() << " skipping update into " << *phi.getIncomingValue(i)
1567+
<< " of " << upVal.str() << " from " << phi << "\n";
15501568
}
15511569
}
15521570
}
@@ -1840,11 +1858,20 @@ void TypeAnalyzer::visitBitCastInst(BitCastInst &I) {
18401858
}
18411859

18421860
void TypeAnalyzer::visitSelectInst(SelectInst &I) {
1843-
if (direction & UP)
1844-
updateAnalysis(I.getTrueValue(), getAnalysis(&I).PurgeAnything(), &I);
1845-
if (direction & UP)
1846-
updateAnalysis(I.getFalseValue(), getAnalysis(&I).PurgeAnything(), &I);
1847-
1861+
if (direction & UP) {
1862+
auto Data = getAnalysis(&I).PurgeAnything();
1863+
if (EnzymeStrictAliasing || (I.getTrueValue() == I.getFalseValue())) {
1864+
updateAnalysis(I.getTrueValue(), Data, &I);
1865+
updateAnalysis(I.getFalseValue(), Data, &I);
1866+
} else {
1867+
if (EnzymePrintType) {
1868+
llvm::errs() << " skipping update into " << *I.getTrueValue() << " of "
1869+
<< Data.str() << " from " << I << "\n";
1870+
llvm::errs() << " skipping update into " << *I.getFalseValue() << " of "
1871+
<< Data.str() << " from " << I << "\n";
1872+
}
1873+
}
1874+
}
18481875
if (direction & DOWN) {
18491876
// special case for min/max result is still that operand [even if something
18501877
// is 0]

enzyme/test/TypeAnalysis/strictalphi.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ bb153: ; preds = %bb216
2727

2828
; CHECK: f - {} |
2929
; CHECK-NEXT: e
30-
; CHECK-NEXT: %i78 = call noalias nonnull i8* @_Znwm(i64 8): {[-1]:Pointer, [-1,0]:Integer}
30+
; CHECK-NEXT: %i78 = call noalias nonnull i8* @_Znwm(i64 8): {[-1]:Pointer}
3131
; CHECK-NEXT: br label %bb155: {}
3232
; CHECK-NEXT: bb155
3333
; CHECK-NEXT: %i159 = phi i8* [ %i78, %e ], [ %i220, %bb216 ]: {[-1]:Pointer, [-1,0]:Integer}

enzyme/test/TypeAnalysis/strictphi.ll

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=f -enzyme-strict-aliasing=0 -o /dev/null | FileCheck %s
2+
3+
source_filename = "<source>"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
5+
target triple = "x86_64-unknown-linux-gnu"
6+
7+
%class.Testing = type { %struct.Header, %struct.Header }
8+
%struct.Header = type { %struct.Base, i32 }
9+
%struct.Base = type { %struct.Base*, %struct.Base* }
10+
11+
define dso_local void @f(%class.Testing* %arg) {
12+
bb:
13+
%i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0
14+
%i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0
15+
%i13 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0
16+
%i14 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0, i32 0
17+
br label %bb2
18+
19+
bb2: ; preds = %bb2, %bb
20+
%i3 = phi %struct.Base** [ %i1, %bb ], [ %i7, %bb2 ]
21+
%i4 = phi %struct.Base* [ %i, %bb ], [ %i5, %bb2 ]
22+
%i5 = load %struct.Base*, %struct.Base** %i3, align 8, !tbaa !3
23+
%i6 = icmp eq %struct.Base* %i5, null
24+
%i7 = getelementptr inbounds %struct.Base, %struct.Base* %i5, i64 0, i32 1
25+
br i1 %i6, label %bb8, label %bb2, !llvm.loop !7
26+
27+
bb8: ; preds = %bb2
28+
%i9 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 1, i32 1
29+
%i10 = bitcast %struct.Base** %i9 to double*
30+
%i11 = load double, double* %i10, align 8, !tbaa !9
31+
br label %bb15
32+
33+
bb15: ; preds = %bb15, %bb8
34+
%i16 = phi %struct.Base** [ %i14, %bb8 ], [ %i20, %bb15 ]
35+
%i17 = phi %struct.Base* [ %i13, %bb8 ], [ %i18, %bb15 ]
36+
%i18 = load %struct.Base*, %struct.Base** %i16, align 8, !tbaa !3
37+
%i19 = icmp eq %struct.Base* %i18, null
38+
%i20 = getelementptr inbounds %struct.Base, %struct.Base* %i18, i64 0, i32 1
39+
br i1 %i19, label %bb21, label %bb15, !llvm.loop !7
40+
41+
bb21: ; preds = %bb15
42+
%i22 = getelementptr inbounds %struct.Base, %struct.Base* %i17, i64 1, i32 1
43+
%i23 = bitcast %struct.Base** %i22 to double*
44+
%i24 = load double, double* %i23, align 8, !tbaa !9
45+
tail call void @_Z5printdd(double %i11, double %i24)
46+
ret void
47+
}
48+
49+
declare void @_Z5printdd(double, double)
50+
51+
!llvm.module.flags = !{!0, !1}
52+
!llvm.ident = !{!2}
53+
54+
!0 = !{i32 7, !"Dwarf Version", i32 4}
55+
!1 = !{i32 1, !"wchar_size", i32 4}
56+
!2 = !{!"clang version 12.0.1 (https://github.com/llvm/llvm-project.git fed41342a82f5a3a9201819a82bf7a48313e296b)"}
57+
!3 = !{!4, !4, i64 0}
58+
!4 = !{!"any pointer", !5, i64 0}
59+
!5 = !{!"omnipotent char", !6, i64 0}
60+
!6 = !{!"Simple C++ TBAA"}
61+
!7 = distinct !{!7, !8}
62+
!8 = !{!"llvm.loop.mustprogress"}
63+
!9 = !{!10, !10, i64 0}
64+
!10 = !{!"double", !5, i64 0}
65+
66+
; CHECK: %class.Testing* %arg: {[-1]:Pointer}
67+
; CHECK-NEXT: bb
68+
; CHECK-NEXT: %i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0: {[-1]:Pointer}
69+
; CHECK-NEXT: %i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0: {[-1]:Pointer}
70+
; CHECK-NEXT: %i13 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0: {[-1]:Pointer}
71+
; CHECK-NEXT: %i14 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0, i32 0: {[-1]:Pointer}
72+
; CHECK-NEXT: br label %bb2: {}
73+
; CHECK-NEXT: bb2
74+
; CHECK-NEXT: %i3 = phi %struct.Base** [ %i1, %bb ], [ %i7, %bb2 ]: {[-1]:Pointer, [-1,0]:Pointer}
75+
; CHECK-NEXT: %i4 = phi %struct.Base* [ %i, %bb ], [ %i5, %bb2 ]: {[-1]:Pointer, [-1,24]:Float@double}
76+
; CHECK-NEXT: %i5 = load %struct.Base*, %struct.Base** %i3, align 8, !tbaa !3: {[-1]:Pointer}
77+
; CHECK-NEXT: %i6 = icmp eq %struct.Base* %i5, null: {[-1]:Integer}
78+
; CHECK-NEXT: %i7 = getelementptr inbounds %struct.Base, %struct.Base* %i5, i64 0, i32 1: {[-1]:Pointer}
79+
; CHECK-NEXT: br i1 %i6, label %bb8, label %bb2, !llvm.loop !7: {}
80+
; CHECK-NEXT: bb8
81+
; CHECK-NEXT: %i9 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 1, i32 1: {[-1]:Pointer, [-1,0]:Float@double}
82+
; CHECK-NEXT: %i10 = bitcast %struct.Base** %i9 to double*: {[-1]:Pointer, [-1,0]:Float@double}
83+
; CHECK-NEXT: %i11 = load double, double* %i10, align 8, !tbaa !9: {[-1]:Float@double}
84+
; CHECK-NEXT: br label %bb15: {}
85+
; CHECK-NEXT: bb15
86+
; CHECK-NEXT: %i16 = phi %struct.Base** [ %i14, %bb8 ], [ %i20, %bb15 ]: {[-1]:Pointer, [-1,0]:Pointer}
87+
; CHECK-NEXT: %i17 = phi %struct.Base* [ %i13, %bb8 ], [ %i18, %bb15 ]: {[-1]:Pointer, [-1,24]:Float@double}
88+
; CHECK-NEXT: %i18 = load %struct.Base*, %struct.Base** %i16, align 8, !tbaa !3: {[-1]:Pointer}
89+
; CHECK-NEXT: %i19 = icmp eq %struct.Base* %i18, null: {[-1]:Integer}
90+
; CHECK-NEXT: %i20 = getelementptr inbounds %struct.Base, %struct.Base* %i18, i64 0, i32 1: {[-1]:Pointer}
91+
; CHECK-NEXT: br i1 %i19, label %bb21, label %bb15, !llvm.loop !7: {}
92+
; CHECK-NEXT: bb21
93+
; CHECK-NEXT: %i22 = getelementptr inbounds %struct.Base, %struct.Base* %i17, i64 1, i32 1: {[-1]:Pointer, [-1,0]:Float@double}
94+
; CHECK-NEXT: %i23 = bitcast %struct.Base** %i22 to double*: {[-1]:Pointer, [-1,0]:Float@double}
95+
; CHECK-NEXT: %i24 = load double, double* %i23, align 8, !tbaa !9: {[-1]:Float@double}
96+
; CHECK-NEXT: tail call void @_Z5printdd(double %i11, double %i24): {}
97+
; CHECK-NEXT: ret void: {}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=f -enzyme-strict-aliasing=0 -o /dev/null | FileCheck %s
2+
3+
source_filename = "<source>"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
5+
target triple = "x86_64-unknown-linux-gnu"
6+
7+
%class.Testing = type { %struct.Header, %struct.Header }
8+
%struct.Header = type { %struct.Base, i32 }
9+
%struct.Base = type { %struct.Base* }
10+
11+
define dso_local void @f(%class.Testing* nocapture nonnull readonly %arg) {
12+
bb:
13+
%i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0
14+
%i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0
15+
%i2 = load %struct.Base*, %struct.Base** %i1, align 8, !tbaa !3
16+
%i3 = icmp eq %struct.Base* %i2, null
17+
%i4 = select i1 %i3, %struct.Base* %i, %struct.Base* %i2
18+
%i5 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 2
19+
%i6 = bitcast %struct.Base* %i5 to double*
20+
%i7 = load double, double* %i6, align 8, !tbaa !10
21+
%i8 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1
22+
%i9 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0
23+
%i10 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0, i32 0
24+
%i11 = load %struct.Base*, %struct.Base** %i10, align 8, !tbaa !3
25+
%i12 = icmp eq %struct.Base* %i11, null
26+
%i13 = select i1 %i12, %struct.Base* %i9, %struct.Base* %i11
27+
%i14 = getelementptr inbounds %struct.Base, %struct.Base* %i13, i64 2
28+
%i15 = bitcast %struct.Base* %i14 to double*
29+
%i16 = load double, double* %i15, align 8, !tbaa !10
30+
tail call void (...) @_Z6printfPKcz(double %i7, double %i16)
31+
ret void
32+
}
33+
34+
declare void @_Z6printfPKcz(...)
35+
36+
!llvm.module.flags = !{!0, !1}
37+
!llvm.ident = !{!2}
38+
39+
!0 = !{i32 7, !"Dwarf Version", i32 4}
40+
!1 = !{i32 1, !"wchar_size", i32 4}
41+
!2 = !{!"clang version 12.0.1 (https://github.com/llvm/llvm-project.git fed41342a82f5a3a9201819a82bf7a48313e296b)"}
42+
!3 = !{!4, !6, i64 0}
43+
!4 = !{!"_ZTS6Header", !5, i64 0, !9, i64 8}
44+
!5 = !{!"_ZTS4Base", !6, i64 0}
45+
!6 = !{!"any pointer", !7, i64 0}
46+
!7 = !{!"omnipotent char", !8, i64 0}
47+
!8 = !{!"Simple C++ TBAA"}
48+
!9 = !{!"int", !7, i64 0}
49+
!10 = !{!11, !11, i64 0}
50+
!11 = !{!"double", !7, i64 0}
51+
52+
; CHECK: f - {} |{[-1]:Pointer}:{}
53+
; CHECK-NEXT: %class.Testing* %arg: {[-1]:Pointer, [-1,0]:Pointer, [-1,16]:Pointer}
54+
; CHECK-NEXT: bb
55+
; CHECK-NEXT: %i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
56+
; CHECK-NEXT: %i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
57+
; CHECK-NEXT: %i2 = load %struct.Base*, %struct.Base** %i1, align 8, !tbaa !3: {[-1]:Pointer}
58+
; CHECK-NEXT: %i3 = icmp eq %struct.Base* %i2, null: {[-1]:Integer}
59+
; CHECK-NEXT: %i4 = select i1 %i3, %struct.Base* %i, %struct.Base* %i2: {[-1]:Pointer, [-1,16]:Float@double}
60+
; CHECK-NEXT: %i5 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 2: {[-1]:Pointer, [-1,0]:Float@double}
61+
; CHECK-NEXT: %i6 = bitcast %struct.Base* %i5 to double*: {[-1]:Pointer, [-1,0]:Float@double}
62+
; CHECK-NEXT: %i7 = load double, double* %i6, align 8, !tbaa !10: {[-1]:Float@double}
63+
; CHECK-NEXT: %i8 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1: {[-1]:Pointer, [-1,0]:Pointer}
64+
; CHECK-NEXT: %i9 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
65+
; CHECK-NEXT: %i10 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
66+
; CHECK-NEXT: %i11 = load %struct.Base*, %struct.Base** %i10, align 8, !tbaa !3: {[-1]:Pointer}
67+
; CHECK-NEXT: %i12 = icmp eq %struct.Base* %i11, null: {[-1]:Integer}
68+
; CHECK-NEXT: %i13 = select i1 %i12, %struct.Base* %i9, %struct.Base* %i11: {[-1]:Pointer, [-1,16]:Float@double}
69+
; CHECK-NEXT: %i14 = getelementptr inbounds %struct.Base, %struct.Base* %i13, i64 2: {[-1]:Pointer, [-1,0]:Float@double}
70+
; CHECK-NEXT: %i15 = bitcast %struct.Base* %i14 to double*: {[-1]:Pointer, [-1,0]:Float@double}
71+
; CHECK-NEXT: %i16 = load double, double* %i15, align 8, !tbaa !10: {[-1]:Float@double}
72+
; CHECK-NEXT: tail call void (...) @_Z6printfPKcz(double %i7, double %i16): {}
73+
; CHECK-NEXT: ret void: {}

0 commit comments

Comments
 (0)