
Merge branch 'main' into vwells/linalg_to_llvm_conversion
vwellsTT authored Dec 19, 2024
2 parents bf6bf17 + 82295b0 commit 277aa4e
Showing 79 changed files with 673 additions and 494 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/on-pr.yml
@@ -5,6 +5,10 @@ on:
pull_request:
branches: [ "main" ]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
pre-commit:
uses: ./.github/workflows/pre-commit.yml
@@ -32,7 +36,7 @@ jobs:
gh workflow run ${{ env.WORKFLOW_NAME }} \
--repo ${{ env.TARGET_REPO }} --ref main \
--field test_mark=push \
--field mlir_override=${{ github.sha }}
--field mlir_override=${{ github.event.pull_request.head.sha }}
gh run list --workflow=${{ env.WORKFLOW_NAME }} --repo ${{ env.TARGET_REPO }} --limit 1
echo "Triggered ${{ env.TARGET_REPO }}"
echo "### Triggered [${{ env.TARGET_REPO }}](https://github.com/${{ env.TARGET_REPO }}/actions/workflows/${{ env.WORKFLOW_NAME }}) :rocket:" >> $GITHUB_STEP_SUMMARY
34 changes: 34 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -1129,6 +1129,40 @@ def TTIR_OnesOp : TTIR_Op<"ones"> {
let results = (outs AnyRankedTensor:$result);
}

def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]> {
let summary = "Reverse operation.";

let description = [{
Reverses the order of elements in the `operand` along the specified
`dimensions` and produces a `result` tensor.

Examples:
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "ttir.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "ttir.reverse"(%operand) {
dimensions = array<i64: 1, 0>
} : (tensor<3x2xi64>) -> tensor<3x2xi64>
// %result: [[6, 5], [4, 3], [2, 1]]
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
DenseI64ArrayAttr:$dimensions);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h
@@ -7,7 +7,6 @@

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"
#include <cstdint>

namespace mlir::tt::ttnn {

26 changes: 5 additions & 21 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -176,17 +176,6 @@ def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
let description = [{
Eltwise absolute operation.
}];

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
wa::TTNNOperandWorkarounds tileLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::Tile);
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addInputOperandWorkaround(tileLayoutWorkaround)
.addInputOperandWorkaround(tileLayoutWorkaround)
.addOutputOperandWorkaround(tileLayoutWorkaround);
}
}];
}

def TTNN_CbrtOp : TTNN_ElementwiseUnaryOp<"cbrt"> {
@@ -567,8 +556,8 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
AnyRankedTensor:$weight);
AnyRankedTensor:$weight,
AnyRankedTensor:$output);

let results = (outs AnyRankedTensor:$result);

@@ -817,6 +806,9 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds();
}
}];

let hasVerifier = 1;
@@ -858,14 +850,6 @@ def TTNN_EmptyOp : TTNN_Op<"empty"> {
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::RowMajor);
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
}
}];

let hasVerifier = 1;
}

Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -164,6 +164,8 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
DataType getDataType() const;
uint64_t getElementSizeBytes() const;
int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForSharding(ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid);
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForL1Interleaved(ArrayRef<int64_t> tensorShape, Type elementType, mlir::AffineMap linear, GridAttr grid);
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getShardShape() const;
llvm::SmallVector<int64_t> getScalarShardShape() const;
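A minimal sketch of how the two new static shard-shape helpers might be called. The variables `tensorShape`, `layoutAttr`, and `deviceGrid`, as well as the accessor names `getLinear()` and `getElementType()`, are assumptions for illustration rather than part of this diff:

// Per-core shard shape for a sharded layout (sketch, not from this commit).
llvm::SmallVector<int64_t> shardedShape =
    TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
        tensorShape, layoutAttr.getLinear(), deviceGrid);

// L1-interleaved layouts also need the element type, since the shard shape
// depends on tile size.
llvm::SmallVector<int64_t> interleavedShape =
    TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved(
        tensorShape, layoutAttr.getElementType(), layoutAttr.getLinear(),
        deviceGrid);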
54 changes: 40 additions & 14 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -83,33 +83,52 @@ struct TTNNOperandWorkarounds {
}
};

// Workaround result struct that encapsulates the previous and target
// (workaround) value and a method indicating whether the workaround modifies
// the workaround value.
template <typename T>
struct WorkaroundResult {
T previousValue;
T targetValue;
bool isModified() const { return previousValue != targetValue; }
};

// Layout workaround result struct.
struct LayoutWorkaroundResult : public WorkaroundResult<Layout> {};

// Buffer type workaround result struct.
struct BufferTypeWorkaroundResult : public WorkaroundResult<BufferType> {};

// Memory layout workaround result struct.
struct MemoryLayoutWorkaroundResult
: public WorkaroundResult<std::optional<TensorMemoryLayout>> {};

// Struct that encapsulates the result of applying the workarounds.
// It contains the target tensor layout, buffer type and tensor memory layout
// results and a flag indicating whether the workarounds were applied.
struct WorkaroundResult {
// Target tensor layout.
std::pair<Layout, bool> targetTensorLayoutResult;
struct WorkaroundResults {
// Tensor layout workaround result.
LayoutWorkaroundResult tensorLayoutResult;

// Target tensor buffer type.
std::pair<BufferType, bool> targetTensorBufferTypeResult;
// Tensor buffer type workaround result.
BufferTypeWorkaroundResult tensorBufferTypeResult;

// Target tensor memory layout. Can be nullopt for tensors on host.
std::pair<std::optional<TensorMemoryLayout>, bool>
targetTensorMemoryLayoutResult;
// Tensor memory layout workaround result.
MemoryLayoutWorkaroundResult tensorMemoryLayoutResult;

// Returns true if any of the workarounds were applied.
bool modified() const {
return targetTensorLayoutResult.second ||
targetTensorBufferTypeResult.second ||
targetTensorMemoryLayoutResult.second;
bool isModified() const {
return tensorLayoutResult.isModified() ||
tensorBufferTypeResult.isModified() ||
tensorMemoryLayoutResult.isModified();
}
};

// Apply the operand workarounds to the layout attribute that contains
// tensor layout, buffer type and tensor memory layout arguments.
// Returns the result of applying the workarounds.
WorkaroundResult applyWorkarounds(const TTNNOperandWorkarounds &workaround,
const TTNNLayoutAttr &inputLayoutAttr);
WorkaroundResults applyWorkarounds(const TTNNOperandWorkarounds &workaround,
const TTNNLayoutAttr &inputLayoutAttr);

// Class that encapsulates operands workarounds.
// It contains input and output workarounds for operands.
@@ -170,6 +189,13 @@ class TTNNOperandsWorkarounds {
llvm::SmallVector<TTNNOperandWorkarounds> outputOperandWorkarounds;
};

// Workaround factory class that creates workarounds for ops.
class TTNNOperandsWorkaroundsFactory {
public:
// Create workarounds for max_pool2d op operands.
static TTNNOperandsWorkarounds createMaxPool2DOpOperandsWorkarounds();
};

} // namespace mlir::tt::ttnn::wa

#endif
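To make the intended usage concrete, a minimal sketch of applying an op's workarounds and reacting only to what actually changed. `applyWorkarounds`, `WorkaroundResults`, and the factory are the declarations above; the accessor `getInputOperandWorkarounds()` and the surrounding `inputLayoutAttr` are assumptions for illustration:

// Fetch the workarounds for max_pool2d from the new factory (sketch).
wa::TTNNOperandsWorkarounds opWorkarounds =
    wa::TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds();

// Apply the first input-operand workaround to the current layout attribute.
wa::WorkaroundResults results = wa::applyWorkarounds(
    opWorkarounds.getInputOperandWorkarounds().front(), inputLayoutAttr);

// Each sub-result keeps the previous and target values, so a pass can
// materialize only the conversions that are actually required.
if (results.tensorLayoutResult.isModified()) {
  Layout targetLayout = results.tensorLayoutResult.targetValue;
  // ... insert a layout conversion to targetLayout here ...
}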
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -130,7 +130,7 @@ struct TTIRToTTNNBackendPipelineOptions
//
Option<bool> layouotWorkaroundsEnabled{
*this, "enable-layout-workaround-pass",
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(false)};
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)};

Option<bool> decompositionWorkaroundsEnabled{
*this, "enable-decomposition-workaround-pass",
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -38,7 +38,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
let options = [
Option<"layouotWorkaroundsEnabled",
"ttnn-enable-layout-workaround-pass",
"bool", /*default=*/"false",
"bool", /*default=*/"true",
"TTNN Layout Workarounds Pass">,
Option<"decompositionWorkaroundsEnabled",
"ttnn-enable-decomposition-workaround-pass",
14 changes: 13 additions & 1 deletion include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
@@ -5,13 +5,25 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H
#define TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"

namespace mlir::tt::ttnn::utils {
// Get or insert device for the given operation.
mlir::Value getOrInsertDevice(mlir::PatternRewriter &rewriter,
GetDeviceOp getOrInsertDevice(mlir::PatternRewriter &rewriter,
mlir::Operation *op);

// Helper method to insert a ToLayoutOp to convert the input operand to the
// desired tensor layout, buffer type and memory layout.
ToLayoutOp
createToLayoutOp(mlir::Operation *op,
mlir::TypedValue<RankedTensorType> inputValue,
PatternRewriter &rewriter, Layout targetTensorLayout,
BufferType targetTensorBufferType,
std::optional<TensorMemoryLayout> targetTensorMemoryLayout);
} // namespace mlir::tt::ttnn::utils

#endif
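A minimal sketch of how a rewrite pattern might use the updated helpers; the op, the operand index, and the chosen target layout values are assumptions for illustration:

// getOrInsertDevice now returns the GetDeviceOp itself; callers that need a
// Value (e.g. as an operand) wrap the result explicitly.
GetDeviceOp deviceOp = mlir::tt::ttnn::utils::getOrInsertDevice(rewriter, op);
mlir::Value device = deviceOp.getResult();

// Convert the first operand to tile layout in DRAM with interleaved memory.
ToLayoutOp toLayout = mlir::tt::ttnn::utils::createToLayoutOp(
    op, mlir::cast<mlir::TypedValue<RankedTensorType>>(op->getOperand(0)),
    rewriter, Layout::Tile, BufferType::DRAM,
    TensorMemoryLayout::Interleaved);
op->setOperand(0, toLayout.getResult());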
11 changes: 9 additions & 2 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -6,6 +6,7 @@
#define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H

#include <llvm/Support/CommandLine.h>
#include <mlir/IR/Value.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
@@ -46,9 +47,15 @@ createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
// Return the L1 memory usage of the output tensor of the given op.
// Used within L1 interleaved policies.
//
uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout,
DeviceAttr &deviceAttr);
uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout);

// Helper method to get the tensor layout attribute from the tensor value.
TTNNLayoutAttr
getLayoutAttrFromTensor(mlir::TypedValue<RankedTensorType> tensorValue);

// Helper method to get the element type for the given tensor layout and data type.
Type getElementType(MLIRContext *context, Layout tensorLayout,
DataType dataType);
} // namespace mlir::tt::ttnn::utils

#endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
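A rough sketch of the intended call pattern for the reworked helpers; the surrounding `op` and the choice of Layout::Tile / DataType::BFloat16 are assumptions for illustration:

// The L1-usage query now takes only the layout; it no longer needs the op or
// the device attribute.
TTNNLayoutAttr outputLayout = mlir::tt::ttnn::utils::getLayoutAttrFromTensor(
    mlir::cast<mlir::TypedValue<RankedTensorType>>(op->getResult(0)));
uint64_t outputL1Bytes =
    mlir::tt::ttnn::utils::getOpOutputL1Usage(outputLayout);

// Element type corresponding to a tile layout with bfloat16 data.
Type elemType = mlir::tt::ttnn::utils::getElementType(
    op->getContext(), Layout::Tile, DataType::BFloat16);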
34 changes: 34 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1709,6 +1709,33 @@ class StableHLOToTTIRReturnOpConversionPattern
}
};

class StableHLOToTTIROpReverseOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ReverseOp> {

using OpConversionPattern<mlir::stablehlo::ReverseOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::ReverseOp srcOp,
mlir::stablehlo::ReverseOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));

tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::ReverseOp>(
srcOp,
outputType, // result type
adaptor.getOperand(), // input
outputTensor, // output
adaptor.getDimensionsAttr() // dimensions
);
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
@@ -1910,6 +1937,12 @@ void addReturnOpConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIRReturnOpConversionPattern>(typeConverter, ctx);
}

void addReverseOpConversionPattern(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIROpReverseOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
@@ -1938,6 +1971,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addIotaOpConversionPattern(ctx, patterns, typeConverter);
addScatterOpConversionPatterns(ctx, patterns, typeConverter);
addReturnOpConversionPatterns(ctx, patterns, typeConverter);
addReverseOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
14 changes: 8 additions & 6 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -135,7 +135,8 @@ class OnesOpConversionPattern : public OpConversionPattern<ttir::OnesOp> {
// Device only exists if memLayout is *not* null
//
auto device =
memLayout ? ::ttnn::utils::getOrInsertDevice(rewriter, op) : nullptr;
memLayout ? mlir::Value(::ttnn::utils::getOrInsertDevice(rewriter, op))
: nullptr;

// MemoryConfigAttr only exists if memLayout is *not* null
//
@@ -234,8 +235,9 @@ class ToLayoutOpConversionPattern
rewriter.replaceOpWithNewOp<ttnn::ToLayoutOp>(
op, this->getTypeConverter()->convertType(result), adaptor.getInput(),
outputLayout, outputDataType, outputMemConfigAttr,
isOutputOnHost ? nullptr
: ::ttnn::utils::getOrInsertDevice(rewriter, op));
isOutputOnHost
? nullptr
: mlir::Value(::ttnn::utils::getOrInsertDevice(rewriter, op)));

return success();
}
@@ -247,8 +249,8 @@
// EmbeddingBackwardOp supports row major layout for the first and second
// operands.
for (mlir::Operation *user : op.getResult().getUsers()) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::MaxPool2dOp>(user) ||
isa<ttir::SliceOp>(user) || isa<ttir::EmbeddingOp>(user) ||
if (isa<ttir::Conv2dOp>(user) || isa<ttir::SliceOp>(user) ||
isa<ttir::EmbeddingOp>(user) ||
(isa<ttir::EmbeddingBackwardOp>(user) &&
(user->getOperand(0) == op || user->getOperand(1) == op))) {
return true;
@@ -352,7 +354,7 @@ class EmbeddingOpConversionPattern
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight());
adaptor.getInput(), adaptor.getWeight(), adaptor.getOutput());

return success();
}
(Diff truncated; remaining changed files not shown.)
