Skip to content

Commit 43c0a9c

Browse files
authored
Introduce enzyme_iter for preserving approximation levels through differentiation. (rust-lang#229)
* Enzyme iter
* Temp
* Fix indexing
* Fix recursive struct accumulation
1 parent 9be5033 commit 43c0a9c

File tree

8 files changed, with 183 additions and 94 deletions.

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 30 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -922,8 +922,16 @@ class AdjointGenerator
922922
if (!gutils->isConstantValue(orig_vec)) {
923923
SmallVector<Value *, 4> sv;
924924
sv.push_back(gutils->getNewFromOriginal(EEI.getIndexOperand()));
925+
926+
size_t size = 1;
927+
if (EEI.getType()->isSized())
928+
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
929+
EEI.getType()) +
930+
7) /
931+
8;
925932
((DiffeGradientUtils *)gutils)
926-
->addToDiffeIndexed(orig_vec, diffe(&EEI, Builder2), sv, Builder2);
933+
->addToDiffe(orig_vec, diffe(&EEI, Builder2), Builder2,
934+
TR.addingType(size, &EEI), sv);
927935
}
928936
setDiffe(&EEI, Constant::getNullValue(EEI.getType()), Builder2);
929937
}
@@ -1000,11 +1008,20 @@ class AdjointGenerator
10001008
auto opidx = (idx < l1) ? idx : (idx - l1);
10011009
SmallVector<Value *, 4> sv;
10021010
sv.push_back(ConstantInt::get(Type::getInt32Ty(SVI.getContext()), opidx));
1003-
if (!gutils->isConstantValue(SVI.getOperand(opnum)))
1011+
if (!gutils->isConstantValue(SVI.getOperand(opnum))) {
1012+
size_t size = 1;
1013+
if (SVI.getOperand(opnum)->getType()->isSized())
1014+
size =
1015+
(gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1016+
SVI.getOperand(opnum)->getType()) +
1017+
7) /
1018+
8;
10041019
((DiffeGradientUtils *)gutils)
1005-
->addToDiffeIndexed(SVI.getOperand(opnum),
1006-
Builder2.CreateExtractElement(loaded, instidx),
1007-
sv, Builder2);
1020+
->addToDiffe(SVI.getOperand(opnum),
1021+
Builder2.CreateExtractElement(loaded, instidx),
1022+
Builder2, TR.addingType(size, SVI.getOperand(opnum)),
1023+
sv);
1024+
}
10081025
++instidx;
10091026
}
10101027
setDiffe(&SVI, Constant::getNullValue(SVI.getType()), Builder2);
@@ -1032,8 +1049,15 @@ class AdjointGenerator
10321049
SmallVector<Value *, 4> sv;
10331050
for (auto i : EVI.getIndices())
10341051
sv.push_back(ConstantInt::get(Type::getInt32Ty(EVI.getContext()), i));
1052+
size_t size = 1;
1053+
if (EVI.getType()->isSized())
1054+
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
1055+
EVI.getType()) +
1056+
7) /
1057+
8;
10351058
((DiffeGradientUtils *)gutils)
1036-
->addToDiffeIndexed(orig_op0, prediff, sv, Builder2);
1059+
->addToDiffe(orig_op0, prediff, Builder2, TR.addingType(size, &EVI),
1060+
sv);
10371061
}
10381062

10391063
setDiffe(&EVI, Constant::getNullValue(EVI.getType()), Builder2);

enzyme/Enzyme/Enzyme.cpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -749,6 +749,7 @@ class Enzyme : public ModulePass {
749749

750750
std::map<CallInst *, DerivativeMode> toLower;
751751
std::set<CallInst *> InactiveCalls;
752+
std::set<CallInst *> IterCalls;
752753
retry:;
753754
for (BasicBlock &BB : F) {
754755
for (Instruction &I : BB) {
@@ -808,6 +809,10 @@ class Enzyme : public ModulePass {
808809
}
809810
}
810811
}
812+
if (Fn->getName() == "__enzyme_iter") {
813+
Fn->addFnAttr(Attribute::ReadNone);
814+
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
815+
}
811816
if (Fn->getName().contains("__enzyme_call_inactive")) {
812817
InactiveCalls.insert(CI);
813818
}
@@ -1243,13 +1248,18 @@ class Enzyme : public ModulePass {
12431248
F->getName() == "__enzyme_pointer") {
12441249
toErase.push_back(CI);
12451250
}
1251+
if (F->getName() == "__enzyme_iter") {
1252+
CI->replaceAllUsesWith(CI->getArgOperand(0));
1253+
toErase.push_back(CI);
1254+
}
12461255
}
12471256
}
12481257
}
12491258
}
12501259
}
12511260
for (auto I : toErase) {
12521261
I->eraseFromParent();
1262+
changed = true;
12531263
}
12541264

