Skip to content

Commit

Permalink
reduce prod op
Browse files Browse the repository at this point in the history
  • Loading branch information
mmanzoorTT committed Jan 15, 2025
1 parent 431fdaa commit bcc9fbd
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 12 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,13 @@ def TTIR_MaxOp : TTIR_ReductionOp<"max"> {
}];
}

def TTIR_ProdOp : TTIR_ReductionOp<"prod"> {
  // Fixed copy-paste error: summary/description previously said "Min".
  let summary = "Product reduction op.";
  let description = [{
    Product reduction op.
  }];
}

def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,13 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
}];
}

def TTNN_ProdOp : TTNN_ReductionOp<"prod"> {
  // Fixed copy-paste error: summary/description previously said "Min".
  let summary = "Product reduction op.";
  let description = [{
    Product reduction op.
  }];
}

def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ enum ReductionOpType: uint32 {
Sum,
Mean,
Max,
Prod,
}

table ReductionOp {
Expand Down
17 changes: 12 additions & 5 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,14 @@ class StableHLOToTTIRReduceOpConversionPattern
matchAndRewrite(mlir::stablehlo::ReduceOp srcOp,
mlir::stablehlo::ReduceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
llvm::errs() << "Reduction conversion\n";
LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
if (!legalityResult.succeeded()) {
return legalityResult;
}

const mlir::Operation &innerOp = srcOp.getBody().front().front();
llvm::errs() << "isa<mulop: " << isa<mlir::stablehlo::MulOp>(innerOp) << '\n';

if (mlir::isa<mlir::stablehlo::AddOp>(innerOp)) {
return matchAndRewriteInternal<mlir::tt::ttir::SumOp>(srcOp, adaptor,
Expand All @@ -84,10 +86,13 @@ class StableHLOToTTIRReduceOpConversionPattern
return matchAndRewriteInternal<mlir::tt::ttir::MaxOp>(srcOp, adaptor,
rewriter);
}
if (mlir::isa<mlir::stablehlo::MulOp>(innerOp)) {
return matchAndRewriteInternal<mlir::tt::ttir::ProdOp>(srcOp, adaptor,
rewriter);
}

return failure();
}

private:
LogicalResult checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp,
mlir::stablehlo::ReduceOp::Adaptor adaptor,
Expand All @@ -112,19 +117,21 @@ class StableHLOToTTIRReduceOpConversionPattern
matchAndRewriteInternal(mlir::stablehlo::ReduceOp &srcOp,
mlir::stablehlo::ReduceOp::Adaptor &adaptor,
ConversionPatternRewriter &rewriter) const {
llvm::errs() << "rewrite internal\n";
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResultTypes().front()));
llvm::errs() << "outputType: "; outputType.dump();
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

llvm::errs() << "emptyOp: "; outputTensor->dump();
// Can't reuse the original dimensions attribute because it uses i64 type.
mlir::ArrayAttr dimArg = rewriter.getI32ArrayAttr(
llvm::SmallVector<int32_t>(srcOp.getDimensions()));

rewriter.replaceOpWithNewOp<DestOp>(
llvm::errs() << "dimArg: "; dimArg.dump();
auto temp = rewriter.replaceOpWithNewOp<DestOp>(
srcOp, outputType, adaptor.getInputs().front(), outputTensor,
false /* keep_dim */, dimArg);

llvm::errs() << "converted reduce op: "; temp.dump();
return success();
}
};
Expand Down
7 changes: 5 additions & 2 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,13 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {
LogicalResult
matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Lower a TTIR reduction op 1:1 to its TTNN counterpart, forwarding the
  // optional reduction-dimension list and the keep_dim flag unchanged.
  // (Removed leftover llvm::errs()/dump() debug output that was committed.)
  rewriter.replaceOpWithNewOp<TTNNOpTy>(
      op, this->getTypeConverter()->convertType(op.getType()),
      adaptor.getInput(), adaptor.getKeepDim(),
      adaptor.getDimArg().value_or(nullptr));
  return success();
}
};
Expand Down Expand Up @@ -1206,7 +1209,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
ReductionOpConversionPattern<ttir::ProdOp, ttnn::ProdOp>,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
EmbeddingBackwardOpConversionPattern,
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1857,3 +1857,19 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce ProdOp
//===----------------------------------------------------------------------===//

// ProdOp kernel builder: emits the generic-region body for a "prod" reduction.
void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                                ::mlir::Block *block) {
  // NOLINTNEXTLINE
  createReduceOp(opBuilder, block, getLoc(), "prod");
}

// ProdOp verification: defers to the shared reduction-op verifier, which
// checks the input type against the (optional) reduction dimensions.
::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
9 changes: 9 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1343,4 +1343,13 @@ ::mlir::LogicalResult SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce ProdOp
//===----------------------------------------------------------------------===//

// ProdOp verification: defers to the shared reduction-op verifier, which
// checks the input type against the (optional) reduction dimensions.
::mlir::LogicalResult ProdOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

} // namespace mlir::tt::ttnn
10 changes: 9 additions & 1 deletion lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,12 +425,20 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MeanOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MinOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::ProdOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::SumOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MeanOp>>(&getContext());
ttnn::MeanOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MinOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::ProdOp>>(&getContext());

