Skip to content

Commit e40a859

Browse files
authored
Fix OMP fork arg handling (rust-lang#485)
1 parent 9e48a51 commit e40a859

File tree

3 files changed

+228
-44
lines changed

3 files changed

+228
-44
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3878,7 +3878,7 @@ class AdjointGenerator
38783878
}
38793879

38803880
assert(uncacheable_args_map.find(&call) != uncacheable_args_map.end());
3881-
const std::map<Argument *, bool> &uncacheable_argsAbove =
3881+
const std::map<Argument *, bool> &uncacheable_args =
38823882
uncacheable_args_map.find(&call)->second;
38833883

38843884
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&call));
@@ -3902,38 +3902,6 @@ class AdjointGenerator
39023902
"could not derive underlying task contents from omp call");
39033903
}
39043904

3905-
std::map<Argument *, bool> uncacheable_args;
3906-
{
3907-
auto in_arg = call.getCalledFunction()->arg_begin();
3908-
auto pp_arg = task->arg_begin();
3909-
3910-
// Global.tid is cacheable
3911-
uncacheable_args[pp_arg] = false;
3912-
++pp_arg;
3913-
// Bound.tid is cacheable
3914-
uncacheable_args[pp_arg] = false;
3915-
++pp_arg;
3916-
3917-
// Ignore the first three args of init call
3918-
++in_arg;
3919-
++in_arg;
3920-
++in_arg;
3921-
3922-
for (; pp_arg != task->arg_end();) {
3923-
// If var-args then we may still have args even though outermost
3924-
// has no more
3925-
if (in_arg == call.getCalledFunction()->arg_end()) {
3926-
uncacheable_args[pp_arg] = true;
3927-
} else {
3928-
assert(uncacheable_argsAbove.find(in_arg) !=
3929-
uncacheable_argsAbove.end());
3930-
uncacheable_args[pp_arg] = uncacheable_argsAbove.find(in_arg)->second;
3931-
++in_arg;
3932-
}
3933-
++pp_arg;
3934-
}
3935-
}
3936-
39373905
auto called = task;
39383906
// bool modifyPrimal = true;
39393907

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -537,14 +537,15 @@ struct CacheAnalysis {
537537
return {};
538538
}
539539

540-
if (Fn->getName().startswith("MPI_") ||
541-
Fn->getName().startswith("enzyme_wrapmpi$$"))
540+
StringRef funcName = Fn->getName();
541+
542+
if (funcName.startswith("MPI_") || funcName.startswith("enzyme_wrapmpi$$"))
542543
return {};
543544

