Skip to content

Commit

Permalink
[mlir][sparse] partially support lowering sparse coiteration loops to…
Browse files Browse the repository at this point in the history
… scf.while/for. (#105565)
  • Loading branch information
Peiming Liu authored Aug 23, 2024
1 parent ebc4a66 commit f607102
Show file tree
Hide file tree
Showing 9 changed files with 549 additions and 91 deletions.
22 changes: 15 additions & 7 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,32 @@ class I64BitSet {
return *this;
}

bool isSubSetOf(const I64BitSet p) const {
I64BitSet tmp = *this;
tmp |= p;
return tmp == p;
}

// Needed by `llvm::const_set_bits_iterator_impl`.
int find_first() const { return min(); }
int find_next(unsigned prev) const {
if (prev >= max())
if (prev >= max() - 1)
return -1;

uint64_t b = storage >> (prev + 1);
if (b == 0)
return -1;
uint64_t b = storage >> (prev + static_cast<int64_t>(1));
assert(b != 0);

return llvm::countr_zero(b) + prev + 1;
return llvm::countr_zero(b) + prev + static_cast<int64_t>(1);
}

bool operator[](unsigned i) const {
assert(i < 64);
return (storage & (1 << i)) != 0;
return (storage & (static_cast<int64_t>(1) << i)) != 0;
}
unsigned min() const {
unsigned m = llvm::countr_zero(storage);
return m == 64 ? -1 : m;
}
unsigned min() const { return llvm::countr_zero(storage); }
unsigned max() const { return 64 - llvm::countl_zero(storage); }
unsigned count() const { return llvm::popcount(storage); }
bool empty() const { return storage == 0; }
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
.take_back(getRegionDefinedSpace(regionIdx).count());
}
ValueRange getYieldedValues(unsigned regionIdx);

// Returns a vector of regions that are the `sub-cases` of the given case region.
// E.g., `case %it1, _, %it3` is a subcase of `case %it1, %it2, %it3`.
SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
}];

let hasVerifier = 1;
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
return success();
}

SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
SmallVector<Region *> ret;
I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
for (Region &r : getCaseRegions())
if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
ret.push_back(&r);

return ret;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//
Expand Down
291 changes: 290 additions & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
return success();
}

static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
Value loopCrd,
ArrayRef<std::unique_ptr<SparseIterator>> iters,
ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
if (subCases.empty())
return userReduc;

// The current branch that we are handling.
Region *b = subCases.front();
Value casePred = constantI1(rewriter, loc, true);
I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
for (unsigned i : caseBits.bits()) {
SparseIterator *it = iters[i].get();
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
it->getCrd(), loopCrd);
casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
}
scf::IfOp ifOp = rewriter.create<scf::IfOp>(
loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());

// Erase the empty block.
rewriter.eraseBlock(&ifOp.getThenRegion().front());
// Set up block arguments: user-provided values -> loop coord -> iterators.
SmallVector<Value> blockArgs(userReduc);
blockArgs.push_back(loopCrd);
for (unsigned idx : caseBits.bits())
llvm::append_range(blockArgs, iters[idx]->getCursor());

IRMapping mapping;
for (auto [from, to] :
llvm::zip_equal(b->front().getArguments(), blockArgs)) {
mapping.map(from, to);
}

// Clone the region, we can not erase the region now because the same region
// might be a subcase for multiple lattice point.
rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
ifOp.getThenRegion().begin(), mapping);

// replace sparse_tensor::YieldOp -> scf::YieldOp
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
ValueRange yields = spY.getResults();
rewriter.eraseOp(spY);
rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
rewriter.create<scf::YieldOp>(loc, yields);

// Generates remaining case recursively.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
subCases.drop_front(), userReduc);
if (!res.empty())
rewriter.create<scf::YieldOp>(loc, res);

rewriter.setInsertionPointAfter(ifOp);
return ifOp.getResults();
}

static ValueRange genLoopWithIterator(
PatternRewriter &rewriter, Location loc, SparseIterator *it,
ValueRange reduc, bool iterFirst,
function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
Region &loopBody, SparseIterator *it,
ValueRange reduc)>
bodyBuilder) {
if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
{
OpBuilder::InsertionGuard guard(rewriter);
// Erase the implicit yield operation created by ForOp when there is no
// yielding values.
if (!forOp.getBody()->empty())
rewriter.eraseOp(&forOp.getBody()->front());
assert(forOp.getBody()->empty());

it->linkNewScope(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
it, forOp.getRegionIterArgs());

rewriter.setInsertionPointToEnd(forOp.getBody());
rewriter.create<scf::YieldOp>(loc, ret);
}
return forOp.getResults();
}
SmallVector<Value> ivs;
// TODO: always put iterator SSA values at the end of argument list to be
// consistent with coiterate operation.
if (!iterFirst)
llvm::append_range(ivs, it->getCursor());
// Appends the user-provided values.
llvm::append_range(ivs, reduc);
if (iterFirst)
llvm::append_range(ivs, it->getCursor());

