diff --git a/.github/workflows/on-pr.yml b/.github/workflows/on-pr.yml index f3e30ac80b..478a9263fd 100644 --- a/.github/workflows/on-pr.yml +++ b/.github/workflows/on-pr.yml @@ -5,6 +5,10 @@ on: pull_request: branches: [ "main" ] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: pre-commit: uses: ./.github/workflows/pre-commit.yml @@ -32,7 +36,7 @@ jobs: gh workflow run ${{ env.WORKFLOW_NAME }} \ --repo ${{ env.TARGET_REPO }} --ref main \ --field test_mark=push \ - --field mlir_override=${{ github.sha }} + --field mlir_override=${{ github.event.pull_request.head.sha }} gh run list --workflow=${{ env.WORKFLOW_NAME }} --repo ${{ env.TARGET_REPO }} --limit 1 echo "Triggered ${{ env.TARGET_REPO }}" echo "### Triggered [${{ env.TARGET_REPO }}](https://github.com/${{ env.TARGET_REPO }}/actions/workflows/${{ env.WORKFLOW_NAME }}) :rocket:" >> $GITHUB_STEP_SUMMARY diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index e685172c24..76a972df10 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -1129,6 +1129,40 @@ def TTIR_OnesOp : TTIR_Op<"ones"> { let results = (outs AnyRankedTensor:$result); } +def TTIR_ReverseOp : TTIR_DPSOp<"reverse", [AllShapesMatch<["input", "result"]>]> { + let summary = "Reverse operation."; + + let description = [{ + Reverses the order of elements in the `operand` along the specified + `dimensions` and produces a `result` tensor. + + Examples: + // %operand = [[1, 2], [3, 4], [5, 6]] + %result = "ttir.reverse"(%operand) { + dimensions = array + } : (tensor<3x2xi32>) -> tensor<3x2xi32> + // %result: [[2, 1], [4, 3], [6, 5]] + + // %operand = [[1, 2], [3, 4], [5, 6]] + %result = "ttir.reverse"(%operand) { + dimensions = array + } : (tensor<3x2xi64>) -> tensor<3x2xi64> + // %result: [[6, 5], [4, 3], [2, 1]] + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + DenseI64ArrayAttr:$dimensions); + + let results = (outs AnyRankedTensor:$result); + + let extraClassDeclaration = [{ + MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + }]; + + let hasVerifier = 1; +} + def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike, AllShapesMatch<["value", "result"]>]> { let summary = "Constant op."; diff --git a/include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h b/include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h index 1744a1d415..77d7c131f1 100644 --- a/include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h @@ -7,7 +7,6 @@ #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h" -#include namespace mlir::tt::ttnn { diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 758fb41d7e..4e434db56c 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -176,17 +176,6 @@ def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> { let description = [{ Eltwise absolute operation. 
}]; - - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } - wa::TTNNOperandsWorkarounds getOperandsWorkarounds() { - wa::TTNNOperandWorkarounds tileLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::Tile); - return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds() - .addInputOperandWorkaround(tileLayoutWorkaround) - .addInputOperandWorkaround(tileLayoutWorkaround) - .addOutputOperandWorkaround(tileLayoutWorkaround); - } - }]; } def TTNN_CbrtOp : TTNN_ElementwiseUnaryOp<"cbrt"> { @@ -567,8 +556,8 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> { }]; let arguments = (ins AnyRankedTensor:$input, - AnyRankedTensor:$output, - AnyRankedTensor:$weight); + AnyRankedTensor:$weight, + AnyRankedTensor:$output); let results = (outs AnyRankedTensor:$result); @@ -817,6 +806,9 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> { let extraClassDeclaration = [{ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } + wa::TTNNOperandsWorkarounds getOperandsWorkarounds() { + return wa::TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds(); + } }]; let hasVerifier = 1; @@ -858,14 +850,6 @@ def TTNN_EmptyOp : TTNN_Op<"empty"> { OptionalAttr:$memory_config); let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - wa::TTNNOperandsWorkarounds getOperandsWorkarounds() { - wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround = wa::TTNNOperandWorkarounds(Layout::RowMajor); - return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds() - .addOutputOperandWorkaround(rowMajorLayoutWorkaround); - } - }]; - let hasVerifier = 1; } diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td index e483b07bf2..94d05eadcb 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td @@ -164,6 +164,8 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { DataType getDataType() const; uint64_t getElementSizeBytes() const; int64_t getTensorSizeInBytes(ArrayRef tensorShape, ::mlir::tt::DeviceAttr device) const; + static llvm::SmallVector calculateLogicalShardShapeForSharding(ArrayRef tensorShape, mlir::AffineMap linear, GridAttr grid); + static llvm::SmallVector calculateLogicalShardShapeForL1Interleaved(ArrayRef tensorShape, Type elementType, mlir::AffineMap linear, GridAttr grid); llvm::SmallVector getStride(ArrayRef logicalShape) const; llvm::SmallVector getShardShape() const; llvm::SmallVector getScalarShardShape() const; diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h index 4122b0ca03..9e07a0315e 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h @@ -83,33 +83,52 @@ struct TTNNOperandWorkarounds { } }; +// Workaround result struct that encapsulates the previous and target +// (workaround) value and a method indicating whether the workaround modifies +// the workaround value. +template +struct WorkaroundResult { + T previousValue; + T targetValue; + bool isModified() const { return previousValue != targetValue; } +}; + +// Layout workaround result struct. +struct LayoutWorkaroundResult : public WorkaroundResult {}; + +// Buffer type workaround result struct. +struct BufferTypeWorkaroundResult : public WorkaroundResult {}; + +// Memory layout workaround result struct. 
+struct MemoryLayoutWorkaroundResult + : public WorkaroundResult> {}; + // Struct that encapsulates the result of applying the workarounds. // It contains the target tensor layout, buffer type and tensor memory layout // results and a flag indicating whether the workarounds were applied. -struct WorkaroundResult { - // Target tensor layout. - std::pair targetTensorLayoutResult; +struct WorkaroundResults { + // Tensor layout workaround result. + LayoutWorkaroundResult tensorLayoutResult; - // Target tensor buffer type. - std::pair targetTensorBufferTypeResult; + // Tensor buffer type workaround result. + BufferTypeWorkaroundResult tensorBufferTypeResult; - // Target tensor memory layout. Can be nullopt for tensors on host. - std::pair, bool> - targetTensorMemoryLayoutResult; + // Tensor memory layout workaround result. + MemoryLayoutWorkaroundResult tensorMemoryLayoutResult; // Returns true if any of the workarounds were applied. - bool modified() const { - return targetTensorLayoutResult.second || - targetTensorBufferTypeResult.second || - targetTensorMemoryLayoutResult.second; + bool isModified() const { + return tensorLayoutResult.isModified() || + tensorBufferTypeResult.isModified() || + tensorMemoryLayoutResult.isModified(); } }; // Apply the operand workarounds to the layout attribute that contains // tensor layout, buffer type and tensor memory layout arguments. // Returns the result of applying the workarounds. -WorkaroundResult applyWorkarounds(const TTNNOperandWorkarounds &workaround, - const TTNNLayoutAttr &inputLayoutAttr); +WorkaroundResults applyWorkarounds(const TTNNOperandWorkarounds &workaround, + const TTNNLayoutAttr &inputLayoutAttr); // Class that encapsulates operands workarounds. // It contains input and output workarounds for operands. @@ -170,6 +189,13 @@ class TTNNOperandsWorkarounds { llvm::SmallVector outputOperandWorkarounds; }; +// Workaround factory class that creates workarounds for ops. +class TTNNOperandsWorkaroundsFactory { +public: + // Create workarounds for max_pool2d op operands. 
+ static TTNNOperandsWorkarounds createMaxPool2DOpOperandsWorkarounds(); +}; + } // namespace mlir::tt::ttnn::wa #endif diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index d27c488eda..a65f95c6b2 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -130,7 +130,7 @@ struct TTIRToTTNNBackendPipelineOptions // Option layouotWorkaroundsEnabled{ *this, "enable-layout-workaround-pass", - llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(false)}; + llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)}; Option decompositionWorkaroundsEnabled{ *this, "enable-decomposition-workaround-pass", diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td index 4597db87e1..9d87bd0cf9 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td @@ -38,7 +38,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> { let options = [ Option<"layouotWorkaroundsEnabled", "ttnn-enable-layout-workaround-pass", - "bool", /*default=*/"false", + "bool", /*default=*/"true", "TTNN Layout Workarounds Pass">, Option<"decompositionWorkaroundsEnabled", "ttnn-enable-decomposition-workaround-pass", diff --git a/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h index 2dc83388d1..f491f2ed5e 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h +++ b/include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h @@ -5,13 +5,25 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H #define TTMLIR_DIALECT_TTNN_UTILS_TRANSFORMUTILS_H +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" namespace mlir::tt::ttnn::utils { // Get or insert device for the given operation. -mlir::Value getOrInsertDevice(mlir::PatternRewriter &rewriter, +GetDeviceOp getOrInsertDevice(mlir::PatternRewriter &rewriter, mlir::Operation *op); + +// Helper method to insert a ToLayoutOp to convert the input operand to the +// desired tensor layout, buffer type and memory layout. +ToLayoutOp +createToLayoutOp(mlir::Operation *op, + mlir::TypedValue inputValue, + PatternRewriter &rewriter, Layout targetTensorLayout, + BufferType targetTensorBufferType, + std::optional targetTensorMemoryLayout); } // namespace mlir::tt::ttnn::utils #endif diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h index d3fb76bda9..71dc98b7f8 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h +++ b/include/ttmlir/Dialect/TTNN/Utils/Utils.h @@ -6,6 +6,7 @@ #define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H #include +#include #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" @@ -46,9 +47,15 @@ createRankedTensorTypeWithEncoding(RankedTensorType tensorType, // Return the L1 memory usage of the output tensor of the given op. // Used within L1 interleaved policies. // -uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout, - DeviceAttr &deviceAttr); +uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout); +// Helper method to get the tensor layout attribute from the tensor value. +TTNNLayoutAttr +getLayoutAttrFromTensor(mlir::TypedValue tensorValue); + +// Helper method to get the element type for the given tensor layout and data. 
+Type getElementType(MLIRContext *context, Layout tensorLayout, + DataType dataType); } // namespace mlir::tt::ttnn::utils #endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 4f2f82361f..61a3154a10 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1709,6 +1709,33 @@ class StableHLOToTTIRReturnOpConversionPattern } }; +class StableHLOToTTIROpReverseOpConversionPattern + : public OpConversionPattern { + + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::ReverseOp srcOp, + mlir::stablehlo::ReverseOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto outputType = mlir::cast( + getTypeConverter()->convertType(srcOp.getResult().getType())); + + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + rewriter.replaceOpWithNewOp( + srcOp, + outputType, // result type + adaptor.getOperand(), // input + outputTensor, // output + adaptor.getDimensionsAttr() // dimensions + ); + return success(); + } +}; + void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { @@ -1910,6 +1937,12 @@ void addReturnOpConversionPatterns(MLIRContext *ctx, patterns.add(typeConverter, ctx); } +void addReverseOpConversionPattern(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); +} + } // namespace namespace mlir::tt { @@ -1938,6 +1971,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addIotaOpConversionPattern(ctx, patterns, typeConverter); addScatterOpConversionPatterns(ctx, patterns, typeConverter); addReturnOpConversionPatterns(ctx, patterns, typeConverter); + addReverseOpConversionPattern(ctx, patterns, typeConverter); } } // namespace mlir::tt diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index e60261bada..945492481d 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -135,7 +135,8 @@ class OnesOpConversionPattern : public OpConversionPattern { // Device only exists if memLayout is *not* null // auto device = - memLayout ? ::ttnn::utils::getOrInsertDevice(rewriter, op) : nullptr; + memLayout ? mlir::Value(::ttnn::utils::getOrInsertDevice(rewriter, op)) + : nullptr; // MemoryConfigAttr only exists if memLayout is *not* null // @@ -234,8 +235,9 @@ class ToLayoutOpConversionPattern rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(result), adaptor.getInput(), outputLayout, outputDataType, outputMemConfigAttr, - isOutputOnHost ? nullptr - : ::ttnn::utils::getOrInsertDevice(rewriter, op)); + isOutputOnHost + ? nullptr + : mlir::Value(::ttnn::utils::getOrInsertDevice(rewriter, op))); return success(); } @@ -247,8 +249,8 @@ class ToLayoutOpConversionPattern // EmbeddingBackwardOp supports row major layout for the first and second // operands. 
for (mlir::Operation *user : op.getResult().getUsers()) { - if (isa(user) || isa(user) || - isa(user) || isa(user) || + if (isa(user) || isa(user) || + isa(user) || (isa(user) && (user->getOperand(0) == op || user->getOperand(1) == op))) { return true; @@ -352,7 +354,7 @@ class EmbeddingOpConversionPattern ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight()); + adaptor.getInput(), adaptor.getWeight(), adaptor.getOutput()); return success(); } diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 319d8b1e60..ef9fd29f4a 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -1545,6 +1545,40 @@ ::mlir::LogicalResult mlir::tt::ttir::FillCacheOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// ReverseOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::ReverseOp::verify() { + llvm::ArrayRef dimensions = getDimensions(); + + // Check that all given dimensions are unique/not repeating. + llvm::SmallDenseSet uniqueDims(dimensions.begin(), dimensions.end()); + + if (uniqueDims.size() != dimensions.size()) { + return emitOpError("dimensions should be unique. Got: ") << dimensions; + } + + ::mlir::RankedTensorType operandTy = getInput().getType(); + + // Check that each dimension is positive and within valid interval [0, + // operandRank). + for (int64_t dim : dimensions) { + if (dim < 0) { + return emitOpError( + "all dimensions should be non-negative. Got dimension: ") + << dim; + } + + if (dim >= operandTy.getRank()) { + return emitOpError("all dimensions should be in interval [0, ") + << operandTy.getRank() << "). Got dimension: " << dim; + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // GenericOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp index 4d58d7b5aa..4a6f26b5e4 100644 --- a/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp +++ b/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp @@ -14,7 +14,6 @@ void BFInterleavedPolicy::run() { for (Operation &funcOp : rootOp->getRegion(0).getOps()) { func::FuncOp func = dyn_cast(funcOp); mlir::tt::scheduler::Scheduler scheduler(&func); - deviceAttr = getCurrentScopeDevice(func); // Initialize the policy. 
// @@ -53,8 +52,7 @@ void BFInterleavedPolicy::run() { // if (hasL1BufferType(op)) { TTNNLayoutAttr layout = getL1InterleavedLayout(op); - uint64_t opOutputL1Usage = - utils::getOpOutputL1Usage(op, layout, deviceAttr); + uint64_t opOutputL1Usage = utils::getOpOutputL1Usage(layout); if (currentL1Usage + opOutputL1Usage <= getAvailableL1CacheSize()) { allocOfL1Mem = opOutputL1Usage; @@ -92,8 +90,7 @@ void BFInterleavedPolicy::run() { uint64_t numOfUsers = std::distance(nextOpForScheduling->user_begin(), nextOpForScheduling->user_end()); currentL1UsagePerOp[nextOpForScheduling].l1MemUsagePerUser = - utils::getOpOutputL1Usage(nextOpForScheduling, opL1MemSpec.layout, - deviceAttr); + utils::getOpOutputL1Usage(opL1MemSpec.layout); currentL1UsagePerOp[nextOpForScheduling].numOfUnscheduledUsers = numOfUsers; currentL1Usage += diff --git a/lib/Dialect/TTNN/Analysis/GreedyL1InterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/GreedyL1InterleavedPolicy.cpp index 5606132906..cf1adbf595 100644 --- a/lib/Dialect/TTNN/Analysis/GreedyL1InterleavedPolicy.cpp +++ b/lib/Dialect/TTNN/Analysis/GreedyL1InterleavedPolicy.cpp @@ -130,7 +130,6 @@ GreedyL1InterleavedPolicy::OpConfig GreedyL1InterleavedPolicy::getGreedyConfig( void GreedyL1InterleavedPolicy::run() { for (Operation &funcOp : rootOp->getRegion(0).getOps()) { func::FuncOp func = dyn_cast(funcOp); - deviceAttr = getCurrentScopeDevice(func); // Start the policy. // @@ -166,8 +165,8 @@ void GreedyL1InterleavedPolicy::run() { if (op->hasOneUse() && hasL1BufferType(op)) { L1Usage l1Usage; - l1Usage.outputL1Usage = utils::getOpOutputL1Usage( - op, getL1InterleavedLayout(op), deviceAttr); + l1Usage.outputL1Usage = + utils::getOpOutputL1Usage(getL1InterleavedLayout(op)); l1Usage.requiredL1Usage = 0; opsL1Usage[op] = l1Usage; } @@ -192,8 +191,8 @@ void GreedyL1InterleavedPolicy::run() { // if (operandOpLayout.hasInterleavedL1TensorMemoryLayout()) { L1Usage l1Usage; - l1Usage.outputL1Usage = utils::getOpOutputL1Usage( - operandOp, operandOpLayout, deviceAttr); + l1Usage.outputL1Usage = + utils::getOpOutputL1Usage(operandOpLayout); l1Usage.requiredL1Usage = OpMemSpecMap[operandOp].requiredL1Usage; opsL1Usage[operandOp] = l1Usage; } @@ -252,15 +251,14 @@ void GreedyL1InterleavedPolicy::run() { std::max(intermediateRequiredL1Usage, intermediateL1Usage + OpMemSpecMap[operandOp].requiredL1Usage); - intermediateL1Usage += utils::getOpOutputL1Usage( - operandOp, OpMemSpecMap[operandOp].layout, deviceAttr); + intermediateL1Usage += + utils::getOpOutputL1Usage(OpMemSpecMap[operandOp].layout); } } OpMemSpecMap[op].requiredL1Usage = std::max(intermediateRequiredL1Usage, intermediateL1Usage + - utils::getOpOutputL1Usage( - op, OpMemSpecMap[op].layout, deviceAttr)); + utils::getOpOutputL1Usage(OpMemSpecMap[op].layout)); } } } diff --git a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp index 3f4ef25ab2..799ef6c5c8 100644 --- a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp @@ -228,16 +228,21 @@ void LegalLayoutAnalysis::analysisImplementation() { TensorMemoryLayoutAttr::get(op->getContext(), TensorMemoryLayout::Interleaved))); - // L1 Interleaved (same as above). - analysisResult.push_back(TTNNLayoutAttr::get( - op->getContext(), tensorShape, elementType, BufferType::L1, - analysisInput.maxGrid, - TensorMemoryLayoutAttr::get(op->getContext(), - TensorMemoryLayout::Interleaved))); + // L1 Interleaved - It must be tiled. 
+ // TODO(odjuricic): Check that this is always the case. + if (elementType == tileElementType) { + analysisResult.push_back(TTNNLayoutAttr::get( + op->getContext(), tensorShape, elementType, BufferType::L1, + analysisInput.maxGrid, + TensorMemoryLayoutAttr::get(op->getContext(), + TensorMemoryLayout::Interleaved))); + } // L1 Sharded TTNNLayoutAttr shardedBase = layout.withBufferType(op->getContext(), BufferType::L1) + .withMemoryLayout(op->getContext(), + TensorMemoryLayout::BlockSharded) .withElementType(op->getContext(), elementType); assert(analysisInput.maxGrid.getShape().size() == 2 && diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp index 43c5984ed9..c7bf769ddc 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp @@ -4,13 +4,10 @@ #include -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "ttmlir/Utils.h" -#include "llvm/ADT/TypeSwitch.h" using namespace mlir::tt::ttnn; @@ -68,6 +65,67 @@ bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const { (getMemLayout().getValue() == TensorMemoryLayout::Interleaved); } +// Calculate the logical shape of the shard. +// +// Shard is defined as a piece of the tensor that is mapped to a single grid +// core. This function returns the shard shape for tensors with BLOCK SHARDED +// tensor memory layout. +// +// All examples assume that the tensor is mapped to a 8x8 grid. +// Example: tensor<32x32xbf16> -> {4, 4} +// Example: tensor<65x65xbf16> -> {9, 9} +// +// return The logical shard shape in case of block sharded tensor memory layout. +llvm::SmallVector +TTNNLayoutAttr::calculateLogicalShardShapeForSharding( + ArrayRef tensorShape, mlir::AffineMap linear, GridAttr grid) { + assert(linear.getNumResults() == grid.getShape().size()); + mlir::SmallVector physicalShape = + ttmlir::utils::evalShape(linear, tensorShape); + mlir::SmallVector shardShape(linear.getNumResults()); + for (size_t i = 0; i < linear.getNumResults(); ++i) { + shardShape[i] = + (physicalShape[i] + grid.getShape()[i] - 1) / grid.getShape()[i]; + } + return shardShape; +} + +// Calculate the logical shape of the shard. +// +// Shard is defined as a piece of the tensor that is mapped to a single grid +// core. This function returns the shard shape for tensors with INTERLEAVED +// tensor memory layout. +// +// All examples assume that the tensor is mapped to a 8x8 grid. +// Example: tensor<1x1024xbf16> ( -> 32 tiles ) -> {1, 1} +// Example: tensor<512x512xbf16> ( -> 256 tiles ) -> {1, 4} +// Example: tensor<32x2049xbf16> ( -> 65 tiles ) -> {1, 2} +// +// return The logical shard shape in case of interleaved tensor memory layout. 
+llvm::SmallVector +TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved( + ArrayRef tensorShape, mlir::Type elementType, + mlir::AffineMap linear, mlir::tt::GridAttr grid) { + assert(linear.getNumResults() == grid.getShape().size()); + assert(mlir::isa(elementType)); + + mlir::SmallVector physicalShape = + ttmlir::utils::evalShape(linear, tensorShape); + mlir::SmallVector physicalTiledShape = + mlir::cast(elementType).getTiledShape(physicalShape); + uint64_t numOfTiles = + std::accumulate(physicalTiledShape.begin(), physicalTiledShape.end(), 1, + std::multiplies()); + uint64_t numOfGridUnits = + std::accumulate(grid.getShape().begin(), grid.getShape().end(), 1, + std::multiplies()); + + mlir::SmallVector shardShape; + shardShape.resize(grid.getShape().size() - 1, 1); + shardShape.push_back((numOfTiles + numOfGridUnits - 1) / numOfGridUnits); + return mlir::cast(elementType).getScalarShape(shardShape); +} + // Get stride given tensor logical shape llvm::SmallVector TTNNLayoutAttr::getStride(ArrayRef logicalShape) const { @@ -157,12 +215,12 @@ mlir::tt::DataType TTNNLayoutAttr::getDataType() const { return elementTypeToDataType(elementType); } -// Gets the size of shard in bytes +// Get the size of the element in bytes // -// This function returns the size of the shard in bytes. -// Size is calculated by multiplying shard shape with element size. +// This function returns the size of a single tensor element in bytes. +// Distinction is made between scalar types and TileType. // -// return The size of the shard in bytes. +// return The size of the element in bytes. uint64_t TTNNLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); if (isTiled()) { @@ -177,7 +235,7 @@ uint64_t TTNNLayoutAttr::getElementSizeBytes() const { // Return the shape of the shard. // Example: memref<2x2x!tt.tile<32x32xf32>> -> { 2, 2 } // Example: memref<128x128xf32> -> { 128, 128 } -// Example: memref<2x3!tt.tile<32x32xf32>> -> { 2, 3 } +// Example: memref<2x3x!tt.tile<32x32xf32>> -> { 2, 3 } // // return The shape of the shard. llvm::SmallVector TTNNLayoutAttr::getShardShape() const { @@ -283,13 +341,13 @@ mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape( "shard rank"); SmallVector symReplacements; - for (unsigned i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) { + for (size_t i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) { symReplacements.push_back( getAffineConstantExpr(shardShape[i], getContext())); } SmallVector dimReplacements; - for (unsigned i = 0; i < physicalMemoryMap.getNumDims(); ++i) { + for (size_t i = 0; i < physicalMemoryMap.getNumDims(); ++i) { dimReplacements.push_back(getAffineDimExpr(i, getContext())); } @@ -453,14 +511,23 @@ TTNNLayoutAttr TTNNLayoutAttr::get( Type elementType, BufferType bufferType, GridAttr grid, TensorMemoryLayoutAttr memLayoutAttr, ArrayRef> collapseIntervals) { + // Construct a new affine map which will be used to map from logical - // space to physical space + // space to physical space. 
AffineMap linear = collapsedLinearAffineMap( context, tensorShape, grid.getShape(), collapseIntervals); - // Calculate shard shape by evaluating the linear map with last element - // of the tensor shape and dividing it by the grid shape - mlir::SmallVector shardShape = - calculateLogicalShardShape(tensorShape, linear, grid); + + // Calculate shard shape + mlir::SmallVector shardShape; + if (bufferType == BufferType::L1 && + memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved) { + shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved( + tensorShape, elementType, linear, grid); + } else { + shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForSharding( + tensorShape, linear, grid); + } + // Build memref type with the given parameters MemRefType memRefType = buildMemRef( context, shardShape, elementType, bufferType); diff --git a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp index 0dd7eaaafd..848d80a3ee 100644 --- a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp @@ -31,33 +31,30 @@ TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(int inputSize, // Method to apply tensor workarounds. If the workaround is present, it // applies the workaround, and returns both the target workaround argument and // a flag indicating whether the workaround was applied. -WorkaroundResult applyWorkarounds(const TTNNOperandWorkarounds &workaround, - const TTNNLayoutAttr &inputLayoutAttr) { - WorkaroundResult result; - result.targetTensorLayoutResult.first = +WorkaroundResults applyWorkarounds(const TTNNOperandWorkarounds &workaround, + const TTNNLayoutAttr &inputLayoutAttr) { + WorkaroundResults results; + results.tensorLayoutResult.targetValue = workaround.tensorLayoutWorkaround.value_or(inputLayoutAttr.getLayout()); - result.targetTensorLayoutResult.second = - result.targetTensorLayoutResult.first != inputLayoutAttr.getLayout(); + results.tensorLayoutResult.previousValue = inputLayoutAttr.getLayout(); - result.targetTensorBufferTypeResult.first = + results.tensorBufferTypeResult.targetValue = workaround.tensorBufferTypeWorkaround.value_or( inputLayoutAttr.getBufferType()); - result.targetTensorBufferTypeResult.second = - result.targetTensorBufferTypeResult.first != + results.tensorBufferTypeResult.previousValue = inputLayoutAttr.getBufferType(); // If the tensor memory layout workaround is present, apply it. // Otherwise, return the input tensor memory layout, which may be // nullopt if tensor is on host. - result.targetTensorMemoryLayoutResult.first = + results.tensorMemoryLayoutResult.targetValue = workaround.tensorMemoryLayoutWorkaround.has_value() ? workaround.tensorMemoryLayoutWorkaround : inputLayoutAttr.getMemLayoutOpt(); - result.targetTensorMemoryLayoutResult.second = - result.targetTensorMemoryLayoutResult.first != + results.tensorMemoryLayoutResult.previousValue = inputLayoutAttr.getMemLayoutOpt(); - return result; + return results; } // Operands workarounds factory method. 
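A minimal standalone sketch (not part of this patch) of the rounding arithmetic behind the two shard-shape helpers added earlier in this patch (TTNNOpsAttrs.cpp); the 8x8 grid and the tensor shapes mirror the examples in their doc comments, while ceilDiv and the main driver are hypothetical scaffolding for illustration only:

#include <cstdint>
#include <iostream>

// Ceiling division, as used when distributing a physical shape over a grid.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Block-sharded case: each physical dimension is split across the matching
  // grid dimension. tensor<65x65xbf16> on an 8x8 grid -> logical shard {9, 9}.
  std::cout << ceilDiv(65, 8) << "x" << ceilDiv(65, 8) << "\n";

  // L1 interleaved case: the tensor is tiled into 32x32 tiles and the tiles
  // are spread over all 64 cores. tensor<32x2049xbf16> -> 1x65 = 65 tiles,
  // ceil(65 / 64) = 2 tiles on the most loaded core (doc example {1, 2}).
  int64_t numTiles = ceilDiv(32, 32) * ceilDiv(2049, 32);
  std::cout << numTiles << " tiles, " << ceilDiv(numTiles, 8 * 8)
            << " per core\n";
  return 0;
}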
@@ -71,4 +68,23 @@ TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(Operation *op) { return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds( tensorInputs, tensorResults); } + +/////////////////////////////////////////////////////////////////////////////// +// Factory methods to create a set of workarounds for specific operations +/////////////////////////////////////////////////////////////////////////////// + +// Factory method to create a set of workarounds for max pool 2d operation +// operands. The max pool 2d operation can accept input in both row-major and +// tile layout, but the output of the operation is strictly in row-major layout. +// In order to keep the output consistent with the input, the row-major +// workaround is applied to both the input and output operands. +TTNNOperandsWorkarounds +TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() { + wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround = + wa::TTNNOperandWorkarounds(Layout::RowMajor); + return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds() + .addInputOperandWorkaround(rowMajorLayoutWorkaround) + .addInputOperandWorkaround(rowMajorLayoutWorkaround) + .addOutputOperandWorkaround(rowMajorLayoutWorkaround); +} } // namespace mlir::tt::ttnn::wa diff --git a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp index 528ffc2ebe..2c0c48dbcc 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp @@ -11,6 +11,7 @@ #include "ttmlir/Dialect/TTNN/Types/Types.h" #include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include "ttmlir/Utils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -22,7 +23,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "ttmlir/Utils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" @@ -35,82 +35,35 @@ namespace mlir::tt::ttnn { #define GEN_PASS_DEF_TTNNWORKAROUNDS #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc" -// Helper method to get the tensor layout attribute from the op operand. -static TTNNLayoutAttr getLayoutAttrFromOpOperand(OpOperand &opOperand) { - auto tensorType = mlir::cast(opOperand.get().getType()); - return mlir::cast(tensorType.getEncoding()); -} - -// Helper method to get the tensor layout attribute from the op result. -static TTNNLayoutAttr getLayoutAttrFromOpResult(OpResult &opResult) { - auto tensorType = mlir::cast(opResult.getType()); - return mlir::cast(tensorType.getEncoding()); -} - -// Helper method to get the element type for the given tensor layout and data. -static Type getElementType(MLIRContext *context, Layout tensorLayout, - DataType dataType) { - return tensorLayout == Layout::Tile - ? TileType::get(context, {ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, - dataType) - : ttnn::utils::createRowMajorTypeFromDtype(context, dataType); -} - -// Helper method to insert a ToLayoutOp to convert the input operand to the -// desired tensor layout, buffer type and memory layout. 
-static mlir::Value -createToLayoutOp(wa::TTNNWorkaroundInterface &op, OpOperand &inputOperand, - PatternRewriter &rewriter, Layout targetTensorLayout, - BufferType targetTensorBufferType, - std::optional targetTensorMemoryLayout) { - TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromOpOperand(inputOperand); - - // Create element type based on tensor layout. - Type elementType = getElementType(rewriter.getContext(), targetTensorLayout, - inputLayoutAttr.getDataType()); - - // Create tensor memory layout attribute. - ttnn::TensorMemoryLayoutAttr outputMemLayoutAttr = - targetTensorMemoryLayout.has_value() - ? ttnn::TensorMemoryLayoutAttr::get(rewriter.getContext(), - targetTensorMemoryLayout.value()) - : nullptr; - - // Create the output memory config attribute. - ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get( - rewriter.getContext(), - ttnn::BufferTypeAttr::get(rewriter.getContext(), targetTensorBufferType), - ttnn::ShardSpecAttr::get( - op.getContext(), - ttnn::ShapeAttr::get(rewriter.getContext(), - inputLayoutAttr.getMemref().getShape())), - outputMemLayoutAttr); +// If the layout of the output result has changed as a result of applying a +// workaround, this method transforms the layout back to the previous state +// by inserting a ToLayoutOp after the op result output in order to maintain +// the workarounds changes locally. +// +static void revertOutputLayout(wa::TTNNWorkaroundInterface &op, + PatternRewriter &rewriter, + wa::WorkaroundResults &workaroundResults, + mlir::TypedValue newOpResult) { + // Check if the data type of the output result has changed. + if (!workaroundResults.isModified()) { + return; + } - // Get the input operand type. - RankedTensorType inputOperandType = - mlir::cast(inputOperand.get().getType()); - - // Create a ToLayoutOp to convert the input operand to the desired - // tensor layout, buffer type and memory layout. - return rewriter - .create( - op.getLoc(), - ttnn::utils::createRankedTensorTypeWithEncoding( - inputOperandType, - inputLayoutAttr - .withElementType(rewriter.getContext(), elementType) - .withBufferType(rewriter.getContext(), targetTensorBufferType) - .withMemoryLayout(rewriter.getContext(), - outputMemLayoutAttr)), - inputOperand.get(), - LayoutAttr::get(rewriter.getContext(), targetTensorLayout), - DataTypeAttr::get(rewriter.getContext(), - inputLayoutAttr.getDataType()), - outputMemConfigAttr, - (targetTensorBufferType == ttnn::BufferType::SystemMemory) - ? nullptr - : utils::getOrInsertDevice(rewriter, op)) - ->getResult(0); + // Insert the toLayoutOp after the op output. + rewriter.setInsertionPointAfter(op); + + // Cast the data type back to the previous data type by inserting ToLayoutOp. + mlir::Value castLayoutOp = utils::createToLayoutOp( + op.getOperation(), newOpResult, rewriter, + workaroundResults.tensorLayoutResult.previousValue, + workaroundResults.tensorBufferTypeResult.previousValue, + workaroundResults.tensorMemoryLayoutResult.previousValue); + + // Replace the new output result with the casted output result. + rewriter.replaceUsesWithIf( + newOpResult, castLayoutOp, [&](OpOperand &operand) { + return operand.getOwner() != castLayoutOp.getDefiningOp(); + }); } // Helper method to apply workarounds to an input operand. This method inserts a @@ -121,24 +74,26 @@ static bool workaroundInputOperand( PatternRewriter &rewriter, wa::TTNNWorkaroundInterface op) { // Get the current input tensor layout, buffer type and memory layout from the // input operand. 
- TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromOpOperand(inputOperand); + auto inputValue = + mlir::cast>(inputOperand.get()); + TTNNLayoutAttr inputLayoutAttr = utils::getLayoutAttrFromTensor(inputValue); // Apply the workarounds on the input operand workaround arguments - wa::WorkaroundResult inputWorkaroundResult = + wa::WorkaroundResults inputWorkaroundResults = applyWorkarounds(inputWorkaround, inputLayoutAttr); // If there were no modifications by workarounds, return false. - if (!inputWorkaroundResult.modified()) { + if (!inputWorkaroundResults.isModified()) { return false; } // Apply the workarounds on the input operand by inserting the ToLayoutOp with // the desired tensor layout, buffer type and memory layout. - mlir::Value insertedToLayoutOpValue = createToLayoutOp( - op, inputOperand, rewriter, - inputWorkaroundResult.targetTensorLayoutResult.first, - inputWorkaroundResult.targetTensorBufferTypeResult.first, - inputWorkaroundResult.targetTensorMemoryLayoutResult.first); + mlir::Value insertedToLayoutOpValue = utils::createToLayoutOp( + op.getOperation(), inputValue, rewriter, + inputWorkaroundResults.tensorLayoutResult.targetValue, + inputWorkaroundResults.tensorBufferTypeResult.targetValue, + inputWorkaroundResults.tensorMemoryLayoutResult.targetValue); // Insert to layout op between the current op and the input operand // to convert the input operand to the desired tensor layout, buffer type. @@ -159,34 +114,29 @@ static bool workaroundInputOperand( // the // output result and returns true if the workarounds were successfully // applied. -static bool workaroundOutputOperand( - OpResult &opResult, const wa::TTNNOperandWorkarounds &outputWorkaround, - PatternRewriter &rewriter, wa::TTNNWorkaroundInterface op) { +static bool +workaroundOutputOperand(mlir::TypedValue opResult, + const wa::TTNNOperandWorkarounds &outputWorkaround, + PatternRewriter &rewriter, + wa::TTNNWorkaroundInterface op) { // Get the current output tensor layout, buffer type and memory layout from // the input operand. - TTNNLayoutAttr opResultLayoutAttr = getLayoutAttrFromOpResult(opResult); + TTNNLayoutAttr opResultLayoutAttr = utils::getLayoutAttrFromTensor(opResult); // Apply the workarounds on the output result workaround arguments - wa::WorkaroundResult outputWorkaroundResult = + wa::WorkaroundResults outputWorkaroundResults = wa::applyWorkarounds(outputWorkaround, opResultLayoutAttr); - // At this point, the DPS result should already be propagated, hence we only - // need to verify that the output workaround is in sync with the current DPS - // result. - assert(!(outputWorkaroundResult.modified() && - mlir::isa(op.getOperation())) && - "Output operand workarounds not supported for DPS ops"); - // If there were no modifications by workarounds, return false. - if (!outputWorkaroundResult.modified()) { + if (!outputWorkaroundResults.isModified()) { return false; } // Create the data type attribute. - Type elementType = - getElementType(rewriter.getContext(), - outputWorkaroundResult.targetTensorLayoutResult.first, - opResultLayoutAttr.getDataType()); + Type elementType = utils::getElementType( + rewriter.getContext(), + outputWorkaroundResults.tensorLayoutResult.targetValue, + opResultLayoutAttr.getDataType()); // Get the input operand type. RankedTensorType opResultType = @@ -194,11 +144,10 @@ static bool workaroundOutputOperand( // Create tensor memory layout attribute. 
TensorMemoryLayoutAttr outputMemLayoutAttr = - outputWorkaroundResult.targetTensorMemoryLayoutResult.first.has_value() + outputWorkaroundResults.tensorMemoryLayoutResult.targetValue ? ttnn::TensorMemoryLayoutAttr::get( rewriter.getContext(), - outputWorkaroundResult.targetTensorMemoryLayoutResult.first - .value()) + *outputWorkaroundResults.tensorMemoryLayoutResult.targetValue) : nullptr; // Create the new output result type with the updated tensor layout, buffer @@ -209,7 +158,7 @@ static bool workaroundOutputOperand( opResultLayoutAttr.withElementType(rewriter.getContext(), elementType) .withBufferType( rewriter.getContext(), - outputWorkaroundResult.targetTensorBufferTypeResult.first) + outputWorkaroundResults.tensorBufferTypeResult.targetValue) .withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr)); // Update the type of result with applied workarounds. @@ -219,15 +168,15 @@ static bool workaroundOutputOperand( // Some ops defines attributes with tensor layout, buffer type and memory // layout, hence we need to update the attributes as well. For example, // the empty op defines layout and memory_config attributes. - if (outputWorkaroundResult.targetTensorLayoutResult.second && + if (outputWorkaroundResults.tensorLayoutResult.isModified() && op->getAttrDictionary().get("layout")) { LayoutAttr updatedLayoutAttr = rewriter.getAttr( - outputWorkaroundResult.targetTensorLayoutResult.first); + outputWorkaroundResults.tensorLayoutResult.targetValue); op->setAttr("layout", updatedLayoutAttr); } - if ((outputWorkaroundResult.targetTensorBufferTypeResult.second || - outputWorkaroundResult.targetTensorMemoryLayoutResult.second) && + if ((outputWorkaroundResults.tensorBufferTypeResult.isModified() || + outputWorkaroundResults.tensorMemoryLayoutResult.isModified()) && op->getAttrDictionary().get("memory_config")) { MemoryConfigAttr currentMemoryConfig = @@ -235,17 +184,17 @@ static bool workaroundOutputOperand( // Create the output memory config attribute. // Check if the buffer type got updated. - if (outputWorkaroundResult.targetTensorBufferTypeResult.second) { + if (outputWorkaroundResults.tensorBufferTypeResult.isModified()) { currentMemoryConfig = currentMemoryConfig.withBufferType( rewriter.getContext(), - outputWorkaroundResult.targetTensorBufferTypeResult.first); + outputWorkaroundResults.tensorBufferTypeResult.targetValue); } // Check if the memory layout got updated. - if (outputWorkaroundResult.targetTensorMemoryLayoutResult.second) { + if (outputWorkaroundResults.tensorMemoryLayoutResult.isModified()) { currentMemoryConfig = currentMemoryConfig.withMemoryLayout( rewriter.getContext(), - outputWorkaroundResult.targetTensorMemoryLayoutResult.first + outputWorkaroundResults.tensorMemoryLayoutResult.targetValue .value()); } @@ -254,61 +203,35 @@ static bool workaroundOutputOperand( } }); - return true; -} - -// Propagate the workaround changes for DPS input operands if they are applied -// in above graph transforms, either in a pattern for a current op, or in a -// pattern matched for a previous ops. -static bool propagateDpsInitChangesToDpsResults(wa::TTNNWorkaroundInterface &op, - PatternRewriter &rewriter) { - // Check if the op is a DPS op. - if (!mlir::isa(op.getOperation())) { - return false; - } - - bool modified = false; - - auto dpsOp = mlir::cast(op.getOperation()); - mlir::OperandRange dpsInits = dpsOp.getDpsInits(); - - // Iterate through all dps destination operands and propagate the changes if - // any. 
- for (size_t dpsInitIndex = 0; dpsInitIndex < dpsInits.size(); - dpsInitIndex++) { - OpOperand *dpsInit = dpsOp.getDpsInitOperand(dpsInitIndex); - OpResult tiedDpsResult = dpsOp.getTiedOpResult(dpsInit); - - // If the DPS destination is changed, update the DPS result as well. - if (tiedDpsResult.getType() != dpsInit->get().getType()) { - modified = true; - rewriter.modifyOpInPlace( - op, [&]() { tiedDpsResult.setType(dpsInit->get().getType()); }); - } - } + revertOutputLayout(op, rewriter, outputWorkaroundResults, opResult); - return modified; + return true; } // TTNNWorkaroundInterface rewriter applies workarounds to the operands of TTNN // operations. TTNNWorkaroundInterface is an interface on TTNN_Op, so this -// pattern should match each op in the IR. +// pattern should match each op in the IR. Each op has a default implementation +// of the interface that returns a default TTNNOperandsWorkarounds object +// without workarounds. For each op that is required, we can override the +// default implementation to return the specific workarounds for the op. // -// The rewriter processes both input and output operands of TTNN operations: +// The main goal of the rewriter is to apply workaround changes to the input and +// output operands of TTNN operations. The idea is to insert a ToLayoutOp before +// input operands and after output results to apply the necessary workarounds in +// order to keep workaround changes consistent and local to the affected op. The +// rewriter processes both input and output operands of TTNN operations: // 1. **Input Operands**: The rewriter iterates through all input tensor // operands and applies the necessary workarounds. -// - Workarounds are applied by inserting ToLayoutOp with the desired tensor -// layout, buffer type, and memory layout. -// 2. **DPS result propagation**: The rewriter propagates changes to tied DPS -// destination operands to ensure consistency with previous graph -// transformations, either in the current op match or previous op matches. -// 3. **Output Operands**: Output workarounds are applied only if the operation -// is not a DPS op. -// - At this stage, all DPS result changes should be propagated. An assertion -// ensures that the output result workaround matches -// the corresponding DPS output result. +// - If the input workarounds makes any changes to the input operand layout, +// we are inserting a ToLayoutOp before the op to transform the layout to the +// desired tensor layout, buffer type, and memory layout. +// 2. **Output Operands**: The rewriter iterates through all output tensor +// results and applies the necessary workarounds. // - Workarounds are applied by updating the output result type with the new // tensor layout, buffer type, and memory layout. +// - If the output workaround makes any changes to the output layout, we +// are inserting a ToLayoutOp after the op to transform the layout back to +// the previous state in order to maintain the workarounds changes locally. // - For operations that define attributes with tensor layout, buffer type, // and memory layout, these attributes are also updated. // For example, the empty op defines layout and memory_config attributes. @@ -321,7 +244,8 @@ class TTNNOperandsWorkaroundsRewriter LogicalResult matchAndRewrite(wa::TTNNWorkaroundInterface op, PatternRewriter &rewriter) const final { - // To layout op is a special case, we don't want to rewrite it. + // To layout op is a special case, we don't want to rewrite it. 
We use it + // to apply workarounds to the operands and results of TTNN operations. if (mlir::isa(op.getOperation())) { return failure(); } @@ -348,11 +272,6 @@ class TTNNOperandsWorkaroundsRewriter std::get<1>(pair), rewriter, op); }); - // Propagate the workaround changes for DPS input operands to DPS results if - // they are applied in above graph transforms, either in a pattern for a - // current op, or in a pattern matched for a previous ops. - modified |= propagateDpsInitChangesToDpsResults(op, rewriter); - // Filter out all the output tensor results. auto outputTensorResults = llvm::make_filter_range(op->getOpResults(), [](OpResult v) { @@ -366,8 +285,10 @@ class TTNNOperandsWorkaroundsRewriter [&](std::tuple pair) { modified |= std::get<1>(pair).hasAnyWorkaround() && - workaroundOutputOperand(std::get<0>(pair), - std::get<1>(pair), rewriter, op); + workaroundOutputOperand( + mlir::cast>( + std::get<0>(pair)), + std::get<1>(pair), rewriter, op); }); // Return success if the transformations were applied. diff --git a/lib/Dialect/TTNN/Utils/TransformUtils.cpp b/lib/Dialect/TTNN/Utils/TransformUtils.cpp index bef022fda2..ed4b318ec5 100644 --- a/lib/Dialect/TTNN/Utils/TransformUtils.cpp +++ b/lib/Dialect/TTNN/Utils/TransformUtils.cpp @@ -6,15 +6,16 @@ #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" namespace mlir::tt::ttnn::utils { // Gets or inserts a GetDeviceOp at the top of the current block of the given // operation. -Value getOrInsertDevice(PatternRewriter &rewriter, Operation *op) { +GetDeviceOp getOrInsertDevice(PatternRewriter &rewriter, Operation *op) { Block *block = op->getBlock(); for (auto &op : block->getOperations()) { if (auto deviceOp = dyn_cast(op)) { - return deviceOp.getResult(); + return deviceOp; } } @@ -29,6 +30,65 @@ Value getOrInsertDevice(PatternRewriter &rewriter, Operation *op) { op->getLoc(), rewriter.getType(deviceAttr), ttnn::MeshShapeAttr::get(op->getContext(), meshShape[0], meshShape[1])); rewriter.restoreInsertionPoint(currentInsertionPoint); - return deviceOp.getResult(); + return deviceOp; +} + +// Helper method to insert a ToLayoutOp to convert the input operand to the +// desired tensor layout, buffer type and memory layout. +ToLayoutOp +createToLayoutOp(Operation *op, mlir::TypedValue inputValue, + PatternRewriter &rewriter, Layout targetTensorLayout, + BufferType targetTensorBufferType, + std::optional targetTensorMemoryLayout) { + TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromTensor(inputValue); + + // Create element type based on tensor layout. + Type elementType = getElementType(rewriter.getContext(), targetTensorLayout, + inputLayoutAttr.getDataType()); + + // Create tensor memory layout attribute. + ttnn::TensorMemoryLayoutAttr outputMemLayoutAttr = + targetTensorMemoryLayout.has_value() + ? ttnn::TensorMemoryLayoutAttr::get(rewriter.getContext(), + targetTensorMemoryLayout.value()) + : nullptr; + + // Get the input operand type. + RankedTensorType inputToLayoutOpType = + mlir::cast(inputValue.getType()); + + // Create the new encoding for the output tensor type. + TTNNLayoutAttr toLayoutOpResultEncoding = + inputLayoutAttr.withElementType(rewriter.getContext(), elementType) + .withBufferType(rewriter.getContext(), targetTensorBufferType) + .withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr); + + // Create the output result type with the new encoding. 
+ RankedTensorType toLayoutOpResultType = + ttnn::utils::createRankedTensorTypeWithEncoding(inputToLayoutOpType, + toLayoutOpResultEncoding); + + // Create the output memory config attribute. + ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get( + rewriter.getContext(), + ttnn::BufferTypeAttr::get(rewriter.getContext(), targetTensorBufferType), + ttnn::ShardSpecAttr::get( + op->getContext(), + ttnn::ShapeAttr::get(rewriter.getContext(), + toLayoutOpResultEncoding.getShardShape())), + outputMemLayoutAttr); + + // Get the device value if the tensor output is not on the host. + auto deviceValue = targetTensorBufferType == ttnn::BufferType::SystemMemory + ? nullptr + : Value(utils::getOrInsertDevice(rewriter, op)); + + // Create a ToLayoutOp to convert the input operand to the desired + // tensor layout, buffer type and memory layout. + return rewriter.create( + op->getLoc(), toLayoutOpResultType, inputValue, + LayoutAttr::get(rewriter.getContext(), targetTensorLayout), + DataTypeAttr::get(rewriter.getContext(), inputLayoutAttr.getDataType()), + outputMemConfigAttr, deviceValue); } } // namespace mlir::tt::ttnn::utils diff --git a/lib/Dialect/TTNN/Utils/Utils.cpp b/lib/Dialect/TTNN/Utils/Utils.cpp index 6976dd35f4..90091b1ff8 100644 --- a/lib/Dialect/TTNN/Utils/Utils.cpp +++ b/lib/Dialect/TTNN/Utils/Utils.cpp @@ -4,6 +4,9 @@ #include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include "ttmlir/Dialect/TTNN/Types/Types.h" +#include + namespace mlir::tt::ttnn::utils { // Map TT::MemorySpace to TTNN::BufferType // @@ -117,24 +120,29 @@ createRankedTensorTypeWithEncoding(RankedTensorType tensorType, tensorType.getElementType(), encoding); } -uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout, - DeviceAttr &deviceAttr) { - assert(mlir::isa(op->getResult(0).getType()) && - "L1 memory usage of the ops without output tensors cannot be " - "calculated."); - +uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout) { // In case the opLayout is not in L1 memory space, L1 memory usage is 0. // if (opLayout.hasDRAMBufferType()) { return 0; } - llvm::ArrayRef opOutputTensorShape = - mlir::cast(op->getResult(0).getType()).getShape(); + return opLayout.getShardSizeInBytes(); +} + +// Helper method to get the tensor layout attribute from the value. +TTNNLayoutAttr +getLayoutAttrFromTensor(mlir::TypedValue tensorValue) { + return mlir::cast(tensorValue.getType().getEncoding()); +} - uint64_t opL1OutputUsage = - opLayout.getTensorSizeInBytes(opOutputTensorShape, deviceAttr); - return opL1OutputUsage; +// Helper method to get the element type for the given tensor layout and data. +Type getElementType(MLIRContext *context, Layout tensorLayout, + DataType dataType) { + return tensorLayout == Layout::Tile + ? 
TileType::get(context, {ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, + dataType) + : ttnn::utils::createRowMajorTypeFromDtype(context, dataType); } } // namespace mlir::tt::ttnn::utils diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 6e3745c91d..5fd09d5e4a 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -753,9 +753,9 @@ createEmbeddingOp(FlatbufferObjectCache &cache, EmbeddingOp op) { cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto in1 = cache.at<::tt::target::TensorRef>( getOperandThroughDPSOps(op.getWeight())); - auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, - kHostAllocatedAddress, kHostAllocatedSize); - return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, output); + auto out = cache.at<::tt::target::TensorRef>( + getOperandThroughDPSOps(op.getResult())); + return ::tt::target::ttnn::CreateEmbeddingOp(*cache.fbb, in0, in1, out); } template diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 268959e8a2..eac3b0ebb2 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -118,19 +118,6 @@ Tensor getOpOutputTensor(OpContext opContextHandle, std::vector getTensorData(Tensor tensor); -namespace legacy { -/* Will be deprecated soon once FEs migrate to new API */ - -Event submit(Device deviceHandle, Binary executableHandle, - std::uint32_t programIndex, std::vector const &inputs, - std::vector const &outputs); - -void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, - std::uint32_t programIndex, - std::vector<::ttnn::Tensor *> const &inputs, - std::vector<::ttnn::Tensor *> const &outputs); -} // namespace legacy - std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputs); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 2da673ad19..1b5b775b07 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -456,11 +456,8 @@ Event submit(Device deviceHandle, Binary executableHandle, std::vector const &outputHandles) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - LOG_WARNING("This submit API will soon be deprecated. Please switch to the " - "new API."); - return ::tt::runtime::ttnn::legacy::submit(deviceHandle, executableHandle, - programIndex, inputHandles, - outputHandles); + LOG_FATAL("This submit API is deprecated for TTNN. 
Please switch to the " + "new API."); } #endif diff --git a/runtime/lib/ttnn/CMakeLists.txt b/runtime/lib/ttnn/CMakeLists.txt index 6a68c4c7b9..a1af94b868 100644 --- a/runtime/lib/ttnn/CMakeLists.txt +++ b/runtime/lib/ttnn/CMakeLists.txt @@ -24,6 +24,9 @@ add_library(TTRuntimeTTNN ) # We have to set the C++ standard to 20 because tt-metal requires it set_property(TARGET TTRuntimeTTNN PROPERTY CXX_STANDARD 20) +target_include_directories(TTRuntimeTTNN PRIVATE + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn +) target_include_directories(TTRuntimeTTNN PUBLIC ${PROJECT_SOURCE_DIR}/runtime/include ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index cc84cc8dae..953310193f 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -44,6 +44,9 @@ add_library(TTRuntimeTTNNOps set_property(TARGET TTRuntimeTTNNOps PROPERTY CXX_STANDARD 20) target_compile_options(TTRuntimeTTNNOps PUBLIC -mavx -mavx2 -fsized-deallocation) +target_include_directories(TTRuntimeTTNNOps PRIVATE + ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn +) target_include_directories(TTRuntimeTTNNOps PUBLIC ${PROJECT_SOURCE_DIR}/runtime/include ${PROJECT_SOURCE_DIR}/runtime/lib/ttnn/include diff --git a/runtime/lib/ttnn/operations/ccl/all_gather.cpp b/runtime/lib/ttnn/operations/ccl/all_gather.cpp index eee27e7bab..8c9e7e00c2 100644 --- a/runtime/lib/ttnn/operations/ccl/all_gather.cpp +++ b/runtime/lib/ttnn/operations/ccl/all_gather.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "all_gather.h" +#include "operations/ccl/all_gather.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/ttnn/utils.h" diff --git a/runtime/lib/ttnn/operations/context/get_device.cpp b/runtime/lib/ttnn/operations/context/get_device.cpp index 376b9ff744..3efc45cf50 100644 --- a/runtime/lib/ttnn/operations/context/get_device.cpp +++ b/runtime/lib/ttnn/operations/context/get_device.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "get_device.h" +#include "operations/context/get_device.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/conv/conv2d.cpp b/runtime/lib/ttnn/operations/conv/conv2d.cpp index 29a866f816..26d71df1ac 100644 --- a/runtime/lib/ttnn/operations/conv/conv2d.cpp +++ b/runtime/lib/ttnn/operations/conv/conv2d.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "conv2d.h" +#include "operations/conv/conv2d.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/creation/arange.cpp b/runtime/lib/ttnn/operations/creation/arange.cpp index f51937462a..dc2d56f3f3 100644 --- a/runtime/lib/ttnn/operations/creation/arange.cpp +++ b/runtime/lib/ttnn/operations/creation/arange.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "arange.h" +#include "operations/creation/arange.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/ttnn/utils.h" diff --git a/runtime/lib/ttnn/operations/creation/empty.cpp b/runtime/lib/ttnn/operations/creation/empty.cpp index e92fbc893b..929ae05602 100644 --- a/runtime/lib/ttnn/operations/creation/empty.cpp +++ b/runtime/lib/ttnn/operations/creation/empty.cpp @@ -2,7 +2,7 @@ // // 
SPDX-License-Identifier: Apache-2.0 -#include "empty.h" +#include "operations/creation/empty.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/detail/workarounds.h" diff --git a/runtime/lib/ttnn/operations/creation/full.cpp b/runtime/lib/ttnn/operations/creation/full.cpp index 7a3654296f..c5fff8342f 100644 --- a/runtime/lib/ttnn/operations/creation/full.cpp +++ b/runtime/lib/ttnn/operations/creation/full.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "full.h" +#include "operations/creation/full.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/detail/workarounds.h" diff --git a/runtime/lib/ttnn/operations/creation/ones.cpp b/runtime/lib/ttnn/operations/creation/ones.cpp index 36cf5d5af4..ab82d6d71b 100644 --- a/runtime/lib/ttnn/operations/creation/ones.cpp +++ b/runtime/lib/ttnn/operations/creation/ones.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ones.h" +#include "operations/creation/ones.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/ttnn/operations/utils.h" @@ -38,9 +38,9 @@ void run(const ::tt::target::ttnn::OnesOp *op, ProgramContext &context) { if (op->device()) { DeviceVariant targetDevice = context.getTargetDevice(op->device()->global_id()); - assert(std::holds_alternative>( - targetDevice) && - "ttnn::ones does not support MeshDevice."); + LOG_ASSERT(std::holds_alternative>( + targetDevice), + "ttnn::ones does not support MeshDevice."); device = std::get>(targetDevice); } diff --git a/runtime/lib/ttnn/operations/data_movement/concat.cpp b/runtime/lib/ttnn/operations/data_movement/concat.cpp index 189f0e6d60..fc4cbd70b2 100644 --- a/runtime/lib/ttnn/operations/data_movement/concat.cpp +++ b/runtime/lib/ttnn/operations/data_movement/concat.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "concat.h" +#include "operations/data_movement/concat.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" diff --git a/runtime/lib/ttnn/operations/data_movement/reshape.cpp b/runtime/lib/ttnn/operations/data_movement/reshape.cpp index b38f96872f..0be7114f00 100644 --- a/runtime/lib/ttnn/operations/data_movement/reshape.cpp +++ b/runtime/lib/ttnn/operations/data_movement/reshape.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "reshape.h" +#include "operations/data_movement/reshape.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" diff --git a/runtime/lib/ttnn/operations/data_movement/slice.cpp b/runtime/lib/ttnn/operations/data_movement/slice.cpp index 87ba89d800..03b82c94ae 100644 --- a/runtime/lib/ttnn/operations/data_movement/slice.cpp +++ b/runtime/lib/ttnn/operations/data_movement/slice.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "slice.h" +#include "operations/data_movement/slice.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "ttmlir/Target/TTNN/program_generated.h" diff --git a/runtime/lib/ttnn/operations/data_movement/transpose.cpp b/runtime/lib/ttnn/operations/data_movement/transpose.cpp index c86c0ee10a..634c93c3c5 100644 --- a/runtime/lib/ttnn/operations/data_movement/transpose.cpp +++ b/runtime/lib/ttnn/operations/data_movement/transpose.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "transpose.h" +#include "operations/data_movement/transpose.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include 
"tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/deletion/deallocate.cpp b/runtime/lib/ttnn/operations/deletion/deallocate.cpp index e871a9ea64..2ddaa2de4f 100644 --- a/runtime/lib/ttnn/operations/deletion/deallocate.cpp +++ b/runtime/lib/ttnn/operations/deletion/deallocate.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "deallocate.h" +#include "operations/deletion/deallocate.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp index ff47bdcdd8..40f80e259e 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "binary.h" +#include "operations/eltwise/binary/binary.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp index 921b542ed2..08124841a0 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary_composite.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "binary_composite.h" +#include "operations/eltwise/binary/binary_composite.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" diff --git a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp index 44f1413898..01c7694f08 100644 --- a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ternary.h" +#include "operations/eltwise/ternary/ternary.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h" diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp index d24dc24f8d..3ba13a4fdb 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "unary.h" +#include "operations/eltwise/unary/unary.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp index 31514f0fe5..632377c2f4 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "unary_composite.h" +#include "operations/eltwise/unary/unary_composite.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include 
"tt/runtime/ttnn/operations/eltwise/unary/utils.h" diff --git a/runtime/lib/ttnn/operations/embedding/embedding.cpp b/runtime/lib/ttnn/operations/embedding/embedding.cpp index 511d8256de..e1f4df4587 100644 --- a/runtime/lib/ttnn/operations/embedding/embedding.cpp +++ b/runtime/lib/ttnn/operations/embedding/embedding.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "embedding.h" +#include "operations/embedding/embedding.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/embedding/embedding_backward.cpp b/runtime/lib/ttnn/operations/embedding/embedding_backward.cpp index 8d340e5686..8131626578 100644 --- a/runtime/lib/ttnn/operations/embedding/embedding_backward.cpp +++ b/runtime/lib/ttnn/operations/embedding/embedding_backward.cpp @@ -2,14 +2,14 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "embedding_backward.h" +#include "operations/embedding/embedding_backward.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/ttnn/utils.h" #include "ttmlir/Target/TTNN/program_generated.h" -#include -#include +#include "ttnn/operations/embedding_backward/embedding_backward.hpp" +#include "ttnn/tensor/tensor.hpp" namespace tt::runtime::ttnn::operations::embedding_backward { void run(const ::tt::target::ttnn::EmbeddingBackwardOp *op, diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp index f97f71e403..662b16faec 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "utils.h" +#include "tt/runtime/ttnn/operations/eltwise/binary/utils.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/workarounds.h" diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp index b5f707f2cc..60f5efe3b3 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "utils.h" +#include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/workarounds.h" diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp index d8437666de..dc9aaa9669 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "utils.h" +#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h" #include "tt/runtime/detail/logger.h" namespace tt::runtime::ttnn::operations::unary { 
diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp index 60ee2ddc2b..d9085b3df1 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.cpp @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC // // SPDX-License-Identifier: Apache-2.0 -#include "utils.h" +#include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/ttnn/utils.h" diff --git a/runtime/lib/ttnn/operations/kv_cache/fill_cache.cpp b/runtime/lib/ttnn/operations/kv_cache/fill_cache.cpp index 89022f64a1..50648a2371 100644 --- a/runtime/lib/ttnn/operations/kv_cache/fill_cache.cpp +++ b/runtime/lib/ttnn/operations/kv_cache/fill_cache.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "fill_cache.h" +#include "operations/kv_cache/fill_cache.h" namespace tt::runtime::ttnn::operations::kv_cache { void run(const ::tt::target::ttnn::FillCacheOp *op, ProgramContext &context) { diff --git a/runtime/lib/ttnn/operations/kv_cache/update_cache.cpp b/runtime/lib/ttnn/operations/kv_cache/update_cache.cpp index fae1da40c6..10f4626e59 100644 --- a/runtime/lib/ttnn/operations/kv_cache/update_cache.cpp +++ b/runtime/lib/ttnn/operations/kv_cache/update_cache.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "update_cache.h" +#include "operations/kv_cache/update_cache.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/workarounds.h" diff --git a/runtime/lib/ttnn/operations/layout/from_device.cpp b/runtime/lib/ttnn/operations/layout/from_device.cpp index e26e3be2a3..b2a5830554 100644 --- a/runtime/lib/ttnn/operations/layout/from_device.cpp +++ b/runtime/lib/ttnn/operations/layout/from_device.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "to_device.h" +#include "operations/layout/from_device.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/layout/to_device.cpp b/runtime/lib/ttnn/operations/layout/to_device.cpp index 414afc9f05..c885ea530a 100644 --- a/runtime/lib/ttnn/operations/layout/to_device.cpp +++ b/runtime/lib/ttnn/operations/layout/to_device.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "to_device.h" +#include "operations/layout/to_device.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/layout/to_layout.cpp b/runtime/lib/ttnn/operations/layout/to_layout.cpp index bf80ef292e..baeb63db03 100644 --- a/runtime/lib/ttnn/operations/layout/to_layout.cpp +++ b/runtime/lib/ttnn/operations/layout/to_layout.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "to_layout.h" +#include "operations/layout/to_layout.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/layout/to_memory_config.cpp b/runtime/lib/ttnn/operations/layout/to_memory_config.cpp index 61f09bdced..26ede90b0b 100644 --- a/runtime/lib/ttnn/operations/layout/to_memory_config.cpp +++ b/runtime/lib/ttnn/operations/layout/to_memory_config.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "to_memory_config.h" +#include 
"operations/layout/to_memory_config.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/layout/typecast.cpp b/runtime/lib/ttnn/operations/layout/typecast.cpp index e59a64a401..63a9ba63da 100644 --- a/runtime/lib/ttnn/operations/layout/typecast.cpp +++ b/runtime/lib/ttnn/operations/layout/typecast.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "typecast.h" +#include "operations/layout/typecast.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/ttnn/utils.h" diff --git a/runtime/lib/ttnn/operations/matmul/matmul.cpp b/runtime/lib/ttnn/operations/matmul/matmul.cpp index 896797d59c..5cfcc6b8da 100644 --- a/runtime/lib/ttnn/operations/matmul/matmul.cpp +++ b/runtime/lib/ttnn/operations/matmul/matmul.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "matmul.h" +#include "operations/matmul/matmul.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/normalization/softmax.cpp b/runtime/lib/ttnn/operations/normalization/softmax.cpp index 432f920956..86bbe1b3a1 100644 --- a/runtime/lib/ttnn/operations/normalization/softmax.cpp +++ b/runtime/lib/ttnn/operations/normalization/softmax.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "softmax.h" +#include "operations/normalization/softmax.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index a20bdc51b4..51e58e0ebf 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "maxpool2d.h" +#include "operations/pool/maxpool2d.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/detail/workarounds.h" diff --git a/runtime/lib/ttnn/operations/reduction/reduction.cpp b/runtime/lib/ttnn/operations/reduction/reduction.cpp index a74373ee9f..631df3f51a 100644 --- a/runtime/lib/ttnn/operations/reduction/reduction.cpp +++ b/runtime/lib/ttnn/operations/reduction/reduction.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "reduction.h" +#include "operations/reduction/reduction.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/ttnn/operations/utils.h" diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index b1f95a8e83..5d1bf6fd0d 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -45,7 +45,7 @@ namespace tt::runtime::ttnn { using LogType = ::tt::runtime::logger::LogType; -void tracyLogOpLocation(const ::tt::target::ttnn::Operation *op) { +static void tracyLogOpLocation(const ::tt::target::ttnn::Operation *op) { #ifdef TT_RUNTIME_ENABLE_PERF_TRACE TracyMessage(op->loc_info()->c_str(), op->loc_info()->size()); #endif @@ -235,76 +235,6 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) { } } -namespace legacy { - -static bool handleNopProgram(::tt::target::ttnn::Program const *program, - std::vector<::ttnn::Tensor *> const &inputs, - std::vector<::ttnn::Tensor *> const &outputs) { - - bool isNop = program->inputs()->size() == 1 && - 
program->outputs()->size() == 1 && - program->inputs()->Get(0)->global_id() == - program->outputs()->Get(0)->global_id(); - - if (isNop) { - void *src = ::tt::tt_metal::get_raw_host_data_ptr(*inputs.at(0)); - void *dst = ::tt::tt_metal::get_raw_host_data_ptr(*outputs.at(0)); - std::uint32_t size = outputs[0]->volume() * outputs[0]->element_size(); - std::memcpy(dst, src, size); - } - return isNop; -} - -void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle, - std::uint32_t programIndex, - std::vector<::ttnn::Tensor *> const &inputs, - std::vector<::ttnn::Tensor *> const &outputs) { - ::tt::target::ttnn::TTNNBinary const &fbb = *getBinary(executableHandle); - ::tt::target::ttnn::Program const *program = - fbb.programs()->Get(programIndex); - if (handleNopProgram(program, inputs, outputs)) { - return; - } - std::unordered_map liveTensors; - std::vector programInputs; - int inputIndex = 0; - LOG_ASSERT(program->inputs()->size() == inputs.size(), - "Program input size mismatch: ", program->inputs()->size(), - " != ", inputs.size()); - for (::tt::target::TensorRef const *input : *program->inputs()) { - auto [iter, inserted] = - liveTensors.try_emplace(input->global_id(), inputs[inputIndex++]); - LOG_ASSERT(inserted, "Duplicate input tensor"); - programInputs.push_back(input->global_id()); - } - - int outputIndex = 0; - std::vector programOutputs; - LOG_ASSERT(program->outputs()->size() == outputs.size()); - for (::tt::target::TensorRef const *output : *program->outputs()) { - auto [iter, inserted] = - liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]); - LOG_ASSERT(inserted, "Duplicate output tensor"); - programOutputs.push_back(output->global_id()); - } - ProgramExecutor executor(executableHandle, liveTensors, programInputs, - programOutputs, &meshDevice); - executor.execute(program); - outputIndex = 0; - for (uint32_t outputId : programOutputs) { - const ::ttnn::Tensor &src = - executor.getContext().getTensorPool().at(outputId); - const ::ttnn::Tensor &dst = *(outputs[outputIndex++]); - size_t srcSize = src.volume() * src.element_size(); - size_t dstSize = dst.volume() * dst.element_size(); - LOG_ASSERT(srcSize == dstSize, "Output tensor size mismatch"); - const void *srcPtr = ::tt::tt_metal::get_raw_host_data_ptr(src); - void *dstPtr = ::tt::tt_metal::get_raw_host_data_ptr(dst); - std::memcpy(dstPtr, srcPtr, dstSize); - } -} -} // namespace legacy - std::vector runProgram(::ttnn::MeshDevice &meshDevice, Binary executableHandle, std::uint32_t programIndex, diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 18bb0d47c8..56a205546a 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -489,34 +489,6 @@ std::vector getTensorData(Tensor tensor) { static_cast(dataPtr) + nnTensor->volume()); } -namespace legacy { - -Event submit(Device deviceHandle, Binary executableHandle, - std::uint32_t programIndex, - std::vector const &inputHandles, - std::vector const &outputHandles) { - ::ttnn::MeshDevice &meshDevice = - deviceHandle.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); - std::vector<::ttnn::Tensor *> inputs; - inputs.reserve(inputHandles.size()); - for (auto &input : inputHandles) { - LOG_ASSERT(input.matchesRuntime(DeviceRuntime::TTNN)); - inputs.push_back(static_cast<::ttnn::Tensor *>(input.handle.get())); - } - - std::vector<::ttnn::Tensor *> outputs; - outputs.reserve(outputHandles.size()); - for (auto &output : outputHandles) { - LOG_ASSERT(output.matchesRuntime(DeviceRuntime::TTNN)); - 
outputs.push_back(static_cast<::ttnn::Tensor *>(output.handle.get())); - } - - tt::runtime::ttnn::legacy::runProgram(meshDevice, executableHandle, - programIndex, inputs, outputs); - return Event(nullptr, DeviceRuntime::TTNN); -} -} // namespace legacy - std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles) { diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/reverse_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/reverse_op.mlir new file mode 100644 index 0000000000..e223ffb4f5 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/reverse_op.mlir @@ -0,0 +1,12 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s + +module @jit_eltwise_reverse attributes {} { + func.func @reverse_op(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> { + %0 = "stablehlo.reverse"(%arg0) {dimensions = array} : (tensor<32x64xf32>) -> tensor<32x64xf32> + // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[REV:[0-9]+]] = "ttir.reverse"(%arg0, %0) <{dimensions = array}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> + return %0 : tensor<32x64xf32> + // CHECK: return %[[REV]] : tensor<32x64xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/reverse/reverse_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/reverse/reverse_tests_negative.mlir new file mode 100644 index 0000000000..744996bc56 --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/reverse/reverse_tests_negative.mlir @@ -0,0 +1,34 @@ +// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s +// Negative tests for reverse operation + +// Verify that parsing fails if dimensions are not unique. +module attributes {} { + func.func @reverse_non_unique_dims(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> { + // CHECK: error: 'ttir.reverse' op dimensions should be unique. Got: 0, 0 + %0 = tensor.empty() : tensor<32x64xf32> + %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> + return %1 : tensor<32x64xf32> + } +} + +// Verify that parsing fails if any dimension is negative. +// ----- +module attributes {} { + func.func @reverse_negative_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> { + // CHECK: error: 'ttir.reverse' op all dimensions should be non-negative. Got dimension: -1 + %0 = tensor.empty() : tensor<32x64xf32> + %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> + return %1 : tensor<32x64xf32> + } +} + +// Verify that parsing fails if any dimension is out of range [0, operandRank). +// ----- +module attributes {} { + func.func @reverse_out_of_bounds_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> { + // CHECK: error: 'ttir.reverse' op all dimensions should be in interval [0, 2). 
Got dimension: 2 + %0 = tensor.empty() : tensor<32x64xf32> + %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> + return %1 : tensor<32x64xf32> + } +} diff --git a/test/ttmlir/Dialect/TTIR/reverse/reverse_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/reverse/reverse_tests_positive.mlir new file mode 100644 index 0000000000..babd0c13ec --- /dev/null +++ b/test/ttmlir/Dialect/TTIR/reverse/reverse_tests_positive.mlir @@ -0,0 +1,24 @@ +// RUN: ttmlir-opt %s | FileCheck %s + +module attributes {} { + func.func @reverse_first_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> { + %0 = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[C:.*]] = "ttir.reverse"[[C:.*]] + %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> + return %1 : tensor<32x64xf32> + } + + func.func @reverse_second_dim(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> { + %0 = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[C:.*]] = "ttir.reverse"[[C:.*]] + %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> + return %1 : tensor<32x64xf32> + } + + func.func @reverse_both_dims(%arg0: tensor<32x64xf32>) -> tensor<32x64xf32> { + %0 = tensor.empty() : tensor<32x64xf32> + // CHECK: %[[C:.*]] = "ttir.reverse"[[C:.*]] + %1 = "ttir.reverse"(%arg0, %0) <{dimensions = array}> : (tensor<32x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> + return %1 : tensor<32x64xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir new file mode 100644 index 0000000000..f24ff11845 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir @@ -0,0 +1,39 @@ +// RUN: ttmlir-opt --ttnn-workaround --canonicalize %s | FileCheck %s +#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> +#dram = #ttnn.buffer_type +#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]> +#system_memory = #ttnn.buffer_type +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> 
(d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<16384x32xbf16, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<4096x32xbf16, #system_memory>> +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<512x1x!tt.tile<32x32, bf16>, #dram>, > +#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<128x1x!tt.tile<32x32, bf16>, #dram>, > +module attributes {tt.device = #device, tt.system_desc = #system_desc} { + func.func @forward(%arg0: tensor<1x128x128x32xbf16, #ttnn_layout>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> { + %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> + // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"[[C:.*]] + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<512x1>>, >}> : (tensor<1x128x128x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xbf16, #ttnn_layout2> + %2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xbf16, #ttnn_layout2>) -> tensor<1x1x16384x32xbf16, #ttnn_layout2> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<128x1>>, >, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3> + // CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) + // Check that the input operand is transformed into the row major layout. + // CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout" + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16384x32>>, > + // CHECK-SAME: -> tensor<1x1x16384x32xbf16, + // Check that the output operand is transformed into the row major layout. 
+ // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<4096x32>>, > + // CHECK-SAME: -> tensor<1x1x4096x32xbf16, + %4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xbf16, #ttnn_layout2>, tensor<1x1x4096x32xbf16, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3> + // CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_OUTPUT_DPS]], %[[DEVICE_OP]]) + // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT:.*]] = "ttnn.to_layout"(%[[MAX_POOL_2D_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: layout = #ttnn.layout + // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<128x1>>, > + // CHECK-SAME: -> tensor<1x1x4096x32xbf16 + %5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout3> + // CHECK-NEXT: ttnn.reshape + %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> + return %6 : tensor<1x64x64x32xbf16, #ttnn_layout1> + } +} diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir deleted file mode 100644 index 41edc95bf8..0000000000 --- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/simple_workaround.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: ttmlir-opt --ttnn-workaround=ttnn-enable-layout-workaround-pass %s | FileCheck %s -#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> -#dram = #ttnn.buffer_type -#system_memory = #ttnn.buffer_type -#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system_memory>> -#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, > -#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, > -module attributes {tt.device = #device} { - func.func @forward(%arg0: tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout> { - %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> - // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"[[C:.*]] - %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<2x4>>, >}> : (tensor<64x128xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1> - // CHECK-NEXT: %[[RM_DEVICE_LAYOUT_OP:.*]] = "ttnn.to_layout"(%arg0, %[[DEVICE_OP]]) - // CHECK-SAME: layout = 
#ttnn.layout - // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout1> - %2 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2> - // CHECK-NEXT: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) - // CHECK-SAME: layout = #ttnn.layout - // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<64x128>>, > - // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout1> - %3 = "ttnn.abs"(%1, %2) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2> - // CHECK-NEXT: %[[TO_LAYOUT_LEFT:.*]] = "ttnn.to_layout"(%[[RM_DEVICE_LAYOUT_OP]], %[[DEVICE_OP]]) - // CHECK-SAME: layout = #ttnn.layout - // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout2> - // CHECK-NEXT: %[[TO_LAYOUT_RIGHT:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]]) - // CHECK-SAME: layout = #ttnn.layout - // CHECK-SAME: -> tensor<64x128xf32, #ttnn_layout2> - %4 = "ttnn.to_layout"(%3) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<64x128>>>}> : (tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout> - return %4 : tensor<64x128xf32, #ttnn_layout> - } -} diff --git a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_dram_operands_l1_op.mlir b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_dram_operands_l1_op.mlir index ec809a60a7..d8da6e2ce9 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_dram_operands_l1_op.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_dram_operands_l1_op.mlir @@ -4,7 +4,7 @@ module attributes {} { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x20x!tt.tile<32x32, bf16>, #dram>, > // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x32x!tt.tile<32x32, bf16>, #dram>, > - // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<5120x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<5120x8192xbf16, #[[LAYOUT_6]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<5120x8192xbf16>, tensor<5120x8192xbf16>) -> tensor<5120x8192xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_l1_operands_dram_op.mlir b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_l1_operands_dram_op.mlir index 0460f6ac47..7c9d90427b 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_l1_operands_dram_op.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/all_l1_operands_dram_op.mlir @@ -2,17 +2,16 @@ module attributes {} { func.func @forward(%arg0: tensor<6144x1024xbf16>, %arg1: tensor<1024x6144xbf16>) -> tensor<6144x6144xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<24x4x!tt.tile<32x32, bf16>, #l1_>, > - // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) 
-> (0, d0, d1)>, memref<4x24x!tt.tile<32x32, bf16>, #l1_>, > - // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<24x24x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x96x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<24x24x!tt.tile<32x32, bf16>, #dram>, > %0 = tensor.empty() : tensor<6144x1024xbf16> // CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<6144x1024xbf16, #[[LAYOUT_5]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<6144x1024xbf16>, tensor<6144x1024xbf16>) -> tensor<6144x1024xbf16> %2 = tensor.empty() : tensor<1024x6144xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1024x6144xbf16, #[[LAYOUT_6]]> + // CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1024x6144xbf16, #[[LAYOUT_5]]> %3 = "ttir.relu"(%arg1, %2) <{operandSegmentSizes = array}> : (tensor<1024x6144xbf16>, tensor<1024x6144xbf16>) -> tensor<1024x6144xbf16> %4 = tensor.empty() : tensor<6144x6144xbf16> - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<6144x6144xbf16, #[[LAYOUT_7]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<6144x6144xbf16, #[[LAYOUT_6]]> %5 = "ttir.matmul"(%1, %3, %4) : (tensor<6144x1024xbf16>, tensor<1024x6144xbf16>, tensor<6144x6144xbf16>) -> tensor<6144x6144xbf16> return %5 : tensor<6144x6144xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir index 5446082c75..73001109ac 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir @@ -16,8 +16,8 @@ module attributes {} { func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<5120x1024xbf16>, %arg2: tensor<5120x1024xbf16>) -> tensor<4096x1024xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #l1_>, > - // CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x4x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<4096x5120xbf16> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_5]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir index ee44b78c21..40d98d135d 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir @@ -19,9 +19,9 @@ module attributes {} { func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<5120x9216xbf16>, %arg2: tensor<9216x1024xbf16>, %arg3: tensor<5120x1024xbf16>) -> tensor<4096x1024xbf16> { // CHECK: #[[L1_:.*]] = 
#ttnn.buffer_type - // CHECK: #[[LAYOUT_9:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_9:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, > // CHECK: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x36x!tt.tile<32x32, bf16>, #dram>, > - // CHECK: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x4x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<4096x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_9]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir index 8f018f9515..a1a1354ad1 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir @@ -15,7 +15,7 @@ module attributes {} { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, > // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, > - // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<5120x4096xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_4]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir index 0791c46295..bc697a2b37 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir @@ -14,7 +14,7 @@ module attributes {} { func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<4096x5120xbf16>, %arg2: tensor<5120x5120xbf16>, %arg3: tensor<5120x5120xbf16>) -> tensor<4096x5120xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, > - // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> 
(d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<4096x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir index 049f8f0b45..b0a3dfa56b 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir @@ -14,7 +14,7 @@ module attributes {} { func.func @forward(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x8192xbf16>, %arg3: tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > - // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<2048x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir index 0a63866a63..47f6bb28df 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir @@ -14,7 +14,7 @@ module attributes {} { func.func @forward(%arg0: tensor<5120x5120xbf16>, %arg1: tensor<5120x5120xbf16>, %arg2: tensor<5120x4096xbf16>, %arg3: tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, > - // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<5120x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir index c75c2f39c7..f7be80d504 100644 --- 
a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir @@ -14,7 +14,7 @@ module attributes {} { func.func @forward(%arg0: tensor<8192x2048xbf16>, %arg1: tensor<8192x2048xbf16>, %arg2: tensor<2048x2048xbf16>, %arg3: tensor<2048x2048xbf16>) -> tensor<8192x2048xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > - // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<8192x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir index 635540ea61..19e5121ca9 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir @@ -13,17 +13,16 @@ module attributes {} { func.func @forward(%arg0: tensor<2048x8192xbf16>, %arg1: tensor<2048x8192xbf16>, %arg2: tensor<8192x2048xbf16>, %arg3: tensor<8192x2048xbf16>) -> tensor<2048x2048xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, > - // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, > - // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > %0 = tensor.empty() : tensor<2048x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_4]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> %2 = tensor.empty() : tensor<8192x2048xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_6]]> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_4]]> %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> %4 = tensor.empty() : tensor<2048x2048xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_7]]> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> 
tensor<2048x2048xbf16, #[[LAYOUT_6]]> %5 = "ttir.matmul"(%1, %3, %4) : (tensor<2048x8192xbf16>, tensor<8192x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> return %5 : tensor<2048x2048xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir index 97892500aa..16b0eb1b53 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir @@ -3,13 +3,13 @@ // CHECK-DAG: #[[LOC_MATMUL_IN0:.*]] = loc("matmul_1_in_0_layout"(#loc3)) // CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3)) // CHECK-DAG: #[[LOC_MATMUL:.*]] = loc("matmul_1"(#loc3)) -// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<4x3x!tt.tile<32x32, bf16>, #l1_>, > +// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x12x!tt.tile<32x32, bf16>, #l1_>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2) // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]]) - // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<4x3>>, >}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]]) + // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<1x12>>, >}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]]) // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]]) %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2) return %1 : tensor<64x96xbf16> diff --git a/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp b/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp index 3bc0c54410..10980cbdab 100644 --- a/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp +++ b/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNN.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" @@ -88,14 +89,14 @@ class GreedyL1InterleavedPolicyBase : public ::testing::Test { TensorMemoryLayoutAttr::get(&context, tensorMemoryLayout); if (legalLayouts.find(op) == legalLayouts.end()) { legalLayouts[op] = std::vector{TTNNLayoutAttr::get( - &context, getTensorRankedType().getShape(), builder.getF32Type(), - memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}), - tensorMemoryLayoutAttr)}; + &context, getTensorRankedType().getShape(), + mlir::tt::TileType::get(&context, builder.getF32Type()), memorySpace, + mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayoutAttr)}; } else { legalLayouts[op].push_back(TTNNLayoutAttr::get( - &context, getTensorRankedType().getShape(), builder.getF32Type(), - memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}), - tensorMemoryLayoutAttr)); + &context, getTensorRankedType().getShape(), + mlir::tt::TileType::get(&context, builder.getF32Type()), memorySpace, + mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayoutAttr)); } }
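Editor's note: a minimal sketch of how a frontend call site might migrate off the legacy submit path removed above, assuming the new API declared in runtime/include/tt/runtime/detail/ttnn.h takes only the input tensors and returns the program outputs; the vector element type (Tensor), the namespace, and the helper name below are assumptions inferred from this diff, not confirmed signatures.

#include "tt/runtime/detail/ttnn.h"

#include <cstdint>
#include <vector>

// Hypothetical migration helper (not part of the diff).
std::vector<::tt::runtime::Tensor>
runProgramOnce(::tt::runtime::Device device, ::tt::runtime::Binary binary,
               std::uint32_t programIndex,
               std::vector<::tt::runtime::Tensor> const &inputs) {
  // Legacy path (removed): legacy::submit(device, binary, programIndex, inputs,
  // outputs) filled caller-provided output handles and returned an Event.
  // The new submit is assumed to run the program and return the output tensors
  // directly, so no pre-allocated output handles are needed.
  return ::tt::runtime::ttnn::submit(device, binary, programIndex, inputs);
}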