Skip to content

Commit

Permalink
new FoldBitWidth pass to fix cornell-zhang/allo#121
Browse files Browse the repository at this point in the history
  • Loading branch information
matth2k committed Jan 23, 2024
1 parent 337cf82 commit 23983cd
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/hcl/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createAnyWidthIntegerPass();
std::unique_ptr<OperationPass<ModuleOp>> createMoveReturnToInputPass();
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeCastPass();
std::unique_ptr<OperationPass<ModuleOp>> createRemoveStrideMapPass();
std::unique_ptr<OperationPass<ModuleOp>> createFoldBitWidthPass();
std::unique_ptr<OperationPass<ModuleOp>> createMemRefDCEPass();
std::unique_ptr<OperationPass<ModuleOp>> createDataPlacementPass();
std::unique_ptr<OperationPass<ModuleOp>> createTransformInterpreterPass();
Expand Down
5 changes: 5 additions & 0 deletions include/hcl/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def RemoveStrideMap : Pass<"remove-stride-map", "ModuleOp"> {
let constructor = "mlir::hcl::createRemoveStrideMapPass()";
}

// Declares the `fold-bit-width` pass, which is anchored on ModuleOp. Instances
// are created through mlir::hcl::createFoldBitWidthPass().
def FoldBitWidth : Pass<"fold-bit-width", "ModuleOp"> {
let summary = "Remove ext and trunc operations surrounding wrap-around ops";
let constructor = "mlir::hcl::createFoldBitWidthPass()";
}

