File tree 2 files changed +11
-2
lines changed
2 files changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -5981,7 +5981,8 @@ class AdjointGenerator
5981
5981
return ;
5982
5982
}
5983
5983
5984
- if (funcName == " MPI_Send" || funcName == " MPI_Ssend" ) {
5984
+ if (funcName == " MPI_Send" || funcName == " MPI_Ssend" ||
5985
+ funcName == " PMPI_Send" || funcName == " PMPI_Ssend" ) {
5985
5986
if (Mode == DerivativeMode::ReverseModeGradient ||
5986
5987
Mode == DerivativeMode::ReverseModeCombined ||
5987
5988
Mode == DerivativeMode::ForwardMode) {
@@ -6008,6 +6009,12 @@ class AdjointGenerator
6008
6009
statusArg--;
6009
6010
if (auto PT = dyn_cast<PointerType>(statusArg->getType ()))
6010
6011
statusType = PT->getPointerElementType ();
6012
+ } else if (Function *recvfn =
6013
+ called->getParent ()->getFunction (" PMPI_Recv" )) {
6014
+ auto statusArg = recvfn->arg_end ();
6015
+ statusArg--;
6016
+ if (auto PT = dyn_cast<PointerType>(statusArg->getType ()))
6017
+ statusType = PT->getPointerElementType ();
6011
6018
}
6012
6019
if (statusType == nullptr ) {
6013
6020
statusType = ArrayType::get (Type::getInt8Ty (call.getContext ()), 24 );
Original file line number Diff line number Diff line change @@ -3627,7 +3627,9 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
3627
3627
}
3628
3628
if (funcName == " MPI_Send" || funcName == " MPI_Ssend" ||
3629
3629
funcName == " MPI_Bsend" || funcName == " MPI_Recv" ||
3630
- funcName == " MPI_Brecv" ) {
3630
+ funcName == " MPI_Brecv" || funcName == " PMPI_Send" ||
3631
+ funcName == " PMPI_Ssend" || funcName == " PMPI_Bsend" ||
3632
+ funcName == " PMPI_Recv" || funcName == " PMPI_Brecv" ) {
3631
3633
TypeTree buf = TypeTree (BaseType::Pointer);
3632
3634
3633
3635
if (Constant *C = dyn_cast<Constant>(call.getOperand (2 ))) {
You can’t perform that action at this time.
0 commit comments