-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Reduce ops workaround for keepDim=false (#1625)
This PR adds TTNN workarounds for these Metal issues: - tenstorrent/tt-metal#13361 - By decomposing `reduce(keepDim=false)` into `reduce(keepDim=true) + reshape` - tenstorrent/tt-metal#16118 - By annulling dimensions argument when all dims are being reduced As part of this work I've also: - Enabled conversion of `stablehlo.reduce` op with multiple reduce dimensions - Added reduce ops verifiers in TTIR - Added a separate function in TTNNWorkarounds to run rewrite patterns for decomposition and layout workarounds - Added lots of unit tests for reduce ops to cover conversions and verifiers - Added lots of silicon tests for reduce ops Opened issue #1624 on myself to revert these workarounds once Metal issues are fixed. Closes #805, #848 After implementing these workarounds and running tests, I've encountered [another Metal issue](tenstorrent/tt-metal#16104), this time in `reshape` op. I've debugged it and I have a local fix, I will send a PR to fix it in Metal repo, confirmed with reshape op owners. I've opened myself an issue #1640 to enable Reduce ops silicon tests after this fix is uplifted. Another issue that I've encountered while working on this is that after the workaround pass decompositions, if we are changing the shapes of the ops tensors, that means that their layout needs to be changed too, but layout pass is done before the workaround pass. I've managed to solve it by reusing the layout of the input tensor, but I am not sure if that is a good solution and maybe we need to repeat some of the layout logic again after workaround decompositions. 
FYI @sdjordjevicTT Here is the example TTNN IR before the workarounds: ``` %3 = "ttnn.sum"(%2) <{dim_arg = [0: i32, 1 : i32, 2: i32], keep_dim = false}> : (tensor<128x32x4xf32, #ttnn_layout2>) -> tensor<1xf32, #ttnn_layout2> ``` and after the workarounds: ``` %3 = "ttnn.sum"(%2) <{keep_dim = true}> : (tensor<128x32x4xf32, #ttnn_layout2>) -> tensor<1x1x1xf32, #ttnn_layout2> %4 = "ttnn.reshape"(%3) <{shape = [1 : i32]}> : (tensor<1x1x1xf32, #ttnn_layout2>) -> tensor<1xf32, #ttnn_layout3> ```
- Loading branch information
Showing
24 changed files
with
954 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
140 changes: 140 additions & 0 deletions
140
include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H | ||
#define TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H | ||
|
||
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" | ||
|
||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "llvm/ADT/SmallSet.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
|
||
namespace mlir::tt::ttnn::workarounds::decomposition { | ||
|
||
// Extracts reduce dimensions' values from the dimArg attribute. In case when
// dimArg is not specified, returns empty vector.
// NOTE(review): implementation not visible in this header — presumably each
// ArrayAttr element is cast to an integer attribute; confirm in the .cpp.
llvm::SmallVector<int64_t>
getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg);

// Calculates the shape of the new Reduce op created in the workaround, based
// on the input shape and reducing dimensions.
// NOTE(review): assumed to return the input shape with every reduced
// dimension collapsed to 1 (keepDim=true semantics) — verify against the
// definition in the .cpp.
llvm::SmallVector<int64_t>
calculateNewReduceShape(RankedTensorType inputType,
                        const std::optional<mlir::ArrayAttr> &dimArg);
|
||
// This workaround addresses the next Metal issue: | ||
// https://github.com/tenstorrent/tt-metal/issues/13361 | ||
// | ||
// TODO(mrakita): Remove this workaround once these Metal issues are fixed | ||
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624). | ||
// | ||
template <typename ReduceOp> | ||
class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> { | ||
public: | ||
using OpRewritePattern<ReduceOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(ReduceOp srcOp, | ||
PatternRewriter &rewriter) const override { | ||
if (srcOp.getKeepDim()) { | ||
return failure(); | ||
} | ||
|
||
RankedTensorType inputType = srcOp.getInput().getType(); | ||
RankedTensorType outputType = srcOp.getResult().getType(); | ||
|
||
ReduceOp newReduceOp = | ||
createReduceOpWithKeepDim(srcOp, rewriter, inputType, outputType); | ||
|
||
// Metal TTNN implementation of Reduce ops doesn't yet support | ||
// keepDim=false. As a workaround, we convert Reduce ops to combination of | ||
// Reduce op with keepDim=true + Reshape op to remove the reduce dims so | ||
// that the rest of the graph is not affected. In case when this is not | ||
// needed (for example because type converters already promoted rank of the | ||
// op result) then we avoid adding unnecessary Reshape op. | ||
if (outputType.getShape().size() < inputType.getShape().size()) { | ||
replaceOpWithReshapeOp(srcOp, newReduceOp, rewriter, outputType); | ||
} else { | ||
rewriter.replaceOp(srcOp, newReduceOp); | ||
} | ||
|
||
return success(); | ||
} | ||
|
||
private: | ||
ReduceOp createReduceOpWithKeepDim(ReduceOp srcOp, PatternRewriter &rewriter, | ||
RankedTensorType inputType, | ||
RankedTensorType outputType) const { | ||
llvm::SmallVector<int64_t> outputShapeVec = | ||
calculateNewReduceShape(inputType, srcOp.getDimArg()); | ||
|
||
TTNNLayoutAttr newOutputLayoutAttr = | ||
mlir::cast<TTNNLayoutAttr>(outputType.getEncoding()) | ||
.withTensorShape(rewriter.getContext(), outputShapeVec); | ||
|
||
RankedTensorType newOutputType = RankedTensorType::get( | ||
outputShapeVec, outputType.getElementType(), newOutputLayoutAttr); | ||
|
||
return rewriter.create<ReduceOp>(srcOp.getLoc(), newOutputType, | ||
srcOp.getInput(), true /*keep_dim*/, | ||
srcOp.getDimArg().value_or(nullptr)); | ||
} | ||
|
||
void replaceOpWithReshapeOp(ReduceOp srcOp, ReduceOp newReduceOp, | ||
PatternRewriter &rewriter, | ||
RankedTensorType outputType) const { | ||
mlir::ArrayAttr shapeAttr = rewriter.getI32ArrayAttr( | ||
llvm::SmallVector<int32_t>(outputType.getShape())); | ||
|
||
rewriter.replaceOpWithNewOp<mlir::tt::ttnn::ReshapeOp>( | ||
srcOp, outputType, newReduceOp, shapeAttr); | ||
} | ||
}; | ||
|
||
// This workaround addresses the next Metal issue: | ||
// https://github.com/tenstorrent/tt-metal/issues/16118 | ||
// | ||
// TODO(mrakita): Remove this workaround once these Metal issues are fixed | ||
// (tracked by https://github.com/tenstorrent/tt-mlir/issues/1624). | ||
// | ||
template <typename ReduceOp> | ||
class ReduceOpsAllDimsRewritePattern : public OpRewritePattern<ReduceOp> { | ||
public: | ||
using OpRewritePattern<ReduceOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(ReduceOp srcOp, | ||
PatternRewriter &rewriter) const override { | ||
if (!srcOp.getDimArg() || srcOp.getDimArg()->empty()) { | ||
return failure(); | ||
} | ||
|
||
llvm::SmallVector<int64_t> reduceDims = getReduceDims(srcOp.getDimArg()); | ||
llvm::SmallSet<int64_t, 4> uniqueReduceDims(reduceDims.begin(), | ||
reduceDims.end()); | ||
|
||
// Check if reduce is done over all dimensions of the input tensor. | ||
if (uniqueReduceDims.size() != | ||
srcOp.getInput().getType().getShape().size()) { | ||
return failure(); | ||
} | ||
|
||
// In case when reduce is done over all dimensions of the input we need to | ||
// unset the dimensions attribute, because Metal supports reduce over all | ||
// dimensions for any tensor rank when reduce dimensions are not specified, | ||
// but it doesn't support reduce for tensors with rank larger than 2 when | ||
// reduce dimensions are specified. | ||
rewriter.replaceOpWithNewOp<ReduceOp>(srcOp, srcOp.getResult().getType(), | ||
srcOp.getInput(), srcOp.getKeepDim(), | ||
nullptr); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace mlir::tt::ttnn::workarounds::decomposition | ||
|
||
#endif // TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_REDUCEOPSREWRITEPATTERN_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.