Add Reduce ops workaround for keepDim=false (#1625)
This PR adds TTNN workarounds for these Metal issues:
- tenstorrent/tt-metal#13361 - By decomposing
`reduce(keepDim=false)` into `reduce(keepDim=true) + reshape`
- tenstorrent/tt-metal#16118 - By unsetting the
dimensions argument when all dims are being reduced

As part of this work I've also:
- Enabled conversion of `stablehlo.reduce` op with multiple reduce
dimensions
- Added reduce ops verifiers in TTIR
- Added a separate function in TTNNWorkarounds to run rewrite patterns
for decomposition and layout workarounds
- Added lots of unit tests for reduce ops to cover conversions and
verifiers
- Added lots of silicon tests for reduce ops

Opened issue #1624 to track
reverting these workarounds once the Metal issues are fixed.

Closes #805, #848

After implementing these workarounds and running tests, I've encountered
[another Metal
issue](tenstorrent/tt-metal#16104), this time
in the `reshape` op. I've debugged it and have a local fix; I will send a
PR to fix it in the Metal repo, as confirmed with the `reshape` op owners. I've
opened issue
#1640 to enable Reduce ops
silicon tests after this fix is uplifted.

Another issue I've encountered while working on this: after the workaround
pass decompositions, if we change the shapes of the ops' tensors, their
layouts need to change too, but the layout pass runs before the workaround
pass. I've managed to solve this by reusing the layout of the input tensor,
but I am not sure that is a good solution; maybe we need to repeat some of
the layout logic after the workaround decompositions. FYI @sdjordjevicTT

Here is the example TTNN IR before the workarounds:
```
%3 = "ttnn.sum"(%2) <{dim_arg = [0: i32, 1 : i32, 2: i32], keep_dim = false}> : (tensor<128x32x4xf32, #ttnn_layout2>) -> tensor<1xf32, #ttnn_layout2>
```

and after the workarounds:
```
%3 = "ttnn.sum"(%2) <{keep_dim = true}> : (tensor<128x32x4xf32, #ttnn_layout2>) -> tensor<1x1x1xf32, #ttnn_layout2>
%4 = "ttnn.reshape"(%3) <{shape = [1 : i32]}> : (tensor<1x1x1xf32, #ttnn_layout2>) -> tensor<1xf32, #ttnn_layout3>
```
mrakitaTT authored Dec 20, 2024
1 parent 590a287 commit cb3e406
Showing 24 changed files with 954 additions and 82 deletions.
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -651,6 +651,8 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
return {builder.getAffineMapArrayAttr(indexingMaps),
builder.getArrayAttr(iteratorTypes)};}
}];

let hasVerifier = 1;
}

def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -581,6 +581,8 @@ class TTNN_ReductionOp<string mnemonic, list<Trait> traits = []> : TTNN_Op<mnemo
OptionalAttr<I32ArrayAttr>:$dim_arg);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -145,6 +145,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayoutAttr memLayoutAttr);
TTNNLayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
TTNNLayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector<int64_t> shardShape);
TTNNLayoutAttr withTensorShape(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape);

bool isSystemBufferType() const { return ::mlir::tt::ttnn::isSystemBufferType(getBufferType()); }
bool isDeviceBufferType() const { return ::mlir::tt::ttnn::isDeviceBufferType(getBufferType()); }
@@ -0,0 +1,140 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H
#define TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::tt::ttnn::workarounds::decomposition {

// Extracts the reduce dimensions' values from the dimArg attribute. If
// dimArg is not specified, returns an empty vector.
llvm::SmallVector<int64_t>
getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg);

// Calculates the shape of the new Reduce op created in the workaround, based
// on the input shape and the reduce dimensions.
llvm::SmallVector<int64_t>
calculateNewReduceShape(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg);

// This workaround addresses the following Metal issue:
// https://github.com/tenstorrent/tt-metal/issues/13361
//
// TODO(mrakita): Remove this workaround once these Metal issues are fixed
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624).
//
template <typename ReduceOp>
class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> {
public:
using OpRewritePattern<ReduceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ReduceOp srcOp,
PatternRewriter &rewriter) const override {
if (srcOp.getKeepDim()) {
return failure();
}

RankedTensorType inputType = srcOp.getInput().getType();
RankedTensorType outputType = srcOp.getResult().getType();

ReduceOp newReduceOp =
createReduceOpWithKeepDim(srcOp, rewriter, inputType, outputType);

// The Metal TTNN implementation of Reduce ops doesn't yet support
// keepDim=false. As a workaround, we convert a Reduce op into a combination
// of a Reduce op with keepDim=true and a Reshape op that removes the reduce
// dims, so that the rest of the graph is not affected. When this is not
// needed (for example, because type converters have already promoted the
// rank of the op result), we avoid adding an unnecessary Reshape op.
if (outputType.getShape().size() < inputType.getShape().size()) {
replaceOpWithReshapeOp(srcOp, newReduceOp, rewriter, outputType);
} else {
rewriter.replaceOp(srcOp, newReduceOp);
}

return success();
}

