Skip to content

Commit ffe46d5

Browse files
authored
Handle Pointer Returns in Forward Mode (rust-lang#352)
1 parent 534beda commit ffe46d5

File tree

3 files changed

+86
-11
lines changed

3 files changed

+86
-11
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7757,10 +7757,19 @@ class AdjointGenerator
77577757
} else {
77587758
diffe = diffes;
77597759
}
7760-
gutils->replaceAWithB(newcall, diffe);
7761-
gutils->erase(newcall);
7762-
if (!gutils->isConstantValue(&call))
7763-
setDiffe(&call, diffe, Builder2);
7760+
7761+
auto ifound = gutils->invertedPointers.find(orig);
7762+
if (ifound != gutils->invertedPointers.end()) {
7763+
auto placeholder = cast<PHINode>(&*ifound->second);
7764+
gutils->replaceAWithB(placeholder, diffe);
7765+
gutils->erase(placeholder);
7766+
} else {
7767+
gutils->replaceAWithB(newcall, diffe);
7768+
gutils->erase(newcall);
7769+
if (!gutils->isConstantValue(&call)) {
7770+
setDiffe(&call, diffe, Builder2);
7771+
}
7772+
}
77647773
} else {
77657774
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
77667775
}

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,8 +2314,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
23142314
return AugmentedCachedFunctions.find(tup)->second;
23152315
}
23162316

2317-
void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
2318-
DIFFE_TYPE retType, ReturnType retVal) {
2317+
void createTerminator(TypeResults &TR, DiffeGradientUtils *gutils,
2318+
BasicBlock *oBB, DIFFE_TYPE retType, ReturnType retVal) {
23192319

23202320
BasicBlock *nBB = cast<BasicBlock>(gutils->getNewFromOriginal(oBB));
23212321
assert(nBB);
@@ -2341,8 +2341,14 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
23412341

23422342
toret =
23432343
nBuilder.CreateInsertValue(toret, gutils->getNewFromOriginal(ret), 0);
2344-
toret =
2345-
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);
2344+
2345+
if (TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
2346+
toret = nBuilder.CreateInsertValue(
2347+
toret, gutils->invertPointerM(ret, nBuilder), 1);
2348+
} else {
2349+
toret =
2350+
nBuilder.CreateInsertValue(toret, gutils->diffe(ret, nBuilder), 1);
2351+
}
23462352
break;
23472353
}
23482354
case ReturnType::Void: {
@@ -3662,8 +3668,7 @@ Function *EnzymeLogic::CreateForwardDiff(
36623668
}
36633669

36643670
auto TRo = TA.analyzeFunction(oldTypeInfo);
3665-
bool retActive = TRo.getReturnAnalysis().Inner0().isPossibleFloat() &&
3666-
!todiff->getReturnType()->isVoidTy();
3671+
bool retActive = retType != DIFFE_TYPE::CONSTANT;
36673672

36683673
ReturnType retVal =
36693674
returnValue ? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
@@ -3847,7 +3852,7 @@ Function *EnzymeLogic::CreateForwardDiff(
38473852
maker->visit(&*it);
38483853
}
38493854

3850-
createTerminator(gutils, &oBB, retType, retVal);
3855+
createTerminator(TR, gutils, &oBB, retType, retVal);
38513856
}
38523857

38533858
gutils->eraseFictiousPHIs();
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
define dso_local noalias nonnull double* @_Z6toHeapd(double %x) {
4+
entry:
5+
%call = call noalias nonnull dereferenceable(8) i8* @_Znwm(i64 8)
6+
%0 = bitcast i8* %call to double*
7+
store double %x, double* %0, align 8
8+
ret double* %0
9+
}
10+
11+
declare dso_local nonnull i8* @_Znwm(i64)
12+
13+
define dso_local double @_Z6squared(double %x) {
14+
entry:
15+
%call = call double* @_Z6toHeapd(double %x)
16+
%0 = load double, double* %call, align 8
17+
%mul = fmul double %0, %x
18+
ret double %mul
19+
}
20+
21+
define dso_local double @_Z7dsquared(double %x) {
22+
entry:
23+
%call = call double (...) @_Z16__enzyme_fwddiffz(i8* bitcast (double (double)* @_Z6squared to i8*), double %x, double 1.000000e+00)
24+
ret double %call
25+
}
26+
27+
declare dso_local double @_Z16__enzyme_fwddiffz(...)
28+
29+
30+
31+
; CHECK: define dso_local double @_Z7dsquared(double %x)
32+
; CHECK-NEXT: entry:
33+
; CHECK-NEXT: %0 = call fast double @fwddiffe_Z6squared(double %x, double 1.000000e+00)
34+
; CHECK-NEXT: ret double %0
35+
; CHECK-NEXT: }
36+
37+
; CHECK: define internal double @fwddiffe_Z6squared(double %x, double %"x'")
38+
; CHECK-NEXT: entry:
39+
; CHECK-NEXT: %call = call double* @_Z6toHeapd(double %x)
40+
; CHECK-NEXT: %0 = call { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
41+
; CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 1
42+
; CHECK-NEXT: %2 = load double, double* %call, align 8
43+
; CHECK-NEXT: %3 = load double, double* %1, align 8
44+
; CHECK-NEXT: %4 = fmul fast double %3, %x
45+
; CHECK-NEXT: %5 = fmul fast double %"x'", %2
46+
; CHECK-NEXT: %6 = fadd fast double %4, %5
47+
; CHECK-NEXT: ret double %6
48+
; CHECK-NEXT: }
49+
50+
; CHECK: define internal { double*, double* } @fwddiffe_Z6toHeapd(double %x, double %"x'")
51+
; CHECK-NEXT: entry:
52+
; CHECK-NEXT: %call = call noalias nonnull dereferenceable(8) i8* @_Znwm(i64 8)
53+
; CHECK-NEXT: %0 = call noalias nonnull dereferenceable(8) i8* @_Znwm(i64 8)
54+
; CHECK-NEXT: %"'ipc" = bitcast i8* %0 to double*
55+
; CHECK-NEXT: %1 = bitcast i8* %call to double*
56+
; CHECK-NEXT: store double %x, double* %1, align 8
57+
; CHECK-NEXT: store double %"x'", double* %"'ipc", align 8
58+
; CHECK-NEXT: %2 = insertvalue { double*, double* } undef, double* %1, 0
59+
; CHECK-NEXT: %3 = insertvalue { double*, double* } %2, double* %"'ipc", 1
60+
; CHECK-NEXT: ret { double*, double* } %3
61+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)