Skip to content

Commit

Permalink
Fix indexing and c tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 17, 2021
1 parent 13a79c4 commit 189a8ff
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
9 changes: 9 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos

// Case 2: The correct exiting block terminator unconditionally branches a different block, change to a conditional branch depending on if we are the first iteration
} else if (succ.size() == 1) {
lc.latchMerge->getTerminator()->eraseFromParent();
mergeBuilder.SetInsertPoint(lc.latchMerge);

assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa<BranchInst>(mergeBuilder.GetInsertBlock()->back()));

// If first iteration, branch to the exiting block, otherwise the backlatch
mergeBuilder.CreateCondBr(firstiter, succ[0], reverseBlocks[backlatch]);
Expand All @@ -187,6 +191,8 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos

lc.latchMerge->getTerminator()->eraseFromParent();
mergeBuilder.SetInsertPoint(lc.latchMerge);

assert(mergeBuilder.GetInsertBlock()->size() == 0 || !isa<BranchInst>(mergeBuilder.GetInsertBlock()->back()));
mergeBuilder.CreateCondBr(firstiter, splitBlock, reverseBlocks[backlatch]);

}
Expand Down Expand Up @@ -858,6 +864,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B

if (targetToPreds.size() == 1) {
if (replacePHIs == nullptr) {
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
BuilderM.CreateBr( targetToPreds.begin()->first );
} else {
for (auto pair : *replacePHIs) {
Expand Down Expand Up @@ -962,6 +969,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
Value* phi = lookupValueFromCache(BuilderM, ctx, cache);

if (replacePHIs == nullptr) {
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
BuilderM.CreateCondBr(phi, *done[std::make_pair(block, branch->getSuccessor(0))].begin(), *done[std::make_pair(block, branch->getSuccessor(1))].begin());
} else {
for (auto pair : *replacePHIs) {
Expand Down Expand Up @@ -1076,6 +1084,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B

if (replacePHIs == nullptr) {
if (targetToPreds.size() == 2) {
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
BuilderM.CreateCondBr(which, /*true*/targets[1], /*false*/targets[0]);
} else {
auto swit = BuilderM.CreateSwitch(which, targets.back(), targets.size()-1);
Expand Down
9 changes: 7 additions & 2 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -966,12 +966,14 @@ class GradientUtils {
llvm::errs() << "starting outermost ph at " << allocationPreheaders[i]->getName() << "|ctx=" << ctx->getName() <<"\n";
sublimits.push_back(std::make_pair(size, lims));
size = nullptr;
lims.clear();
}
}

if (size != nullptr) {
llvm::errs() << "starting final outermost ph at " << allocationPreheaders[contexts.size()-1]->getName()<<"|ctx=" << ctx->getName() << "\n";
sublimits.push_back(std::make_pair(size, lims));
lims.clear();
}
return sublimits;
}
Expand Down Expand Up @@ -1118,6 +1120,7 @@ class GradientUtils {
indices.push_back(idx.var);
available[idx.var] = idx.var;
}
llvm::errs() << "W sl idx=" << i << " " << *idx.var << " header=" << idx.header->getName() << "\n";

Value* lim = unwrapM(riter->second, BuilderM, available, /*lookupIfAble*/true);
assert(lim);
Expand All @@ -1129,9 +1132,11 @@ class GradientUtils {
}

if (indices.size() > 0) {
llvm::errs() << "sl idx=" << i << " " << *indices[0] << "\n";
Value* idx = indices[0];
for(unsigned i=1; i<indices.size(); i++) {
idx = BuilderM.CreateNUWAdd(idx, BuilderM.CreateNUWMul(indices[i], limits[i-1]));
for(unsigned ind=1; ind<indices.size(); ind++) {
llvm::errs() << "sl idx=" << i << " " << *indices[ind] << "\n";
idx = BuilderM.CreateNUWAdd(idx, BuilderM.CreateNUWMul(indices[ind], limits[ind-1]));
}
next = BuilderM.CreateGEP(next, {idx});
}
Expand Down

0 comments on commit 189a8ff

Please sign in to comment.