private:
ReduceOp createReduceOpWithKeepDim(ReduceOp srcOp, PatternRewriter &rewriter,
RankedTensorType inputType,
RankedTensorType outputType) const {
llvm::SmallVector<int64_t> outputShapeVec =
calculateNewReduceShape(inputType, srcOp.getDimArg());

TTNNLayoutAttr newOutputLayoutAttr =
mlir::cast<TTNNLayoutAttr>(outputType.getEncoding())
.withTensorShape(rewriter.getContext(), outputShapeVec);

RankedTensorType newOutputType = RankedTensorType::get(
outputShapeVec, outputType.getElementType(), newOutputLayoutAttr);

return rewriter.create<ReduceOp>(srcOp.getLoc(), newOutputType,
srcOp.getInput(), true /*keep_dim*/,
srcOp.getDimArg().value_or(nullptr));
}

void replaceOpWithReshapeOp(ReduceOp srcOp, ReduceOp newReduceOp,
PatternRewriter &rewriter,
RankedTensorType outputType) const {
mlir::ArrayAttr shapeAttr = rewriter.getI32ArrayAttr(
llvm::SmallVector<int32_t>(outputType.getShape()));

rewriter.replaceOpWithNewOp<mlir::tt::ttnn::ReshapeOp>(
srcOp, outputType, newReduceOp, shapeAttr);
}
};

// This workaround addresses the following Metal issue:
// https://github.com/tenstorrent/tt-metal/issues/16118
//
// TODO(mrakita): Remove this workaround once these Metal issues are fixed
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624).
//
template <typename ReduceOp>
class ReduceOpsAllDimsRewritePattern : public OpRewritePattern<ReduceOp> {
public:
using OpRewritePattern<ReduceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ReduceOp srcOp,
PatternRewriter &rewriter) const override {
if (!srcOp.getDimArg() || srcOp.getDimArg()->empty()) {
return failure();
}

llvm::SmallVector<int64_t> reduceDims = getReduceDims(srcOp.getDimArg());
llvm::SmallSet<int64_t, 4> uniqueReduceDims(reduceDims.begin(),
reduceDims.end());

// Check if reduce is done over all dimensions of the input tensor.
if (uniqueReduceDims.size() !=
srcOp.getInput().getType().getShape().size()) {
return failure();
}

// When reduce is done over all dimensions of the input, we need to unset
// the dimensions attribute: Metal supports reduce over all dimensions for
// any tensor rank when reduce dimensions are not specified, but it doesn't
// support reduce for tensors with rank larger than 2 when reduce dimensions
// are specified.
rewriter.replaceOpWithNewOp<ReduceOp>(srcOp, srcOp.getResult().getType(),
srcOp.getInput(), srcOp.getKeepDim(),
nullptr);

return success();
}
};

} // namespace mlir::tt::ttnn::workarounds::decomposition

#endif // TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H
7 changes: 3 additions & 4 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -116,10 +116,9 @@ class StableHLOToTTIRReduceOpConversionPattern
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

mlir::ArrayAttr dimArg = rewriter.getArrayAttr(SmallVector<Attribute>(
1, rewriter.getI32IntegerAttr(adaptor.getDimensionsAttr().size() > 0
? adaptor.getDimensionsAttr()[0]
: 1)));
// Can't reuse the original dimensions attribute because it uses i64 type.
mlir::ArrayAttr dimArg = rewriter.getI32ArrayAttr(
llvm::SmallVector<int32_t>(srcOp.getDimensions()));

rewriter.replaceOpWithNewOp<DestOp>(
srcOp, outputType, adaptor.getInputs().front(), outputTensor,
87 changes: 73 additions & 14 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"

@@ -1672,32 +1673,32 @@ static void buildGenericEltwiseUnaryRegion(::mlir::Location loc,
opBuilder.create<mlir::tt::ttir::YieldOp>(loc, mlir::ValueRange({result}));
}

// AddOp generic region builder
// AddOp generic region builder.
void mlir::tt::ttir::AddOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::AddFOp>(getLoc(), opBuilder, block);
}

