From 1dd729332e16037cd0ced6d8ec8128666fb1e1b3 Mon Sep 17 00:00:00 2001
From: Muhammad Asif Manzoor
Date: Tue, 14 Jan 2025 02:10:28 +0000
Subject: [PATCH] reduce prod op

---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td     |   7 ++
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td     |   7 ++
 include/ttmlir/Target/TTNN/program.fbs        |   1 +
 .../StableHLOToTTIRPatterns.cpp               |  27 +++++
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp      |   1 +
 lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp    |   3 +-
 lib/Dialect/TTIR/IR/TTIROps.cpp               |  33 ++++++
 lib/Dialect/TTNN/IR/TTNNOps.cpp               |   9 ++
 .../Workarounds/TTNNWorkarounds.cpp           |   6 +-
 lib/Target/TTNN/TTNNToFlatbuffer.cpp          |   6 +
 runtime/include/tt/runtime/detail/ttnn.h      |   1 +
 .../ttnn/operations/reduction/reduction.cpp   |  32 +++++
 .../{ => reduction}/reduce_add_op.mlir        |   0
 .../{ => reduction}/reduce_maximum_op.mlir    |   0
 .../reduction/reduce_prod_op.mlir             |  82 +++++++++++++
 .../reduce_ops/negative_reduce_prod_op.mlir   |  25 ++++
 .../TTNN/{ => reduction}/simple_max.mlir      |   0
 .../TTNN/{ => reduction}/simple_mean.mlir     |   0
 .../Dialect/TTNN/reduction/simple_prod.mlir   |  80 +++++++++++++
 .../TTNN/{ => reduction}/simple_sum.mlir      |   0
 .../{ => reduction}/reduce_add_op.mlir        |   2 +-
 .../{ => reduction}/reduce_maximum_op.mlir    |   2 +-
 .../StableHLO/reduction/reduce_prod_op.mlir   | 111 ++++++++++++++++++
 23 files changed, 431 insertions(+), 4 deletions(-)
 rename test/ttmlir/Conversion/StableHLOToTTIR/{ => reduction}/reduce_add_op.mlir (100%)
 rename test/ttmlir/Conversion/StableHLOToTTIR/{ => reduction}/reduce_maximum_op.mlir (100%)
 create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_prod_op.mlir
 create mode 100644 test/ttmlir/Dialect/TTIR/reduce_ops/negative_reduce_prod_op.mlir
 rename test/ttmlir/Dialect/TTNN/{ => reduction}/simple_max.mlir (100%)
 rename test/ttmlir/Dialect/TTNN/{ => reduction}/simple_mean.mlir (100%)
 create mode 100644 test/ttmlir/Dialect/TTNN/reduction/simple_prod.mlir
 rename test/ttmlir/Dialect/TTNN/{ => reduction}/simple_sum.mlir (100%)
 rename test/ttmlir/Silicon/StableHLO/{ => reduction}/reduce_add_op.mlir (97%)
 rename test/ttmlir/Silicon/StableHLO/{ => reduction}/reduce_maximum_op.mlir (97%)
 create mode 100644 test/ttmlir/Silicon/StableHLO/reduction/reduce_prod_op.mlir

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 710f88cfe..ddea44e4f 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -704,6 +704,13 @@ def TTIR_MaxOp : TTIR_ReductionOp<"max"> {
   }];
 }
 
+def TTIR_ProdOp : TTIR_ReductionOp<"prod"> {
+  let summary = "Product reduction op.";
+  let description = [{
+    Product reduction op.
+  }];
+}
+
 def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
   let summary = "Embedding op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index ba2484ac5..6273771cd 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -609,6 +609,13 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
   }];
 }
 
+def TTNN_ProdOp : TTNN_ReductionOp<"prod"> {
+  let summary = "Product reduction op.";
+  let description = [{
+    Product reduction op.
+  }];
+}
+
 def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
   let summary = "Embedding op.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 0838c629d..ceb6f437a 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -168,6 +168,7 @@ enum ReductionOpType: uint32 {
   Sum,
   Mean,
   Max,
+  Prod,
 }
 
 table ReductionOp {
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 40574741d..a35088a2d 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -84,6 +84,10 @@ class StableHLOToTTIRReduceOpConversionPattern
       return matchAndRewriteInternal<mlir::tt::ttir::MaxOp>(srcOp, adaptor, rewriter);
     }
 
+    if (mlir::isa<mlir::stablehlo::MulOp>(innerOp)) {
+      return matchAndRewriteInternal<mlir::tt::ttir::ProdOp>(srcOp, adaptor,
+                                                             rewriter);
+    }
     return failure();
   }
 
@@ -104,6 +108,29 @@ class StableHLOToTTIRReduceOpConversionPattern
           "Expecting StableHLO Reduce OP to have a body operation defined.");
     }
 
+    mlir::Operation &innerOp = srcOp.getBody().front().front();
+    if (mlir::isa<mlir::stablehlo::MulOp>(innerOp)) {
+      RankedTensorType inputType = mlir::cast<RankedTensorType>(
+          getTypeConverter()->convertType(srcOp.getInputs()[0].getType()));
+      int64_t inputTensorRank = inputType.getRank();
+      int64_t numDims = srcOp.getDimensions().size();
+      mlir::Type elementType = inputType.getElementType();
+      if (inputTensorRank > 4) {
+        return rewriter.notifyMatchFailure(
+            srcOp, "Input tensor rank is greater than 4 for reduce(product).");
+      }
+      if ((numDims > 1) && (numDims != inputTensorRank)) {
+        return rewriter.notifyMatchFailure(
+            srcOp, "TTNN only supports reduce(prod) along one dimension or all "
+                   "dimensions.");
+      }
+      if ((numDims == inputTensorRank) && !elementType.isBF16()) {
+        return rewriter.notifyMatchFailure(
+            srcOp, "TTNN only supports Reduce(prod) along all dimensions for "
+                   "bfloat16 datatype.");
+      }
+    }
+
     return success();
   }
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 2e84eb347..d7b8c422f 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -1206,6 +1206,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
            ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
            ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
+           ReductionOpConversionPattern<ttir::ProdOp, ttnn::ProdOp>,
            ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
            BroadcastOpConversionPattern,
            EmbeddingOpConversionPattern,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index f92e730ba..a80642919 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -846,7 +846,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   //
   patterns.add<DefaultOpConversionPattern<ttnn::SumOp>,
                DefaultOpConversionPattern<ttnn::MeanOp>,
-               DefaultOpConversionPattern<ttnn::MaxOp>>(typeConverter, ctx);
+               DefaultOpConversionPattern<ttnn::MaxOp>,
+               DefaultOpConversionPattern<ttnn::ProdOp>>(typeConverter, ctx);
 
   // Conv ops
   //
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 0b58ed860..31da6b7f8 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -2013,6 +2013,23 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
     return reduceOp->emitOpError("Reduce dimensions are not unique");
   }
 
+  if (mlir::isa<mlir::tt::ttir::ProdOp>(reduceOp)) {
+    int64_t numReduceDims = uniqueReduceDims.size();
+    mlir::Type elementType = inputType.getElementType();
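+    // These checks mirror the current TTNN runtime restrictions on prod so
+    // that unsupported reductions fail at compile time instead of at runtime.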
+    if (inputTensorRank > 4) {
+      return reduceOp->emitOpError(
+          "Input tensor rank is greater than 4 for reduce(product).");
+    }
+    if ((numReduceDims > 1) && (numReduceDims != inputTensorRank)) {
+      return reduceOp->emitOpError("TTNN only supports reduce(prod) along one "
+                                   "dimension or all dimensions.");
+    }
+    if ((numReduceDims == inputTensorRank) && !elementType.isBF16()) {
+      return reduceOp->emitOpError("TTNN only supports Reduce(prod) along all "
+                                   "dimensions for bfloat16 datatype.");
+    }
+  }
+
   // TODO(mrakita): Add a check that depending on inputShape, reduceDims and
   // keepDim computes the expected output shape and checks if it matches the
   // actual output shape. Tracked by:
@@ -2068,3 +2085,19 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
 ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
   return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
 }
+
+//===----------------------------------------------------------------------===//
+// Reduce ProdOp
+//===----------------------------------------------------------------------===//
+
+// ProdOp kernel builder.
+void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
+                                                ::mlir::Block *block) {
+  // NOLINTNEXTLINE
+  createReduceOp(opBuilder, block, getLoc(), "prod");
+}
+
+// ProdOp verification.
+::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() {
+  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+}
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index eccb1e9ba..908a53fb8 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -1343,4 +1343,13 @@ ::mlir::LogicalResult SumOp::verify() {
   return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
 }
 
+//===----------------------------------------------------------------------===//
+// Reduce ProdOp
+//===----------------------------------------------------------------------===//
+
+// ProdOp verification.
+::mlir::LogicalResult ProdOp::verify() {
+  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+}
+
 } // namespace mlir::tt::ttnn
diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
index a51166537..80b314ae9 100644
--- a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
+++ b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -425,12 +425,16 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
                  ttnn::MaxOp>,
              workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
                  ttnn::MeanOp>,
+             workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
+                 ttnn::ProdOp>,
              workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
                  ttnn::SumOp>,
              workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
                  ttnn::MaxOp>,
              workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
-                 ttnn::MeanOp>>(&getContext());
+                 ttnn::MeanOp>,
+             workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
+                 ttnn::ProdOp>>(&getContext());
 
   runRewritePatterns(std::move(patterns),
                      GreedyRewriteConfig::kNoLimit /*maxIterations*/);
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 170944311..c114e970a 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -742,6 +742,8 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
     type = ::tt::target::ttnn::ReductionOpType::Mean;
   } else if constexpr (std::is_same_v<ReductionOp, MaxOp>) {
     type = ::tt::target::ttnn::ReductionOpType::Max;
+  } else if constexpr (std::is_same_v<ReductionOp, ProdOp>) {
+    type = ::tt::target::ttnn::ReductionOpType::Prod;
   } else {
     llvm_unreachable("unhandled ReductionOp");
   }
@@ -1116,6 +1118,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createReductionOp(cache, maxOp), debugString,
                            locInfo);
   }
+  if (auto prodOp = dyn_cast<ProdOp>(op); prodOp) {
+    return createOperation(cache, createReductionOp(cache, prodOp), debugString,
+                           locInfo);
+  }
   if (auto embeddingOp = dyn_cast<EmbeddingOp>(op); embeddingOp) {
     return createOperation(cache, createEmbeddingOp(cache, embeddingOp),
                            debugString, locInfo);
   }
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index 081ef02fe..c16b6711d 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -31,6 +31,7 @@
 #include "ttnn/operations/normalization/softmax/softmax.hpp"
 #include "ttnn/operations/pool/generic/generic_pools.hpp"
 #include "ttnn/operations/reduction/generic/generic_reductions.hpp"
+#include "ttnn/operations/reduction/prod/prod.hpp"
 #include "ttnn/tensor/host_buffer/functions.hpp"
 #include "ttnn/tensor/host_buffer/owned_buffer.hpp"
 #include "ttnn/tensor/shape/shape.hpp"
diff --git a/runtime/lib/ttnn/operations/reduction/reduction.cpp b/runtime/lib/ttnn/operations/reduction/reduction.cpp
index 631df3f51..fc37ded5b 100644
--- a/runtime/lib/ttnn/operations/reduction/reduction.cpp
+++ b/runtime/lib/ttnn/operations/reduction/reduction.cpp
@@ -35,6 +35,34 @@ static void runReductionOp(
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 
+static void runReductionProdOp(::tt::target::ttnn::ReductionOp const *op,
+                               ProgramTensorPool &tensorPool) {
+
+  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
+      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
+  const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
+  DEBUG_ASSERT(in.is_allocated());
+
+  const auto *fbDimArg = op->dim_arg();
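+  // ttnn::prod takes one reduce dimension plus an all-dimensions flag, so
+  // fold the serialized dim_arg into that form: a missing dim_arg or a full
+  // rank-4 dimension list means reduce over all dimensions; otherwise the
+  // single entry selects the dimension to reduce.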
+  int dim = 0;
+  bool all_dimensions = false;
+  if (fbDimArg) {
+    ::ttnn::SmallVector<int> dimArg =
+        ::ttnn::SmallVector<int>(fbDimArg->begin(), fbDimArg->end());
+    dim = dimArg[0];
+    if (dimArg.size() == 4) {
+      all_dimensions = true;
+    }
+  } else {
+    all_dimensions = true;
+  }
+
+  ::ttnn::Tensor out = ::ttnn::prod(in, all_dimensions, dim, op->keep_dim(),
+                                    outputMemoryConfig /* memory_config_arg */);
+
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+
 void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   switch (op->type()) {
@@ -50,6 +78,10 @@ void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
     runReductionOp(op, tensorPool, ::ttnn::max);
     break;
   }
+  case ::tt::target::ttnn::ReductionOpType::Prod: {
+    runReductionProdOp(op, tensorPool);
+    break;
+  }
   }
 }
 } // namespace tt::runtime::ttnn::operations::reduction
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/reduce_add_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_add_op.mlir
similarity index 100%
rename from test/ttmlir/Conversion/StableHLOToTTIR/reduce_add_op.mlir
rename to test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_add_op.mlir
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/reduce_maximum_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_maximum_op.mlir
similarity index 100%
rename from test/ttmlir/Conversion/StableHLOToTTIR/reduce_maximum_op.mlir
rename to test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_maximum_op.mlir
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_prod_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_prod_op.mlir
new file mode 100644
index 000000000..95a2fa1cf
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_prod_op.mlir
@@ -0,0 +1,82 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+module @jit_reduce_prod attributes {} {
+  func.func public @test_reduce_prod_4to3dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32x4xf32> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_4to3dim
+    // CHECK: tensor.empty
+    // CHECK: "ttir.prod"
+    // CHECK-SAME: dim_arg = [1 : i32]
+    // CHECK-SAME: keep_dim = false
+    // CHECK-SAME: tensor<128x10x32x4xf32>
+    // CHECK-SAME: -> tensor<128x32x4xf32>
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32x4xf32>
+    return %0 : tensor<128x32x4xf32>
+  }
+
+  func.func public @test_reduce_prod_4to0dim(%arg0: tensor<128x10x32x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_4to0dim
+    // CHECK: tensor.empty
+    // CHECK: "ttir.prod"
+    // CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32]
+    // CHECK-SAME: keep_dim = false
+    // CHECK-SAME: tensor<128x10x32x4xbf16>
+    // CHECK-SAME: -> tensor<1xbf16>
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2, 3] : (tensor<128x10x32x4xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+
+  func.func public @test_reduce_prod_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> {
+    // CHECK: tensor.empty
+    // CHECK: "ttir.prod"
+    // CHECK-SAME: dim_arg = [1 : i32]
+    // CHECK-SAME: keep_dim = false
+    // CHECK-SAME: tensor<128x10x4xf32>
+    // CHECK-SAME: -> tensor<128x4xf32>
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10x4xf32>, tensor<f32>) -> tensor<128x4xf32>
+    return %0 : tensor<128x4xf32>
+  }
+
+  func.func public @test_reduce_prod_3to0dim(%arg0: tensor<128x10x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK: tensor.empty
+    // CHECK: "ttir.prod"
+    // CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32]
+    // CHECK-SAME: keep_dim = false
+    // CHECK-SAME: tensor<128x10x4xbf16>
+    // CHECK-SAME: -> tensor<1xbf16>
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2] : (tensor<128x10x4xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+
+  func.func public @test_reduce_prod_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
+    // CHECK: tensor.empty
+    // CHECK: "ttir.prod"
+    // CHECK-SAME: dim_arg = [1 : i32]
+    // CHECK-SAME: keep_dim = false
+    // CHECK-SAME: tensor<128x10xf32>
+    // CHECK-SAME: -> tensor<128xf32>
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
+    return %0 : tensor<128xf32>
+  }
+
+  func.func public @test_reduce_prod_2to0dim(%arg0: tensor<128x10xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK: tensor.empty
+    // CHECK: "ttir.prod"
+    // CHECK-SAME: dim_arg = [0 : i32, 1 : i32]
+    // CHECK-SAME: keep_dim = false
+    // CHECK-SAME: tensor<128x10xbf16>
+    // CHECK-SAME: -> tensor<1xbf16>
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1] : (tensor<128x10xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+
+  func.func public @test_reduce_prod_1to0dim(%arg0: tensor<128xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK: tensor.empty
+    // CHECK: "ttir.prod"
+    // CHECK-SAME: dim_arg = [0 : i32]
+    // CHECK-SAME: keep_dim = false
+    // CHECK-SAME: tensor<128xbf16>
+    // CHECK-SAME: -> tensor<1xbf16>
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0] : (tensor<128xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTIR/reduce_ops/negative_reduce_prod_op.mlir b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_reduce_prod_op.mlir
new file mode 100644
index 000000000..0ff9e07b2
--- /dev/null
+++ b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_reduce_prod_op.mlir
@@ -0,0 +1,25 @@
+// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
+// Negative tests for reduce(prod) op
+
+// CHECK: error: 'ttir.prod' op TTNN only supports reduce(prod) along one dimension or all dimensions.
+func.func public @test_reduce_prod_multiple_dims(%arg0: tensor<128x10x32x4xf32>) -> tensor<128x32xf32> {
+  %0 = tensor.empty() : tensor<128x32xf32>
+  %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [1 : i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128x32xf32>) -> tensor<128x32xf32>
+  return %1 : tensor<128x32xf32>
+}
+
+// -----
+// CHECK: error: 'ttir.prod' op TTNN only supports Reduce(prod) along all dimensions for bfloat16 datatype.
+func.func public @test_reduce_prod_all_dims_f32(%arg0: tensor<128x10x32x4xf32>) -> tensor<1xf32> {
+  %0 = tensor.empty() : tensor<1xf32>
+  %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<1xf32>) -> tensor<1xf32>
+  return %1 : tensor<1xf32>
+}
+
+// -----
+// CHECK: error: 'ttir.prod' op Input tensor rank is greater than 4 for reduce(product).
+func.func public @test_reduce_prod_higher_rank(%arg0: tensor<128x10x32x4x1xf32>) -> tensor<10x32x4x1xf32> {
+  %0 = tensor.empty() : tensor<10x32x4x1xf32>
+  %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128x10x32x4x1xf32>, tensor<10x32x4x1xf32>) -> tensor<10x32x4x1xf32>
+  return %1 : tensor<10x32x4x1xf32>
+}
diff --git a/test/ttmlir/Dialect/TTNN/simple_max.mlir b/test/ttmlir/Dialect/TTNN/reduction/simple_max.mlir
similarity index 100%
rename from test/ttmlir/Dialect/TTNN/simple_max.mlir
rename to test/ttmlir/Dialect/TTNN/reduction/simple_max.mlir
diff --git a/test/ttmlir/Dialect/TTNN/simple_mean.mlir b/test/ttmlir/Dialect/TTNN/reduction/simple_mean.mlir
similarity index 100%
rename from test/ttmlir/Dialect/TTNN/simple_mean.mlir
rename to test/ttmlir/Dialect/TTNN/reduction/simple_mean.mlir
diff --git a/test/ttmlir/Dialect/TTNN/reduction/simple_prod.mlir b/test/ttmlir/Dialect/TTNN/reduction/simple_prod.mlir
new file mode 100644
index 000000000..acc743853
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/reduction/simple_prod.mlir
@@ -0,0 +1,80 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+
+module attributes {} {
+  func.func public @test_reduce_prod_4to3dim(%arg0: tensor<128x10x32x4xf32>) -> tensor<128x32x4xf32> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_4to3dim
+    %0 = tensor.empty() : tensor<128x32x4xf32>
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: dim_arg = [1 : i32]
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x32x4xf32,
+    // CHECK-SAME: -> tensor<128x1x32x4xf32,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [128 : i32, 32 : i32, 4 : i32]
+    // CHECK-SAME: tensor<128x1x32x4xf32,
+    // CHECK-SAME: -> tensor<128x32x4xf32
+    %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128x32x4xf32>) -> tensor<128x32x4xf32>
+    return %1 : tensor<128x32x4xf32>
+  }
+
+  func.func public @test_reduce_prod_4to0dim(%arg0: tensor<128x10x32x4xbf16>) -> tensor<1xbf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_4to0dim
+    %0 = tensor.empty() : tensor<1xbf16>
+    // CHECK-NOT: dim_arg = [1 : i32]
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x32x4xbf16,
+    // CHECK-SAME: -> tensor<1x1x1x1xbf16,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [1 : i32]
+    // CHECK-SAME: tensor<1x1x1x1xbf16,
+    // CHECK-SAME: -> tensor<1xbf16
+    %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    return %1 : tensor<1xbf16>
+  }
+
+  func.func public @test_reduce_prod_3to2dim(%arg0: tensor<128x10x4xf32>) -> tensor<128x4xf32> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_3to2dim
+    %0 = tensor.empty() : tensor<128x4xf32>
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: dim_arg = [1 : i32]
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x4xf32,
+    // CHECK-SAME: -> tensor<128x1x4xf32,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [128 : i32, 4 : i32]
+    // CHECK-SAME: tensor<128x1x4xf32,
+    // CHECK-SAME: -> tensor<128x4xf32
+    %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x4xf32>, tensor<128x4xf32>) -> tensor<128x4xf32>
+    return %1 : tensor<128x4xf32>
+  }
+
+  func.func public @test_reduce_prod_3to0dim(%arg0: tensor<128x10x4xbf16>) -> tensor<1xbf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_3to0dim
+    %0 = tensor.empty() : tensor<1xbf16>
+    // CHECK-NOT: dim_arg = [1 : i32]
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x4xbf16,
+    // CHECK-SAME: -> tensor<1x1x1xbf16,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [1 : i32]
+    // CHECK-SAME: tensor<1x1x1xbf16,
+    // CHECK-SAME: -> tensor<1xbf16
+    %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0 : i32, 1 : i32, 2 : i32], keep_dim = false}> : (tensor<128x10x4xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    return %1 : tensor<1xbf16>
+  }
+
+  func.func public @test_reduce_prod_1to0dim(%arg0: tensor<128xbf16>) -> tensor<1xbf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_1to0dim
+    %0 = tensor.empty() : tensor<1xbf16>
+    // CHECK-NOT: dim_arg = [0 : i32]
+    // CHECK-NOT: ttnn.reshape
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128xbf16,
+    // CHECK-SAME: -> tensor<1xbf16,
+    %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    return %1 : tensor<1xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/simple_sum.mlir b/test/ttmlir/Dialect/TTNN/reduction/simple_sum.mlir
similarity index 100%
rename from test/ttmlir/Dialect/TTNN/simple_sum.mlir
rename to test/ttmlir/Dialect/TTNN/reduction/simple_sum.mlir
diff --git a/test/ttmlir/Silicon/StableHLO/reduce_add_op.mlir b/test/ttmlir/Silicon/StableHLO/reduction/reduce_add_op.mlir
similarity index 97%
rename from test/ttmlir/Silicon/StableHLO/reduce_add_op.mlir
rename to test/ttmlir/Silicon/StableHLO/reduction/reduce_add_op.mlir
index 89f51123e..8e30abf28 100644
--- a/test/ttmlir/Silicon/StableHLO/reduce_add_op.mlir
+++ b/test/ttmlir/Silicon/StableHLO/reduction/reduce_add_op.mlir
@@ -1,7 +1,7 @@
 // REQUIRES: stablehlo
 // RUN: rm -rf %t.ttnn
 // RUN: rm -rf %t.mlir
-// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-opt %s --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
 // RUN: FileCheck --input-file=%t.mlir %s
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 // UNSUPPORTED: true
diff --git a/test/ttmlir/Silicon/StableHLO/reduce_maximum_op.mlir b/test/ttmlir/Silicon/StableHLO/reduction/reduce_maximum_op.mlir
similarity index 97%
rename from test/ttmlir/Silicon/StableHLO/reduce_maximum_op.mlir
rename to test/ttmlir/Silicon/StableHLO/reduction/reduce_maximum_op.mlir
index 8ee57fd52..2f56153b4 100644
--- a/test/ttmlir/Silicon/StableHLO/reduce_maximum_op.mlir
+++ b/test/ttmlir/Silicon/StableHLO/reduction/reduce_maximum_op.mlir
@@ -1,7 +1,7 @@
 // REQUIRES: stablehlo
 // RUN: rm -rf %t.ttnn
 // RUN: rm -rf %t.mlir
-// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-opt %s --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
 // RUN: FileCheck --input-file=%t.mlir %s
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 // UNSUPPORTED: true
diff --git a/test/ttmlir/Silicon/StableHLO/reduction/reduce_prod_op.mlir b/test/ttmlir/Silicon/StableHLO/reduction/reduce_prod_op.mlir
new file mode 100644
index 000000000..b1bc410bc
--- /dev/null
+++ b/test/ttmlir/Silicon/StableHLO/reduction/reduce_prod_op.mlir
@@ -0,0 +1,111 @@
+// REQUIRES: stablehlo
+// RUN: rm -rf %t.ttnn
+// RUN: rm -rf %t.mlir
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
+// RUN:     ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// RUN: FileCheck --input-file=%t.mlir %s
+
+module @jit_reduce_prod attributes {} {
+  func.func public @test_reduce_prod_4to3dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32x4xf32> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_4to3dim
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: dim_arg = [1 : i32]
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x32x4xf32,
+    // CHECK-SAME: -> tensor<128x1x32x4xf32,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [128 : i32, 32 : i32, 4 : i32]
+    // CHECK-SAME: tensor<128x1x32x4xf32,
+    // CHECK-SAME: -> tensor<128x32x4xf32
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32x4xf32>
+    return %0 : tensor<128x32x4xf32>
+  }
+
+  func.func public @test_reduce_prod_4to0dim(%arg0: tensor<128x10x32x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_4to0dim
+    // CHECK-NOT: dim_arg
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x32x4xbf16,
+    // CHECK-SAME: -> tensor<1x1x1x1xbf16,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [1 : i32]
+    // CHECK-SAME: tensor<1x1x1x1xbf16,
+    // CHECK-SAME: -> tensor<1xbf16
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2, 3] : (tensor<128x10x32x4xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+
+  func.func public @test_reduce_prod_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x10xf32> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_3to2dim
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: dim_arg = [2 : i32]
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x4xf32,
+    // CHECK-SAME: -> tensor<128x10x1xf32,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [128 : i32, 10 : i32]
+    // CHECK-SAME: tensor<128x10x1xf32,
+    // CHECK-SAME: -> tensor<128x10xf32
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [2] : (tensor<128x10x4xf32>, tensor<f32>) -> tensor<128x10xf32>
+    return %0 : tensor<128x10xf32>
+  }
+
+  func.func public @test_reduce_prod_3to0dim(%arg0: tensor<128x10x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_3to0dim
+    // CHECK-NOT: dim_arg
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10x4xbf16,
+    // CHECK-SAME: -> tensor<1x1x1xbf16,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [1 : i32]
+    // CHECK-SAME: tensor<1x1x1xbf16,
+    // CHECK-SAME: -> tensor<1xbf16
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2] : (tensor<128x10x4xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+
+  func.func public @test_reduce_prod_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_2to1dim
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: dim_arg = [1 : i32]
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10xf32,
+    // CHECK-SAME: -> tensor<128x1xf32,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [128 : i32]
+    // CHECK-SAME: tensor<128x1xf32,
+    // CHECK-SAME: -> tensor<128xf32
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
+    return %0 : tensor<128xf32>
+  }
+
+  func.func public @test_reduce_prod_2to0dim(%arg0: tensor<128x10xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_2to0dim
+    // CHECK-NOT: dim_arg
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128x10xbf16,
+    // CHECK-SAME: -> tensor<1x1xbf16,
+    // CHECK: "ttnn.reshape"(%[[PROD]])
+    // CHECK-SAME: shape = [1 : i32]
+    // CHECK-SAME: tensor<1x1xbf16,
+    // CHECK-SAME: -> tensor<1xbf16
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1] : (tensor<128x10xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+
+  func.func public @test_reduce_prod_1to0dim(%arg0: tensor<128xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
+    // CHECK-LABEL: func.func public @test_reduce_prod_1to0dim
+    // CHECK-NOT: dim_arg
+    // CHECK-NOT: ttnn.reshape
+    // CHECK: %[[PROD:[0-9]+]] = "ttnn.prod"
+    // CHECK-SAME: keep_dim = true
+    // CHECK-SAME: (tensor<128xbf16,
+    // CHECK-SAME: -> tensor<1xbf16,
+    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0] : (tensor<128xbf16>, tensor<bf16>) -> tensor<bf16>
+    return %0 : tensor<bf16>
+  }
+}