Skip to content

Commit 1fca52a

Browse files
wsmosesUbuntu
and
Ubuntu
authored
PMPI recognition (rust-lang#548)
Co-authored-by: Ubuntu <[email protected]>
1 parent b5e84b9 commit 1fca52a

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5981,7 +5981,8 @@ class AdjointGenerator
59815981
return;
59825982
}
59835983

5984-
if (funcName == "MPI_Send" || funcName == "MPI_Ssend") {
5984+
if (funcName == "MPI_Send" || funcName == "MPI_Ssend" ||
5985+
funcName == "PMPI_Send" || funcName == "PMPI_Ssend") {
59855986
if (Mode == DerivativeMode::ReverseModeGradient ||
59865987
Mode == DerivativeMode::ReverseModeCombined ||
59875988
Mode == DerivativeMode::ForwardMode) {
@@ -6008,6 +6009,12 @@ class AdjointGenerator
60086009
statusArg--;
60096010
if (auto PT = dyn_cast<PointerType>(statusArg->getType()))
60106011
statusType = PT->getPointerElementType();
6012+
} else if (Function *recvfn =
6013+
called->getParent()->getFunction("PMPI_Recv")) {
6014+
auto statusArg = recvfn->arg_end();
6015+
statusArg--;
6016+
if (auto PT = dyn_cast<PointerType>(statusArg->getType()))
6017+
statusType = PT->getPointerElementType();
60116018
}
60126019
if (statusType == nullptr) {
60136020
statusType = ArrayType::get(Type::getInt8Ty(call.getContext()), 24);

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3627,7 +3627,9 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
36273627
}
36283628
if (funcName == "MPI_Send" || funcName == "MPI_Ssend" ||
36293629
funcName == "MPI_Bsend" || funcName == "MPI_Recv" ||
3630-
funcName == "MPI_Brecv") {
3630+
funcName == "MPI_Brecv" || funcName == "PMPI_Send" ||
3631+
funcName == "PMPI_Ssend" || funcName == "PMPI_Bsend" ||
3632+
funcName == "PMPI_Recv" || funcName == "PMPI_Brecv") {
36313633
TypeTree buf = TypeTree(BaseType::Pointer);
36323634

36333635
if (Constant *C = dyn_cast<Constant>(call.getOperand(2))) {

0 commit comments

Comments
 (0)