// MultiplyOp generic region builder
// MultiplyOp generic region builder.
void mlir::tt::ttir::MultiplyOp::buildGenericRegion(
::mlir::OpBuilder &opBuilder, ::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::MulFOp>(getLoc(), opBuilder, block);
}

// ExpOp generic region builder
// ExpOp generic region builder.
void mlir::tt::ttir::ExpOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseUnaryRegion<math::ExpOp>(getLoc(), opBuilder, block);
}

// DivOp generic region builder
// DivOp generic region builder.
void mlir::tt::ttir::DivOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
return buildGenericEltwiseBinaryRegion<arith::DivFOp>(getLoc(), opBuilder,
block);
}

// MaximumOp generic region builder
// MaximumOp generic region builder.
void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
buildGenericEltwiseBinaryRegion<arith::MaximumFOp>(getLoc(), opBuilder,
@@ -1708,7 +1709,7 @@ void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// KernelOp
//===----------------------------------------------------------------------===//

// KernelOp builders
// KernelOp builders.
static mlir::tt::ttir::KernelOp
buildKernelOp(::mlir::OpBuilder &opBuilder, ::mlir::Location loc,
::mlir::StringRef kernelName, ::mlir::StringRef kernelKind,
@@ -1717,31 +1718,89 @@ buildKernelOp(::mlir::OpBuilder &opBuilder, ::mlir::Location loc,
loc, outputs.getTypes(), kernelName, kernelKind, inputs, outputs);
}

// Reduce op kernel builder
// Reduce op kernel builder.
static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block,
mlir::Location loc, ::mlir::StringRef kernelKind) {
auto kernelOp = buildKernelOp(opBuilder, loc, "reduce", kernelKind,
block->getArgument(0), block->getArgument(1));
opBuilder.create<mlir::tt::ttir::YieldOp>(loc, kernelOp->getResults());
}

// Sum op kernel builder
void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// Common verifier for all Reduce ops.
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &reduceDims) {
if (!reduceDims) {
return mlir::success();
}

int64_t inputTensorRank = inputType.getRank();

llvm::SmallSet<int64_t, 4> uniqueReduceDims;
for (mlir::Attribute reduceDim : *reduceDims) {
int64_t reduceDimInt = mlir::cast<mlir::IntegerAttr>(reduceDim).getInt();
if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) {
return reduceOp->emitOpError("Reduce dimensions are out of range");
}
uniqueReduceDims.insert(reduceDimInt);
}

if (uniqueReduceDims.size() != reduceDims->size()) {
return reduceOp->emitOpError("Reduce dimensions are not unique");
}

// TODO(mrakita): Add a check that depending on inputShape, reduceDims and
// keepDim computes the expected output shape and checks if it matches the
// actual output shape. Tracked by:
// https://github.com/tenstorrent/tt-mlir/issues/1639

return mlir::success();
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

// MaxOp kernel builder.
void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "sum");
createReduceOp(opBuilder, block, getLoc(), "max");
}

// MaxOp verification.
::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

// Mean op kernel builder
//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

// MeanOp kernel builder.
void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "mean");
}

// Max op kernel builder
void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// MeanOp verification.
::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

// SumOp kernel builder.
void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "max");
createReduceOp(opBuilder, block, getLoc(), "sum");
}

// SumOp verification.
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
48 changes: 48 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -1310,4 +1310,52 @@ ::mlir::LogicalResult mlir::tt::ttnn::PermuteOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// Reduction ops
//===----------------------------------------------------------------------===//

// Common verifier for all Reduction ops.
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &reduceDims) {
int64_t inputTensorRank = inputType.getRank();

// TODO(mrakita): Only last two dimensions can be reduced, check for that
// too.
if (reduceDims && reduceDims->size() > 2 &&
static_cast<int64_t>(reduceDims->size()) != inputTensorRank) {
return reduceOp->emitOpError("Reduce on more than two dimensions is not "
"currently supported by TTNN");
}

return mlir::success();
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

// MaxOp verification.
::mlir::LogicalResult MaxOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

// MeanOp verification.
::mlir::LogicalResult MeanOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

// SumOp verification.
::mlir::LogicalResult SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

} // namespace mlir::tt::ttnn