544-
if (Fn->getName() == "__kmpc_for_static_init_4" ||
545-
Fn->getName() == "__kmpc_for_static_init_4u" ||
546-
Fn->getName() == "__kmpc_for_static_init_8" ||
547-
Fn->getName() == "__kmpc_for_static_init_8u") {
545+
if (funcName == "__kmpc_for_static_init_4" ||
546+
funcName == "__kmpc_for_static_init_4u" ||
547+
funcName == "__kmpc_for_static_init_8" ||
548+
funcName == "__kmpc_for_static_init_8u") {
548549
return {};
549550
}
550551

@@ -644,12 +645,48 @@ struct CacheAnalysis {
644645

645646
std::map<Argument *, bool> uncacheable_args;
646647

647-
auto arg = Fn->arg_begin();
648-
for (unsigned i = 0; i < args.size(); ++i) {
649-
uncacheable_args[arg] = !args_safe[i];
648+
if (funcName == "__kmpc_fork_call") {
649+
Value *op = callsite_op->getArgOperand(2);
650+
Function *task = nullptr;
651+
while (!(task = dyn_cast<Function>(op))) {
652+
if (auto castinst = dyn_cast<ConstantExpr>(op))
653+
if (castinst->isCast()) {
654+
op = castinst->getOperand(0);
655+
continue;
656+
}
657+
if (auto CI = dyn_cast<CastInst>(op)) {
658+
op = CI->getOperand(0);
659+
continue;
660+
}
661+
llvm::errs() << "op: " << *op << "\n";
662+
assert(0 && "unknown fork call arg");
663+
}
664+
665+
auto arg = task->arg_begin();
666+
667+
// Global.tid is cacheable
668+
uncacheable_args[arg] = false;
650669
++arg;
651-
if (arg == Fn->arg_end()) {
652-
break;
670+
// Bound.tid is cacheable
671+
uncacheable_args[arg] = false;
672+
++arg;
673+
674+
// Ignore first three arguments of fork call
675+
for (unsigned i = 3; i < args.size(); ++i) {
676+
uncacheable_args[arg] = !args_safe[i];
677+
++arg;
678+
if (arg == Fn->arg_end()) {
679+
break;
680+
}
681+
}
682+
} else {
683+
auto arg = Fn->arg_begin();
684+
for (unsigned i = 0; i < args.size(); ++i) {
685+
uncacheable_args[arg] = !args_safe[i];
686+
++arg;
687+
if (arg == Fn->arg_end()) {
688+
break;
689+
}
653690
}
654691
}
655692

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
; RUN: if [ %llvmver -ge 9 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -adce -simplifycfg -S | FileCheck %s; fi

source_filename = "lulesh.cc"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.ident_t = type { i32, i32, i32, i32, i8* }

@0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 514, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8
@2 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @0, i32 0, i32 0) }, align 8

; Function Attrs: norecurse nounwind uwtable mustprogress
define void @caller(double* %out, double* %dout, double* %in, double* %din) {
entry:
  call void @_Z17__enzyme_autodiffPvS_S_m(i8* bitcast (void (double*, double*, i64)* @_ZL16LagrangeLeapFrogPdm to i8*), double* %out, double* %dout, double* %in, double* %din, i64 100)
  ret void
}

declare dso_local void @_Z17__enzyme_autodiffPvS_S_m(i8*, double*, double*, double*, double*, i64)

; Function Attrs: inlinehint nounwind uwtable mustprogress
define internal void @_ZL16LagrangeLeapFrogPdm(double* noalias %out, double* noalias %in, i64 %length) #3 {
entry:
  tail call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @2, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i64, double*, double*)* @.omp_outlined. to void (i32*, i32*, ...)*), i64 %length, double* %out, double* %in)
  ret void
}

; Function Attrs: norecurse nounwind uwtable
define internal void @.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* nocapture noalias %out, double* nocapture noalias %tmp) #4 {
entry:
  %.omp.lb = alloca i64, align 8
  %.omp.ub = alloca i64, align 8
  %.omp.stride = alloca i64, align 8
  %.omp.is_last = alloca i32, align 4
  %sub4 = add i64 %length, -1
  %cmp.not = icmp eq i64 %length, 0
  br i1 %cmp.not, label %omp.precond.end, label %omp.precond.then

omp.precond.then:                                 ; preds = %entry
  %0 = bitcast i64* %.omp.lb to i8*
  store i64 0, i64* %.omp.lb, align 8, !tbaa !3
  %1 = bitcast i64* %.omp.ub to i8*
  store i64 %sub4, i64* %.omp.ub, align 8, !tbaa !3
  %2 = bitcast i64* %.omp.stride to i8*
  store i64 1, i64* %.omp.stride, align 8, !tbaa !3
  %3 = bitcast i32* %.omp.is_last to i8*
  store i32 0, i32* %.omp.is_last, align 4, !tbaa !7
  %4 = load i32, i32* %.global_tid., align 4, !tbaa !7
  call void @__kmpc_for_static_init_8u(%struct.ident_t* nonnull @1, i32 %4, i32 34, i32* nonnull %.omp.is_last, i64* nonnull %.omp.lb, i64* nonnull %.omp.ub, i64* nonnull %.omp.stride, i64 1, i64 1)
  %5 = load i64, i64* %.omp.ub, align 8, !tbaa !3
  %cmp6 = icmp ugt i64 %5, %sub4
  %cond = select i1 %cmp6, i64 %sub4, i64 %5
  store i64 %cond, i64* %.omp.ub, align 8, !tbaa !3
  %6 = load i64, i64* %.omp.lb, align 8, !tbaa !3
  %add29 = add i64 %cond, 1
  %cmp730 = icmp ult i64 %6, %add29
  br i1 %cmp730, label %omp.inner.for.body, label %omp.loop.exit