12551265
Logic.clear();

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -787,6 +787,39 @@ Function *PreProcessCache::preprocessForClone(Function *F,
787787
}
788788
}
789789

790+
{
791+
std::vector<CallInst *> ItersToErase;
792+
for (auto &BB : *NewF) {
793+
for (auto &I : BB) {
794+
795+
if (auto CI = dyn_cast<CallInst>(&I)) {
796+
797+
Function *called = CI->getCalledFunction();
798+
#if LLVM_VERSION_MAJOR >= 11
799+
if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledOperand()))
800+
#else
801+
if (auto castinst = dyn_cast<ConstantExpr>(CI->getCalledValue()))
802+
#endif
803+
{
804+
if (castinst->isCast()) {
805+
if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
806+
called = fn;
807+
}
808+
}
809+
810+
if (called && called->getName() == "__enzyme_iter") {
811+
ItersToErase.push_back(CI);
812+
}
813+
}
814+
}
815+
}
816+
for (auto Call : ItersToErase) {
817+
IRBuilder<> B(Call);
818+
Call->setArgOperand(
819+
0, B.CreateAdd(Call->getArgOperand(0), Call->getArgOperand(1)));
820+
}
821+
}
822+
790823
if (EnzymeLowerGlobals) {
791824
std::vector<CallInst *> Calls;
792825
std::vector<ReturnInst *> Returns;

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2557,6 +2557,9 @@ Value *GradientUtils::invertPointerM(Value *oval, IRBuilder<> &BuilderM) {
25572557
invertedPointers[arg] = li;
25582558
return lookupM(invertedPointers[arg], BuilderM);
25592559
} else if (auto arg = dyn_cast<BinaryOperator>(oval)) {
2560+
if (arg->getOpcode() == Instruction::FAdd)
2561+
return lookupM(getNewFromOriginal(arg), BuilderM);
2562+
25602563
if (!arg->getType()->isIntOrIntVectorTy()) {
25612564
llvm::errs() << *oval << "\n";
25622565
}

enzyme/Enzyme/GradientUtils.h

Lines changed: 34 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -1321,8 +1321,9 @@ class DiffeGradientUtils : public GradientUtils {
13211321
}
13221322

13231323
// Returns created select instructions, if any
1324-
std::vector<SelectInst *>
1325-
addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, Type *addingType) {
1324+
std::vector<SelectInst *> addToDiffe(Value *val, Value *dif,
1325+
IRBuilder<> &BuilderM, Type *addingType,
1326+
ArrayRef<Value *> idxs = {}) {
13261327
if (auto arg = dyn_cast<Argument>(val))
13271328
assert(arg->getParent() == oldFunc);
13281329
if (auto inst = dyn_cast<Instruction>(val))
@@ -1407,19 +1408,27 @@ class DiffeGradientUtils : public GradientUtils {
14071408
}
14081409
assert(!val->getType()->isPointerTy());
14091410
assert(!isConstantValue(val));
1410-
if (val->getType() != dif->getType()) {
1411-
llvm::errs() << "val: " << *val << " dif: " << *dif << "\n";
1411+
1412+
Value *ptr = getDifferential(val);
1413+
1414+
if (idxs.size() != 0) {
1415+
SmallVector<Value *, 4> sv;
1416+
sv.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), 0));
1417+
for (auto i : idxs)
1418+
sv.push_back(i);
1419+
ptr = BuilderM.CreateGEP(ptr, sv);
1420+
cast<GetElementPtrInst>(ptr)->setIsInBounds(true);
14121421
}
1413-
assert(val->getType() == dif->getType());
1414-
auto old = diffe(val, BuilderM);
1415-
assert(val->getType() == old->getType());
1422+
Value *old = BuilderM.CreateLoad(ptr);
1423+
1424+
assert(dif->getType() == old->getType());
14161425
Value *res = nullptr;
1417-
if (val->getType()->isIntOrIntVectorTy()) {
1426+
if (old->getType()->isIntOrIntVectorTy()) {
14181427
if (!addingType) {
14191428
llvm::errs() << "module: " << *oldFunc->getParent() << "\n";
14201429
llvm::errs() << "oldFunc: " << *oldFunc << "\n";
14211430
llvm::errs() << "newFunc: " << *newFunc << "\n";
1422-
llvm::errs() << "val: " << *val << "\n";
1431+
llvm::errs() << "val: " << *val << " old: " << old << "\n";
14231432
}
14241433
assert(addingType);
14251434
assert(addingType->isFPOrFPVectorTy());
@@ -1448,30 +1457,34 @@ class DiffeGradientUtils : public GradientUtils {
14481457
addedSelects.erase(addedSelects.end() - 1);
14491458
res = BuilderM.CreateSelect(
14501459
select->getCondition(),
1451-
BuilderM.CreateBitCast(select->getTrueValue(), val->getType()),
1452-
BuilderM.CreateBitCast(select->getFalseValue(), val->getType()));
1460+
BuilderM.CreateBitCast(select->getTrueValue(), old->getType()),
1461+
BuilderM.CreateBitCast(select->getFalseValue(), old->getType()));
14531462
assert(select->getNumUses() == 0);
14541463
} else {
1455-
res = BuilderM.CreateBitCast(res, val->getType());
1464+
res = BuilderM.CreateBitCast(res, old->getType());
14561465
}
1457-
BuilderM.CreateStore(res, getDifferential(val));
1466+
BuilderM.CreateStore(res, ptr);
14581467
// store->setAlignment(align);
14591468
return addedSelects;
1460-
} else if (val->getType()->isFPOrFPVectorTy()) {
1469+
} else if (old->getType()->isFPOrFPVectorTy()) {
14611470
// TODO consider adding type
14621471
res = faddForSelect(old, dif);
14631472

1464-
BuilderM.CreateStore(res, getDifferential(val));
1473+
BuilderM.CreateStore(res, ptr);
14651474
// store->setAlignment(align);
14661475
return addedSelects;
1467-
} else if (val->getType()->isStructTy()) {
1468-
auto st = cast<StructType>(val->getType());
1476+
} else if (auto st = dyn_cast<StructType>(old->getType())) {
14691477
for (unsigned i = 0; i < st->getNumElements(); ++i) {
1478+
// TODO pass in full type tree here and recurse into tree.
1479+
if (st->getElementType(i)->isPointerTy())
1480+
continue;
14701481
Value *v = ConstantInt::get(Type::getInt32Ty(st->getContext()), i);
1471-
SelectInst *addedSelect = addToDiffeIndexed(
1472-
val, BuilderM.CreateExtractValue(dif, {i}), {v}, BuilderM);
1473-
if (addedSelect) {
1474-
addedSelects.push_back(addedSelect);
1482+
SmallVector<Value *, 2> idx2(idxs.begin(), idxs.end());
1483+
idx2.push_back(v);
1484+
auto selects = addToDiffe(val, BuilderM.CreateExtractValue(dif, {i}),
1485+
BuilderM, nullptr, idx2);
1486+
for (auto select : selects) {
1487+
addedSelects.push_back(select);
14751488
}
14761489
}
14771490
return addedSelects;
@@ -1502,72 +1515,6 @@ class DiffeGradientUtils : public GradientUtils {
15021515
BuilderM.CreateStore(toset, tostore);
15031516
}
15041517

1505-
SelectInst *addToDiffeIndexed(Value *val, Value *dif, ArrayRef<Value *> idxs,
1506-
IRBuilder<> &BuilderM) {
1507-
if (auto arg = dyn_cast<Argument>(val))
1508-
assert(arg->getParent() == oldFunc);
1509-
if (auto inst = dyn_cast<Instruction>(val))
1510-
assert(inst->getParent()->getParent() == oldFunc);
1511-
assert(!isConstantValue(val));
1512-
SmallVector<Value *, 4> sv;
1513-
sv.push_back(ConstantInt::get(Type::getInt32Ty(val->getContext()), 0));
1514-
for (auto i : idxs)
1515-
sv.push_back(i);
1516-
Value *ptr = BuilderM.CreateGEP(getDifferential(val), sv);
1517-
cast<GetElementPtrInst>(ptr)->setIsInBounds(true);
1518-
Value *old = BuilderM.CreateLoad(ptr);
1519-
1520-
Value *res = nullptr;
1521-
1522-
if (old->getType()->isIntOrIntVectorTy()) {
1523-
res = BuilderM.CreateFAdd(
1524-
BuilderM.CreateBitCast(old, IntToFloatTy(old->getType())),
1525-
BuilderM.CreateBitCast(dif, IntToFloatTy(dif->getType())));
1526-
res = BuilderM.CreateBitCast(res, old->getType());
1527-
} else if (old->getType()->isFPOrFPVectorTy()) {
1528-
res = BuilderM.CreateFAdd(old, dif);
1529-
} else {
1530-
assert(old);
1531-
assert(dif);
1532-
llvm::errs() << *newFunc << "\n"
1533-
<< "cannot handle type " << *old << "\n"
1534-
<< *dif;
1535-
assert(0 && "cannot handle type");
1536-
report_fatal_error("cannot handle type");
1537-
}
1538-
1539-
SelectInst *addedSelect = nullptr;
1540-
1541-
//! optimize fadd of select to select of fadd
1542-
// TODO: Handle Selects of ints
1543-
if (SelectInst *select = dyn_cast<SelectInst>(dif)) {
1544-
if (ConstantFP *ci = dyn_cast<ConstantFP>(select->getTrueValue())) {
1545-
if (ci->isZero()) {
1546-
cast<Instruction>(res)->eraseFromParent();
1547-
res = BuilderM.CreateSelect(
1548-
select->getCondition(), old,
1549-
BuilderM.CreateFAdd(old, select->getFalseValue()));
1550-
addedSelect = cast<SelectInst>(res);
1551-
goto endselect;
1552-
}
1553-
}
1554-
if (ConstantFP *ci = dyn_cast<ConstantFP>(select->getFalseValue())) {
1555-
if (ci->isZero()) {
1556-
cast<Instruction>(res)->eraseFromParent();
1557-
res = BuilderM.CreateSelect(
1558-
select->getCondition(),
1559-
BuilderM.CreateFAdd(old, select->getTrueValue()), old);
1560-
addedSelect = cast<SelectInst>(res);
1561-
goto endselect;
1562-
}
1563-
}
1564-
}
1565-
endselect:;
1566-
1567-
BuilderM.CreateStore(res, ptr);
1568-
return addedSelect;
1569-
}
1570-
15711518
void freeCache(llvm::BasicBlock *forwardPreheader,
15721519
const SubLimitType &sublimits, int i, llvm::AllocaInst *alloc,
15731520
llvm::ConstantInt *byteSizeOfType, llvm::Value *storeInto,

enzyme/test/Integration/ReverseMode/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
# Run regression and unit tests
22
add_lit_testsuite(check-enzyme-integration-reverse "Running enzyme reverse mode integration tests"
33
${CMAKE_CURRENT_BINARY_DIR}
4-
DEPENDS ${ENZYME_TEST_DEPS}
4+
DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR}
55
ARGS -v
66
)
77

Lines changed: 69 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,69 @@
1+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
2+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
3+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
4+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
5+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
6+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
7+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
8+
// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
9+
10+
#include <stdio.h>
11+
#include <stdint.h>
12+
#include <math.h>
13+
14+
#include "test_utils.h"
15+
16+
__attribute__((noinline))
17+
uint64_t factorial(uint64_t x) {
18+
if (x == 0) return 1;
19+
return x * factorial(x-1);
20+
}
21+
22+
double my_sin(double x) {
23+
double result = 0;
24+
uint64_t N = 12;
25+
for(uint64_t i=0; i<=N; i++) {
26+
if (i % 2 == 0) continue;
27+
result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 1 : -1);
28+
}
29+
return result;
30+
31+
}
32+
33+
uint64_t __enzyme_iter(uint64_t, uint64_t);
34+
35+
double __enzyme_autodiff(void*, double);
36+
37+
double my_sin2(double x) {
38+
double result = 0;
39+
uint64_t N = __enzyme_iter(12, 1);
40+
for(uint64_t i=0; i<=N; i++) {
41+
if (i % 2 == 0) continue;
42+
result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 1 : -1);
43+
}
44+
return result;
45+
}
46+
47+
double d_mysin2(double x) {
48+
return __enzyme_autodiff(my_sin2, x);
49+
}
50+
double dd_mysin2(double x) {
51+
return __enzyme_autodiff(d_mysin2, x);
52+
}
53+
double ddd_mysin2(double x) {
54+
return __enzyme_autodiff(dd_mysin2, x);
55+
}
56+
double dddd_mysin2(double x) {
57+
return __enzyme_autodiff(ddd_mysin2, x);
58+
}
59+
60+
int main() {
61+
double x = 1.23;
62+
printf("my_sin(x=%f)=%e\n", x, my_sin(x));
63+
printf("my_sin2(x=%f)=%e\n", x, my_sin2(x));
64+
APPROX_EQ(my_sin2(x), my_sin(x), 10e-10);
65+
printf("dd_my_sin2(x=%f)=%e\n", x, dd_mysin2(x));
66+
APPROX_EQ(dd_mysin2(x), -my_sin(x), 10e-10);
67+
printf("dddd_my_sin2(x=%f)=%e\n", x, dddd_mysin2(x));
68+
APPROX_EQ(dddd_mysin2(x), my_sin(x), 10e-10);
69+
}

0 commit comments

Comments
 (0)