def MemRefDCE : Pass<"memref-dce", "ModuleOp"> {
let summary = "Remove MemRefs that are never loaded from";
let constructor = "mlir::hcl::createMemRefDCEPass()";
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_library(MLIRHCLPasses
Passes.cpp
LegalizeCast.cpp
RemoveStrideMap.cpp
FoldBitWidth.cpp
MemRefDCE.cpp
DataPlacement.cpp
TransformInterpreter.cpp
Expand Down
321 changes: 321 additions & 0 deletions lib/Transforms/FoldBitWidth.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
/*
* Copyright HeteroCL authors. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <utility>

#include "hcl/Transforms/Passes.h"

using namespace mlir;
using namespace hcl;

namespace mlir {
namespace hcl {
/// Conversion pattern for scalar wrap-around integer ops (the pass registers
/// it for arith.addi / arith.subi / arith.muli).
///
/// It matches `trunc(op(ext(a), ext(b), ...))` where every operand is
/// extended from the same narrow integer type that the single truncating user
/// produces, and rewrites it to `op(a, b, ...)` directly on the narrow type.
/// The low bits of a wrap-around op depend only on the low bits of its
/// operands, so the ext/trunc pair is redundant.
template <typename OpTy> struct FoldWidth : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  /// Returns the integer type `operand` had before extension when it is an
  /// integer-typed result of arith.extsi or arith.extui; std::nullopt for
  /// block arguments, non-integer values, and any other producer.
  static std::optional<IntegerType> operandIfExtended(Value operand) {
    auto *definingOp = operand.getDefiningOp();
    if (!definingOp)
      return std::nullopt;

    if (!isa<IntegerType>(operand.getType()))
      return std::nullopt;

    if (isa<arith::ExtSIOp, arith::ExtUIOp>(definingOp))
      return cast<IntegerType>(definingOp->getOperand(0).getType());

    return std::nullopt;
  }

  /// Returns the narrow result type when the only use of `val` is an
  /// arith.trunci producing an integer; std::nullopt otherwise.
  static std::optional<IntegerType>
  valIfTruncated(TypedValue<IntegerType> val) {
    if (!val.hasOneUse())
      return std::nullopt;
    auto *op = *val.getUsers().begin();
    if (auto trunc = dyn_cast<arith::TruncIOp>(*op))
      if (auto truncType = dyn_cast<IntegerType>(trunc.getType()))
        return truncType;

    return std::nullopt;
  }

  /// Legality callback for the conversion target: returns false only when
  /// `op` is exactly the ext -> op -> trunc shape this pattern folds, i.e.
  /// when the pattern must be applied. Every other op is reported legal.
  static bool opIsLegal(OpTy op) {
    if (op->getNumResults() != 1)
      return true;
    if (op->getNumOperands() == 0)
      return true;
    if (!isa<IntegerType>(op->getResultTypes().front()))
      return true;

    auto outType =
        valIfTruncated(cast<TypedValue<IntegerType>>(op->getResult(0)));
    if (!outType.has_value())
      return true;

    // The first operand fixes the narrow type every other operand must share,
    // and that type has to be the one the trunc produces.
    auto operandType = operandIfExtended(op->getOperand(0));
    if (!operandType.has_value() || operandType != outType)
      return true;

    // Every operand must be an extension from that same narrow type.
    for (auto operand : op->getOperands()) {
      auto oW = operandIfExtended(operand);
      if (oW != operandType)
        return true;
    }
    return false;
  }

  /// Replaces the truncating user of `op` with a new `OpTy` built on the
  /// un-extended operands, then erases `op`. Fails when `op` is already legal.
  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (opIsLegal(op))
      return failure();

    // opIsLegal() returned false, so the result has exactly one trunc user
    // and every operand is defined by an extension op.
    auto outType =
        valIfTruncated(cast<TypedValue<IntegerType>>(op->getResult(0)));

    // Look through each extension to the narrow source value.
    SmallVector<Value> operands;
    operands.reserve(op->getNumOperands());
    for (auto operand : op->getOperands())
      operands.push_back(operand.getDefiningOp()->getOperand(0));

    // Attributes of the wide op are not copied onto the narrow op.
    SmallVector<Type> resultTypes = {*outType};
    auto newOp = rewriter.create<OpTy>(op.getLoc(), resultTypes, operands);

    // Route the replacement through the rewriter rather than mutating uses
    // directly, so the conversion driver can track (and undo) the change.
    Operation *trunc = *op->getUsers().begin();
    rewriter.replaceOp(trunc, newOp->getResults());
    rewriter.eraseOp(op);

    return success();
  }
};

/// Conversion pattern for binary named linalg ops on memrefs (the pass
/// registers it for linalg.add / linalg.sub / linalg.mul).
///
/// It matches the buffer-level analogue of `trunc(op(ext(a), ext(b)))`: both
/// inputs are temporaries filled by a linalg.generic whose body is a single
/// arith.extsi/extui, and the output is a temporary consumed by a
/// linalg.generic whose body is a single arith.trunci. When the narrow buffer
/// types all agree, the op is rebuilt directly on the narrow buffers and the
/// three conversion generics are erased.
template <typename OpTy> struct FoldLinalgWidth : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  /// Returns the position of `item` inside `opList`. `item` must be a member
  /// of the list; anything else is a programming error.
  static unsigned getIndex(mlir::Block::OpListType &opList, Operation *item) {
    for (auto op : llvm::enumerate(opList))
      if (&op.value() == item)
        return op.index();
    // Unlike assert(false), this never falls off the end of a non-void
    // function when assertions are compiled out.
    llvm_unreachable("Op not in Op list");
  }

  /// True when every user of `memref` sits directly in the block that defines
  /// `memref`, which is what makes a positional ordering of the users
  /// meaningful.
  static bool usersInParentBlock(Value memref) {
    Block *parent = memref.getParentBlock();
    return llvm::all_of(memref.getUsers(), [parent](Operation *user) {
      return user->getBlock() == parent;
    });
  }

  /// Returns the users of `memref` ordered by their position in the block.
  /// All users are expected to live in the same block.
  static SmallVector<Operation *> getUsersSorted(Value memref) {
    SmallVector<Operation *> users(memref.getUsers().begin(),
                                   memref.getUsers().end());

    std::sort(users.begin(), users.end(), [](Operation *a, Operation *b) {
      return a->isBeforeInBlock(b);
    });

    return users;
  }

  /// If `memref` is a temporary written by an extending linalg.generic,
  /// returns the narrow source buffer together with that generic op;
  /// std::nullopt otherwise.
  static std::optional<std::pair<TypedValue<MemRefType>, linalg::GenericOp>>
  operandIfExtended(TypedValue<MemRefType> memref) {
    if (memref.getUsers().empty())
      return std::nullopt;

    // Users nested in other blocks cannot be ordered against each other, so
    // such buffers are conservatively left alone.
    if (!usersInParentBlock(memref))
      return std::nullopt;

    auto users = getUsersSorted(memref);
    // If a buffer is used for the sake of type-conversion it should only have 2
    // uses.
    if (users.size() != 2)
      return std::nullopt;

    // If this is an extended operand, the first use should be a GenericOp that
    // extends
    auto genericOp = dyn_cast<linalg::GenericOp>(users.front());
    if (!genericOp)
      return std::nullopt;

    // Check that the Generic Op is used to extend with memref as an output.
    // Sizes are checked before front() is touched.
    if (genericOp.getInputs().size() != 1 ||
        genericOp.getOutputs().size() == 0 ||
        genericOp.getOutputs().front() != memref ||
        genericOp.getBody()->getOperations().size() != 2)
      return std::nullopt;

    auto &operation = genericOp.getBody()->front();
    if (!isa<arith::ExtSIOp, arith::ExtUIOp>(operation))
      return std::nullopt;

    // The source must itself be a memref to be usable as a replacement input.
    auto source =
        dyn_cast<TypedValue<MemRefType>>(genericOp.getInputs().front());
    if (!source)
      return std::nullopt;

    // Return the memory buffer that is being extended and the GenericOp too
    return std::pair(source, genericOp);
  }

  /// If `memref` is a temporary read by a truncating linalg.generic, returns
  /// the narrow destination buffer together with that generic op;
  /// std::nullopt otherwise.
  static std::optional<std::pair<TypedValue<MemRefType>, linalg::GenericOp>>
  valIfTruncated(TypedValue<MemRefType> memref) {
    if (memref.getUsers().empty())
      return std::nullopt;

    // Same restriction as in operandIfExtended().
    if (!usersInParentBlock(memref))
      return std::nullopt;

    auto users = getUsersSorted(memref);
    // If a buffer is used for the sake of type-conversion it should only have 2
    // uses.
    if (users.size() != 2)
      return std::nullopt;

    // If this is a truncated operand, the last use should be a GenericOp that
    // truncates
    auto genericOp = dyn_cast<linalg::GenericOp>(users.back());
    if (!genericOp)
      return std::nullopt;

    // Check that the Generic Op is used to truncate the memref input. The
    // generic may have no inputs at all (memref used only as an output), so
    // sizes are checked before front() is touched.
    if (genericOp.getInputs().size() == 0 ||
        genericOp.getOutputs().size() != 1 ||
        genericOp.getInputs().front() != memref ||
        genericOp.getBody()->getOperations().size() != 2)
      return std::nullopt;

    auto &operation = genericOp.getBody()->front();
    if (!isa<arith::TruncIOp>(operation))
      return std::nullopt;

    // The destination must be a memref to be usable as a replacement output.
    auto destination =
        dyn_cast<TypedValue<MemRefType>>(genericOp.getOutputs().front());
    if (!destination)
      return std::nullopt;

    // Return the memory buffer that is being truncated and the GenericOp too
    return std::pair(destination, genericOp);
  }

  /// Legality callback for the conversion target: returns false only when
  /// `op` matches the extend -> op -> truncate buffer pattern and must be
  /// rewritten. Every other op (including tensor-typed ones) is legal.
  static bool opIsLegal(OpTy op) {

    // Should be a binary operation
    if (op.getInputs().size() != 2)
      return true;
    if (op.getOutputs().size() != 1)
      return true;

    // This callback sees every instance of the op, including ones that work
    // on tensors; only memref operands can take part in the pattern.
    auto inputs = op.getInputs();
    auto output = dyn_cast<TypedValue<MemRefType>>(op.getOutputs().front());
    auto lhs = dyn_cast<TypedValue<MemRefType>>(inputs[0]);
    auto rhs = dyn_cast<TypedValue<MemRefType>>(inputs[1]);
    if (!output || !lhs || !rhs)
      return true;

    auto outType = valIfTruncated(output);
    if (!outType.has_value())
      return true;

    auto firstOperand = operandIfExtended(lhs);
    if (!firstOperand.has_value() ||
        firstOperand->first.getType() != outType->first.getType())
      return true;

    auto secondOperand = operandIfExtended(rhs);
    if (!secondOperand.has_value() ||
        firstOperand->first.getType() != secondOperand->first.getType())
      return true;

    // At this point, we know all memref types are equivalent so the pattern
    // should be applied
    return false;
  }

  /// Rebuilds `op` on the narrow buffers and erases the two extending
  /// generics, the truncating generic and `op` itself. Fails when `op` is
  /// already legal.
  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (opIsLegal(op))
      return failure();

    // opIsLegal() returned false, so all three operands are memrefs and the
    // lookups below succeed.
    auto outType =
        valIfTruncated(cast<TypedValue<MemRefType>>(op.getOutputs().front()));
    auto inputs = op.getInputs();
    auto firstOperand =
        operandIfExtended(cast<TypedValue<MemRefType>>(inputs[0]));
    auto secondOperand =
        operandIfExtended(cast<TypedValue<MemRefType>>(inputs[1]));

    // Extension and trunc should be opt away
    SmallVector<Value> operands({firstOperand->first, secondOperand->first});

    SmallVector<Value> results({outType->first});

    // Create the new linalg operation, and move the output memory buffer up in
    // the instructions so that it dominates
    auto newop = rewriter.create<OpTy>(op->getLoc(), operands, results);

    // The output buffer may be a block argument (no defining op), or may be
    // defined in another block, in which case it already dominates the new
    // op. Only hoist when it is defined later in the same block.
    if (Operation *outDef = outType->first.getDefiningOp())
      if (outDef->getBlock() == newop->getBlock() &&
          newop->isBeforeInBlock(outDef))
        outDef->moveBefore(newop);

    // It is safe to delete these operations, because we force that each
    // memory buffer only has 2 uses
    rewriter.eraseOp(outType->second);
    rewriter.eraseOp(firstOperand->second);
    rewriter.eraseOp(secondOperand->second);
    rewriter.eraseOp(op);
    assert(opIsLegal(newop));

    return success();
  }
};
} // namespace hcl
} // namespace mlir

namespace {
/// Module pass that folds redundant extend/truncate pairs around wrap-around
/// add/sub/mul operations, for both scalar arith ops and named linalg ops,
/// and then runs region DCE over the module's regions.
struct HCLFoldBitWidthTransformation
    : public FoldBitWidthBase<HCLFoldBitWidthTransformation> {
  void runOnOperation() override {
    Operation *root = getOperation();
    MLIRContext *ctx = &getContext();

    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);

    // Scalar wrap-around ops: one pattern per op kind.
    patterns.add<FoldWidth<arith::AddIOp>, FoldWidth<arith::SubIOp>,
                 FoldWidth<arith::MulIOp>>(ctx);

    // Named linalg wrap-around ops: one pattern per op kind.
    patterns.add<FoldLinalgWidth<linalg::AddOp>, FoldLinalgWidth<linalg::SubOp>,
                 FoldLinalgWidth<linalg::MulOp>>(ctx);

    // Each op kind is legal unless its own pattern's predicate says it still
    // has the foldable shape.
    target.addDynamicallyLegalOp<arith::AddIOp>(
        FoldWidth<arith::AddIOp>::opIsLegal);
    target.addDynamicallyLegalOp<arith::SubIOp>(
        FoldWidth<arith::SubIOp>::opIsLegal);
    target.addDynamicallyLegalOp<arith::MulIOp>(
        FoldWidth<arith::MulIOp>::opIsLegal);
    target.addDynamicallyLegalOp<linalg::AddOp>(
        FoldLinalgWidth<linalg::AddOp>::opIsLegal);
    target.addDynamicallyLegalOp<linalg::SubOp>(
        FoldLinalgWidth<linalg::SubOp>::opIsLegal);
    target.addDynamicallyLegalOp<linalg::MulOp>(
        FoldLinalgWidth<linalg::MulOp>::opIsLegal);

    if (failed(applyPartialConversion(root, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }

    // Clean up with region DCE; its result is deliberately ignored.
    OpBuilder builder(root);
    IRRewriter rewriter(builder);
    (void)runRegionDCE(rewriter, root->getRegions());
  }
};
} // namespace

namespace mlir {
namespace hcl {
/// Creates a new instance of the fold-bit-width module pass.
std::unique_ptr<OperationPass<ModuleOp>> createFoldBitWidthPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<HCLFoldBitWidthTransformation>();
  return pass;
}
} // namespace hcl
} // namespace mlir
10 changes: 10 additions & 0 deletions tools/hcl-opt/hcl-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ static llvm::cl::opt<bool> removeStrideMap("remove-stride-map",
llvm::cl::desc("Remove stride map"),
llvm::cl::init(false));

// Command-line flag `fold-bit-width`, off by default. When set, main() adds
// the pass returned by mlir::hcl::createFoldBitWidthPass() to the pipeline.
static llvm::cl::opt<bool> foldBitWidth(
"fold-bit-width",
llvm::cl::desc(
"Remove ext and trunc operations surrounding wrap-around ops"),
llvm::cl::init(false));

static llvm::cl::opt<bool> lowerPrintOps("lower-print-ops",
llvm::cl::desc("Lower print ops"),
llvm::cl::init(false));
Expand Down Expand Up @@ -319,6 +325,10 @@ int main(int argc, char **argv) {
pm.addPass(mlir::hcl::createRemoveStrideMapPass());
}

if (foldBitWidth) {
pm.addPass(mlir::hcl::createFoldBitWidthPass());
}

if (bufferization) {
pm.addPass(mlir::bufferization::createOneShotBufferizePass());
}
Expand Down

0 comments on commit 23983cd

Please sign in to comment.