From 6b1897c90889a50c12a7f13204799485a3360780 Mon Sep 17 00:00:00 2001
From: Jackson Nie
Date: Tue, 29 Oct 2024 18:48:16 +0000
Subject: [PATCH] #1090: Remove ttnn::CompositeToLayoutOp, use ttnn::ToLayoutOp
 instead

---
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     | 24 +-----
 .../ttmlir/Dialect/TTNN/Transforms/Passes.td  |  6 +-
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      |  2 +-
 .../TTNN/Analysis/DFShardingPolicy.cpp        |  8 +-
 .../TTNN/Analysis/LegalGridAnalysis.cpp       |  2 +-
 lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp  |  2 +-
 lib/Dialect/TTNN/Transforms/Optimizer.cpp     | 45 +++++------
 lib/Dialect/TTNN/Transforms/Passes.cpp        | 76 ++++++++++---------
 test/ttmlir/Dialect/TTNN/test_grid_set.mlir   |  6 +-
 9 files changed, 75 insertions(+), 96 deletions(-)

diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 42518c6e71..f878e7f21d 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -27,22 +27,6 @@ def TTNN_GetDeviceOp : TTNN_Op<"get_device"> {
   let results = (outs TT_Device:$device);
 }
 
-def TTNN_CompositeToLayoutOp : TTNN_Op<"composite_to_layout"> {
-  let summary = "Composite toLayout op.";
-  let description = [{
-    This op wraps all layout information gathered from ttir.toLayout. It is used/updated by the optimizer
-    to perform optimizations, and later broken down into specific memory/layout operations (toDevice, toMemoryConfig, toLayout etc.).
-  }];
-
-  let arguments = (ins AnyRankedTensor:$input,
-                       TTNN_LayoutAttr:$layout,
-                       TT_DataTypeAttr:$dtype,
-                       TTNN_MemoryConfigAttr:$memory_config,
-                       Optional<TT_Device>:$device);
-
-  let results = (outs AnyRankedTensor:$result);
-}
-
 def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
   let summary = "ToMemoryConfig op.";
   let description = [{
@@ -65,10 +49,10 @@ def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
 def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
   let summary = "ToLayout op.";
   let description = [{
-    This op converts the layout of the input tensor based on the given layout.
-    It handles:
-      - ROW_MAJOR to TILE
-      - TILE to ROW_MAJOR
+    This op wraps all layout information gathered from ttir.toLayout. It is used/updated by the optimizer
+    to perform optimizations, and later broken down into specific memory/layout operations (toDevice, toMemoryConfig etc.).
+    Currently in the TTNN backend, we use this op solely for tilize/untilize, therefore marking all other attrs as optional.
+    Once ttnn::to_layout supports other attrs, we can remove the optional tag.
   }];
 
   let arguments = (ins AnyRankedTensor:$input,
diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
index c3f8f8dd27..c29fa977b4 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -14,10 +14,10 @@ def TTNNDeallocate: Pass<"ttnn-deallocate", "::mlir::ModuleOp"> {
   }];
 }
 
-def TTNNDecomposeCompositeLayouts: Pass<"ttnn-decompose-composite-layouts", "::mlir::ModuleOp"> {
-  let summary = "Decompose composite layout ops to according memory ops.";
+def TTNNDecomposeLayouts: Pass<"ttnn-decompose-layouts", "::mlir::ModuleOp"> {
+  let summary = "Decompose ToLayoutOps to more granular memory ops.";
   let description = [{
-    This pass decomposes composite layouts to memory ops (e.g. toDevice, toMemoryConfig etc.).
+    This pass decomposes ToLayoutOps to memory ops (e.g. toDevice, toMemoryConfig etc.).
   }];
 }
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 812d07b413..ed1cb9d578 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -200,7 +200,7 @@ class ToLayoutOpConversionPattern
         op.getContext(),
         ttnn::ShapeAttr::get(rewriter.getContext(), outputMemref.getShape())));
 
-    rewriter.replaceOpWithNewOp<ttnn::CompositeToLayoutOp>(
+    rewriter.replaceOpWithNewOp<ttnn::ToLayoutOp>(
         op, this->getTypeConverter()->convertType(result), adaptor.getInput(),
         outputLayout, outputDataType, outputMemConfigAttr,
         isOutputOnHost ? nullptr : getOrInsertDevice(rewriter, op));
diff --git a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
index 0af013c390..6136ce277c 100644
--- a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -8,10 +8,6 @@
 
 namespace mlir::tt::ttnn {
 
-bool isCompositeToLayoutOp(mlir::Operation *op) {
-  return isa<ttnn::CompositeToLayoutOp>(op);
-}
-
 void DFShardingPolicy::run(
     const std::unordered_set<Edge> &overrideReshardEdges) {
   rootOp->walk([&](func::FuncOp func) {
@@ -39,7 +35,7 @@ void DFShardingPolicy::run(
         //
         if (l1ChainConfigs->back().isEmpty()) {
           for (auto *op : scheduleableOps) {
-            if (isCompositeToLayoutOp(op)) {
+            if (isa<ttnn::ToLayoutOp>(op)) {
              currentOp = op;
              break;
            }
@@ -57,7 +53,7 @@ void DFShardingPolicy::run(
        // Skip starting sharding chain if currentOp is a memory management op.
        //
        if (l1ChainConfigs->back().isEmpty() &&
-            isCompositeToLayoutOp(currentOp)) {
+            isa<ttnn::ToLayoutOp>(currentOp)) {
          currentOp = nullptr;
          continue;
        }
diff --git a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
index c47f9a63cd..512006c387 100644
--- a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
+++ b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
@@ -52,7 +52,7 @@ bool cantChangeOutputLayout(Operation *op) {
     return true;
   }
 
-  if (llvm::isa<ttnn::CompositeToLayoutOp>(op)) {
+  if (llvm::isa<ttnn::ToLayoutOp>(op)) {
     return true;
   }
 
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index e7f5c0f6b0..7ab62daefa 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -64,7 +64,7 @@ void createTTNNPipelineLoweringPasses(
 
 void createTTNNPipelineLayoutDecompositionPass(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
-  pm.addPass(createTTNNDecomposeCompositeLayouts());
+  pm.addPass(createTTNNDecomposeLayouts());
 }
 
 void createTTNNPipelineDeallocPass(
diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp
index 139c76c1c7..6c30a98eec 100644
--- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp
+++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp
@@ -163,16 +163,15 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
       //       TTNN will need to be special handled as well. Depends on ttnn
      //       layout attr refactor and lowering.
      //
-      else if (isa<ttnn::CompositeToLayoutOp>(op)) {
+      else if (isa<ttnn::ToLayoutOp>(op)) {
        BufferType bufferType =
            utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
        TensorMemoryLayout tensorMemoryLayout =
            utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());

        // Update the device op with the new tensor type.
        //
-        ttnn::CompositeToLayoutOp compositeLayoutOp =
-            llvm::cast<ttnn::CompositeToLayoutOp>(op);
-        compositeLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
+        ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(op);
+        toLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
            op->getContext(),
            ttnn::TensorMemoryLayoutAttr::get(op->getContext(),
                                              tensorMemoryLayout),
@@ -243,45 +242,41 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
    Operation *producerOp = edge.producerOp;
    Operation *consumerOp = edge.consumerOp;

-    // If producerOp is a compositeToLayoutOp, adjust its output layout(update
+    // If producerOp is a toLayoutOp, adjust its output layout(update
    // inplace) to reflect consumerOp's output layout. If producerOp is not a
-    // compositeToLayoutOp, insert a compositeToLayoutOp in between producerOp
+    // toLayoutOp, insert a toLayoutOp in between producerOp
    // and consumerOp.
    //
-    if (isa<ttnn::CompositeToLayoutOp>(producerOp)) {
-      ttnn::CompositeToLayoutOp compositeLayoutOp =
-          llvm::cast<ttnn::CompositeToLayoutOp>(producerOp);
+    if (isa<ttnn::ToLayoutOp>(producerOp)) {
+      ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(producerOp);

      tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
          mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
              .getEncoding());

-      RankedTensorType compositeLayoutOpTensorType =
-          mlir::cast<RankedTensorType>(
-              compositeLayoutOp.getResult().getType());
-      llvm::ArrayRef<int64_t> compositeLayoutOpTensorShape =
-          compositeLayoutOpTensorType.getShape();
-      tt::LayoutAttr compositeLayoutOpLayout = mlir::cast<tt::LayoutAttr>(
-          compositeLayoutOpTensorType.getEncoding());
+      RankedTensorType toLayoutOpTensorType =
+          mlir::cast<RankedTensorType>(toLayoutOp.getResult().getType());
+      llvm::ArrayRef<int64_t> toLayoutOpTensorShape =
+          toLayoutOpTensorType.getShape();
+      tt::LayoutAttr toLayoutOpLayout =
+          mlir::cast<tt::LayoutAttr>(toLayoutOpTensorType.getEncoding());

      // TODO(nobradovic): Match memory space and layout of consumer op. This
      // actually needs to be properly resolved based on op type, output
      // layout and other inputs.
      //
      RankedTensorType newTensorType = RankedTensorType::get(
-          compositeLayoutOpTensorShape,
-          compositeLayoutOpTensorType.getElementType(),
-          compositeLayoutOpLayout
-              .withElementType(compositeLayoutOp->getContext(),
+          toLayoutOpTensorShape, toLayoutOpTensorType.getElementType(),
+          toLayoutOpLayout
+              .withElementType(toLayoutOp->getContext(),
                               consumerOpOutputLayout.getElementType())
-              .withMemorySpace(compositeLayoutOp.getContext(),
+              .withMemorySpace(toLayoutOp.getContext(),
                               consumerOpOutputLayout.getMemorySpace())
-              .withMemoryLayout(compositeLayoutOp.getContext(),
+              .withMemoryLayout(toLayoutOp.getContext(),
                                consumerOpOutputLayout.getMemLayout())
-              .withGrid(compositeLayoutOp.getContext(),
-                        compositeLayoutOpTensorType,
+              .withGrid(toLayoutOp.getContext(), toLayoutOpTensorType,
                        consumerOpOutputLayout.getGrid()));

-      compositeLayoutOp.getResult().setType(newTensorType);
+      toLayoutOp.getResult().setType(newTensorType);
    }
    // TODO (nobradovic): Memory layout reconfig needs to be reimplemented for
    // TTNN dialect.
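
Note on the decomposition pass implemented in the next file: with this change, ttnn.to_layout carries the layout enum plus optional dtype and memory_config attributes, and TTNNDecomposeLayouts reads them back before emitting the granular ops. Below is a minimal sketch, not part of the patch, of how downstream code can match the op and guard the optional attributes; the accessor names (getLayout, getDtype, getMemoryConfig) come from this patch, while the header path and the helper function are assumptions.

// Sketch only (hypothetical helper): match ttnn::ToLayoutOp and guard its
// optional attributes. Header path is an assumption.
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

static void inspectToLayout(mlir::Operation *op) {
  auto toLayout = mlir::dyn_cast<mlir::tt::ttnn::ToLayoutOp>(op);
  if (!toLayout) {
    return;
  }
  // `layout` stays required; dtype/memory_config may be absent until
  // ttnn::to_layout supports them natively, so check before use.
  (void)toLayout.getLayout();
  if (toLayout.getDtype().has_value()) {
    // ... use toLayout.getDtype().value() ...
  }
  if (toLayout.getMemoryConfig().has_value()) {
    // ... use toLayout.getMemoryConfig().value() ...
  }
}
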
diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp
index dca2c8bc93..cb0d8c8869 100644
--- a/lib/Dialect/TTNN/Transforms/Passes.cpp
+++ b/lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -11,7 +11,7 @@
 namespace mlir::tt::ttnn {
 
 #define GEN_PASS_DEF_TTNNDEALLOCATE
-#define GEN_PASS_DEF_TTNNDECOMPOSECOMPOSITELAYOUTS
+#define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"
 
 class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
@@ -98,13 +98,12 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
   }
 };
 
-class TTNNDecomposeCompositeLayouts
-    : public impl::TTNNDecomposeCompositeLayoutsBase<
-          TTNNDecomposeCompositeLayouts> {
+class TTNNDecomposeLayouts
+    : public impl::TTNNDecomposeLayoutsBase<TTNNDecomposeLayouts> {
 
 public:
-  using impl::TTNNDecomposeCompositeLayoutsBase<
-      TTNNDecomposeCompositeLayouts>::TTNNDecomposeCompositeLayoutsBase;
+  using impl::TTNNDecomposeLayoutsBase<
+      TTNNDecomposeLayouts>::TTNNDecomposeLayoutsBase;
 
   void runOnOperation() final {
     ModuleOp module = getOperation();
@@ -113,14 +112,14 @@ class TTNNDecomposeCompositeLayouts
     module->walk([&](func::FuncOp func) {
       assert(func.getBody().hasOneBlock());
       func->walk([&](Operation *op) {
-        if (!isa<ttnn::CompositeToLayoutOp>(op)) {
+        if (!isa<ttnn::ToLayoutOp>(op)) {
          return;
        }
        opsToReplace.push_back(op);
      });
    });
    for (Operation *op : opsToReplace) {
-      this->createLayoutConversionOps(mlir::cast<ttnn::CompositeToLayoutOp>(op),
+      this->createLayoutConversionOps(mlir::cast<ttnn::ToLayoutOp>(op),
                                      rewriter);
      rewriter.eraseOp(op);
    }
@@ -164,11 +163,14 @@ class TTNNDecomposeCompositeLayouts
 
    void print() const {
      llvm::errs() << "OpsToCreate{ \n"
-                   << "\t" << "CreateToDeviceOp: " << createToDeviceOp << "\n"
-                   << "\t" << "CreateFromDeviceOp: " << createFromDeviceOp
-                   << "\n"
-                   << "\t" << "CreateToLayoutOp: " << createToLayoutOp << "\n"
-                   << "\t" << "CreateTypecastOp: " << createTypecastOp << "\n"
+                   << "\t"
+                   << "CreateToDeviceOp: " << createToDeviceOp << "\n"
+                   << "\t"
+                   << "CreateFromDeviceOp: " << createFromDeviceOp << "\n"
+                   << "\t"
+                   << "CreateToLayoutOp: " << createToLayoutOp << "\n"
+                   << "\t"
+                   << "CreateTypecastOp: " << createTypecastOp << "\n"
                   << "\t"
                   << "CreateToMemoryConfigOp: " << createToMemoryConfigOp
                   << "\n"
@@ -208,13 +210,15 @@ class TTNNDecomposeCompositeLayouts
  }
 
  std::pair<LayoutInfo, LayoutInfo>
-  getInputOutputLayouts(ttnn::CompositeToLayoutOp op) const {
+  getInputOutputLayouts(ttnn::ToLayoutOp op) const {
    LayoutInfo input, output;
 
    auto inputLayoutAttr =
        mlir::cast<tt::LayoutAttr>(op.getInput().getType().getEncoding());
    auto inputMemref = inputLayoutAttr.getMemref();
-    MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig();
+
+    assert(op.getMemoryConfig().has_value());
+    MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig().value();
 
    input.bufferType =
        ttnn::utils::toTTNNBufferType(inputLayoutAttr.getMemorySpace());
@@ -224,7 +228,8 @@ class TTNNDecomposeCompositeLayouts
    output.layoutEnum = op.getLayout();
 
    input.dataType = ttnn::utils::getDataTypeFromMemRef(inputMemref);
-    output.dataType = op.getDtype();
+    assert(op.getDtype().has_value());
+    output.dataType = op.getDtype().value();
 
    input.tensorMemoryLayout =
        ttnn::utils::toTTNNTensorMemoryLayout(inputLayoutAttr.getMemLayout());
@@ -273,13 +278,13 @@ class TTNNDecomposeCompositeLayouts
    return opsToCreate;
  }
 
-  bool isCreationValid(ttnn::CompositeToLayoutOp op, const LayoutInfo &input,
+  bool isCreationValid(ttnn::ToLayoutOp op, const LayoutInfo &input,
                       const LayoutInfo &output,
                       const OpsToCreate &opsToCreate) const {
    if (not opsToCreate.createSomeOp()) {
      op->emitError(
-          "Redundant ttnn::CompositeToLayoutOp - no ttnn layout ops "
ops " + "Redundant ttnn::ToLayoutOp - no ttnn layout ops " "needed, this may be due to the forcing of tile/row major layouts."); return false; } @@ -311,7 +316,7 @@ class TTNNDecomposeCompositeLayouts /* Helper functions to create ttnn layout ops */ template - mlir::Value createOp(ttnn::CompositeToLayoutOp op, IRRewriter &rewriter, + mlir::Value createOp(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, Args... args) const { rewriter.setInsertionPoint(op); @@ -319,7 +324,7 @@ class TTNNDecomposeCompositeLayouts args...); } - mlir::Value createToDeviceOpIfNeeded(ttnn::CompositeToLayoutOp op, + mlir::Value createToDeviceOpIfNeeded(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { @@ -333,7 +338,7 @@ class TTNNDecomposeCompositeLayouts } // FromDeviceOp - mlir::Value createFromDeviceOpIfNeeded(ttnn::CompositeToLayoutOp op, + mlir::Value createFromDeviceOpIfNeeded(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info, @@ -344,7 +349,7 @@ class TTNNDecomposeCompositeLayouts return this->createOp(op, rewriter, currentInput); } - mlir::Value createToLayoutOpIfNeeded(ttnn::CompositeToLayoutOp op, + mlir::Value createToLayoutOpIfNeeded(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { @@ -358,7 +363,7 @@ class TTNNDecomposeCompositeLayouts /*memory_config*/ nullptr, /*device*/ nullptr); } - mlir::Value createTypecastOpIfNeeded(ttnn::CompositeToLayoutOp op, + mlir::Value createTypecastOpIfNeeded(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { @@ -371,7 +376,7 @@ class TTNNDecomposeCompositeLayouts dtypeAttr); } - mlir::Value createToMemoryConfigOpIfNeeded(ttnn::CompositeToLayoutOp op, + mlir::Value createToMemoryConfigOpIfNeeded(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { @@ -387,7 +392,7 @@ class TTNNDecomposeCompositeLayouts /* Functions that create ops based on the layouts of the input output tensors */ - void handleHostInputNoLayoutNoTypecast(ttnn::CompositeToLayoutOp op, + void handleHostInputNoLayoutNoTypecast(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { @@ -401,7 +406,7 @@ class TTNNDecomposeCompositeLayouts op.getResult().replaceAllUsesWith(currentInput); } - void handleHostInputLayoutNoTypecast(ttnn::CompositeToLayoutOp op, + void handleHostInputLayoutNoTypecast(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { @@ -450,7 +455,7 @@ class TTNNDecomposeCompositeLayouts llvm_unreachable("Unreachable code path"); } - void handleHostInputNoLayoutTypecast(ttnn::CompositeToLayoutOp op, + void handleHostInputNoLayoutTypecast(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { @@ -486,8 +491,7 @@ class TTNNDecomposeCompositeLayouts llvm_unreachable("Unreachable code path"); } - void handleHostInputLayoutTypecast(ttnn::CompositeToLayoutOp op, - IRRewriter &rewriter, + void handleHostInputLayoutTypecast(ttnn::ToLayoutOp op, IRRewriter &rewriter, mlir::Value currentInput, const OpCreationInfo &info) const { const LayoutInfo &input = info.input; @@ -557,7 +561,7 @@ class TTNNDecomposeCompositeLayouts llvm_unreachable("Unreachable code path"); } - void handleHostInputLayoutConversion(ttnn::CompositeToLayoutOp op, + void 
+  void handleHostInputLayoutConversion(ttnn::ToLayoutOp op,
                                       IRRewriter &rewriter,
                                       mlir::Value currentInput,
                                       const OpCreationInfo &info) const {
@@ -578,7 +582,7 @@ class TTNNDecomposeCompositeLayouts
    llvm_unreachable("Unreachable code path");
  }
 
-  void handleDeviceInputNoLayoutNoTypecast(ttnn::CompositeToLayoutOp op,
+  void handleDeviceInputNoLayoutNoTypecast(ttnn::ToLayoutOp op,
                                           IRRewriter &rewriter,
                                           mlir::Value currentInput,
                                           const OpCreationInfo &info) const {
@@ -592,7 +596,7 @@ class TTNNDecomposeCompositeLayouts
 
    op.getResult().replaceAllUsesWith(currentInput);
  }
 
-  void handleDeviceInputLayoutNoTypecast(ttnn::CompositeToLayoutOp op,
+  void handleDeviceInputLayoutNoTypecast(ttnn::ToLayoutOp op,
                                         IRRewriter &rewriter,
                                         mlir::Value currentInput,
                                         const OpCreationInfo &info) const {
@@ -682,7 +686,7 @@ class TTNNDecomposeCompositeLayouts
    llvm_unreachable("Unreachable code path");
  }
 
-  void handleDeviceInputNoLayoutTypecast(ttnn::CompositeToLayoutOp op,
+  void handleDeviceInputNoLayoutTypecast(ttnn::ToLayoutOp op,
                                         IRRewriter &rewriter,
                                         mlir::Value currentInput,
                                         const OpCreationInfo &info) const {
@@ -735,7 +739,7 @@ class TTNNDecomposeCompositeLayouts
    llvm_unreachable("Unreachable code path");
  }
 
-  void handleDeviceInputLayoutTypecast(ttnn::CompositeToLayoutOp op,
+  void handleDeviceInputLayoutTypecast(ttnn::ToLayoutOp op,
                                       IRRewriter &rewriter,
                                       mlir::Value currentInput,
                                       const OpCreationInfo &info) const {
@@ -831,7 +835,7 @@ class TTNNDecomposeCompositeLayouts
    llvm_unreachable("Unreachable code path");
  }
 
-  void handleDeviceInputLayoutConversion(ttnn::CompositeToLayoutOp op,
+  void handleDeviceInputLayoutConversion(ttnn::ToLayoutOp op,
                                         IRRewriter &rewriter,
                                         mlir::Value currentInput,
                                         const OpCreationInfo &info) const {
@@ -864,7 +868,7 @@ class TTNNDecomposeCompositeLayouts
   * sizeof(uint32_t). For now, we will always untilize on host. We rarely
   * need device to device untilize, so the perf hit should be acceptable.
   */
-  void createLayoutConversionOps(ttnn::CompositeToLayoutOp op,
+  void createLayoutConversionOps(ttnn::ToLayoutOp op,
                                 IRRewriter &rewriter) const {
    auto [input, output] = getInputOutputLayouts(op);
    OpsToCreate opsToCreate = determineRequiredOps(input, output);
diff --git a/test/ttmlir/Dialect/TTNN/test_grid_set.mlir b/test/ttmlir/Dialect/TTNN/test_grid_set.mlir
index ee983a152f..a65b428875 100644
--- a/test/ttmlir/Dialect/TTNN/test_grid_set.mlir
+++ b/test/ttmlir/Dialect/TTNN/test_grid_set.mlir
@@ -9,13 +9,13 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
   func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> {
     %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
-    %1 = "ttnn.composite_to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x128>>>}> : (tensor<64x128xf32, #layout>, !tt.device<#device>) -> tensor<64x128xf32, #layout1>
-    %2 = "ttnn.composite_to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x128>>>}> : (tensor<64x128xf32, #layout>, !tt.device<#device>) -> tensor<64x128xf32, #layout1>
+    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x128>>>}> : (tensor<64x128xf32, #layout>, !tt.device<#device>) -> tensor<64x128xf32, #layout1>
+    %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x128>>>}> : (tensor<64x128xf32, #layout>, !tt.device<#device>) -> tensor<64x128xf32, #layout1>
     %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x128>>>, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #layout2>
     %4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout1>, tensor<64x128xf32, #layout2>) -> tensor<64x128xf32, #layout2>
     // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>, interleaved>
     // CHECK: %{{.+}} = "ttnn.multiply"{{.+}} -> tensor<64x128xf32, #[[LAYOUT_2]]>
-    %5 = "ttnn.composite_to_layout"(%4) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x128>>>}> : (tensor<64x128xf32, #layout2>) -> tensor<64x128xf32, #layout>
+    %5 = "ttnn.to_layout"(%4) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, , <<64x128>>>}> : (tensor<64x128xf32, #layout2>) -> tensor<64x128xf32, #layout>
     return %5 : tensor<64x128xf32, #layout>
   }
 }
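
For downstream pipelines, the visible API change outside the dialect is the pass entry point rename. A minimal sketch of updating a custom pass pipeline follows; only createTTNNDecomposeLayouts() and the mlir::tt::ttnn namespace come from this patch, while the header path and the wrapper function are assumptions.

// Sketch only: wiring the renamed pass into a custom pipeline.
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

void addLayoutDecomposition(mlir::OpPassManager &pm) {
  // Previously: pm.addPass(mlir::tt::ttnn::createTTNNDecomposeCompositeLayouts());
  pm.addPass(mlir::tt::ttnn::createTTNNDecomposeLayouts());
}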