reduce prod op
mmanzoorTT committed Jan 16, 2025
1 parent 5d8c602 commit 1dd7293
Showing 23 changed files with 431 additions and 4 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -704,6 +704,13 @@ def TTIR_MaxOp : TTIR_ReductionOp<"max"> {
}];
}

def TTIR_ProdOp : TTIR_ReductionOp<"prod"> {
let summary = "Product reduction op.";
let description = [{
Computes the product of the elements of the input tensor along the specified dimensions.
}];
}
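
A minimal usage sketch (hypothetical function name; the shapes and attributes mirror the conversion tests added later in this commit):

func.func @prod_example(%arg0: tensor<128x10x32x4xf32>) -> tensor<128x32x4xf32> {
  %0 = tensor.empty() : tensor<128x32x4xf32>
  %1 = "ttir.prod"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128x32x4xf32>) -> tensor<128x32x4xf32>
  return %1 : tensor<128x32x4xf32>
}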

def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -609,6 +609,13 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
}];
}

def TTNN_ProdOp : TTNN_ReductionOp<"prod"> {
let summary = "Product reduction op.";
let description = [{
Computes the product of the elements of the input tensor along the specified dimensions.
}];
}

def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -168,6 +168,7 @@ enum ReductionOpType: uint32 {
Sum,
Mean,
Max,
Prod,
}

table ReductionOp {
27 changes: 27 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -84,6 +84,10 @@ class StableHLOToTTIRReduceOpConversionPattern
return matchAndRewriteInternal<mlir::tt::ttir::MaxOp>(srcOp, adaptor,
rewriter);
}
if (mlir::isa<mlir::stablehlo::MulOp>(innerOp)) {
return matchAndRewriteInternal<mlir::tt::ttir::ProdOp>(srcOp, adaptor,
rewriter);
}

return failure();
}
@@ -104,6 +108,29 @@ class StableHLOToTTIRReduceOpConversionPattern
"Expecting StableHLO Reduce OP to have a body operation defined.");
}

mlir::Operation &innerOp = srcOp.getBody().front().front();
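// Only product reductions (a stablehlo.multiply body) need the extra
// TTNN-specific legality checks below.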
if (mlir::isa<mlir::stablehlo::MulOp>(innerOp)) {
RankedTensorType inputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getInputs()[0].getType()));
int64_t inputTensorRank = inputType.getRank();
int64_t numDims = srcOp.getDimensions().size();
mlir::Type elementType = inputType.getElementType();
if (inputTensorRank > 4) {
return rewriter.notifyMatchFailure(
srcOp, "Input tensor rank is greater than 4 for reduce(product).");
}
if ((numDims > 1) && (numDims != inputTensorRank)) {
return rewriter.notifyMatchFailure(
srcOp, "TTNN only supports reduce(prod) along one dimension or all "
"dimensions.");
}
if ((numDims == inputTensorRank) && !elementType.isBF16()) {
return rewriter.notifyMatchFailure(
srcOp, "TTNN only supports Reduce(prod) along all dimensions for "
"bfloat16 datatype.");
}
}

return success();
}
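
For example, a product reduction over two of four dimensions, written in the syntax of the conversion tests added below, is rejected by the second check above (a hypothetical case, not one of this commit's tests):

func.func @rejected_partial_prod(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32xf32> {
  %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1, 3] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32xf32>
  return %0 : tensor<128x32xf32>
}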

1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -1206,6 +1206,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
ReductionOpConversionPattern<ttir::ProdOp, ttnn::ProdOp>,
ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
3 changes: 2 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -846,7 +846,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>,
DefaultOpConversionPattern<ttnn::MeanOp>,
DefaultOpConversionPattern<ttnn::MaxOp>>(typeConverter, ctx);
DefaultOpConversionPattern<ttnn::MaxOp>,
DefaultOpConversionPattern<ttnn::ProdOp>>(typeConverter, ctx);

// Conv ops
//
33 changes: 33 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -2013,6 +2013,23 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
return reduceOp->emitOpError("Reduce dimensions are not unique");
}

if (mlir::isa<mlir::tt::ttir::ProdOp>(reduceOp)) {
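// These checks mirror the TTNN prod kernel's restrictions: input rank at
// most 4, and reduction along either a single dimension or all dimensions
// (the all-dimensions case only for bfloat16).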
int64_t numReduceDims = uniqueReduceDims.size();
mlir::Type elementType = inputType.getElementType();
if (inputTensorRank > 4) {
return reduceOp->emitOpError(
"Input tensor rank is greater than 4 for reduce(product).");
}
if ((numReduceDims > 1) && (numReduceDims != inputTensorRank)) {
return reduceOp->emitOpError("TTNN only supports reduce(prod) along one "
"dimension or all dimensions.");
}
if ((numReduceDims == inputTensorRank) && !elementType.isBF16()) {
return reduceOp->emitOpError("TTNN only supports Reduce(prod) along all "
"dimensions for bfloat16 datatype.");
}
}

// TODO(mrakita): Add a check that depending on inputShape, reduceDims and
// keepDim computes the expected output shape and checks if it matches the
// actual output shape. Tracked by:
@@ -2068,3 +2085,19 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce ProdOp
//===----------------------------------------------------------------------===//

// ProdOp kernel builder.
void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "prod");
}

// ProdOp verification.
::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
9 changes: 9 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -1343,4 +1343,13 @@ ::mlir::LogicalResult SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce ProdOp
//===----------------------------------------------------------------------===//

// ProdOp verification.
::mlir::LogicalResult ProdOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

} // namespace mlir::tt::ttnn
6 changes: 5 additions & 1 deletion lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -425,12 +425,16 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MeanOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::ProdOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::SumOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MeanOp>>(&getContext());
ttnn::MeanOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::ProdOp>>(&getContext());

runRewritePatterns(std::move(patterns),
GreedyRewriteConfig::kNoLimit /*maxIterations*/);
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -742,6 +742,8 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
type = ::tt::target::ttnn::ReductionOpType::Mean;
} else if constexpr (std::is_same_v<ReductionOp, MaxOp>) {
type = ::tt::target::ttnn::ReductionOpType::Max;
} else if constexpr (std::is_same_v<ReductionOp, ProdOp>) {
type = ::tt::target::ttnn::ReductionOpType::Prod;
} else {
llvm_unreachable("unhandled ReductionOp");
}
@@ -1116,6 +1118,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createReductionOp(cache, maxOp), debugString,
locInfo);
}
if (auto prodOp = dyn_cast<ProdOp>(op); prodOp) {
return createOperation(cache, createReductionOp(cache, prodOp), debugString,
locInfo);
}
if (auto embeddingOp = dyn_cast<EmbeddingOp>(op); embeddingOp) {
return createOperation(cache, createEmbeddingOp(cache, embeddingOp),
debugString, locInfo);
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -31,6 +31,7 @@
#include "ttnn/operations/normalization/softmax/softmax.hpp"
#include "ttnn/operations/pool/generic/generic_pools.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/operations/reduction/prod/prod.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/owned_buffer.hpp"
#include "ttnn/tensor/shape/shape.hpp"
32 changes: 32 additions & 0 deletions runtime/lib/ttnn/operations/reduction/reduction.cpp
@@ -35,6 +35,34 @@ static void runReductionOp(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runReductionProdOp(const ::tt::target::ttnn::ReductionOp *op,
                               ProgramTensorPool &tensorPool) {
::tt::tt_metal::MemoryConfig outputMemoryConfig =
::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
DEBUG_ASSERT(in.is_allocated());

const auto *fbDimArg = op->dim_arg();
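// ::ttnn::prod takes a single dimension index plus an all-dimensions flag
// rather than a dimension list, so the serialized dim_arg is collapsed
// here: an absent dim_arg, or a full rank-4 list (the maximum rank the
// verifier allows), selects reduction over all dimensions.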
int dim = 0;
bool all_dimensions = false;
if (fbDimArg) {
::ttnn::SmallVector<int> dimArg =
::ttnn::SmallVector<int>(fbDimArg->begin(), fbDimArg->end());
dim = dimArg[0];
if (dimArg.size() == 4) {
all_dimensions = true;
}
} else {
all_dimensions = true;
}

::ttnn::Tensor out = ::ttnn::prod(in, all_dimensions, dim, op->keep_dim(),
outputMemoryConfig /* memory_config_arg */);

tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
@@ -50,6 +78,10 @@ void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
runReductionOp(op, tensorPool, ::ttnn::max);
break;
}
case ::tt::target::ttnn::ReductionOpType::Prod: {
runReductionProdOp(op, tensorPool);
break;
}
}
}
} // namespace tt::runtime::ttnn::operations::reduction
@@ -0,0 +1,82 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_reduce_prod attributes {} {
func.func public @test_reduce_prod_4to3dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32x4xf32> {
// CHECK-LABEL: func.func public @test_reduce_prod_4to3dim
// CHECK: tensor.empty
// CHECK: "ttir.prod"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128x32x4xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32x4xf32>
return %0 : tensor<128x32x4xf32>
}

func.func public @test_reduce_prod_4to0dim(%arg0: tensor<128x10x32x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
// CHECK-LABEL: func.func public @test_reduce_prod_4to0dim
// CHECK: tensor.empty
// CHECK: "ttir.prod"
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xbf16>
// CHECK-SAME: -> tensor<1xbf16>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2, 3] : (tensor<128x10x32x4xbf16>, tensor<bf16>) -> tensor<bf16>
return %0 : tensor<bf16>
}

func.func public @test_reduce_prod_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> {
// CHECK-LABEL: func.func public @test_reduce_prod_3to2dim
// CHECK: tensor.empty
// CHECK: "ttir.prod"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<128x4xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10x4xf32>, tensor<f32>) -> tensor<128x4xf32>
return %0 : tensor<128x4xf32>
}

func.func public @test_reduce_prod_3to0dim(%arg0: tensor<128x10x4xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
// CHECK-LABEL: func.func public @test_reduce_prod_3to0dim
// CHECK: tensor.empty
// CHECK: "ttir.prod"
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xbf16>
// CHECK-SAME: -> tensor<1xbf16>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1, 2] : (tensor<128x10x4xbf16>, tensor<bf16>) -> tensor<bf16>
return %0 : tensor<bf16>
}

func.func public @test_reduce_prod_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK-LABEL: func.func public @test_reduce_prod_2to1dim
// CHECK: tensor.empty
// CHECK: "ttir.prod"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xf32>
// CHECK-SAME: -> tensor<128xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
return %0 : tensor<128xf32>
}

func.func public @test_reduce_prod_2to0dim(%arg0: tensor<128x10xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
// CHECK-LABEL: func.func public @test_reduce_prod_2to0dim
// CHECK: tensor.empty
// CHECK: "ttir.prod"
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xbf16>
// CHECK-SAME: -> tensor<1xbf16>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0, 1] : (tensor<128x10xbf16>, tensor<bf16>) -> tensor<bf16>
return %0 : tensor<bf16>
}

func.func public @test_reduce_prod_1to0dim(%arg0: tensor<128xbf16>, %cst_0: tensor<bf16>) -> tensor<bf16> {
// CHECK-LABEL: func.func public @test_reduce_prod_1to0dim
// CHECK: tensor.empty
// CHECK: "ttir.prod"
// CHECK-SAME: dim_arg = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128xbf16>
// CHECK-SAME: -> tensor<1xbf16>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.multiply across dimensions = [0] : (tensor<128xbf16>, tensor<bf16>) -> tensor<bf16>
return %0 : tensor<bf16>
}
}
25 changes: 25 additions & 0 deletions test/ttmlir/Dialect/TTIR/reduce_ops/negative_reduce_prod_op.mlir
@@ -0,0 +1,25 @@
// RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
// Negative tests for reduce(prod) op

// CHECK: error: 'ttir.prod' op TTNN only supports reduce(prod) along one dimension or all dimensions.
func.func public @test_reduce_prod_multiple_dims(%arg0: tensor<128x10x32x4xf32>) -> tensor<128x32xf32> {
%0 = tensor.empty() : tensor<128x32xf32>
%1 = "ttir.prod"(%arg0, %0) <{dim_arg = [1 : i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128x32xf32>) -> tensor<128x32xf32>
return %1 : tensor<128x32xf32>
}

// -----
// CHECK: error: 'ttir.prod' op TTNN only supports reduce(prod) along all dimensions for bfloat16 datatype.
func.func public @test_reduce_prod_all_dims_f32(%arg0: tensor<128x10x32x4xf32>) -> tensor<1xf32> {
%0 = tensor.empty() : tensor<1xf32>
%1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0: i32, 1 : i32, 2: i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<1xf32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}

// -----
// CHECK: error: 'ttir.prod' op Input tensor rank is greater than 4 for reduce(product).
func.func public @test_reduce_prod_higher_rank(%arg0: tensor<128x10x32x4x1xf32>) -> tensor<10x32x4x1xf32> {
%0 = tensor.empty() : tensor<10x32x4x1xf32>
%1 = "ttir.prod"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128x10x32x4x1xf32>, tensor<10x32x4x1xf32>) -> tensor<10x32x4x1xf32>
return %1 : tensor<10x32x4x1xf32>
}
File renamed without changes.
File renamed without changes.