From 676c714e8caaa966d3f816f4fcda9a835e77a9e9 Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Thu, 19 Dec 2024 22:09:19 +0000 Subject: [PATCH] Use tilized dram-interleaved as default input-output layout --- .../ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td | 6 +- .../Dialect/TTNN/Pipelines/TTNNPipelines.h | 2 +- .../ttmlir/Dialect/TTNN/Transforms/Passes.td | 2 +- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 93 +---- lib/Dialect/TT/IR/TTOpsTypes.cpp | 4 +- .../TTNN/Analysis/BFInterleavedPolicy.cpp | 19 + lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 2 +- lib/Dialect/TTNN/Transforms/Optimizer.cpp | 16 + .../TTNN/Transforms/TTNNDecomposeLayouts.cpp | 54 +-- lib/Dialect/TTNN/Transforms/TTNNLayout.cpp | 60 +++- .../Workarounds/TTNNWorkarounds.cpp | 55 ++- .../include/tt/runtime/detail/workarounds.h | 22 +- runtime/lib/common/workarounds.cpp | 5 +- .../ttnn/include/tt/runtime/ttnn/types.cpp | 9 +- .../ttnn/include/tt/runtime/ttnn/utils.cpp | 5 - .../lib/ttnn/include/tt/runtime/ttnn/utils.h | 10 - .../lib/ttnn/operations/creation/empty.cpp | 4 +- runtime/lib/ttnn/operations/creation/full.cpp | 5 +- runtime/lib/ttnn/operations/creation/ones.cpp | 4 +- .../tt/runtime/ttnn/operations/utils.h | 10 + .../lib/ttnn/operations/pool/maxpool2d.cpp | 6 +- runtime/lib/ttnn/runtime.cpp | 4 + runtime/test/python/ttnn/test_runtime_api.py | 46 +-- runtime/tools/python/ttrt/common/run.py | 13 +- .../eltwise/unary/expm1/simple_expm1.mlir | 2 +- .../eltwise/unary/log1p/simple_log1p.mlir | 2 +- .../TTNN/eltwise/unary/relu/simple_relu.mlir | 13 +- .../TTNN/eltwise/unary/sign/simple_sign.mlir | 2 +- .../TTNN/eltwise/unary/tan/simple_tan.mlir | 2 +- .../TTNN/eltwise/unary/tanh/simple_tanh.mlir | 2 +- .../bf_interleaved_policy/fork_join_01.mlir | 12 +- .../bf_interleaved_policy/fork_join_02.mlir | 16 +- .../fork_join.mlir | 10 +- .../optimizer/input_layout_loc_override.mlir | 5 +- .../TTNN/optimizer/multiple_add_with_loc.mlir | 2 +- .../optimizer/output_layout_override.mlir | 2 + .../optimizer/sharding_matmul_override_0.mlir | 6 +- .../sharding_matmul_override_32.mlir | 4 +- test/ttmlir/Dialect/TTNN/simple_clamp.mlir | 5 +- .../TTNN/simple_get_dimension_size.mlir | 1 - test/ttmlir/Dialect/TTNN/simple_scatter.mlir | 2 +- test/ttmlir/Dialect/TTNN/simple_where.mlir | 6 +- .../eltwise_binary_op_chain.mlir | 49 --- .../Silicon/StableHLO/scalar_add_op.mlir | 1 - test/ttmlir/Silicon/TTNN/deallocate.mlir | 9 +- .../Silicon/TTNN/eltwise/binary/add/add.mlir | 11 + .../TTNN/eltwise/binary/add/add_int32.mlir | 11 + .../TTNN/eltwise/binary/concat/concat.mlir | 11 + .../Silicon/TTNN/eltwise/binary/div/div.mlir | 11 + .../Silicon/TTNN/eltwise/binary/ge/ge.mlir | 11 + .../TTNN/eltwise/binary/maximum/maximum.mlir | 11 + .../TTNN/eltwise/binary/minimum/minimum.mlir | 16 + .../eltwise/binary/multiply/multiply.mlir | 11 + .../eltwise/binary/remainder/remainder.mlir | 12 + .../TTNN/eltwise/binary/scatter/scatter.mlir | 14 + .../eltwise/binary/subtract/subtract.mlir | 11 + .../TTNN/eltwise/ternary/where/where.mlir | 14 + .../Silicon/TTNN/eltwise/unary/cbrt/cbrt.mlir | 11 + .../Silicon/TTNN/eltwise/unary/ceil/ceil.mlir | 11 + .../TTNN/eltwise/unary/clamp/clamp.mlir | 11 + .../TTNN/eltwise/unary/cosine/cosine.mlir | 11 + .../TTNN/eltwise/unary/expm1/expm1.mlir | 12 + .../TTNN/eltwise/unary/floor/floor.mlir | 15 + .../Silicon/TTNN/eltwise/unary/gelu/gelu.mlir | 15 + .../get_dimension_size.mlir | 9 + .../eltwise/unary/is_finite/is_finite.mlir | 15 + .../eltwise/unary/leaky_relu/leaky_relu.mlir | 11 + .../Silicon/TTNN/eltwise/unary/log/log.mlir | 11 + 
.../TTNN/eltwise/unary/log1p/log1p.mlir | 12 + .../TTNN/eltwise/unary/negate/negate.mlir | 10 + .../eltwise/unary/recipricol/recipricol.mlir | 11 + .../Silicon/TTNN/eltwise/unary/relu/relu.mlir | 11 + .../TTNN/eltwise/unary/rsqrt/rsqrt.mlir | 11 + .../TTNN/eltwise/unary/sigmoid/sigmoid.mlir | 11 + .../Silicon/TTNN/eltwise/unary/sign/sign.mlir | 12 + .../Silicon/TTNN/eltwise/unary/sine/sine.mlir | 11 + .../Silicon/TTNN/eltwise/unary/sqrt/sqrt.mlir | 11 + .../Silicon/TTNN/eltwise/unary/tan/tan.mlir | 11 + .../Silicon/TTNN/eltwise/unary/tanh/tanh.mlir | 11 + test/ttmlir/Silicon/TTNN/ones.mlir | 8 +- .../TTNN/perf_unit/test_perf_ceil.mlir | 2 +- .../TTNN/perf_unit/test_perf_clamp.mlir | 5 +- .../TTNN/perf_unit/test_perf_cosine.mlir | 2 +- .../TTNN/perf_unit/test_perf_expm1.mlir | 2 +- .../Silicon/TTNN/perf_unit/test_perf_log.mlir | 2 +- .../TTNN/perf_unit/test_perf_log1p.mlir | 2 +- .../TTNN/perf_unit/test_perf_sign.mlir | 2 +- .../TTNN/perf_unit/test_perf_sine.mlir | 2 +- .../Silicon/TTNN/perf_unit/test_perf_tan.mlir | 2 +- .../TTNN/perf_unit/test_perf_tanh.mlir | 2 +- .../TTNN/perf_unit/test_perf_where.mlir | 4 +- test/ttmlir/Silicon/TTNN/reshape/reshape.mlir | 10 + .../eltwise_binary_op_chain.mlir | 29 ++ test/ttmlir/Silicon/TTNN/simple_eltwise.mlir | 334 ------------------ test/ttmlir/Silicon/TTNN/simple_repeat.mlir | 6 +- test/ttmlir/Silicon/TTNN/softmax/softmax.mlir | 15 + test/ttmlir/Silicon/TTNN/squeeze/squeeze.mlir | 10 + .../TTNN/{ => typecast}/simple_typecast.mlir | 0 .../Silicon/TTNN/typecast/typecast.mlir | 12 + 99 files changed, 806 insertions(+), 652 deletions(-) delete mode 100644 test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/add/add.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/add/add_int32.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/concat/concat.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/div/div.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/ge/ge.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/maximum/maximum.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/minimum/minimum.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/multiply/multiply.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/remainder/remainder.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/scatter/scatter.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/binary/subtract/subtract.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/ternary/where/where.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/cbrt/cbrt.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/ceil/ceil.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/clamp/clamp.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/cosine/cosine.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/expm1/expm1.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/floor/floor.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/gelu/gelu.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/get_dimension_size/get_dimension_size.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/is_finite/is_finite.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/leaky_relu/leaky_relu.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/log/log.mlir create mode 100644 
test/ttmlir/Silicon/TTNN/eltwise/unary/log1p/log1p.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/negate/negate.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/recipricol/recipricol.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/relu/relu.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/rsqrt/rsqrt.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/sigmoid/sigmoid.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/sign/sign.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/sine/sine.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/sqrt/sqrt.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/tan/tan.mlir create mode 100644 test/ttmlir/Silicon/TTNN/eltwise/unary/tanh/tanh.mlir create mode 100644 test/ttmlir/Silicon/TTNN/reshape/reshape.mlir create mode 100644 test/ttmlir/Silicon/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir delete mode 100644 test/ttmlir/Silicon/TTNN/simple_eltwise.mlir create mode 100644 test/ttmlir/Silicon/TTNN/softmax/softmax.mlir create mode 100644 test/ttmlir/Silicon/TTNN/squeeze/squeeze.mlir rename test/ttmlir/Silicon/TTNN/{ => typecast}/simple_typecast.mlir (100%) create mode 100644 test/ttmlir/Silicon/TTNN/typecast/typecast.mlir diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td index 8d20a2bcc5..cc8f090c2c 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td @@ -88,11 +88,7 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> { let assemblyFormat = "`<` params `>`"; let extraClassDeclaration = [{ - ::llvm::ArrayRef getShardShapeArray() const - { - return this->getShardSpec().getShardShape().getShape(); - } - + llvm::ArrayRef getShardShape(bool convertTileToScalar = true) const; MemoryConfigAttr withBufferType(::mlir::MLIRContext *context, BufferType bufferType); MemoryConfigAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout); }]; diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index a65f95c6b2..3e8e71de83 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -128,7 +128,7 @@ struct TTIRToTTNNBackendPipelineOptions // Option to enable/disable the workaround pass. 
// - Option layouotWorkaroundsEnabled{ + Option layoutWorkaroundsEnabled{ *this, "enable-layout-workaround-pass", llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)}; diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td index a5f83290d8..8476964d60 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td @@ -36,7 +36,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> { }]; let options = [ - Option<"layouotWorkaroundsEnabled", + Option<"layoutWorkaroundsEnabled", "ttnn-enable-layout-workaround-pass", "bool", /*default=*/"true", "TTNN Layout Workarounds Pass">, diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 2e84eb3471..03cbe9d331 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Conversion/TTIRToTTNN/TTIRToTTNN.h" - #include "ttmlir/Conversion/TTIRToTTNN/Utils.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" @@ -13,6 +12,7 @@ #include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h" #include "ttmlir/Dialect/TTNN/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -170,6 +170,9 @@ class ToLayoutOpConversionPattern rewriter.eraseOp(emptyOp); } + assert(mlir::isa(adaptor.getInput().getType()) && + "Expected RankedTensorType for ToLayoutOp input"); + auto outputLayoutAttr = mlir::cast( op.getResult().getType().getEncoding()); @@ -186,32 +189,6 @@ class ToLayoutOpConversionPattern bool isOutputOnHost = (outputBufferType == ttnn::BufferType::SystemMemory); RankedTensorType result = mlir::cast(op.getType()); - if (!isOutputOnHost) { - // TODO(bug #665): - // Binary ops fail with row major layout in ttnn, defaulting to and - // assuming tile layout for all device tensors... - // Note: mlir doesn't know about this, so tensors may still appear as row - // major in the generated mlir - // TODO(bug #875): - // Remove the following code block once constraints modelling is - // implemented on dialect level - // - // Default to Tile layout unless op supports only RowMajor layout - // - ttnn::Layout newOutputLayoutEnum = - shouldForceRowMajor(op) ? ttnn::Layout::RowMajor : ttnn::Layout::Tile; - - // If the layout of the output tensor changed as a result of forcing the - // layout update the tensor type - if (outputLayoutEnum != newOutputLayoutEnum) { - result = - getLayoutForcedResultTensor(rewriter, result, newOutputLayoutEnum); - op.getResult().setType(result); - outputLayoutAttr = - mlir::cast(result.getEncoding()); - outputLayoutEnum = newOutputLayoutEnum; - } - } ttnn::LayoutAttr outputLayout = ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum); @@ -235,68 +212,6 @@ class ToLayoutOpConversionPattern return success(); } - -private: - bool shouldForceRowMajor(ttir::ToLayoutOp op) const { - // Check if the output tensor is used by an op that only supports row major. - // - // EmbeddingBackwardOp supports row major layout for the first and second - // operands. 
- for (mlir::Operation *user : op.getResult().getUsers()) { - if (isa(user) || isa(user) || - (isa(user) && - (user->getOperand(0) == op || user->getOperand(1) == op))) { - return true; - } - } - - return false; - } - - RankedTensorType - getLayoutForcedResultTensor(ConversionPatternRewriter &rewriter, - RankedTensorType oldOutput, - ttnn::Layout newOutputLayoutEnum) const { - auto oldOutputLayoutAttr = - mlir::cast(oldOutput.getEncoding()); - DataType outputDtype = oldOutputLayoutAttr.getDataType(); - SmallVector oldShardShape = - oldOutputLayoutAttr.getShardShape(); - size_t shardShapeSize = oldShardShape.size(); - assert(shardShapeSize >= 2 && "expected at least 2D shape"); - - if (newOutputLayoutEnum == ttnn::Layout::RowMajor) { - // Set shard shape to match convention of row major layout - auto tileType = - mlir::cast(oldOutputLayoutAttr.getElementType()); - llvm::SmallVector newShardShape(oldShardShape.begin(), - oldShardShape.end()); - newShardShape[shardShapeSize - 2] = - oldShardShape[shardShapeSize - 2] * tileType.getHeight(); - newShardShape[shardShapeSize - 1] = - oldShardShape[shardShapeSize - 1] * tileType.getWidth(); - Type newElementType = ttnn::utils::createRowMajorTypeFromDtype( - rewriter.getContext(), outputDtype); - RankedTensorType result = RankedTensorType::get( - oldOutput.getShape(), oldOutput.getElementType(), - oldOutputLayoutAttr - .withElementType(rewriter.getContext(), newElementType) - .withShardShape(rewriter.getContext(), newShardShape)); - return result; - } - - if (newOutputLayoutEnum == ttnn::Layout::Tile) { - TileType tileType = - TileType::get(rewriter.getContext(), - {ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, outputDtype); - RankedTensorType result = RankedTensorType::get( - oldOutput.getShape(), oldOutput.getElementType(), - oldOutputLayoutAttr.withElementType(rewriter.getContext(), tileType)); - return result; - } - - llvm_unreachable("Unreachable code path. 
Unexpected output layout enum"); - } }; template (std::malloc(size), std::free); diff --git a/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp index 4a6f26b5e4..046f342040 100644 --- a/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp +++ b/lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp @@ -117,6 +117,25 @@ void BFInterleavedPolicy::run() { scheduler.scheduleOp(nextOpForScheduling); } + // TODO (#0000): This is a temporary solution + // Currently ReturnOps are not considered when calculating L1 usage + llvm::SmallVector eraseableL1UsageOps; + for (auto &[op, usage] : currentL1UsagePerOp) { + for (Operation *user : op->getUsers()) { + if (isa(user)) { + usage.numOfUnscheduledUsers -= 1; + } + } + if (usage.numOfUnscheduledUsers == 0) { + eraseableL1UsageOps.push_back(op); + } + } + + for (Operation *op : eraseableL1UsageOps) { + currentL1Usage -= currentL1UsagePerOp[op].l1MemUsagePerUser; + currentL1UsagePerOp.erase(op); + } + assert(currentL1Usage == 0); assert(currentL1UsagePerOp.size() == 0); diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index f1ec29999e..0551ca42f7 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -67,7 +67,7 @@ void createTTNNPipelineLoweringPasses( void createTTNNPipelineWorkaroundPass( OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) { TTNNWorkaroundsOptions workaroundOptions{ - options.layouotWorkaroundsEnabled, + options.layoutWorkaroundsEnabled, options.decompositionWorkaroundsEnabled}; pm.addPass(createTTNNWorkarounds(workaroundOptions)); pm.addPass(mlir::createCanonicalizerPass()); diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp index 9ada2dbb5d..1ce957abb9 100644 --- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp @@ -224,6 +224,22 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { // If schedule is set, apply order of operations to func. // if (opSchedule[func].size() > 1) { + // TODO (#0000): This is a temporary solution - when defaulting to dram + // tile input/output layout, GetDeviceOp can randomly appear as the last + // op in the graph instead of the first. This workaround ensures + // getDeviceOp is always in the beginning of the schedule. + // To reproduce, remove this workaround and run + // Silicon/TTNN/optimizer/mnist_sharding.mlir multiple times (as it is + // non-deterministic). 
+ Operation **it = + std::find_if(opSchedule[func].begin(), opSchedule[func].end(), + [](Operation *op) { return isa(op); }); + if (it != opSchedule[func].end()) { + GetDeviceOp deviceOp = mlir::cast(*it); + opSchedule[func].erase(it); + opSchedule[func].insert(opSchedule[func].begin(), deviceOp); + } + for (size_t i = 0; i < opSchedule[func].size() - 1; i++) { Operation *op = opSchedule[func][i]; diff --git a/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp b/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp index 95cfed6f4e..a5dcaf74d3 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp @@ -30,8 +30,11 @@ class TTNNDecomposeLayouts }); }); for (Operation *op : opsToReplace) { - this->createLayoutConversionOps(mlir::cast(op), - rewriter); + if (failed(createLayoutConversionOps(mlir::cast(op), + rewriter))) { + signalPassFailure(); + return; + } rewriter.eraseOp(op); } } @@ -42,6 +45,7 @@ class TTNNDecomposeLayouts ttnn::Layout layoutEnum; DataType dataType; ttnn::TensorMemoryLayoutAttr tensorMemoryLayout; + GridAttr shardGrid; llvm::SmallVector shardShape; ttnn::MemoryConfigAttr createMemoryConfigAttr(MLIRContext *context) const { @@ -51,7 +55,9 @@ class TTNNDecomposeLayouts ttnn::ShapeAttr::get(context, shardShape)), tensorMemoryLayout); } - + bool isL1Sharded() const { + return isShardedMemoryLayout(tensorMemoryLayout.getValue()); + } bool isOnHost() const { return bufferType == ttnn::BufferType::SystemMemory; } @@ -115,6 +121,9 @@ class TTNNDecomposeLayouts auto inputLayoutAttr = mlir::cast(op.getInput().getType().getEncoding()); + auto outputLayoutAttr = + mlir::cast(op.getResult().getType().getEncoding()); + assert(op.getMemoryConfig().has_value()); MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig().value(); @@ -131,9 +140,12 @@ class TTNNDecomposeLayouts input.tensorMemoryLayout = inputLayoutAttr.getMemLayout(); output.tensorMemoryLayout = outputMemoryConfig.getTensorMemoryLayout(); + input.shardGrid = inputLayoutAttr.getGrid(); + output.shardGrid = outputLayoutAttr.getGrid(); + input.shardShape = inputLayoutAttr.getShardShape(); - output.shardShape = - llvm::SmallVector{outputMemoryConfig.getShardShapeArray()}; + output.shardShape = outputLayoutAttr.getShardShape(); + return {input, output}; } @@ -148,14 +160,6 @@ class TTNNDecomposeLayouts opsToCreate.createTypecastOp = input.dataType != output.dataType; opsToCreate.createToLayoutOp = input.layoutEnum != output.layoutEnum; - // TODO(bug #665): - // Insert a ToLayoutOp manually if we're moving from device to host to - // untilize. Since we're hardcoding tile layout, the tensor may be row - // major in mlir, and therefore it would appear as if we don't need to - // untilize - opsToCreate.createToLayoutOp |= - (opsToCreate.createFromDeviceOp and - output.layoutEnum == ttnn::Layout::RowMajor); // ToDeviceOp can handle the creation of the memory config of the initial // device tensor @@ -168,8 +172,10 @@ class TTNNDecomposeLayouts output.bufferType == ttnn::BufferType::L1) or (input.bufferType == ttnn::BufferType::L1 and output.bufferType == ttnn::BufferType::DRAM); + // If shard grids don't match we need to reshard opsToCreate.createToMemoryConfigOp |= - (input.shardShape != output.shardShape); + (input.isL1Sharded() and output.isL1Sharded() and + input.shardGrid != output.shardGrid); } return opsToCreate; } @@ -764,24 +770,30 @@ class TTNNDecomposeLayouts * sizeof(uint32_t). For now, we will always untilize on host. 
We rarely * need device to device untilize, so the perf hit should be acceptable. */ - void createLayoutConversionOps(ttnn::ToLayoutOp op, - IRRewriter &rewriter) const { + mlir::LogicalResult createLayoutConversionOps(ttnn::ToLayoutOp op, + IRRewriter &rewriter) const { auto [input, output] = getInputOutputLayouts(op); OpsToCreate opsToCreate = determineRequiredOps(input, output); - assert(isCreationValid(op, input, output, opsToCreate) && - "Invalid layout conversion"); + if (not isCreationValid(op, input, output, opsToCreate)) { + return failure(); + } + auto device = op.getDevice(); - assert((device || output.isOnHost()) && - "Op device must be set for output tensors on device"); + if (not device and not output.isOnHost()) { + op->emitError("Device not specified for device tensor"); + return failure(); + } + OpCreationInfo info(device, input, output, opsToCreate); Value currentInput = op.getInput(); if (input.isOnHost()) { handleHostInputLayoutConversion(op, rewriter, currentInput, info); - return; + return success(); } handleDeviceInputLayoutConversion(op, rewriter, currentInput, info); + return success(); } }; } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index e148b575fb..5aa853aa98 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -20,9 +20,6 @@ namespace mlir::tt::ttnn { static const std::array, 1> g_defaultCollapseDims = {{{0, -1}}}; -// Default memory space for tensors on host -static const BufferType g_defaultMemorySpaceHost = BufferType::SystemMemory; - // Default memory space for tesnors on device static const BufferType g_defaultMemorySpaceDevice = BufferType::DRAM; @@ -55,7 +52,7 @@ inline Location appendInputSuffix(Location loc, int64_t operandIndex) { // // Example: tensor<15x10x32xf32> -> tensor<15x10x32xf32, ttnn_layout<...>> // where ttnn_layout<...> is constructed with default values -// SystemMemory, MemoryLayout::None, Grid<1x1> +// Dram, MemoryLayout::Interleaved, Grid<1x1> class TTNNLayoutTensorTypeConverter : public TypeConverter { public: TTNNLayoutTensorTypeConverter(MLIRContext *ctx, GridAttr deviceGrid) { @@ -74,9 +71,12 @@ class TTNNLayoutTensorTypeConverter : public TypeConverter { llvm::ArrayRef> collapseDimsRef( g_defaultCollapseDims); + // Force TileType for tensors + auto elementType = TileType::get(ctx, type.getElementType()); TTNNLayoutAttr newLayout = TTNNLayoutAttr::get( - ctx, type.getShape(), type.getElementType(), g_defaultMemorySpaceHost, - tensorGrid, nullptr /* memLayoutAttr */, collapseDimsRef); + ctx, type.getShape(), elementType, g_defaultMemorySpaceDevice, + tensorGrid, TensorMemoryLayoutAttr::get(ctx, g_defaultMemoryLayout), + collapseDimsRef); return RankedTensorType::get(type.getShape(), type.getElementType(), newLayout); }); @@ -167,7 +167,13 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, BufferType currBufferType = ttnnLayoutAttr.getBufferType(); // Get the current element type (i.e bf16/TileType etc) + // If the defining op is arange, then we need to assume ROW_MAJOR (scalar) + // element type. Type currElementType = ttnnLayoutAttr.getElementType(); + ttir::ArangeOp existingArange = input.getDefiningOp(); + if (existingArange) { + currElementType = ttnnLayoutAttr.getScalarElementType(); + } // Get mem layout. 
If the tensor is on host layout is null TensorMemoryLayoutAttr currMemLayout = ttnnLayoutAttr.getMemLayout(); @@ -220,7 +226,6 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, // it is ROW_MAJOR - and to make it tile layout we still must insert // ToLayoutOp on its output. We can do this by setting the element type to // ty.getElementType() in case desiredElementType is a TileType. - ttir::ArangeOp existingArange = input.getDefiningOp(); if (existingArange) { TTNNLayoutAttr arangeLayout = rewriter.getAttr( ty.getShape(), ty.getElementType(), desiredBufferType, @@ -306,12 +311,15 @@ class TTNNLayoutDPSOperandsRewriter Location newLoc = appendInputSuffix(op.getLoc(), operand.getOperandNumber()); + + bool isTiled = shouldTilize(op, operand.getOperandNumber()); + // Given the operand constraint, create the desired layout for the operand std::optional desiredLayout = createToLayoutOp( rewriter, newLoc, operand.get(), g_defaultMemorySpaceDevice, TensorMemoryLayoutAttr::get(rewriter.getContext(), g_defaultMemoryLayout), - true /* isTiled */); + isTiled); // If layout changed update the operand if (desiredLayout) { @@ -328,6 +336,33 @@ class TTNNLayoutDPSOperandsRewriter return modified ? success() : failure(); } + +private: + bool shouldTilize(DestinationStyleOpInterface dpsOp, + int64_t operandNumber) const { + + Operation *operation = dpsOp.getOperation(); + + // TTNN Reshape does not support implicit tilization/untilization + // Therefore input output layouts should be the same + if (mlir::isa(operation) && operandNumber == 1) { + Value input = dpsOp->getOperand(0); + RankedTensorType inputType = + mlir::cast(input.getType()); + TTNNLayoutAttr inputLayout = + mlir::cast(inputType.getEncoding()); + return mlir::isa(inputLayout.getElementType()); + } + + // These ops constrain to ROW_MAJOR on their operands + if (mlir::isa(operation) || + mlir::isa(operation) || + (mlir::isa(operation) && + operandNumber < 2)) { + return false; + } + return true; + } }; // Updates the layout of the operands of a func::ReturnOp. @@ -342,11 +377,14 @@ class TTNNLayoutFuncReturnRewriter PatternRewriter &rewriter) const final { bool modified = false; for (OpOperand &operand : op->getOpOperands()) { + bool isTiled = true; Location newLoc = appendInputSuffix(op.getLoc(), operand.getOperandNumber()); std::optional layout = createToLayoutOp( - rewriter, newLoc, operand.get(), BufferType::SystemMemory, - nullptr /* tensorMemoryLayoutAttr */, false /* tiled */); + rewriter, newLoc, operand.get(), g_defaultMemorySpaceDevice, + TensorMemoryLayoutAttr::get(rewriter.getContext(), + g_defaultMemoryLayout), + isTiled); if (layout.has_value()) { rewriter.modifyOpInPlace( op, [&]() { op.setOperand(operand.getOperandNumber(), *layout); }); @@ -355,8 +393,6 @@ class TTNNLayoutFuncReturnRewriter } return modified ? 
success() : failure(); } - -private: }; class TTNNLayout : public impl::TTNNLayoutBase { diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp index 74d527c424..16f0f197bd 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp @@ -410,6 +410,57 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { } }; +// ttnn::FullOp does not support 1D tilized tensors +// If the output of full is a 1D tensor and is tiled +// we need to convert it to row major layout then tilize separately +class TTNNFullOpWorkaround : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttnn::FullOp op, + PatternRewriter &rewriter) const override { + Value device = op.getDevice(); + ::mlir::FloatAttr fillValueAttr = op.getFillValueAttr(); + auto outputType = mlir::cast(op.getType()); + ttnn::TTNNLayoutAttr layoutAttr = + mlir::cast(outputType.getEncoding()); + // If the output is not a 1D tensor or the element type is not tilized + // we don't need to apply the workaround + if (outputType.getRank() > 1 || + !mlir::isa(layoutAttr.getElementType())) { + return failure(); + } + + // Can't use withElementType because the shard shape would be wrong + ttnn::TTNNLayoutAttr rowMajorLayoutAttr = ttnn::TTNNLayoutAttr::get( + rewriter.getContext(), outputType.getShape(), + layoutAttr.getScalarElementType(), layoutAttr.getBufferType(), + layoutAttr.getGrid(), layoutAttr.getMemLayout()); + + auto fullOpOutputType = RankedTensorType::get( + outputType.getShape(), outputType.getElementType(), rowMajorLayoutAttr); + auto fullOp = rewriter.create(op.getLoc(), fullOpOutputType, + device, fillValueAttr); + + // Tilize the fullOp output separately + ttnn::MemoryConfigAttr memConfigAttr = + rewriter.getAttr( + rewriter.getAttr(layoutAttr.getBufferType()), + rewriter.getAttr( + rewriter.getAttr(layoutAttr.getShardShape())), + layoutAttr.getMemLayout()); + + bool isOutputOnHost = + (layoutAttr.getBufferType() == ttnn::BufferType::SystemMemory); + + rewriter.replaceOpWithNewOp( + op, op.getType(), fullOp, ttnn::Layout::Tile, + DataTypeAttr::get(rewriter.getContext(), layoutAttr.getDataType()), + memConfigAttr, isOutputOnHost ? nullptr : device); + return success(); + } +}; + // Pass to apply workarounds to the operands of TTNN operations. 
class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { public: @@ -418,7 +469,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { void runOnOperation() final { if (decompositionWorkaroundsEnabled) { RewritePatternSet patterns(&getContext()); - patterns.add, workarounds::decomposition::ReduceOpsKeepDimRewritePattern< @@ -435,7 +486,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { runRewritePatterns(std::move(patterns), GreedyRewriteConfig::kNoLimit /*maxIterations*/); } - if (layouotWorkaroundsEnabled) { + if (layoutWorkaroundsEnabled) { RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); diff --git a/runtime/include/tt/runtime/detail/workarounds.h b/runtime/include/tt/runtime/detail/workarounds.h index cb05a7d98c..c109a3db1e 100644 --- a/runtime/include/tt/runtime/detail/workarounds.h +++ b/runtime/include/tt/runtime/detail/workarounds.h @@ -17,12 +17,12 @@ struct Env { #endif get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true, bool readUpdateIndexFromDeviceForKVCache = true, - bool toDtypeOnHost = true) + bool toDtypeOnHost = true, bool toLayoutAPIAssumeSingleChip = true) #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1 ; #else { - return Env(true, true, true, true); + return Env(true, true, true, true, true); } #endif // TODO(bug #855): Ideally we should have an op that preshards for maxpool2d @@ -45,14 +45,25 @@ struct Env { // to handle this, we should remove this workaround. bool toDtypeOnHost; + // TODO(bug #1778): We currently don't have device grid information (mesh + // shape, offset) in the flatbuffer TensorDesc nor in the mlir LayoutAttr. We + // need to add this information to the tensorDesc so that the runtime toLayout + // API can determine the correct devices. Enabling this workaround will assume + // that a device tensor will reside in the L1/Dram of the first device (device + // id 0) of the device grid. This should be removed once we add the device + // grid information to the tensorDesc. 
+ bool toLayoutAPIAssumeSingleChip; + private: constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands, - bool readUpdateIndexFromDeviceForKVCache, bool toDtypeOnHost) + bool readUpdateIndexFromDeviceForKVCache, bool toDtypeOnHost, + bool toLayoutAPIAssumeSingleChip) : maxpool2dPreshard(maxpool2dPreshard), swapBinaryOperands(swapBinaryOperands), readUpdateIndexFromDeviceForKVCache( readUpdateIndexFromDeviceForKVCache), - toDtypeOnHost(toDtypeOnHost) {} + toDtypeOnHost(toDtypeOnHost), + toLayoutAPIAssumeSingleChip(toLayoutAPIAssumeSingleChip) {} }; inline std::ostream &operator<<(std::ostream &os, const Env &env) { @@ -66,6 +77,9 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) { << env.readUpdateIndexFromDeviceForKVCache << "\n"; os << "\t" << "toDtypeOnHost: " << env.toDtypeOnHost << "\n"; + os << "\t" + << "toLayoutAPIAssumeSingleChip: " << env.toLayoutAPIAssumeSingleChip + << "\n"; os << "}"; return os; } diff --git a/runtime/lib/common/workarounds.cpp b/runtime/lib/common/workarounds.cpp index 9dc45d964e..3d69c9c285 100644 --- a/runtime/lib/common/workarounds.cpp +++ b/runtime/lib/common/workarounds.cpp @@ -8,9 +8,10 @@ namespace tt::runtime::workaround { #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1 const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands, bool readUpdateIndexFromDeviceForKVCache, - bool toDtypeOnHost) { + bool toDtypeOnHost, bool toLayoutAPIAssumeSingleChip) { static const Env config(maxpool2dPreshard, swapBinaryOperands, - readUpdateIndexFromDeviceForKVCache, toDtypeOnHost); + readUpdateIndexFromDeviceForKVCache, toDtypeOnHost, + toLayoutAPIAssumeSingleChip); return config; } #endif diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp index ae956f6ca4..07afbfd3f3 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/types.cpp @@ -48,10 +48,13 @@ ::ttnn::Tensor LayoutConverter::toLayoutIfNeeded(const ::ttnn::Tensor &input) { } ::ttnn::Tensor LayoutConverter::typecastIfNeeded(const ::ttnn::Tensor &input) { - if (shouldTypecast) { - return ::ttnn::typecast(input, outputDesc.dataType); + if (not shouldTypecast) { + return input; } - return input; + if (utils::isOnHost(input.storage_type())) { + return ::ttnn::to_dtype(input, outputDesc.dataType); + } + return ::ttnn::typecast(input, outputDesc.dataType); } ::ttnn::Tensor diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp index fa8aa82ed2..095118ce63 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.cpp @@ -131,11 +131,6 @@ ::ttnn::BufferType toTTNNBufferType(::tt::target::BufferType bufferType) { } }; -std::vector -toShapeFromFBShape(const flatbuffers::Vector &vec) { - return std::vector(vec.begin(), vec.end()); -} - ::ttnn::Layout inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef) { const ::tt::target::Dim2d *tileShape = diff --git a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h index a322bf009d..3aed25eb8a 100644 --- a/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h +++ b/runtime/lib/ttnn/include/tt/runtime/ttnn/utils.h @@ -36,9 +36,6 @@ toTTNNBufferType(::tt::target::MemorySpace memorySpace); // ::ttnn::BufferType toTTNNBufferType(::tt::target::BufferType bufferType); -std::vector -toShapeFromFBShape(const flatbuffers::Vector &vec); - 
::ttnn::Layout inferLayoutFromTileShape(const ::tt::target::TensorRef *tensorRef); @@ -51,13 +48,6 @@ createMemoryConfig(const ::tt::target::TensorRef *tensorRef); Tensor createRuntimeTensorFromTTNN(const ::ttnn::Tensor &tensor); -// TODO: (#1435): Fix int types across shapes -// -inline std::vector -toShapeFromFBShape(const flatbuffers::Vector &vec) { - return std::vector(vec.begin(), vec.end()); -} - } // namespace tt::runtime::ttnn::utils #endif diff --git a/runtime/lib/ttnn/operations/creation/empty.cpp b/runtime/lib/ttnn/operations/creation/empty.cpp index 530bf8cd33..a4514e212b 100644 --- a/runtime/lib/ttnn/operations/creation/empty.cpp +++ b/runtime/lib/ttnn/operations/creation/empty.cpp @@ -19,11 +19,11 @@ struct EmptyTensorConfig { std::optional<::ttnn::MemoryConfig> memoryConfig = std::nullopt; EmptyTensorConfig(const ::tt::target::ttnn::EmptyOp *op) - : shape(::tt::runtime::ttnn::utils::toShapeFromFBShape( + : shape(::tt::runtime::ttnn::operations::utils::toTTNNShape( *op->out()->desc()->shape())), dtype(::tt::runtime::ttnn::operations::utils::getDataType(op->out())), + layout(::tt::runtime::ttnn::utils::toTTNNLayout(op->layout())), numShards(op->num_shards()), strategy(op->strategy()) { - layout = ::tt::runtime::ttnn::utils::toTTNNLayout(op->layout()); if (op->device()) { LOG_ASSERT(op->memcfg(), "Memory config must be provided when device is provided"); diff --git a/runtime/lib/ttnn/operations/creation/full.cpp b/runtime/lib/ttnn/operations/creation/full.cpp index 200dc6969f..2e8b4669c0 100644 --- a/runtime/lib/ttnn/operations/creation/full.cpp +++ b/runtime/lib/ttnn/operations/creation/full.cpp @@ -20,14 +20,13 @@ struct FullTensorConfig { std::optional<::ttnn::MemoryConfig> memoryConfig = std::nullopt; FullTensorConfig(const ::tt::target::ttnn::FullOp *op) - : shape(::tt::runtime::ttnn::utils::toShapeFromFBShape( + : shape(::tt::runtime::ttnn::operations::utils::toTTNNShape( *op->out()->desc()->shape())), dtype(::tt::runtime::ttnn::operations::utils::getDataType(op->out())), + layout(::tt::runtime::ttnn::utils::inferLayoutFromTileShape(op->out())), fillValue(op->fill_value()), numShards(op->num_shards()), strategy(op->strategy()) { - layout = ::tt::runtime::ttnn::utils::inferLayoutFromTileShape(op->out()); - if (!utils::inSystemMemory(op->out())) { memoryConfig = ::tt::runtime::ttnn::utils::createMemoryConfig(op->out()); } diff --git a/runtime/lib/ttnn/operations/creation/ones.cpp b/runtime/lib/ttnn/operations/creation/ones.cpp index 6172037363..2af9a66046 100644 --- a/runtime/lib/ttnn/operations/creation/ones.cpp +++ b/runtime/lib/ttnn/operations/creation/ones.cpp @@ -18,8 +18,8 @@ namespace tt::runtime::ttnn::operations::creation { void run(const ::tt::target::ttnn::OnesOp *op, ProgramContext &context) { ProgramTensorPool &tensorPool = context.getTensorPool(); - const ::ttnn::Shape shape = ::ttnn::Shape(::tt::tt_metal::LegacyShape( - ::tt::runtime::ttnn::utils::toShapeFromFBShape(*op->shape()))); + const ::ttnn::Shape shape = + ::tt::runtime::ttnn::operations::utils::toTTNNShape(*op->shape()); std::optional<::ttnn::DataType> dtype = std::optional<::ttnn::DataType>(); std::optional<::ttnn::Layout> layout = std::optional<::ttnn::Layout>(); diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h index 269e0328f9..a1ad4b5e9c 100644 --- a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h +++ 
b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/utils.h @@ -9,6 +9,7 @@ #include "tt/runtime/ttnn/types.h" #include "ttmlir/Target/TTNN/program_generated.h" #include "types_generated.h" +#include #include namespace tt::runtime::ttnn::operations::utils { @@ -29,5 +30,14 @@ createMemoryConfig(const ::tt::target::MemoryConfigDesc *memcfg, ::tt::tt_metal::DistributedTensorConfig distributedTensorConfigFromFlatbuffer( const ::tt::target::DistributionStrategy *strategy); +template +inline ::ttnn::Shape toTTNNShape(const flatbuffers::Vector &vec) { + std::vector rawShape; + rawShape.reserve(vec.size()); + std::transform( + vec.begin(), vec.end(), std::back_inserter(rawShape), + [](const T &x) -> uint32_t { return static_cast(x); }); + return ::ttnn::Shape(rawShape); +} } // namespace tt::runtime::ttnn::operations::utils #endif diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index 51e58e0ebf..b440c4024a 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -19,9 +19,9 @@ template static ::ttnn::Tensor preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op, DeviceType &device, const ::ttnn::Tensor &input) { - const ::ttnn::Shape inputShape = ::ttnn::Shape(::tt::tt_metal::LegacyShape( - ::tt::runtime::ttnn::utils::toShapeFromFBShape( - *op->in()->desc()->shape()))); + const ::ttnn::Shape inputShape = + ::tt::runtime::ttnn::operations::utils::toTTNNShape( + *op->in()->desc()->shape()); uint32_t output_height = 1 + (op->input_height() + 2 * op->padding_height() - op->dilation_height() * (op->kernel_height() - 1) - 1) / diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index c527a94d7e..b0266e4c66 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -4,6 +4,7 @@ #include "tt/runtime/detail/debug.h" #include "tt/runtime/detail/logger.h" #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/detail/workarounds.h" #include "tt/runtime/ttnn/types.h" #include "tt/runtime/ttnn/utils.h" #include "tt/runtime/utils.h" @@ -269,6 +270,9 @@ Tensor toLayout(Tensor tensor, Device device, Layout layout) { ::ttnn::MeshDevice &meshDevice = device.as<::ttnn::MeshDevice>(DeviceRuntime::TTNN); DeviceVariant targetDevice = getTargetDevice(meshDevice); + if (workaround::Env::get().toLayoutAPIAssumeSingleChip) { + targetDevice = std::ref(*(meshDevice.get_device_index(0))); + } LayoutConverter converter(inputLayoutDesc, outputLayoutDesc); std::shared_ptr<::ttnn::Tensor> out = std::make_shared<::ttnn::Tensor>( converter.convertTensorLayout(ttnnTensor, targetDevice)); diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py index 5454cbcd9a..b7ee8c38a2 100644 --- a/runtime/test/python/ttnn/test_runtime_api.py +++ b/runtime/test/python/ttnn/test_runtime_api.py @@ -34,7 +34,7 @@ def test_to_layout(helper: Helper, shape, dtype, request): ) device_layout = ttrt.runtime.testing.get_dram_interleaved_tile_layout(runtime_dtype) host_layout = ttrt.runtime.testing.get_host_row_major_layout(runtime_dtype) - with DeviceContext([helper.query.device_ids[0]]) as device: + with DeviceContext(helper.query.device_ids) as device: device_tensor = ttrt.runtime.to_layout( runtime_input_tensor, device, device_layout ) @@ -133,19 +133,20 @@ def test_create_tensor_memcpy(helper: Helper, shape, dtype, request): def test_runtime_stitching_eltwise_binary_op_chain(helper: Helper, request): - binary_path = 
f"{TT_MLIR_HOME}/build/test/ttmlir/Runtime/TTNN/runtime_stitching/Output/eltwise_binary_op_chain.mlir.tmp.ttnn" + binary_path = f"{TT_MLIR_HOME}/build/test/ttmlir/Silicon/TTNN/runtime_stitching/Output/eltwise_binary_op_chain.mlir.tmp.ttnn" helper.initialize(request.node.name, binary_path) helper.check_constraints() + first_program: Binary.Program = helper.binary.get_program(0) assert first_program.num_inputs() == 2 inputs_torch = [] inputs_runtime = [] input_layouts = [] - for i in first_program.program["inputs"]: + for i, program_input in enumerate(first_program.program["inputs"]): torch_tensor = torch.randn( - i["desc"]["shape"], + program_input["desc"]["shape"], dtype=Binary.Program.from_data_type( - i["desc"]["layout"]["memory_desc"]["data_type"] + program_input["desc"]["layout"]["memory_desc"]["data_type"] ), ) runtime_dtype = Binary.Program.to_data_type(torch_tensor.dtype) @@ -159,15 +160,27 @@ def test_runtime_stitching_eltwise_binary_op_chain(helper: Helper, request): ) inputs_runtime.append(runtime_tensor) input_layouts.append( - ttrt.runtime.testing.get_dram_interleaved_row_major_layout(runtime_dtype) + ttrt.runtime.get_layout( + executable=helper.binary.fbb, program_index=0, input_index=i + ) ) + program_indices = list(range(helper.binary.get_num_programs())) + last_program: Binary.Program = helper.binary.get_program(program_indices[-1]) + torch_result_tensor = torch.randn( + last_program.program["outputs"][0]["desc"]["shape"], + dtype=Binary.Program.from_data_type( + last_program.program["outputs"][0]["desc"]["layout"]["memory_desc"][ + "data_type" + ] + ), + ) + activations, weights = inputs_runtime activations_layout, weights_layout = input_layouts - with DeviceContext([helper.query.device_ids[0]]) as device: + with DeviceContext(helper.query.device_ids) as device: activations = ttrt.runtime.to_layout(activations, device, activations_layout) weights = ttrt.runtime.to_layout(weights, device, weights_layout) - program_indices = list(range(helper.binary.get_num_programs())) for program_index in program_indices: program = helper.binary.get_program(program_index) assert program.num_inputs() == 2 and program.num_outputs() == 1 @@ -175,20 +188,13 @@ def test_runtime_stitching_eltwise_binary_op_chain(helper: Helper, request): device, helper.binary.fbb, program_index, [activations, weights] ) activations = ttrt.runtime.to_layout(outputs[0], device, activations_layout) - ttrt.runtime.deallocate_tensor(outputs[0], force=True) - activations = ttrt.runtime.to_host(activations, untilize=True) + ttrt.runtime.deallocate_tensor(outputs[0]) + final_result = ttrt.runtime.to_host(activations, untilize=True) + ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), final_result) + ttrt.runtime.deallocate_tensor(activations, force=True) ttrt.runtime.deallocate_tensor(weights, force=True) + ttrt.runtime.deallocate_tensor(final_result, force=True) - last_program: Binary.Program = helper.binary.get_program(program_indices[-1]) - torch_result_tensor = torch.randn( - last_program.program["outputs"][0]["desc"]["shape"], - dtype=Binary.Program.from_data_type( - last_program.program["outputs"][0]["desc"]["layout"]["memory_desc"][ - "data_type" - ] - ), - ) - ttrt.runtime.memcpy(torch_result_tensor.data_ptr(), activations) golden = ( (inputs_torch[0] + inputs_torch[1]).mul(inputs_torch[1]).sub(inputs_torch[1]) ) diff --git a/runtime/tools/python/ttrt/common/run.py b/runtime/tools/python/ttrt/common/run.py index b83c5d390f..d639481a5a 100644 --- a/runtime/tools/python/ttrt/common/run.py +++ 
b/runtime/tools/python/ttrt/common/run.py @@ -147,6 +147,13 @@ def initialize_api(): choices=[True, False], help="disable to_dtype on host workaround", ) + Run.register_arg( + name="--disable-to-layout-api-assume-single-chip", + type=bool, + default=False, + choices=[True, False], + help="disable runtime to_layout api assume single chip workaround", + ) Run.register_arg( name="--result-file", type=str, @@ -387,6 +394,7 @@ def _execute(binaries): not self["--disable-swap-binary-operands"], not self["--disable-read-update-index-for-kv-cache"], not self["--disable-to-dtype-on-host"], + not self["--disable-to-layout-api-assume-single-chip"], ) self.logging.debug(f"setting tt runtime workaround env={workaround_env}") self.logging.debug(f"setting torch manual seed={self['--seed']}") @@ -529,9 +537,12 @@ def _execute(binaries): for i, runtime_output_tensor in enumerate( runtime_outputs ): + output_host = ttrt.runtime.to_host( + runtime_output_tensor, untilize=True + ) ttrt.runtime.memcpy( total_outputs[loop][i], - runtime_output_tensor, + output_host, ) ttrt.runtime.deallocate_tensor( runtime_output_tensor, force=True diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir index bbcbf5dd6f..a8228fe9c0 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir @@ -4,7 +4,7 @@ module attributes {} { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir index 4258e639cb..7d6ca51f3a 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir @@ -4,7 +4,7 @@ module attributes {} { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return 
%{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir index d6b46aae65..b45c3c4087 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir @@ -1,15 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#l1 = #ttnn.buffer_type -#system = #ttnn.buffer_type -#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system>> -#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x16xf32, #system>> -#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x16xf32, #l1>, > module attributes {} { - func.func @forward(%arg0: tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout1> { + func.func @relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32, #ttnn_layout1> + %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] - %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout>, tensor<64x128xf32, #ttnn_layout1>) -> tensor<64x128xf32, #ttnn_layout1> - return %1 : tensor<64x128xf32, #ttnn_layout1> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir index 170eb1b53c..ccc3b82a84 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir @@ -4,7 +4,7 @@ module attributes {} { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.sign"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir index 72d8e14161..987d459aba 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir @@ -4,7 +4,7 @@ module attributes {} { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : 
(tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.tan"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir index 530b4c79bd..62618ae829 100644 --- a/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir @@ -4,7 +4,7 @@ module attributes {} { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir index 73001109ac..1566154286 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_01.mlir @@ -16,19 +16,19 @@ module attributes {} { func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<5120x1024xbf16>, %arg2: tensor<5120x1024xbf16>) -> tensor<4096x1024xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, > - // CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<4096x5120xbf16> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_5]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> %2 = tensor.empty() : tensor<4096x1024xbf16> - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_6]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_2]]> %3 = "ttir.matmul"(%1, %arg1, %2) : (tensor<4096x5120xbf16>, tensor<5120x1024xbf16>, tensor<4096x1024xbf16>) -> tensor<4096x1024xbf16> %4 = tensor.empty() : 
tensor<4096x1024xbf16> - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_6]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_2]]> %5 = "ttir.matmul"(%1, %arg2, %4) : (tensor<4096x5120xbf16>, tensor<5120x1024xbf16>, tensor<4096x1024xbf16>) -> tensor<4096x1024xbf16> %6 = tensor.empty() : tensor<4096x1024xbf16> - // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_6]]> + // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_2]]> %7 = "ttir.add"(%3, %5, %6) <{operandSegmentSizes = array}> : (tensor<4096x1024xbf16>, tensor<4096x1024xbf16>, tensor<4096x1024xbf16>) -> tensor<4096x1024xbf16> return %7 : tensor<4096x1024xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir index 40d98d135d..85fdce7c14 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/bf_interleaved_policy/fork_join_02.mlir @@ -19,23 +19,23 @@ module attributes {} { func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<5120x9216xbf16>, %arg2: tensor<9216x1024xbf16>, %arg3: tensor<5120x1024xbf16>) -> tensor<4096x1024xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_9:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, > - // CHECK: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x36x!tt.tile<32x32, bf16>, #dram>, > - // CHECK: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x64x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x320x!tt.tile<32x32, bf16>, #l1_>, > + // CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x36x!tt.tile<32x32, bf16>, #dram>, > %0 = tensor.empty() : tensor<4096x5120xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_9]]> + // CHECK-DAG: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_5]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> %2 = tensor.empty() : tensor<4096x9216xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x9216xbf16, #[[LAYOUT_10]]> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x9216xbf16, #[[LAYOUT_6]]> %3 = "ttir.matmul"(%1, %arg1, %2) : (tensor<4096x5120xbf16>, tensor<5120x9216xbf16>, tensor<4096x9216xbf16>) -> tensor<4096x9216xbf16> %4 = tensor.empty() : tensor<4096x1024xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_11]]> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_4]]> %5 = "ttir.matmul"(%3, %arg2, %4) : (tensor<4096x9216xbf16>, tensor<9216x1024xbf16>, tensor<4096x1024xbf16>) -> tensor<4096x1024xbf16> %6 = tensor.empty() : tensor<4096x1024xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_11]]> + // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_4]]> %7 = "ttir.matmul"(%1, %arg3, %6) : 
(tensor<4096x5120xbf16>, tensor<5120x1024xbf16>, tensor<4096x1024xbf16>) -> tensor<4096x1024xbf16> %8 = tensor.empty() : tensor<4096x1024xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_11]]> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x1024xbf16, #[[LAYOUT_4]]> %9 = "ttir.add"(%5, %7, %8) <{operandSegmentSizes = array}> : (tensor<4096x1024xbf16>, tensor<4096x1024xbf16>, tensor<4096x1024xbf16>) -> tensor<4096x1024xbf16> return %9 : tensor<4096x1024xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir index 657da93390..043f1f97d4 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/greedy_l1_interleaved_policy/fork_join.mlir @@ -21,8 +21,8 @@ module attributes {} { func.func @forward(%arg0: tensor<64x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<64x32xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type + // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, > // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, > - // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_3]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> @@ -33,10 +33,10 @@ module attributes {} { %6 = tensor.empty() : tensor<64x32xbf16> %7 = "ttir.relu"(%5, %6) <{operandSegmentSizes = array}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %8 = tensor.empty() : tensor<64x32xbf16> - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_5]]> - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_2]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_2]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_2]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_2]]> %9 = "ttir.matmul"(%3, %7, %8) : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> return %9 : tensor<64x32xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir index 16b0eb1b53..2d15f9ad46 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir @@ -4,12 +4,11 @@ // CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3)) // CHECK-DAG: #[[LOC_MATMUL:.*]] = loc("matmul_1"(#loc3)) // CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x12x!tt.tile<32x32, bf16>, #l1_>, > - +// XFAIL: * +// TODO: Layout override on the optimizer needs update after default input/output tile layout. 
module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2) - // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]]) - // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<1x12>>, >}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]]) // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]]) %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2) return %1 : tensor<64x96xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir index 66e1ec0836..f2e4668f67 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir @@ -12,7 +12,7 @@ module attributes {} { %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]> %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) - // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #ttnn_layout>, tensor<1x32x32xf32, #ttnn_layout> + // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #ttnn_layout1>, tensor<1x32x32xf32, #ttnn_layout1> return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4) } loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir index 91f38d446a..a924411791 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir @@ -1,4 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true override-output-layout=add_1_0=4x4:dram:interleaved:row_major:bf16,add_2_0=4x4:l1:interleaved:tile:f32" %s | FileCheck %s +// XFAIL: * +// TODO: Layout override on the optimizer needs update after default input/output tile layout. 
#loc = loc("test_ops.py:17_0_0":0:0) module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) { diff --git a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir index e893e5d2c7..4f58f62c03 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir @@ -1,12 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true max-legal-layouts=0" %s | FileCheck %s module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<96x64xbf16>) -> tensor<64x64xbf16> { - // CHECK: #[[LAYOUT_7:ttnn_layout7]] = #ttnn.ttnn_layout<{{.*}}, memref<{{.*}}, #dram>, {{.*}}> + // CHECK: #[[LAYOUT_3:ttnn_layout3]] = #ttnn.ttnn_layout<{{.*}}, memref<{{.*}}, #dram>, {{.*}}> %0 = tensor.empty() : tensor<64x96xbf16> - // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_7]]> + // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_3]]> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %2 = tensor.empty() : tensor<64x64xbf16> - // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_7]]> + // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_3]]> %3 = "ttir.matmul"(%1, %arg2, %2) : (tensor<64x96xbf16>, tensor<96x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %3 : tensor<64x64xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir index aa4616360b..5d89adb8e1 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir @@ -2,9 +2,9 @@ module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<96x64xbf16>) -> tensor<64x64xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_7:ttnn_layout7]] = #ttnn.ttnn_layout<{{.*}}, memref<{{.*}}, #l1_>, {{.*}}> + // CHECK: #[[LAYOUT_4:ttnn_layout4]] = #ttnn.ttnn_layout<{{.*}}, memref<{{.*}}, #l1_>, {{.*}}> %0 = tensor.empty() : tensor<64x96xbf16> - // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_7]]> + // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_4]]> %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %2 = tensor.empty() : tensor<64x64xbf16> %3 = "ttir.matmul"(%1, %arg2, %2) : (tensor<64x96xbf16>, tensor<96x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> diff --git a/test/ttmlir/Dialect/TTNN/simple_clamp.mlir b/test/ttmlir/Dialect/TTNN/simple_clamp.mlir index 272e07175b..e533e18384 100644 --- a/test/ttmlir/Dialect/TTNN/simple_clamp.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_clamp.mlir @@ -2,11 +2,8 @@ module attributes {} { func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> - // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0, - // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]]) - // CHECK: = 
"ttnn.clamp"(%[[LAYOUT]]) + // CHECK: "ttnn.clamp"(%arg0) // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} - // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #ttnn_layout{{[0-9]+}}>) -> [[TENSOR]] %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir b/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir index f3bd6dab00..76f7f7b1c0 100644 --- a/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir @@ -4,6 +4,5 @@ module attributes {} { %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32> // CHECK: [[VAL:%[0-9]+]] = "ttnn.full"(%{{[0-9]+}}) <{fillValue = 2.100000e+01 : f32}> : (!tt.device<#device>) -> tensor<1xi32, {{.*}}> return %0 : tensor<1xi32> - // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_scatter.mlir b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir index 22ad5c2d03..43c87b89a2 100644 --- a/test/ttmlir/Dialect/TTNN/simple_scatter.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir @@ -8,7 +8,7 @@ module attributes {} { ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): "ttir.yield"(%arg4) : (tensor<1xf32>) -> () }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> - // CHECK: {{[0-9]+}} = "ttnn.scatter"(%4, %2, %5) <{operandSegmentSizes = array}> : (tensor<1x3x32x32xf32, {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>) -> tensor<[[TENSOR_SHAPE1]], {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.scatter"(%arg1, %arg0, %1) <{operandSegmentSizes = array}> : (tensor<1x3x32x32xf32, {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>, tensor<[[TENSOR_SHAPE1]], {{.*}}>) -> tensor<[[TENSOR_SHAPE1]], {{.*}}> return %2 : tensor<1x3x320x320xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE1]], {{.*}}> } diff --git a/test/ttmlir/Dialect/TTNN/simple_where.mlir b/test/ttmlir/Dialect/TTNN/simple_where.mlir index c75c7f817d..d8d469943c 100644 --- a/test/ttmlir/Dialect/TTNN/simple_where.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_where.mlir @@ -5,9 +5,9 @@ module @jit_eltwise_where { %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> %2 = tensor.empty() : tensor<13x37xf32> %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> - // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} - // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) - // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} + // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%arg0, %arg1, %[[EMPTY]]) + // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %arg0, %arg1, %{{[0-9]+}}) return %3 : tensor<13x37xf32> } } diff --git a/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir b/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir deleted file mode 100644 index 35b4d90634..0000000000 --- a/test/ttmlir/Runtime/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: ttmlir-opt 
--ttir-load-system-desc="path=%system_desc_path%" %s > %t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -// TODO: this is a workaround for compiler assuming input tensors are always on host. The ideal is to directly compile ttir graphs. -#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]> -#system_memory = #ttnn.buffer_type -#dram = #ttnn.buffer_type -#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #system_memory>> -#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > -#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xbf16, #dram>, > - -module attributes {tt.device = #device} { - func.func @add(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { - %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> - %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> - %4 = "ttnn.add"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> - %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> - %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> - return %6 : tensor<64x128xbf16, #ttnn_layout> - } -} - -module attributes {tt.device = #device} { - func.func @multiply(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { - %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> - %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> - %4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> - %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> - %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : 
(tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> - return %6 : tensor<64x128xbf16, #ttnn_layout> - } -} - -module attributes {tt.device = #device} { - func.func @subtract(%arg0: tensor<64x128xbf16, #ttnn_layout1>, %arg1: tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout> { - %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> - %1 = "ttnn.to_layout"(%arg0) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %2 = "ttnn.to_layout"(%arg1) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout1>) -> tensor<64x128xbf16, #ttnn_layout1> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xbf16, #ttnn_layout2> - %4 = "ttnn.subtract"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout1>, tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout2> - %5 = "ttnn.from_device"(%4) : (tensor<64x128xbf16, #ttnn_layout2>) -> tensor<64x128xbf16, #ttnn_layout> - %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout}> : (tensor<64x128xbf16, #ttnn_layout>) -> tensor<64x128xbf16, #ttnn_layout> - return %6 : tensor<64x128xbf16, #ttnn_layout> - } -} diff --git a/test/ttmlir/Silicon/StableHLO/scalar_add_op.mlir b/test/ttmlir/Silicon/StableHLO/scalar_add_op.mlir index bae9b2b4ac..4e05a5040b 100644 --- a/test/ttmlir/Silicon/StableHLO/scalar_add_op.mlir +++ b/test/ttmlir/Silicon/StableHLO/scalar_add_op.mlir @@ -5,7 +5,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // RUN: FileCheck --input-file=%t.mlir %s - module @jit_eltwise_scalar_add attributes {} { func.func public @test_scalar_add(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: func.func public @test_scalar_add diff --git a/test/ttmlir/Silicon/TTNN/deallocate.mlir b/test/ttmlir/Silicon/TTNN/deallocate.mlir index cdba160160..aed0a83549 100644 --- a/test/ttmlir/Silicon/TTNN/deallocate.mlir +++ b/test/ttmlir/Silicon/TTNN/deallocate.mlir @@ -7,12 +7,9 @@ module @"dealloc_test" attributes {} { %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) %1 = "ttir.matmul"(%arg0, %arg4, %0) : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O1:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> - // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<784x256xf32, {{.+}}) -> () - // CHECK: "ttnn.deallocate"([[I1]]) {{.+}} : (tensor<1x784xf32, {{.+}}>) -> () %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9) %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]], [[O2:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> - // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () // CHECK: "ttnn.deallocate"([[O1]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10) %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) @@ -21,15 +18,15 @@ module @"dealloc_test" attributes {} { %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11) 
%7 = "ttir.matmul"(%5, %arg2, %6) : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O4:%.+]]) {{.+}} -> tensor<1x10xf32, {{.+}}> - // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<256x10xf32, {{.+}}>) -> () - // CHECK: "ttnn.deallocate"([[O3]]) {{.+}} : (tensor<1x256xf32,{{.+}}>) -> () + // CHECK: "ttnn.deallocate"([[O3]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12) %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]], [[O5:%.+]]) {{.+}} -> tensor<1x10xf32,{{.+}}> - // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> () // CHECK: "ttnn.deallocate"([[O4]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> () %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) + // CHECK: %{{.+}} = "ttnn.softmax"([[I1:%.+]]) {{.+}} -> tensor<1x10xf32, {{.+}}> + // CHECK: "ttnn.deallocate"([[O5]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> () return %11 : tensor<1x10xf32> loc(#loc7) } loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/add/add.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/add/add.mlir new file mode 100644 index 0000000000..0774d60ded --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/add/add.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/add/add_int32.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/add/add_int32.mlir new file mode 100644 index 0000000000..49d028086b --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/add/add_int32.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @addint32(%arg0: tensor<64x128xi32>, %arg1: tensor<64x128xi32>) -> tensor<64x128xi32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xi32> + // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xi32>, tensor<64x128xi32>, tensor<64x128xi32>) -> tensor<64x128xi32> + return %1 : tensor<64x128xi32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/concat/concat.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/concat/concat.mlir new file mode 100644 index 0000000000..ac73de3739 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/concat/concat.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s 
--input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<32x96xf32> + // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]] + %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> + return %1 : tensor<32x96xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/div/div.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/div/div.mlir new file mode 100644 index 0000000000..7ba2da3211 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/div/div.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @div(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.div"[[C:.*]] + %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/ge/ge.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/ge/ge.mlir new file mode 100644 index 0000000000..3449d485bf --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/ge/ge.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @ge(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]] + %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/maximum/maximum.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/maximum/maximum.mlir new file mode 100644 index 0000000000..2659cccdce --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/maximum/maximum.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.maximum"[[C:.*]] + %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/minimum/minimum.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/minimum/minimum.mlir new file mode 100644 index 0000000000..8272db6e6a --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/minimum/minimum.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" 
%s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @minimum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty" + // CHECK-SAME: [[TENSOR:tensor<64x128xf32,]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.minimum" + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: -> [[TENSOR]] + %1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/multiply/multiply.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/multiply/multiply.mlir new file mode 100644 index 0000000000..2edabb7470 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/multiply/multiply.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/remainder/remainder.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/remainder/remainder.mlir new file mode 100644 index 0000000000..dbcff9d786 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/remainder/remainder.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}} + %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}} + return %1 : tensor<32x32xf32> + // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}} +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/scatter/scatter.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/scatter/scatter.mlir new file mode 100644 index 0000000000..ddd89b0787 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/scatter/scatter.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { + %0 = tensor.empty() : tensor<1x3x320x320xf32> + %1 = tensor.empty() : tensor<1x1xi32> + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = 
array, scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): + "ttir.yield"(%arg4) : (tensor<1xf32>) -> () + }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> + // CHECK: "ttnn.scatter" + return %2 : tensor<1x3x320x320xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/binary/subtract/subtract.mlir b/test/ttmlir/Silicon/TTNN/eltwise/binary/subtract/subtract.mlir new file mode 100644 index 0000000000..d7a84fd159 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/binary/subtract/subtract.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @subtract(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]] + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/ternary/where/where.mlir b/test/ttmlir/Silicon/TTNN/eltwise/ternary/where/where.mlir new file mode 100644 index 0000000000..ced442b91f --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/ternary/where/where.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { + %0 = tensor.empty() : tensor<13x37xbf16> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %2 = tensor.empty() : tensor<13x37xbf16> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} + // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%arg0, %arg1, %[[EMPTY]]) + // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %arg0, %arg1, %{{[0-9]+}}) + return %3 : tensor<13x37xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/cbrt/cbrt.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/cbrt/cbrt.mlir new file mode 100644 index 0000000000..a2833f7053 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/cbrt/cbrt.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @cbrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.cbrt"[[C:.*]] + %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/ceil/ceil.mlir 
b/test/ttmlir/Silicon/TTNN/eltwise/unary/ceil/ceil.mlir new file mode 100644 index 0000000000..934599eb74 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/ceil/ceil.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @ceil(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%arg0, [[VAL0]]) + %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + return %1 : tensor<32x32xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/clamp/clamp.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/clamp/clamp.mlir new file mode 100644 index 0000000000..63c17062ac --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/clamp/clamp.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: "ttnn.clamp"(%arg0) + // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} + %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/cosine/cosine.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/cosine/cosine.mlir new file mode 100644 index 0000000000..2391f68249 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/cosine/cosine.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @cosine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.cos"(%arg0, [[VAL0]]) + %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + return %1 : tensor<32x32xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/expm1/expm1.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/expm1/expm1.mlir new file mode 100644 index 0000000000..fc95c1ae07 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/expm1/expm1.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @expm1(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> + %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : 
(tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + return %1 : tensor<64x128xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/floor/floor.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/floor/floor.mlir new file mode 100644 index 0000000000..7af577c8e8 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/floor/floor.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %{{[0-9]+}} = "ttnn.empty" + // CHECK-SAME: [[TENSOR:tensor<64x128xf32,]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %{{[0-9]+}} = "ttnn.floor" + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: -> [[TENSOR]] + %1 = "ttir.floor"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/gelu/gelu.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/gelu/gelu.mlir new file mode 100644 index 0000000000..7e9767e1fd --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/gelu/gelu.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x128xf32, + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: "ttnn.gelu" + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/get_dimension_size/get_dimension_size.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/get_dimension_size/get_dimension_size.mlir new file mode 100644 index 0000000000..afe7f178f6 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/get_dimension_size/get_dimension_size.mlir @@ -0,0 +1,9 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @get_dimension_size(%arg0: tensor<13x21x1x3xf32>) -> tensor<1xi32> { + %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x1x3xf32>) -> tensor<1xi32> + // CHECK: [[VAL:%[0-9]+]] = "ttnn.full"(%{{[0-9]+}}) <{fillValue = 2.100000e+01 : f32}> : (!tt.device<#device>) -> tensor<1xi32, {{.*}}> + return %0 : tensor<1xi32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/is_finite/is_finite.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/is_finite/is_finite.mlir new file mode 100644 index 0000000000..b8dc64fb72 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/is_finite/is_finite.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer 
%t.mlir > %t.ttnn + +func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + // CHECK: %[[C:.*]] = "ttnn.empty" + // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: %[[C:.*]] = "ttnn.isfinite" + // CHECK-SAME: tensor<64x128xbf16, + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: -> [[TENSOR]] + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/leaky_relu/leaky_relu.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/leaky_relu/leaky_relu.mlir new file mode 100644 index 0000000000..018fa352ee --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/leaky_relu/leaky_relu.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @leaky_relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty" + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.leaky_relu" + %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/log/log.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/log/log.mlir new file mode 100644 index 0000000000..fa21eb24cc --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/log/log.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @log(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.log"[[C:.*]] + %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/log1p/log1p.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/log1p/log1p.mlir new file mode 100644 index 0000000000..efdd6b8fe0 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/log1p/log1p.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> + %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + return %1 : tensor<64x128xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> +} diff --git 
a/test/ttmlir/Silicon/TTNN/eltwise/unary/negate/negate.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/negate/negate.mlir new file mode 100644 index 0000000000..5173e6d920 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/negate/negate.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @negate(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: %[[C:.*]] = "ttnn.neg"[[C:.*]] + %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + return %1 : tensor<32x32xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/recipricol/recipricol.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/recipricol/recipricol.mlir new file mode 100644 index 0000000000..a05f62da9f --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/recipricol/recipricol.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @reciprocal(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.reciprocal"[[C:.*]] + %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/relu/relu.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/relu/relu.mlir new file mode 100644 index 0000000000..3a75ad988a --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/relu/relu.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/rsqrt/rsqrt.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/rsqrt/rsqrt.mlir new file mode 100644 index 0000000000..e3ced09427 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/rsqrt/rsqrt.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @rsqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.rsqrt"[[C:.*]] + %1 = "ttir.rsqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/sigmoid/sigmoid.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/sigmoid/sigmoid.mlir new file mode 100644 index 
0000000000..b6ef7a5a44 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/sigmoid/sigmoid.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @sigmoid(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]] + %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/sign/sign.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/sign/sign.mlir new file mode 100644 index 0000000000..368d7f26af --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/sign/sign.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> + %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + // CHECK: %{{[0-9]+}} = "ttnn.sign"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + return %1 : tensor<64x128xf32> + // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/sine/sine.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/sine/sine.mlir new file mode 100644 index 0000000000..ca61f435af --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/sine/sine.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @sine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.sin"(%arg0, [[VAL0]]) + %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + return %1 : tensor<32x32xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/sqrt/sqrt.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/sqrt/sqrt.mlir new file mode 100644 index 0000000000..7c948b7ac1 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/sqrt/sqrt.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @sqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.sqrt"[[C:.*]] + %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, 
tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/tan/tan.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/tan/tan.mlir new file mode 100644 index 0000000000..c1c319dec2 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/tan/tan.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @tan(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.tan"(%arg0, [[VAL0]]) + %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/eltwise/unary/tanh/tanh.mlir b/test/ttmlir/Silicon/TTNN/eltwise/unary/tanh/tanh.mlir new file mode 100644 index 0000000000..bee21d19a9 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/eltwise/unary/tanh/tanh.mlir @@ -0,0 +1,11 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @tanh(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%arg0, [[VAL0]]) + %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/ones.mlir b/test/ttmlir/Silicon/TTNN/ones.mlir index 660de36ae1..d0a0b1b8b8 100644 --- a/test/ttmlir/Silicon/TTNN/ones.mlir +++ b/test/ttmlir/Silicon/TTNN/ones.mlir @@ -4,25 +4,25 @@ module { func.func @ones_2d() -> tensor<32x128xbf16> { - // CHECK: {{.*}} = "ttnn.ones"() {{.*}} + // CHECK: {{.*}} = "ttnn.ones"({{.*}}) {{.*}} %0 = "ttir.ones"() <{shape = array}> : () -> tensor<32x128xbf16> return %0 : tensor<32x128xbf16> } func.func @ones_3d() -> tensor<32x64x128xbf16> { - // CHECK: {{.*}} = "ttnn.ones"() {{.*}} + // CHECK: {{.*}} = "ttnn.ones"({{.*}}) {{.*}} %0 = "ttir.ones"() <{shape = array}> : () -> tensor<32x64x128xbf16> return %0 : tensor<32x64x128xbf16> } func.func @ones_4d_irregular_shapes() -> tensor<13x24x56x42xbf16> { - // CHECK: {{.*}} = "ttnn.ones"() {{.*}} -> tensor<13x24x56x42xbf16{{.*}}> + // CHECK: {{.*}} = "ttnn.ones"({{.*}}) {{.*}} -> tensor<13x24x56x42xbf16{{.*}}> %0 = "ttir.ones"() <{shape = array}> : () -> tensor<13x24x56x42xbf16> return %0 : tensor<13x24x56x42xbf16> } func.func @ones_f32() -> tensor<32x64x128xf32> { - // CHECK: {{.*}} = "ttnn.ones"() {{.*}} -> tensor<32x64x128xf32{{.*}}> + // CHECK: {{.*}} = "ttnn.ones"({{.*}}) {{.*}} -> tensor<32x64x128xf32{{.*}}> %0 = "ttir.ones"() <{shape = array}> : () -> tensor<32x64x128xf32> return %0 : tensor<32x64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir index d554baf2e3..114275c51f 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir @@ -4,7 +4,7 @@ func.func @ceil(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : 
tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%{{[0-9]+}}, [[VAL0]]) + // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%arg0, [[VAL0]]) %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir index 44806c22df..63c17062ac 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir @@ -4,11 +4,8 @@ func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> - // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0, - // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]]) - // CHECK: = "ttnn.clamp"(%[[LAYOUT]]) + // CHECK: "ttnn.clamp"(%arg0) // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} - // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #ttnn_layout{{[0-9]+}}>) -> [[TENSOR]] %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir index 2596e4a132..1598de319a 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir @@ -4,7 +4,7 @@ func.func @cosine(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.cos"(%{{[0-9]+}}, [[VAL0]]) + // CHECK: %{{[0-9]+}} = "ttnn.cos"(%arg0, [[VAL0]]) %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir index 7d035174c0..a499c20ce2 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir @@ -5,7 +5,7 @@ func.func @expm1(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir index d4a7ed331b..ef5244fdae 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir +++ 
b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir @@ -4,7 +4,7 @@ func.func @log(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.log"(%{{[0-9]+}}, [[VAL0]]) + // CHECK: %{{[0-9]+}} = "ttnn.log"(%arg0, [[VAL0]]) %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir index 3d50d3e88f..7e21972a81 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir @@ -6,7 +6,7 @@ func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%arg0, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir index 26fe2b2d0e..8a05b1ae6d 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir @@ -5,7 +5,7 @@ func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> + // CHECK: %{{[0-9]+}} = "ttnn.sign"(%arg0, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir index 61fe517ead..60dc574693 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir @@ -4,7 +4,7 @@ func.func @sine(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.sin"(%{{[0-9]+}}, [[VAL0]]) + // CHECK: %{{[0-9]+}} =
"ttnn.sin"(%arg0, [[VAL0]]) %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir index 47957677b3..d870cee186 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir @@ -5,7 +5,7 @@ func.func @tan(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) + // CHECK: %{{[0-9]+}} = "ttnn.tan"(%arg0, [[VAL0]]) %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir index 4844bd3084..cf12dcf9ac 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir @@ -5,7 +5,7 @@ func.func @tanh(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) + // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%arg0, [[VAL0]]) %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir index 9076f24f4b..ced442b91f 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir @@ -8,7 +8,7 @@ func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> t %2 = tensor.empty() : tensor<13x37xbf16> %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} - // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) - // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) + // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%arg0, %arg1, %[[EMPTY]]) + // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %arg0, %arg1, %{{[0-9]+}}) return %3 : tensor<13x37xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/reshape/reshape.mlir b/test/ttmlir/Silicon/TTNN/reshape/reshape.mlir new file mode 100644 index 0000000000..7156e648be --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/reshape/reshape.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @reshape(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { + %0 = tensor.empty() : tensor<2x4x32x32xbf16> + // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] + %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> + return %1 : tensor<2x4x32x32xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir 
b/test/ttmlir/Silicon/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir new file mode 100644 index 0000000000..bd03fbe155 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/runtime_stitching/eltwise_binary_op_chain.mlir @@ -0,0 +1,29 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +module attributes {} { +func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} + +func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} + +func.func @subtract(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]] + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} +} diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir deleted file mode 100644 index a0452f01f8..0000000000 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ /dev/null @@ -1,334 +0,0 @@ -// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir -// RUN: FileCheck %s --input-file=%t.mlir -// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @ceil(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { - %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> - return %1 : tensor<32x32xf32> -} - -func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { - %0 = tensor.empty() : tensor<64x128xbf16> - // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0, - // CHECK: %[[LAYOUT:.*]] = "ttnn.to_layout"(%[[DEVICE]]) - // CHECK: = "ttnn.clamp"(%[[LAYOUT]]) - // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} - // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #ttnn_layout{{[0-9]+}}>) -> [[TENSOR]] - %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>,
tensor<64x128xbf16>) -> tensor<64x128xbf16> - return %1 : tensor<64x128xbf16> -} - -func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<32x96xf32> - // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]] - %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> - return %1 : tensor<32x96xf32> -} - -func.func @cosine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { - %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.cos"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> - return %1 : tensor<32x32xf32> -} - -func.func @div(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.div"[[C:.*]] - %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %{{[0-9]+}} = "ttnn.empty" - // CHECK-SAME: [[TENSOR:tensor<64x128xf32,]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.floor" - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.floor"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { - // CHECK: %[[C:.*]] = "ttnn.empty" - // CHECK-SAME: [[TENSOR:tensor<64x128xbf16,]] - %0 = tensor.empty() : tensor<64x128xbf16> - // CHECK: %[[C:.*]] = "ttnn.isfinite" - // CHECK-SAME: tensor<64x128xbf16, - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> - return %1 : tensor<64x128xbf16> -} - -func.func @minimum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty" - // CHECK-SAME: [[TENSOR:tensor<64x128xf32,]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.minimum" - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: [[TENSOR]] - // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @ge(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]] - %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.maximum"[[C:.*]] - %1 = "ttir.maximum"(%arg0, %arg1, %0) 
<{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @negate(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { - %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: %[[C:.*]] = "ttnn.neg"[[C:.*]] - %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> - return %1 : tensor<32x32xf32> -} - -func.func @reciprocal(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.reciprocal"[[C:.*]] - %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] - %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @leaky_relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty" - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.leaky_relu" - %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @reshape(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { - %0 = tensor.empty() : tensor<2x4x32x32xbf16> - // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> - return %1 : tensor<2x4x32x32xbf16> -} - -func.func @squeeze(%arg0: tensor<1x2x1x32x32xbf16>) -> tensor<1x2x32x32xbf16> { - %0 = tensor.empty() : tensor<1x2x32x32xbf16> - // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> - return %1 : tensor<1x2x32x32xbf16> -} - -func.func @subtract(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]] - %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @rsqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.rsqrt"[[C:.*]] - %1 = "ttir.rsqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} 
- -func.func @sigmoid(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]] - %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @sqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.sqrt"[[C:.*]] - %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @sine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { - %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.sin"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> - return %1 : tensor<32x32xf32> -} - -func.func @softmax(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> { - %0 = tensor.empty() : tensor<512x1024xbf16> - // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] - // Check for positive dimension attribute - %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> - %2 = tensor.empty() : tensor<512x1024xbf16> - // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] - // Check for negative dimension attribute - %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> - return %3 : tensor<512x1024xbf16> -} - -func.func @cbrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.cbrt"[[C:.*]] - %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @typecast(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { - %0 = tensor.empty() : tensor<64x128xbf16> - // CHECK: %[[C:.*]] = "ttnn.typecast" - // CHECK-SAME: tensor<64x128xf32, - // CHECK-SAME: tensor<64x128xbf16, - %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> - return %1 : tensor<64x128xbf16> -} - -func.func @log(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.log"[[C:.*]] - %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> 
tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> - return %1 : tensor<64x128xf32> - // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> -} - -func.func @expm1(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> - return %1 : tensor<64x128xf32> - // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> -} - -func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - %0 = tensor.empty() : tensor<64x128xf32> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - // CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> - return %1 : tensor<64x128xf32> - // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> -} - -func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { - %0 = tensor.empty() : tensor<32x32xf32> - // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}} - %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> - // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}} - return %1 : tensor<32x32xf32> - // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}} -} - -func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> { - %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32> - // CHECK: [[VAL:%[0-9]+]] = "ttnn.full"(%{{[0-9]+}}) <{fillValue = 2.100000e+01 : f32}> : (!tt.device<#device>) -> tensor<1xi32, {{.*}}> - return %0 : tensor<1xi32> - // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}> -} - -func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { - %0 = tensor.empty() : tensor<13x37xbf16> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> - %2 = tensor.empty() : tensor<13x37xbf16> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> - // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} - // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) - // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) - return %3 : tensor<13x37xbf16> -} - -func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: "ttnn.empty" - // CHECK-SAME: tensor<64x128xf32, - %0 = tensor.empty() : 
tensor<64x128xf32> - // CHECK: "ttnn.gelu" - // CHECK-SAME: tensor<64x128xf32, - // CHECK-SAME: tensor<64x128xf32, - // CHECK-SAME: tensor<64x128xf32, - %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> - return %1 : tensor<64x128xf32> -} - -func.func @tan(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { - %0 = tensor.empty() : tensor<64x128xbf16> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> - return %1 : tensor<64x128xbf16> -} - -func.func @tanh(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { - %0 = tensor.empty() : tensor<64x128xbf16> - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> - return %1 : tensor<64x128xbf16> -} - -func.func @addint32(%arg0: tensor<64x128xi32>, %arg1: tensor<64x128xi32>) -> tensor<64x128xi32> { - %0 = tensor.empty() : tensor<64x128xi32> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xi32>, tensor<64x128xi32>, tensor<64x128xi32>) -> tensor<64x128xi32> - return %1 : tensor<64x128xi32> -} - -func.func @scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { - %0 = tensor.empty() : tensor<1x3x320x320xf32> - %1 = tensor.empty() : tensor<1x1xi32> - %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ - ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): - "ttir.yield"(%arg4) : (tensor<1xf32>) -> () - }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> - return %2 : tensor<1x3x320x320xf32> -} diff --git a/test/ttmlir/Silicon/TTNN/simple_repeat.mlir b/test/ttmlir/Silicon/TTNN/simple_repeat.mlir index ab91af2ee6..cbf0bc34db 100644 --- a/test/ttmlir/Silicon/TTNN/simple_repeat.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_repeat.mlir @@ -3,7 +3,7 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module { func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> { - // CHECK: %{{[0-9]+}} = "ttnn.repeat"(%{{[0-9]+}}) + // CHECK: %{{[0-9]+}} = "ttnn.repeat"(%arg1) %0 = tensor.empty() : tensor<1x16x32xf32> %1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> %2 = tensor.empty() : tensor<1x16x32xf32> @@ -14,8 +14,8 @@ module { module { func.func public @main(%arg0: tensor<1xf32>, %arg1: tensor<512x512xf32>) -> (tensor<512x512xf32>) { - // CHECK: [[VAL0:%[0-9]+]] = "ttnn.reshape"(%{{[0-9]+}}) - // CHECK: %{{[0-9]+}} = "ttnn.repeat"([[VAL0]]) + // CHECK: %{{[0-9]+}} = "ttnn.reshape"(%arg0) + // CHECK: %{{[0-9]+}} = "ttnn.repeat"(%{{[0-9]+}}) %0 = tensor.empty() : tensor<1x1xf32> %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 1 : i32]}> : (tensor<1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> %2 = tensor.empty() : tensor<512x512xf32> diff --git a/test/ttmlir/Silicon/TTNN/softmax/softmax.mlir 
b/test/ttmlir/Silicon/TTNN/softmax/softmax.mlir new file mode 100644 index 0000000000..4322841f0c --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/softmax/softmax.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @softmax(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> { + %0 = tensor.empty() : tensor<512x1024xbf16> + // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] + // Check for positive dimension attribute + %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %2 = tensor.empty() : tensor<512x1024xbf16> + // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] + // Check for negative dimension attribute + %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + return %3 : tensor<512x1024xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/squeeze/squeeze.mlir b/test/ttmlir/Silicon/TTNN/squeeze/squeeze.mlir new file mode 100644 index 0000000000..6ae1a77e0a --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/squeeze/squeeze.mlir @@ -0,0 +1,10 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @squeeze(%arg0: tensor<1x2x1x32x32xbf16>) -> tensor<1x2x32x32xbf16> { + %0 = tensor.empty() : tensor<1x2x32x32xbf16> + // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] + %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> + return %1 : tensor<1x2x32x32xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/simple_typecast.mlir b/test/ttmlir/Silicon/TTNN/typecast/simple_typecast.mlir similarity index 100% rename from test/ttmlir/Silicon/TTNN/simple_typecast.mlir rename to test/ttmlir/Silicon/TTNN/typecast/simple_typecast.mlir diff --git a/test/ttmlir/Silicon/TTNN/typecast/typecast.mlir b/test/ttmlir/Silicon/TTNN/typecast/typecast.mlir new file mode 100644 index 0000000000..9f7ae032fb --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/typecast/typecast.mlir @@ -0,0 +1,12 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +func.func @typecast(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: %[[C:.*]] = "ttnn.typecast" + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xbf16, + %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +}