Skip to content

Commit 8abdb9f

Browse files
authored
Cabs calling convention (rust-lang#749)
* Handle array types in TypeAnalysis
* Handle array calling convention of cabs
* Add tests
1 parent 1d9d047 commit 8abdb9f

File tree

11 files changed: +424 -13 lines changed

11 files changed

+424
-13
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9945,26 +9945,65 @@ class AdjointGenerator
99459945
Value *d = Builder2.CreateCall(called, args);
99469946

99479947
if (args.size() == 2) {
9948-
Value *op0 = diffe(orig->getArgOperand(0), Builder2);
9949-
9950-
Value *op1 = diffe(orig->getArgOperand(1), Builder2);
9948+
Value *op0 = gutils->isConstantValue(orig->getArgOperand(0))
9949+
? nullptr
9950+
: diffe(orig->getArgOperand(0), Builder2);
9951+
Value *op1 = gutils->isConstantValue(orig->getArgOperand(1))
9952+
? nullptr
9953+
: diffe(orig->getArgOperand(1), Builder2);
9954+
9955+
auto rule1 = [&](Value *op) {
9956+
return Builder2.CreateFMul(args[0], Builder2.CreateFDiv(op, d));
9957+
};
99519958

9952-
auto rule = [&](Value *op0, Value *op1) {
9959+
auto rule2 = [&](Value *op0, Value *op1) {
99539960
Value *dif1 =
99549961
Builder2.CreateFMul(args[0], Builder2.CreateFDiv(op0, d));
99559962
Value *dif2 =
99569963
Builder2.CreateFMul(args[1], Builder2.CreateFDiv(op1, d));
99579964
return Builder2.CreateFAdd(dif1, dif2);
99589965
};
99599966

9960-
Value *dif =
9961-
applyChainRule(call.getType(), Builder2, rule, op0, op1);
9967+
Value *dif;
9968+
if (op0 && op1)
9969+
dif = applyChainRule(call.getType(), Builder2, rule2, op0, op1);
9970+
else if (op0)
9971+
dif = applyChainRule(call.getType(), Builder2, rule1, op0);
9972+
else if (op1)
9973+
dif = applyChainRule(call.getType(), Builder2, rule1, op1);
9974+
else
9975+
llvm_unreachable(
9976+
"trying to differentiate a constant instruction");
9977+
99629978
setDiffe(orig, dif, Builder2);
99639979
return;
9964-
} else {
9965-
llvm::errs() << *orig << "\n";
9966-
llvm_unreachable("unknown calling convention found for cabs");
9980+
} else if (args.size() == 1) {
9981+
if (auto AT = dyn_cast<ArrayType>(args[0]->getType())) {
9982+
if (AT->getNumElements() == 2) {
9983+
Value *op = diffe(orig->getArgOperand(0), Builder2);
9984+
Value *args0 = Builder2.CreateExtractValue(args[0], 0);
9985+
Value *args1 = Builder2.CreateExtractValue(args[0], 1);
9986+
9987+
auto rule = [&](Value *op) {
9988+
Value *op0 = Builder2.CreateExtractValue(op, 0);
9989+
Value *op1 = Builder2.CreateExtractValue(op, 1);
9990+
9991+
Value *dif1 =
9992+
Builder2.CreateFMul(args0, Builder2.CreateFDiv(op0, d));
9993+
Value *dif2 =
9994+
Builder2.CreateFMul(args1, Builder2.CreateFDiv(op1, d));
9995+
return Builder2.CreateFAdd(dif1, dif2);
9996+
};
9997+
9998+
Value *dif =
9999+
applyChainRule(call.getType(), Builder2, rule, op);
10000+
setDiffe(orig, dif, Builder2);
10001+
return;
10002+
}
10003+
}
996710004
}
10005+
llvm::errs() << *orig << "\n";
10006+
llvm_unreachable("unknown calling convention found for cabs");
996810007
}
996910008
case DerivativeMode::ReverseModeGradient:
997010009
case DerivativeMode::ReverseModeCombined: {
@@ -9998,10 +10037,31 @@ class AdjointGenerator
999810037
Builder2.CreateFMul(args[i], div), Builder2,
999910038
orig->getType());
1000010039
return;
10001-
} else {
10002-
llvm::errs() << *orig << "\n";
10003-
llvm_unreachable("unknown calling convention found for cabs");
10040+
} else if (args.size() == 1) {
10041+
if (auto AT = dyn_cast<ArrayType>(args[0]->getType())) {
10042+
if (AT->getNumElements() == 2) {
10043+
if (!gutils->isConstantValue(orig->getArgOperand(0))) {
10044+
Value *agg = UndefValue::get(args[0]->getType());
10045+
agg = Builder2.CreateInsertValue(
10046+
agg,
10047+
Builder2.CreateFMul(
10048+
Builder2.CreateExtractValue(args[0], 0), div),
10049+
0);
10050+
agg = Builder2.CreateInsertValue(
10051+
agg,
10052+
Builder2.CreateFMul(
10053+
Builder2.CreateExtractValue(args[0], 1), div),
10054+
1);
10055+
10056+
addToDiffe(orig->getArgOperand(0), agg, Builder2,
10057+
orig->getType());
10058+
return;
10059+
}
10060+
}
10061+
}
1000410062
}
10063+
llvm::errs() << *orig << "\n";
10064+
llvm_unreachable("unknown calling convention found for cabs");
1000510065
}
1000610066
case DerivativeMode::ReverseModePrimal: {
1000710067
return;

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4240,7 +4240,21 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
42404240
llvm::errs() << *T << " - " << call << "\n";
42414241
llvm_unreachable("Unknown type for libm");
42424242
}
4243-
4243+
} else if (auto AT = dyn_cast<ArrayType>(T)) {
4244+
assert(AT->getNumElements() >= 1);
4245+
if (AT->getElementType()->isFloatingPointTy())
4246+
updateAnalysis(
4247+
call.getArgOperand(i),
4248+
TypeTree(ConcreteType(AT->getElementType()->getScalarType()))
4249+
.Only(-1),
4250+
&call);
4251+
else if (AT->getElementType()->isIntegerTy()) {
4252+
updateAnalysis(call.getArgOperand(i),
4253+
TypeTree(BaseType::Integer).Only(-1), &call);
4254+
} else {
4255+
llvm::errs() << *T << " - " << call << "\n";
4256+
llvm_unreachable("Unknown type for libm");
4257+
}
42444258
} else {
42454259
llvm::errs() << *T << " - " << call << "\n";
42464260
llvm_unreachable("Unknown type for libm");
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x, double %y) {
5+
entry:
6+
%call = call double @cabs(double %x, double %y)
7+
ret double %call
8+
}
9+
10+
define double @test_derivative(double %x, double %y) {
11+
entry:
12+
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y, double 1.0)
13+
ret double %0
14+
}
15+
16+
declare double @cabs(double, double)
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_fwddiff(double (double, double)*, ...)
20+
21+
22+
; CHECK: define internal double @fwddiffetester(double %x, double %y, double %"y'")
23+
; CHECK-NEXT: entry:
24+
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
25+
; CHECK-NEXT: %1 = fdiv fast double %"y'", %0
26+
; CHECK-NEXT: %2 = fmul fast double %x, %1
27+
; CHECK-NEXT: ret double %2
28+
; CHECK-NEXT:}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x, double %y) {
5+
entry:
6+
%call = call double @cabs(double %x, double %y)
7+
ret double %call
8+
}
9+
10+
define double @test_derivative(double %x, double %y) {
11+
entry:
12+
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0)
13+
ret double %0
14+
}
15+
16+
declare double @cabs(double, double)
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_fwddiff(double (double, double)*, ...)
20+
21+
22+
; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'")
23+
; CHECK-NEXT: entry:
24+
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
25+
; CHECK-NEXT: %1 = fdiv fast double %"x'", %0
26+
; CHECK-NEXT: %2 = fmul fast double %x, %1
27+
; CHECK-NEXT: %3 = fdiv fast double %"y'", %0
28+
; CHECK-NEXT: %4 = fmul fast double %y, %3
29+
; CHECK-NEXT: %5 = fadd fast double %2, %4
30+
; CHECK-NEXT: ret double %5
31+
; CHECK-NEXT: }
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone willreturn
4+
declare double @cabs([2 x double])
5+
6+
; Function Attrs: nounwind readnone uwtable
7+
define double @tester(double %x, double %y) {
8+
entry:
9+
%agg0 = insertvalue [2 x double] undef, double %x, 0
10+
%agg1 = insertvalue [2 x double] %agg0, double %y, 1
11+
%call = call double @cabs([2 x double] %agg1)
12+
ret double %call
13+
}
14+
15+
define double @test_derivative(double %x, double %y) {
16+
entry:
17+
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_const", double %x, double %y, double 1.0)
18+
ret double %0
19+
}
20+
21+
; Function Attrs: nounwind
22+
declare double @__enzyme_fwddiff(double (double, double)*, ...)
23+
24+
25+
; CHECK: define internal double @fwddiffetester(double %x, double %y, double %"y'")
26+
; CHECK-NEXT: entry:
27+
; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0
28+
; CHECK-NEXT: %"agg1'ipiv" = insertvalue [2 x double] zeroinitializer, double %"y'", 1
29+
; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1
30+
; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1)
31+
; CHECK-NEXT: %1 = extractvalue [2 x double] %"agg1'ipiv", 0
32+
; CHECK-NEXT: %2 = fdiv fast double %1, %0
33+
; CHECK-NEXT: %3 = fmul fast double %x, %2
34+
; CHECK-NEXT: %4 = fdiv fast double %"y'", %0
35+
; CHECK-NEXT: %5 = fmul fast double %y, %4
36+
; CHECK-NEXT: %6 = fadd fast double %3, %5
37+
; CHECK-NEXT: ret double %6
38+
; CHECK-NEXT: }
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone willreturn
4+
declare double @cabs([2 x double])
5+
6+
; Function Attrs: nounwind readnone uwtable
7+
define double @tester(double %x, double %y) {
8+
entry:
9+
%agg0 = insertvalue [2 x double] undef, double %x, 0
10+
%agg1 = insertvalue [2 x double] %agg0, double %y, 1
11+
%call = call double @cabs([2 x double] %agg1)
12+
ret double %call
13+
}
14+
15+
define double @test_derivative(double %x, double %y) {
16+
entry:
17+
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.0, double %y, double 1.0)
18+
ret double %0
19+
}
20+
21+
; Function Attrs: nounwind
22+
declare double @__enzyme_fwddiff(double (double, double)*, ...)
23+
24+
25+
; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'")
26+
; CHECK-NEXT: entry:
27+
; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0
28+
; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1
29+
; CHECK-NEXT: %0 = call fast double @cabs([2 x double] %agg1)
30+
; CHECK-NEXT: %1 = fdiv fast double %"x'", %0
31+
; CHECK-NEXT: %2 = fmul fast double %x, %1
32+
; CHECK-NEXT: %3 = fdiv fast double %"y'", %0
33+
; CHECK-NEXT: %4 = fmul fast double %y, %3
34+
; CHECK-NEXT: %5 = fadd fast double %2, %4
35+
; CHECK-NEXT: ret double %5
36+
; CHECK-NEXT: }
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone uwtable
4+
define double @tester(double %x, double %y) {
5+
entry:
6+
%call = call double @cabs(double %x, double %y)
7+
ret double %call
8+
}
9+
10+
define [3 x double] @test_derivative(double %x, double %y) {
11+
entry:
12+
%0 = tail call [3 x double] (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 1.3, double 2.0, double %y, double 1.0, double 0.0, double 2.0)
13+
ret [3 x double] %0
14+
}
15+
16+
declare double @cabs(double, double)
17+
18+
; Function Attrs: nounwind
19+
declare [3 x double] @__enzyme_fwddiff(double (double, double)*, ...)
20+
21+
22+
; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", double %y, [3 x double] %"y'")
23+
; CHECK-NEXT: entry:
24+
; CHECK-NEXT: %0 = call fast double @cabs(double %x, double %y)
25+
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0
26+
; CHECK-NEXT: %2 = extractvalue [3 x double] %"y'", 0
27+
; CHECK-NEXT: %3 = fdiv fast double %1, %0
28+
; CHECK-NEXT: %4 = fmul fast double %x, %3
29+
; CHECK-NEXT: %5 = fdiv fast double %2, %0
30+
; CHECK-NEXT: %6 = fmul fast double %y, %5
31+
; CHECK-NEXT: %7 = fadd fast double %4, %6
32+
; CHECK-NEXT: %8 = insertvalue [3 x double] undef, double %7, 0
33+
; CHECK-NEXT: %9 = extractvalue [3 x double] %"x'", 1
34+
; CHECK-NEXT: %10 = extractvalue [3 x double] %"y'", 1
35+
; CHECK-NEXT: %11 = fdiv fast double %9, %0
36+
; CHECK-NEXT: %12 = fmul fast double %x, %11
37+
; CHECK-NEXT: %13 = fdiv fast double %10, %0
38+
; CHECK-NEXT: %14 = fmul fast double %y, %13
39+
; CHECK-NEXT: %15 = fadd fast double %12, %14
40+
; CHECK-NEXT: %16 = insertvalue [3 x double] %8, double %15, 1
41+
; CHECK-NEXT: %17 = extractvalue [3 x double] %"x'", 2
42+
; CHECK-NEXT: %18 = extractvalue [3 x double] %"y'", 2
43+
; CHECK-NEXT: %19 = fdiv fast double %17, %0
44+
; CHECK-NEXT: %20 = fmul fast double %x, %19
45+
; CHECK-NEXT: %21 = fdiv fast double %18, %0
46+
; CHECK-NEXT: %22 = fmul fast double %y, %21
47+
; CHECK-NEXT: %23 = fadd fast double %20, %22
48+
; CHECK-NEXT: %24 = insertvalue [3 x double] %16, double %23, 2
49+
; CHECK-NEXT: ret [3 x double] %24
50+
; CHECK-NEXT: }
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -adce -S | FileCheck %s
2+
3+
; Function Attrs: nounwind readnone willreturn
4+
declare double @cabs([2 x double]) #7
5+
6+
; Function Attrs: nounwind readnone uwtable
7+
define double @tester(double %x, double %y) {
8+
entry:
9+
%agg0 = insertvalue [2 x double] undef, double %x, 0
10+
%agg1 = insertvalue [2 x double] %agg0, double %y, 1
11+
%call = call double @cabs([2 x double] %agg1)
12+
ret double %call
13+
}
14+
15+
define [3 x double] @test_derivative(double %x, double %y) {
16+
entry:
17+
%0 = tail call [3 x double] (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, metadata !"enzyme_width", i64 3, double %x, double 1.0, double 1.3, double 2.0, double %y, double 1.0, double 0.0, double 2.0)
18+
ret [3 x double] %0
19+
}
20+
21+
; Function Attrs: nounwind
22+
declare [3 x double] @__enzyme_fwddiff(double (double, double)*, ...)
23+
24+
25+
; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'", double %y, [3 x double] %"y'")
26+
; CHECK-NEXT: entry:
27+
; CHECK-NEXT: %0 = extractvalue [3 x double] %"x'", 0
28+
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 1
29+
; CHECK-NEXT: %2 = extractvalue [3 x double] %"x'", 2
30+
; CHECK-NEXT: %agg0 = insertvalue [2 x double] undef, double %x, 0
31+
; CHECK-NEXT: %3 = extractvalue [3 x double] %"y'", 0
32+
; CHECK-NEXT: %4 = extractvalue [3 x double] %"y'", 1
33+
; CHECK-NEXT: %5 = extractvalue [3 x double] %"y'", 2
34+
; CHECK-NEXT: %agg1 = insertvalue [2 x double] %agg0, double %y, 1
35+
; CHECK-NEXT: %6 = call fast double @cabs([2 x double] %agg1)
36+
; CHECK-NEXT: %7 = fdiv fast double %0, %6
37+
; CHECK-NEXT: %8 = fmul fast double %x, %7
38+
; CHECK-NEXT: %9 = fdiv fast double %3, %6
39+
; CHECK-NEXT: %10 = fmul fast double %y, %9
40+
; CHECK-NEXT: %11 = fadd fast double %8, %10
41+
; CHECK-NEXT: %12 = insertvalue [3 x double] undef, double %11, 0
42+
; CHECK-NEXT: %13 = fdiv fast double %1, %6
43+
; CHECK-NEXT: %14 = fmul fast double %x, %13
44+
; CHECK-NEXT: %15 = fdiv fast double %4, %6
45+
; CHECK-NEXT: %16 = fmul fast double %y, %15
46+
; CHECK-NEXT: %17 = fadd fast double %14, %16
47+
; CHECK-NEXT: %18 = insertvalue [3 x double] %12, double %17, 1
48+
; CHECK-NEXT: %19 = fdiv fast double %2, %6
49+
; CHECK-NEXT: %20 = fmul fast double %x, %19
50+
; CHECK-NEXT: %21 = fdiv fast double %5, %6
51+
; CHECK-NEXT: %22 = fmul fast double %y, %21
52+
; CHECK-NEXT: %23 = fadd fast double %20, %22
53+
; CHECK-NEXT: %24 = insertvalue [3 x double] %18, double %23, 2
54+
; CHECK-NEXT: ret [3 x double] %24
55+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)