omp.inner.for.body:                               ; preds = %omp.precond.then, %omp.inner.for.body
  %.omp.iv.031 = phi i64 [ %add11, %omp.inner.for.body ], [ %6, %omp.precond.then ]
  %arrayidx = getelementptr inbounds double, double* %tmp, i64 %.omp.iv.031
  %7 = load double, double* %arrayidx, align 8, !tbaa !9
  %call = call double @sqrt(double %7) #5
  %outidx = getelementptr inbounds double, double* %out, i64 %.omp.iv.031
  store double %call, double* %outidx, align 8, !tbaa !9
  %add11 = add nuw i64 %.omp.iv.031, 1
  %8 = load i64, i64* %.omp.ub, align 8, !tbaa !3
  %add = add i64 %8, 1
  %cmp7 = icmp ult i64 %add11, %add
  br i1 %cmp7, label %omp.inner.for.body, label %omp.loop.exit

omp.loop.exit:                                    ; preds = %omp.inner.for.body, %omp.precond.then
  call void @__kmpc_for_static_fini(%struct.ident_t* nonnull @1, i32 %4)
  br label %omp.precond.end

omp.precond.end:                                  ; preds = %omp.loop.exit, %entry
  ret void
}

; Function Attrs: nounwind
declare dso_local void @__kmpc_for_static_init_8u(%struct.ident_t*, i32, i32, i32*, i64*, i64*, i64*, i64, i64) local_unnamed_addr #5

; Function Attrs: nofree nounwind willreturn mustprogress
declare dso_local double @sqrt(double) local_unnamed_addr #6

; Function Attrs: nounwind
declare void @__kmpc_for_static_fini(%struct.ident_t*, i32) local_unnamed_addr #5

; Function Attrs: nounwind
declare !callback !11 void @__kmpc_fork_call(%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) local_unnamed_addr #5

attributes #0 = { norecurse nounwind uwtable }
attributes #1 = { argmemonly }

!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}
!nvvm.annotations = !{}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 7, !"uwtable", i32 1}
!2 = !{!"clang version 13.0.0 ([email protected]:llvm/llvm-project 619bfe8bd23f76b22f0a53fedafbfc8c97a15f12)"}
!3 = !{!4, !4, i64 0}
!4 = !{!"long", !5, i64 0}
!5 = !{!"omnipotent char", !6, i64 0}
!6 = !{!"Simple C++ TBAA"}
!7 = !{!8, !8, i64 0}
!8 = !{!"int", !5, i64 0}
!9 = !{!10, !10, i64 0}
!10 = !{!"double", !5, i64 0}
!11 = !{!12}
!12 = !{i64 2, i64 -1, i64 -1, i1 true}

; This should not cache and instead reload from %tmp

; CHECK: define internal void @diffe.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i64 %length, double* noalias nocapture %out, double* nocapture %"out'", double* noalias nocapture %tmp, double* nocapture %"tmp'")
; CHECK-NEXT: entry:
; CHECK-NEXT:   %.omp.lb_smpl = alloca i64
; CHECK-NEXT:   %.omp.ub_smpl = alloca i64
; CHECK-NEXT:   %.omp.stride_smpl = alloca i64
; CHECK-NEXT:   %.omp.is_last = alloca i32
; CHECK-NEXT:   %sub4 = add i64 %length, -1
; CHECK-NEXT:   %cmp.not = icmp eq i64 %length, 0
; CHECK-NEXT:   br i1 %cmp.not, label %invertentry, label %omp.precond.then

; CHECK: omp.precond.then:                                 ; preds = %entry
; CHECK-NEXT:   store i32 0, i32* %.omp.is_last
; CHECK-NEXT:   %0 = load i32, i32* %.global_tid.
; CHECK-NEXT:   store i64 0, i64* %.omp.lb_smpl
; CHECK-NEXT:   store i64 %sub4, i64* %.omp.ub_smpl
; CHECK-NEXT:   store i64 1, i64* %.omp.stride_smpl
; CHECK-NEXT:   call void @__kmpc_for_static_init_8u(%struct.ident_t* nonnull @1, i32 %0, i32 34, i32* nonnull %.omp.is_last, i64* nocapture nonnull %.omp.lb_smpl, i64* nocapture nonnull %.omp.ub_smpl, i64* nocapture nonnull %.omp.stride_smpl, i64 1, i64 1) #0
; CHECK-NEXT:   %_unwrap8 = load i64, i64* %.omp.lb_smpl
; CHECK-NEXT:   %_unwrap9 = load i64, i64* %.omp.ub_smpl
; CHECK-NEXT:   %cmp6_unwrap10 = icmp ugt i64 %_unwrap9, %sub4
; CHECK-NEXT:   %cond_unwrap11 = select i1 %cmp6_unwrap10, i64 %sub4, i64 %_unwrap9
; CHECK-NEXT:   %add29_unwrap = add i64 %cond_unwrap11, 1
; CHECK-NEXT:   %cmp730_unwrap = icmp ult i64 %_unwrap8, %add29_unwrap
; CHECK-NEXT:   br i1 %cmp730_unwrap, label %invertomp.loop.exit.loopexit, label %invertomp.precond.then