TypeRange types = ValueRange(ivs).getTypes();
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
{
OpBuilder::InsertionGuard guard(rewriter);
// Generates loop conditions.
SmallVector<Location> l(types.size(), loc);
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
rewriter.setInsertionPointToStart(before);
ValueRange bArgs = before->getArguments();
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

// Delegates loop body generation.
Region &dstRegion = whileOp.getAfter();
Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
ValueRange aArgs = whileOp.getAfterArguments();
if (iterFirst) {
aArgs = it->linkNewScope(aArgs);
} else {
aArgs = aArgs.take_front(reduc.size());
it->linkNewScope(aArgs.drop_front(reduc.size()));
}

rewriter.setInsertionPointToStart(after);
SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
rewriter.setInsertionPointToEnd(after);

// Forward loops
SmallVector<Value> yields;
ValueRange nx = it->forward(rewriter, loc);
if (iterFirst)
llvm::append_range(yields, nx);
llvm::append_range(yields, ret);
if (!iterFirst)
llvm::append_range(yields, nx);
rewriter.create<scf::YieldOp>(loc, yields);
}
return whileOp.getResults().drop_front(it->getCursor().size());
}

namespace {

/// Sparse codegen rule for number of entries operator.
Expand Down Expand Up @@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
// TODO: put iterator at the end of argument list to be consistent with
// coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
Expand Down Expand Up @@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
}
};

class SparseCoIterateOpConverter
: public OneToNOpConversionPattern<CoIterateOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;

LogicalResult
matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
assert(op.getSpaceDim() == 1 && "Not implemented");
Location loc = op.getLoc();

I64BitSet denseBits(0);
for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
denseBits.set(idx);

// If there exists a case that only contains dense spaces. I.e., case
// bits is a subset of dense bits, or when there is a full empty case (due
// to complements), we need a universal pointer to forward the coiteration
// loop.
bool needUniv =
any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
// A case for complement.
if (caseBits.count() == 0)
return true;
// An all-dense case.
return caseBits.isSubSetOf(denseBits);
});
assert(!needUniv && "Not implemented");
(void)needUniv;

for (Region &region : op.getCaseRegions()) {
// Do a one-shot type conversion on all region blocks, since the same
// region might be used multiple time.
Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping)))
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");

rewriter.applySignatureConversion(block, blockTypeMapping);
}

SmallVector<SparseIterationSpace> spaces;
SmallVector<std::unique_ptr<SparseIterator>> iters;
for (auto [spaceTp, spaceVals] : llvm::zip_equal(
op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
// TODO: do we really need tid?
spaces.push_back(SparseIterationSpace::fromValues(
cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
// Extract the iterator.
iters.push_back(spaces.back().extractIterator(rewriter, loc));
}

auto getFilteredIters = [&iters](I64BitSet caseBits) {
// Retrives a vector of pointers to the iterators used in the case.
SmallVector<SparseIterator *> validIters;
for (auto idx : caseBits.bits())
validIters.push_back(iters[idx].get());
return validIters;
};

// Get a flattened user-provided loop reduction values.
SmallVector<Value> userReduc;
for (ValueRange r : adaptor.getInitArgs())
llvm::append_range(userReduc, r);

// TODO: we need to sort the cases such that they appears in lexical order.
// Although sparsification always generates cases in that order, it might
// not be the case for human-written code.

// Generates a loop sequence, one loop per case.
for (auto [r, caseBits] :
llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
assert(caseBits.count() > 0 && "Complement space not implemented");

// Retrives a vector of pointers to the iterators used in the case.
SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);

if (validIters.size() > 1) {
auto [loop, loopCrd] =
genCoIteration(rewriter, loc, validIters, userReduc,
/*uniIdx=*/nullptr, /*userReducFirst=*/true);

// 1st. find all the cases that is a strict subset of the current case
// condition, for which we generate one branch per case inside the loop.
// The subcases are never empty, it must contains at least the current
// region itself.
// TODO: these cases should be sorted.
SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
assert(!subCases.empty());

ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
iters, subCases, userReduc);

SmallVector<Value> nextIterYields(res);
// 2nd. foward the loop.
for (SparseIterator *it : validIters) {
Value cmp = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
it->forwardIf(rewriter, loc, cmp);
llvm::append_range(nextIterYields, it->getCursor());
}
rewriter.create<scf::YieldOp>(loc, nextIterYields);

// Exit the loop, relink the iterator SSA value.
rewriter.setInsertionPointAfter(loop);
ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
for (SparseIterator *it : validIters)
iterVals = it->linkNewScope(iterVals);
assert(iterVals.empty());

ValueRange curResult = loop->getResults().take_front(userReduc.size());
userReduc.assign(curResult.begin(), curResult.end());
} else {
// This is a simple iteration loop.
assert(caseBits.count() == 1);

Block *block = &r.getBlocks().front();
ValueRange curResult = genLoopWithIterator(
rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
/*bodyBuilder=*/
[block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
SparseIterator *it,
ValueRange reduc) -> SmallVector<Value> {
SmallVector<Value> blockArgs(reduc);
blockArgs.push_back(it->deref(rewriter, loc));
llvm::append_range(blockArgs, it->getCursor());

Block *dstBlock = &dstRegion.getBlocks().front();
rewriter.inlineBlockBefore(
block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
SmallVector<Value> result(yield.getResults());
rewriter.eraseOp(yield);
return result;
});

userReduc.assign(curResult.begin(), curResult.end());
}
}

rewriter.replaceOp(op, userReduc);
return success();
}
};

} // namespace

mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
Expand All @@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(

IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
SparseIterateOpConverter>(converter, patterns.getContext());
SparseIterateOpConverter, SparseCoIterateOpConverter>(
converter, patterns.getContext());
}
Loading

0 comments on commit f607102

Please sign in to comment.