runRewritePatterns(std::move(patterns),
GreedyRewriteConfig::kNoLimit /*maxIterations*/);
Expand Down
21 changes: 17 additions & 4 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,30 +731,39 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
getOperandThroughDPSOps(op.getOutputs().front())),
paramsType, params);
}

#include<iostream>
template <typename ReductionOp>
::flatbuffers::Offset<::tt::target::ttnn::ReductionOp>
createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
std::cerr << "Flatbuffer reduction\n";
::tt::target::ttnn::ReductionOpType type;
if constexpr (std::is_same_v<ReductionOp, SumOp>) {
type = ::tt::target::ttnn::ReductionOpType::Sum;
} else if constexpr (std::is_same_v<ReductionOp, MeanOp>) {
type = ::tt::target::ttnn::ReductionOpType::Mean;
} else if constexpr (std::is_same_v<ReductionOp, MaxOp>) {
type = ::tt::target::ttnn::ReductionOpType::Max;
} else if constexpr (std::is_same_v<ReductionOp, ProdOp>) {
type = ::tt::target::ttnn::ReductionOpType::Prod;
} else {
llvm_unreachable("unhandled ReductionOp");
}

std::cerr << "check-1\n";
auto in =
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
std::cerr << "check-2\n";
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
std::cerr << "check-3\n";
//auto temp = op.getDimArg();
std::cerr << "check-4\n";// << (*temp).size() << '\n';
auto dim_arg =
arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getDimArg());

return ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output,
std::cerr << "check-5\n";
auto temp3 = ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output,
dim_arg, op.getKeepDim());
std::cerr << "check=6\n";
return temp3;
}

::flatbuffers::Offset<::tt::target::ttnn::TransposeOp>
Expand Down Expand Up @@ -1116,6 +1125,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createReductionOp(cache, maxOp), debugString,
locInfo);
}
if (auto prodOp = dyn_cast<ProdOp>(op); prodOp) {
return createOperation(cache, createReductionOp(cache, prodOp), debugString,
locInfo);
}
if (auto embeddingOp = dyn_cast<EmbeddingOp>(op); embeddingOp) {
return createOperation(cache, createEmbeddingOp(cache, embeddingOp),
debugString, locInfo);
Expand Down
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "ttnn/operations/normalization/softmax/softmax.hpp"
#include "ttnn/operations/pool/generic/generic_pools.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/operations/reduction/prod/prod.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/owned_buffer.hpp"
#include "ttnn/tensor/shape/shape.hpp"
Expand Down
51 changes: 51 additions & 0 deletions runtime/lib/ttnn/operations/reduction/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,57 @@ static void runReductionOp(
fbDimArg->end()))
: std::nullopt;

if (fbDimArg) {
std::cerr << "reduce: fbDimArg found\n";
std::cerr << "size: " << fbDimArg->size() << '\n';
std::cerr << "[0]: " << *(fbDimArg->begin()) << '\n';
}
else {
std::cerr << "reduce: fbDimArg not found\n";
}
::ttnn::Tensor out = ttnnOp(
in, dimArg, op->keep_dim(), outputMemoryConfig /* memory_config_arg */,
std::nullopt /* compute_kernel_config */, 1.0f /* scalar */);

tensorPool.insert_or_assign(op->out()->global_id(), out);
}
// Execute a Prod reduction at runtime via ::ttnn::prod.
// Unlike the generic reductions, ::ttnn::prod takes a single reduction
// dimension plus an all_dimensions flag, so the flatbuffer dim_arg list is
// collapsed to its first entry.
// (Removed leftover std::cerr debug tracing, dead commented-out code, and a
// stray mid-file #include <iostream>; added a guard against an empty dim_arg
// vector, which previously indexed element 0 unconditionally.)
static void runReductionProdOp(
    ::tt::target::ttnn::ReductionOp const *op, ProgramTensorPool &tensorPool) {
  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
  const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
  DEBUG_ASSERT(in.is_allocated());

  // Default to dimension 0 when no dim_arg is provided.
  // NOTE(review): entries beyond dim_arg[0] are silently ignored — confirm
  // multi-dim product reductions are decomposed upstream, or pass
  // all_dimensions=true when dim_arg covers the full rank.
  const auto *fbDimArg = op->dim_arg();
  int dim = 0;
  if (fbDimArg && fbDimArg->size() > 0) {
    dim = *fbDimArg->begin();
  }

  ::ttnn::Tensor out =
      ::ttnn::prod(in, false /* all_dimensions */, dim, op->keep_dim(),
                   outputMemoryConfig /* memory_config_arg */);

  tensorPool.insert_or_assign(op->out()->global_id(), out);
}


void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
Expand All @@ -50,6 +95,12 @@ void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
runReductionOp(op, tensorPool, ::ttnn::max);
break;
}
case ::tt::target::ttnn::ReductionOpType::Prod: {
std::cerr << "run ttnn::min\n";
runReductionProdOp(op, tensorPool);
std::cerr << "execution done\n";
break;
}
}
}
} // namespace tt::runtime::ttnn::operations::reduction

0 comments on commit bcc9fbd

Please sign in to comment.