@@ -2883,21 +2883,27 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2883
2883
if (augmenteddata->tapeType &&
2884
2884
augmenteddata->tapeType != additionalValue->getType ()) {
2885
2885
IRBuilder<> BuilderZ (gutils->inversionAllocs );
2886
- // assert(PointerType::getUnqual(augmenteddata->tapeType) ==
2887
- // additionalValue->getType()); auto tapep = additionalValue;
2888
- auto tapep = BuilderZ.CreatePointerCast (
2889
- additionalValue, PointerType::getUnqual (augmenteddata->tapeType ));
2890
- LoadInst *truetape = BuilderZ.CreateLoad (tapep, " truetape" );
2891
- truetape->setMetadata (" enzyme_mustcache" ,
2892
- MDNode::get (truetape->getContext (), {}));
2893
-
2894
- if (!omp) {
2895
- CallInst *ci = cast<CallInst>(CallInst::CreateFree (
2896
- additionalValue, truetape)); // &*BuilderZ.GetInsertPoint()));
2897
- ci->moveAfter (truetape);
2886
+ if (!augmenteddata->tapeType ->isEmptyTy ()) {
2887
+ auto tapep = BuilderZ.CreatePointerCast (
2888
+ additionalValue, PointerType::getUnqual (augmenteddata->tapeType ));
2889
+ LoadInst *truetape = BuilderZ.CreateLoad (tapep, " truetape" );
2890
+ truetape->setMetadata (" enzyme_mustcache" ,
2891
+ MDNode::get (truetape->getContext (), {}));
2892
+
2893
+ if (!omp) {
2894
+ CallInst *ci = cast<CallInst>(CallInst::CreateFree (
2895
+ additionalValue, truetape)); // &*BuilderZ.GetInsertPoint()));
2896
+ ci->moveAfter (truetape);
2897
+ ci->addAttribute (AttributeList::FirstArgIndex, Attribute::NonNull);
2898
+ }
2899
+ additionalValue = truetape;
2900
+ } else {
2901
+ CallInst *ci = cast<CallInst>(
2902
+ CallInst::CreateFree (additionalValue, gutils->inversionAllocs ));
2898
2903
ci->addAttribute (AttributeList::FirstArgIndex, Attribute::NonNull);
2904
+ BuilderZ.Insert (ci);
2905
+ additionalValue = UndefValue::get (augmenteddata->tapeType );
2899
2906
}
2900
- additionalValue = truetape;
2901
2907
}
2902
2908
2903
2909
// TODO here finish up making recursive structs simply pass in i8*
0 commit comments