diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index a99159c92574c..6157e2601ec78 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -164,6 +164,10 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos // Case 2: The correct exiting block terminator unconditionally branches a different block, change to a conditional branch depending on if we are the first iteration } else if (succ.size() == 1) { + lc.latchMerge->getTerminator()->eraseFromParent(); + mergeBuilder.SetInsertPoint(lc.latchMerge); + + assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa(mergeBuilder.GetInsertBlock()->back())); // If first iteration, branch to the exiting block, otherwise the backlatch mergeBuilder.CreateCondBr(firstiter, succ[0], reverseBlocks[backlatch]); @@ -187,6 +191,8 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos lc.latchMerge->getTerminator()->eraseFromParent(); mergeBuilder.SetInsertPoint(lc.latchMerge); + + assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa(mergeBuilder.GetInsertBlock()->back())); mergeBuilder.CreateCondBr(firstiter, splitBlock, reverseBlocks[backlatch]); } @@ -858,6 +864,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B if (targetToPreds.size() == 1) { if (replacePHIs == nullptr) { + assert(BuilderM.GetInsertBlock()->size() == 0 || !isa(BuilderM.GetInsertBlock()->back())); BuilderM.CreateBr( targetToPreds.begin()->first ); } else { for (auto pair : *replacePHIs) { @@ -962,6 +969,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B Value* phi = lookupValueFromCache(BuilderM, ctx, cache); if (replacePHIs == nullptr) { + assert(BuilderM.GetInsertBlock()->size() == 0 || !isa(BuilderM.GetInsertBlock()->back())); BuilderM.CreateCondBr(phi, *done[std::make_pair(block, branch->getSuccessor(0))].begin(), *done[std::make_pair(block, branch->getSuccessor(1))].begin()); } else { for (auto pair : *replacePHIs) { @@ -1076,6 +1084,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B if (replacePHIs == nullptr) { if (targetToPreds.size() == 2) { + assert(BuilderM.GetInsertBlock()->size() == 0 || !isa(BuilderM.GetInsertBlock()->back())); BuilderM.CreateCondBr(which, /*true*/targets[1], /*false*/targets[0]); } else { auto swit = BuilderM.CreateSwitch(which, targets.back(), targets.size()-1); diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index a44598a2d67c1..919fb73df8e48 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -966,12 +966,14 @@ class GradientUtils { llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n"; sublimits.push_back(std::make_pair(size, lims)); size = nullptr; + lims.clear(); } } if (size != nullptr) { llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n"; sublimits.push_back(std::make_pair(size, lims)); + lims.clear(); } return sublimits; } @@ -1118,6 +1120,7 @@ class GradientUtils { indices.push_back(idx.var); available[idx.var] = idx.var; } + llvm::errs() << "W sl idx=" << i << " " << *idx.var << " header=" << idx.header->getName() << "\n"; Value* lim = unwrapM(riter->second, BuilderM, available, /*lookupIfAble*/true); assert(lim); @@ -1129,9 +1132,11 @@ class GradientUtils { } if (indices.size() > 0) { + llvm::errs() << "sl idx=" << i << " " << *indices[0] << "\n"; Value* idx = indices[0]; - for(unsigned i=1; i