diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index dca5ab38f..36df6f363 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -1942,6 +1942,9 @@ struct IfOpEnzymeOpsRemover removalBlockExplore(falseBlock, falseMapping, builder, gradients, pushedCaches); + if (gradients.size() == 0 || pushedCaches.size() == 0) + return success(); + Operation *trueTerm = trueBlock->getTerminator(); Operation *falseTerm = falseBlock->getTerminator(); @@ -2010,6 +2013,8 @@ struct IfOpEnzymeOpsRemover idx++; } + ifOp->replaceAllUsesWith( + newIf->getResults().slice(0, ifOp->getNumResults())); ifOp->erase(); return success();