From bcc9fbd610a859cfaa3d62c1e8f2b1f2afff09d5 Mon Sep 17 00:00:00 2001
From: Muhammad Asif Manzoor
Date: Tue, 14 Jan 2025 02:10:28 +0000
Subject: [PATCH] reduce prod op

---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td    |  7 +++
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td    |  7 +++
 include/ttmlir/Target/TTNN/program.fbs       |  1 +
 .../StableHLOToTTIRPatterns.cpp              | 17 +++++--
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp     |  7 ++-
 lib/Dialect/TTIR/IR/TTIROps.cpp              | 16 ++++++
 lib/Dialect/TTNN/IR/TTNNOps.cpp              |  9 ++++
 .../Workarounds/TTNNWorkarounds.cpp          | 10 +++-
 lib/Target/TTNN/TTNNToFlatbuffer.cpp         | 21 ++++++--
 runtime/include/tt/runtime/detail/ttnn.h     |  1 +
 .../ttnn/operations/reduction/reduction.cpp  | 51 +++++++++++++++++++
 11 files changed, 135 insertions(+), 12 deletions(-)

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 839bd81d9..1fc37bd0b 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -696,6 +696,13 @@ def TTIR_MaxOp : TTIR_ReductionOp<"max"> {
   }];
 }
 
+def TTIR_ProdOp : TTIR_ReductionOp<"prod"> {
+  let summary = "Product reduction op.";
+  let description = [{
+    Product reduction op.
+  }];
+}
+
 def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
   let summary = "Embedding op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index ba2484ac5..a3d37fffd 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -609,6 +609,13 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
   }];
 }
 
+def TTNN_ProdOp : TTNN_ReductionOp<"prod"> {
+  let summary = "Product reduction op.";
+  let description = [{
+    Product reduction op.
+  }];
+}
+
 def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
   let summary = "Embedding op.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 0838c629d..ceb6f437a 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -168,6 +168,7 @@ enum ReductionOpType: uint32 {
   Sum,
   Mean,
   Max,
+  Prod,
 }
 
 table ReductionOp {
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 40574741d..a4fed2def 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -69,12 +69,14 @@ class StableHLOToTTIRReduceOpConversionPattern
   matchAndRewrite(mlir::stablehlo::ReduceOp srcOp,
                   mlir::stablehlo::ReduceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    llvm::errs() << "Reduction conversion\n";
     LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
     if (!legalityResult.succeeded()) {
       return legalityResult;
     }
 
     const mlir::Operation &innerOp = srcOp.getBody().front().front();
+    llvm::errs() << "isa: " << mlir::isa<mlir::stablehlo::MulOp>(innerOp) << '\n';
     if (mlir::isa<mlir::stablehlo::AddOp>(innerOp)) {
       return matchAndRewriteInternal<mlir::tt::ttir::SumOp>(srcOp, adaptor,
                                                             rewriter);
     }
@@ -84,10 +86,13 @@ class StableHLOToTTIRReduceOpConversionPattern
       return matchAndRewriteInternal<mlir::tt::ttir::MaxOp>(srcOp, adaptor,
                                                             rewriter);
     }
+    if (mlir::isa<mlir::stablehlo::MulOp>(innerOp)) {
+      return matchAndRewriteInternal<mlir::tt::ttir::ProdOp>(srcOp, adaptor,
+                                                             rewriter);
+    }
     return failure();
   }
-
 private:
   LogicalResult
   checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp,
                      mlir::stablehlo::ReduceOp::Adaptor adaptor,
@@ -112,19 +117,21 @@ class StableHLOToTTIRReduceOpConversionPattern
   matchAndRewriteInternal(mlir::stablehlo::ReduceOp &srcOp,
                           mlir::stablehlo::ReduceOp::Adaptor &adaptor,
                           ConversionPatternRewriter &rewriter) const {
+    llvm::errs() << "rewrite internal\n";
     auto outputType = mlir::cast<RankedTensorType>(
         getTypeConverter()->convertType(srcOp.getResultTypes().front()));
+    llvm::errs() << "outputType: "; outputType.dump();
     tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
         srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
-
+    llvm::errs() << "emptyOp: "; outputTensor->dump();
     // Can't reuse the original dimensions attribute because it uses i64 type.
     mlir::ArrayAttr dimArg = rewriter.getI32ArrayAttr(
         llvm::SmallVector<int32_t>(srcOp.getDimensions()));
-
-    rewriter.replaceOpWithNewOp<DestOp>(
+    llvm::errs() << "dimArg: "; dimArg.dump();
+    auto temp = rewriter.replaceOpWithNewOp<DestOp>(
         srcOp, outputType, adaptor.getInputs().front(), outputTensor,
         false /* keep_dim */, dimArg);
-
+    llvm::errs() << "converted reduce op: "; temp.dump();
     return success();
   }
 };
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 2e84eb347..0210b4677 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -329,10 +329,13 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {
   LogicalResult
   matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TTNNOpTy>(
+    llvm::errs() << "TTIR->TTNN conversion\n";
+    auto temp = rewriter.replaceOpWithNewOp<TTNNOpTy>(
         op, this->getTypeConverter()->convertType(op.getType()),
         adaptor.getInput(), adaptor.getKeepDim(),
         adaptor.getDimArg().value_or(nullptr));
+    temp.dump();
+    llvm::errs() << "Successful\n";
     return success();
   }
 };
@@ -1206,7 +1209,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
            ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
            ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
+           ReductionOpConversionPattern<ttir::ProdOp, ttnn::ProdOp>,
            ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
            BroadcastOpConversionPattern,
            EmbeddingOpConversionPattern,
            EmbeddingBackwardOpConversionPattern,
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 73daad713..c4212e81b 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -1857,3 +1857,19 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
 ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
   return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
 }
+
+//===----------------------------------------------------------------------===//
+// Reduce ProdOp
+//===----------------------------------------------------------------------===//
+
+// ProdOp kernel builder.
+void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
+                                                ::mlir::Block *block) {
+  // NOLINTNEXTLINE
+  createReduceOp(opBuilder, block, getLoc(), "prod");
+}
+
+// ProdOp verification.
+::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() {
+  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+}
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index eccb1e9ba..778f51069 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -1343,4 +1343,13 @@ ::mlir::LogicalResult SumOp::verify() {
   return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
 }
 
+//===----------------------------------------------------------------------===//
+// Reduce ProdOp
+//===----------------------------------------------------------------------===//
+
+// ProdOp verification.
+::mlir::LogicalResult ProdOp::verify() {
+  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
+}
+
 } // namespace mlir::tt::ttnn
diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
index a51166537..afc8aaa16 100644
--- a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
+++ b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -425,12 +425,20 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
             ttnn::MaxOp>,
         workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
             ttnn::MeanOp>,
+        workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
+            ttnn::MinOp>,
+        workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
+            ttnn::ProdOp>,
         workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
             ttnn::SumOp>,
         workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
             ttnn::MaxOp>,
         workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
-            ttnn::MeanOp>>(&getContext());
+            ttnn::MeanOp>,
+        workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
+            ttnn::MinOp>,
+        workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
+            ttnn::ProdOp>>(&getContext());
 
     runRewritePatterns(std::move(patterns),
                        GreedyRewriteConfig::kNoLimit /*maxIterations*/);
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 170944311..c8e4e8b39 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -731,10 +731,11 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
                             getOperandThroughDPSOps(op.getOutputs().front())),
       paramsType, params);
 }
-
+#include <iostream>
 template <typename ReductionOp>
 ::flatbuffers::Offset<::tt::target::ttnn::ReductionOp>
 createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
+  std::cerr << "Flatbuffer reduction\n";
   ::tt::target::ttnn::ReductionOpType type;
   if constexpr (std::is_same_v<ReductionOp, SumOp>) {
     type = ::tt::target::ttnn::ReductionOpType::Sum;
@@ -742,19 +743,27 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
     type = ::tt::target::ttnn::ReductionOpType::Mean;
   } else if constexpr (std::is_same_v<ReductionOp, MaxOp>) {
     type = ::tt::target::ttnn::ReductionOpType::Max;
+  } else if constexpr (std::is_same_v<ReductionOp, ProdOp>) {
+    type = ::tt::target::ttnn::ReductionOpType::Prod;
   } else {
     llvm_unreachable("unhandled ReductionOp");
   }
-
+  std::cerr << "check-1\n";
   auto in =
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  std::cerr << "check-2\n";
   auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                   kHostAllocatedAddress, kHostAllocatedSize);
+  std::cerr << "check-3\n";
+  // auto temp = op.getDimArg();
+  std::cerr << "check-4\n"; // << (*temp).size() << '\n';
   auto dim_arg = arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getDimArg());
-
-  return ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output,
+  std::cerr << "check-5\n";
+  auto temp3 = ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output,
                                                dim_arg, op.getKeepDim());
+  std::cerr << "check-6\n";
+  return temp3;
 }
 
 ::flatbuffers::Offset<::tt::target::ttnn::TransposeOp>
@@ -1116,6 +1125,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createReductionOp(cache, maxOp), debugString,
                            locInfo);
   }
+  if (auto prodOp = dyn_cast<ProdOp>(op); prodOp) {
+    return createOperation(cache, createReductionOp(cache, prodOp), debugString,
+                           locInfo);
+  }
   if (auto embeddingOp = dyn_cast<EmbeddingOp>(op); embeddingOp) {
     return createOperation(cache, createEmbeddingOp(cache, embeddingOp), debugString,
                            locInfo);
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index b1007d405..42d71104e 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -31,6 +31,7 @@
 #include "ttnn/operations/normalization/softmax/softmax.hpp"
 #include "ttnn/operations/pool/generic/generic_pools.hpp"
 #include "ttnn/operations/reduction/generic/generic_reductions.hpp"
+#include "ttnn/operations/reduction/prod/prod.hpp"
 #include "ttnn/tensor/host_buffer/functions.hpp"
 #include "ttnn/tensor/host_buffer/owned_buffer.hpp"
 #include "ttnn/tensor/shape/shape.hpp"
diff --git a/runtime/lib/ttnn/operations/reduction/reduction.cpp b/runtime/lib/ttnn/operations/reduction/reduction.cpp
index 631df3f51..a699987d7 100644
--- a/runtime/lib/ttnn/operations/reduction/reduction.cpp
+++ b/runtime/lib/ttnn/operations/reduction/reduction.cpp
@@ -28,12 +28,57 @@ static void runReductionOp(
               fbDimArg->end()))
           : std::nullopt;
 
+  if (fbDimArg) {
+    std::cerr << "reduce: fbDimArg found\n";
+    std::cerr << "size: " << fbDimArg->size() << '\n';
+    std::cerr << "[0]: " << *(fbDimArg->begin()) << '\n';
+  } else {
+    std::cerr << "reduce: fbDimArg not found\n";
+  }
   ::ttnn::Tensor out = ttnnOp(
       in, dimArg, op->keep_dim(), outputMemoryConfig /* memory_config_arg */,
       std::nullopt /* compute_kernel_config */, 1.0f /* scalar */);
 
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 
+#include <iostream>
+static void runReductionProdOp(::tt::target::ttnn::ReductionOp const *op,
+                               ProgramTensorPool &tensorPool) {
+  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
+      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
+  const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
+  DEBUG_ASSERT(in.is_allocated());
+
+  const auto *fbDimArg = op->dim_arg();
+  int dim = 0;
+  if (fbDimArg) {
+    std::cerr << "prod: fbDimArg found\n";
+    ::ttnn::SmallVector<int> dimArg =
+        ::ttnn::SmallVector<int>(fbDimArg->begin(), fbDimArg->end());
+
+    std::cerr << "dimArg.size: " << dimArg.size() << '\n';
+    std::cerr << "dimArg[0]: " << dimArg[0] << '\n';
+    dim = dimArg[0];
+  } else {
+    std::cerr << "prod: fbDimArg not found\n";
+  }
+  /*std::optional<::ttnn::SmallVector<int>> dimArg =
+      fbDimArg ? std::make_optional(::ttnn::SmallVector<int>(fbDimArg->begin(),
+                                                             fbDimArg->end()))
+               : std::nullopt;
+  std::cerr << "fbDimArg: " << fbDimArg->size() << '\n';
+  bool all_dimensions = fbDimArg->size() == 4 ? true : false;
+  int dim = *(fbDimArg->begin());
+  std::cerr << "all_dimensions: " << all_dimensions << '\n';*/
+  std::cerr << "dim: " << dim << '\n';
+  ::ttnn::Tensor out =
+      ::ttnn::prod(in, false /* all_dimensions */, dim, op->keep_dim(),
+                   outputMemoryConfig /* memory_config_arg */);
+
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+
 void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
@@ -50,6 +95,12 @@ void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
     runReductionOp(op, tensorPool, ::ttnn::max);
     break;
   }
+  case ::tt::target::ttnn::ReductionOpType::Prod: {
+    std::cerr << "run ttnn::prod\n";
+    runReductionProdOp(op, tensorPool);
+    std::cerr << "execution done\n";
+    break;
+  }
   }
 }
 } // namespace tt::runtime::ttnn::operations::reduction
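
--
For reference, a minimal StableHLO test case that should exercise the new lowering end to end. This is a sketch, not part of the patch: the function name and shapes are made up, and the exact ttir.prod spelling is assumed from the existing sum/max reduction patterns above.

  func.func @reduce_prod(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
    // A stablehlo.multiply reducer body is what the new branch in
    // StableHLOToTTIRReduceOpConversionPattern matches for ProdOp.
    %init = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %0 = stablehlo.reduce(%arg0 init: %init) applies stablehlo.multiply
        across dimensions = [1] : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
    return %0 : tensor<4xf32>
  }

Running this through the StableHLO-to-TTIR conversion (e.g. via ttmlir-opt; exact pass flag assumed) should yield a ttir.prod op with dim_arg = [1 : i32] and keep_dim = false. The TTIRToTTNN pattern then rewrites it to ttnn.prod, the flatbuffer path serializes it as ReductionOpType::Prod, and the runtime dispatches it to runReductionProdOp, which reduces over the single dimension dimArg[0].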