; CHECK: invertentry:                                      ; preds = %entry, %invertomp.precond.then
; CHECK-NEXT:   ret void

; CHECK: invertomp.precond.then:                           ; preds = %invertomp.inner.for.body, %omp.precond.then
; CHECK-NEXT:   %_unwrap = load i32, i32* %.global_tid., align 4, !tbaa !7, !invariant.group !13
; CHECK-NEXT:   call void @__kmpc_for_static_fini(%struct.ident_t* @1, i32 %_unwrap)
; CHECK-NEXT:   br label %invertentry

; CHECK: invertomp.inner.for.body:                         ; preds = %invertomp.loop.exit.loopexit, %incinvertomp.inner.for.body
; CHECK-NEXT:   %"iv'ac.0" = phi i64 [ %_unwrap7, %invertomp.loop.exit.loopexit ], [ %9, %incinvertomp.inner.for.body ]
; CHECK-NEXT:   %_unwrap2 = load i64, i64* %.omp.lb_smpl
; CHECK-NEXT:   %_unwrap3 = add i64 {{((%_unwrap2, %"iv'ac.0")|%"iv'ac.0", %_unwrap2)}}
; CHECK-NEXT:   %"outidx'ipg_unwrap" = getelementptr inbounds double, double* %"out'", i64 %_unwrap3
; CHECK-NEXT:   %1 = load double, double* %"outidx'ipg_unwrap", align 8
; CHECK-NEXT:   store double 0.000000e+00, double* %"outidx'ipg_unwrap", align 8
; CHECK-NEXT:   %arrayidx_unwrap = getelementptr inbounds double, double* %tmp, i64 %_unwrap3
; CHECK-NEXT:   %_unwrap4 = load double, double* %arrayidx_unwrap, align 8, !tbaa !9, !invariant.group !16
; CHECK-NEXT:   %2 = call fast double @llvm.sqrt.f64(double %_unwrap4)
; CHECK-NEXT:   %3 = fmul fast double 5.000000e-01, %1
; CHECK-NEXT:   %4 = fdiv fast double %3, %2
; CHECK-NEXT:   %5 = fcmp fast oeq double %_unwrap4, 0.000000e+00
; CHECK-NEXT:   %6 = select fast i1 %5, double 0.000000e+00, double %4
; CHECK-NEXT:   %"arrayidx'ipg_unwrap" = getelementptr inbounds double, double* %"tmp'", i64 %_unwrap3
; CHECK-NEXT:   %7 = atomicrmw fadd double* %"arrayidx'ipg_unwrap", double %6 monotonic
; CHECK-NEXT:   %8 = icmp eq i64 %"iv'ac.0", 0
; CHECK-NEXT:   br i1 %8, label %invertomp.precond.then, label %incinvertomp.inner.for.body

; CHECK: incinvertomp.inner.for.body:                      ; preds = %invertomp.inner.for.body
; CHECK-NEXT:   %9 = add nsw i64 %"iv'ac.0", -1
; CHECK-NEXT:   br label %invertomp.inner.for.body

; CHECK: invertomp.loop.exit.loopexit:                     ; preds = %omp.precond.then
; CHECK-NEXT:   %_unwrap5 = load i64, i64* %.omp.ub_smpl
; CHECK-NEXT:   %cmp6_unwrap = icmp ugt i64 %_unwrap5, %sub4
; CHECK-NEXT:   %cond_unwrap = select i1 %cmp6_unwrap, i64 %sub4, i64 %_unwrap5
; CHECK-NEXT:   %_unwrap6 = load i64, i64* %.omp.lb_smpl
; CHECK-NEXT:   %_unwrap7 = sub i64 %cond_unwrap, %_unwrap6
; CHECK-NEXT:   br label %invertomp.inner.for.body
; CHECK-NEXT: }

0 commit comments

Comments
 (0)