Add ttnn::ones() op (#1476)
svuckovicTT authored Dec 11, 2024
1 parent 7d53af2 commit 31e5518
Showing 20 changed files with 688 additions and 207 deletions.
68 changes: 68 additions & 0 deletions include/ttmlir/Conversion/TTNNToEmitC/Utils.h
@@ -0,0 +1,68 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_TTNNTOEMITC_UTILS_H
#define TTMLIR_CONVERSION_TTNNTOEMITC_UTILS_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::tt::ttnn_to_emitc::utils {

// Create emitc::OpaqueAttr for ttnn::Shape
//
emitc::OpaqueAttr convertShape(Builder &builder, ttnn::ShapeAttr attr);

// Create emitc::OpaqueAttr for ttnn::TensorMemoryLayout
//
emitc::OpaqueAttr convertTensorMemoryLayout(Builder &builder,
ttnn::TensorMemoryLayoutAttr attr);

// Create emitc::OpaqueAttr for ttnn::BufferType
//
emitc::OpaqueAttr convertBufferType(Builder &builder,
ttnn::BufferTypeAttr attr);

// Create emitc::OpaqueAttr for ttnn::Layout
//
emitc::OpaqueAttr convertLayoutAttr(Builder &builder, ttnn::LayoutAttr attr);

// Create emitc::OpaqueAttr for BoolAttr
//
emitc::OpaqueAttr convertBoolAttr(Builder &builder, BoolAttr attr);

// Create emitc::OpaqueAttr for ttnn::DataType
//
emitc::OpaqueAttr convertDType(Builder &builder, tt::DataTypeAttr attr);

// Create emitc::OpaqueAttr for std::nullopt
//
emitc::OpaqueAttr createStdNullopt(Builder &builder);

// Create ttnn::Shape and return emitc::ExpressionOp
//
// ttnn::Shape has a couple of constructors, but they are explicit and require
// specific data types as input. However, one of the constructors takes a
// tt_metal::Shape - given that it's much easier to construct a
// tt_metal::Shape, we opted to do that here. The call looks like this:
// ttnn::Shape(tt::tt_metal::LegacyShape{dim0, dim1, dim2, ...});
//
// To make it easier on the eyes, these two calls are packed into one, using
// EmitC's ExpressionOp.
//
emitc::ExpressionOp createShapeOp(ConversionPatternRewriter &rewriter,
ttnn::ShapeAttr shapeAttr,
Block *containingBlock, Location loc);

// Create ttnn::MemoryConfig and return emitc::CallOpaqueOp
//
emitc::CallOpaqueOp createMemoryConfigOp(ConversionPatternRewriter &rewriter,
ttnn::MemoryConfigAttr memoryConfig,
Location loc);

} // namespace mlir::tt::ttnn_to_emitc::utils

#endif // TTMLIR_CONVERSION_TTNNTOEMITC_UTILS_H
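
These helpers all funnel into the same idea: render a TTNN attribute as the exact C++ token sequence that EmitC should print. As a minimal sketch of the pattern, assuming emitc::OpaqueAttr carries a verbatim token (the real bodies live in the new Utils.cpp added by this commit, whose diff is not shown above):

#include "ttmlir/Conversion/TTNNToEmitC/Utils.h"

namespace mlir::tt::ttnn_to_emitc::utils {

// A BoolAttr becomes the literal token "true" or "false".
emitc::OpaqueAttr convertBoolAttr(Builder &builder, BoolAttr attr) {
return builder.getAttr<emitc::OpaqueAttr>(attr.getValue() ? "true" : "false");
}

// std::nullopt is emitted verbatim, standing in for absent optional args.
emitc::OpaqueAttr createStdNullopt(Builder &builder) {
return builder.getAttr<emitc::OpaqueAttr>("std::nullopt");
}

} // namespace mlir::tt::ttnn_to_emitc::utils
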
17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -1118,6 +1118,23 @@ def TTIR_ArangeOp : TTIR_Op<"arange"> {
let hasVerifier = 1;
}

def TTIR_OnesOp : TTIR_Op<"ones"> {
let summary = "Creates a tensor filled with ones.";
let description = [{
Tensor operation to create a tensor filled with ones.

Given a `shape`, produces a tensor of that shape, filled with ones.

Example:
%0 = "ttir.ones"() <{shape = array<i32: 64, 28, 28>}> : () -> tensor<64x28x28xbf16>
// %0: [[[1, 1, 1, ..., 1], [1, 1, 1, ..., 1], ..., [1, 1, 1, ..., 1]]]
}];

let arguments = (ins DenseI32ArrayAttr:$shape);

let results = (outs AnyRankedTensor:$result);
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -871,6 +871,27 @@ def TTNN_ArangeOp : TTNN_Op<"arange"> {
let hasVerifier = 1;
}

def TTNN_OnesOp : TTNN_Op<"ones"> {
let summary = "Creates a tensor filled with ones.";
let description = [{
Tensor operation to create a tensor filled with ones.

Given a ShapeAttr `shape`, produces a tensor of that shape, filled with ones.

Example:
%0 = "ttnn.ones"() <{shape = #ttnn.shape<64x28x28>}> : () -> tensor<64x28x28xbf16>
// %0: [[[1, 1, 1, ..., 1], [1, 1, 1, ..., 1], ..., [1, 1, 1, ..., 1]]]
}];

let arguments = (ins TTNN_ShapeAttr:$shape,
OptionalAttr<TT_DataTypeAttr>:$dtype,
OptionalAttr<TTNN_LayoutAttr>:$layout,
Optional<TT_Device>:$device,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

let results = (outs AnyRankedTensor:$result);
}
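
For reference, creating this op from C++ might look as follows (a sketch: the builder signature is the default one TableGen generates for the argument list above, with null attributes for the optional ones and a null Value for the optional device operand; resultType, shapeAttr, and loc are assumed to exist at the call site):

// Hypothetical call site for the generated ttnn::OnesOp builder.
ttnn::OnesOp ones = builder.create<ttnn::OnesOp>(
loc, resultType, shapeAttr,
/*dtype=*/nullptr, /*layout=*/nullptr,
/*device=*/mlir::Value(), /*memory_config=*/nullptr);
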

def TTNN_FullOp : TTNN_Op<"full"> {
let summary = "Full op.";
let description = [{
4 changes: 0 additions & 4 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -35,10 +35,6 @@ mlir::tt::TensorMemoryLayout toTTTensorMemoryLayout(
mlir::tt::MemorySpace
toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);

-// Get Layout from MemRefType
-//
-Layout getLayoutFromMemRef(mlir::MemRefType memref);
-
mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
DataType dtype);

20 changes: 15 additions & 5 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -60,12 +60,21 @@ table EmptyOp {
dtype: DataType;
layout: TensorLayout;
num_shards: uint32;
-device: tt.target.DeviceRef; // optional
-memcfg: tt.target.MemoryConfigDesc; // optional
+device: tt.target.DeviceRef;
+memcfg: tt.target.MemoryConfigDesc;
strategy: tt.target.DistributionStrategy;
out: tt.target.TensorRef;
}

table OnesOp {
shape: [int64];
dtype: DataType = null;
layout: TensorLayout = null;
device: tt.target.DeviceRef;
memcfg: tt.target.MemoryConfigDesc;
out: tt.target.TensorRef;
}
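
Note that dtype and layout are declared "= null", making them FlatBuffers optional scalars, while device and memcfg are sub-tables that may simply be absent. A sketch of how a runtime consumer might read this table, assuming the usual flatc-generated C++ accessors (the generated header name and exact namespace are assumptions):

#include "ttmlir/Target/TTNN/program_generated.h" // assumed flatc output

void readOnesOp(const ::tt::target::ttnn::OnesOp *op) {
// Optional scalar ("= null"): the accessor returns flatbuffers::Optional.
::flatbuffers::Optional<::tt::target::DataType> dtype = op->dtype();
if (!dtype.has_value()) {
// No dtype was serialized; fall back to a runtime default.
}
// Table-typed fields come back as possibly-null pointers.
if (op->device() != nullptr) {
// Device was serialized; materialize the tensor on that device.
}
}
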

table FullOp {
device: tt.target.DeviceRef;
fill_value: float;
@@ -78,9 +87,9 @@ table ArangeOp {
start: float;
end: float;
step: float;
-dtype: tt.target.DataType = null; // optional
-device: tt.target.DeviceRef; // optional
-memcfg: tt.target.MemoryConfigDesc; // optional
+dtype: tt.target.DataType = null;
+device: tt.target.DeviceRef;
+memcfg: tt.target.MemoryConfigDesc;
out: tt.target.TensorRef;
}

@@ -299,6 +308,7 @@ union OpType {
ToDeviceOp,
FromDeviceOp,
EmptyOp,
OnesOp,
FullOp,
EltwiseOp,
LinearOp,
38 changes: 34 additions & 4 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
@@ -5,16 +5,17 @@
#ifndef TTMLIR_TARGET_UTILS_MLIRTOFLATBUFFER_H
#define TTMLIR_TARGET_UTILS_MLIRTOFLATBUFFER_H

-#include <numeric>
-#include <type_traits>
-
-#include "flatbuffers/flatbuffers.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Target/Common/Target.h"
#include "ttmlir/Target/Utils/FlatbufferObjectCache.h"
#include "ttmlir/Utils.h"
+
+#include "flatbuffers/flatbuffers.h"
+
+#include <numeric>
+#include <type_traits>

namespace mlir::tt {

flatbuffers::Offset<::tt::target::LayoutDesc>
@@ -136,6 +137,35 @@ inline ::tt::target::DataType toFlatbuffer(FlatbufferObjectCache &,
}
}

inline ::flatbuffers::Optional<::tt::target::DataType>
toFlatbufferOptional(FlatbufferObjectCache &cache,
::std::optional<::mlir::tt::DataType> dataType) {
return dataType.has_value() ? ::flatbuffers::Optional<::tt::target::DataType>(
toFlatbuffer(cache, dataType.value()))
: ::flatbuffers::nullopt;
}

inline ::tt::target::TensorLayout toFlatbuffer(FlatbufferObjectCache &cache,
ttnn::Layout layout) {
switch (layout) {
case ttnn::Layout::RowMajor:
return ::tt::target::TensorLayout::RowMajor;
case ttnn::Layout::Tile:
return ::tt::target::TensorLayout::Tile;
case ttnn::Layout::Invalid:
return ::tt::target::TensorLayout::Invalid;
}
}

inline ::flatbuffers::Optional<::tt::target::TensorLayout>
toFlatbufferOptional(FlatbufferObjectCache &cache,
::std::optional<mlir::tt::ttnn::Layout> layout) {
return layout.has_value()
? ::flatbuffers::Optional<::tt::target::TensorLayout>(
toFlatbuffer(cache, layout.value()))
: ::flatbuffers::nullopt;
}
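
Both overloads follow the same mapping: an engaged std::optional converts and wraps the value, while std::nullopt becomes flatbuffers::nullopt, which omits the field from the serialized table entirely. A minimal usage sketch (assumes <cassert> and an in-scope FlatbufferObjectCache):

inline void exampleOptionalDtype(FlatbufferObjectCache &cache) {
// Attribute absent on the op: nothing is written to the flatbuffer.
std::optional<mlir::tt::DataType> unset;
assert(!toFlatbufferOptional(cache, unset).has_value());

// Attribute present: the enum is converted and wrapped.
std::optional<mlir::tt::DataType> set = mlir::tt::DataType::BFloat16;
assert(toFlatbufferOptional(cache, set).has_value());
}
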

inline ::tt::target::MemorySpace toFlatbuffer(FlatbufferObjectCache &,
MemorySpace memspace) {
switch (memspace) {
64 changes: 64 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -92,6 +92,69 @@ class TensorEmptyConversionPattern
}
};

class OnesOpConversionPattern : public OpConversionPattern<ttir::OnesOp> {
public:
using OpConversionPattern<ttir::OnesOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::OnesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Get ttnn::TTNNLayoutAttr of the result type
//
ttnn::TTNNLayoutAttr layoutAttr = mlir::cast<ttnn::TTNNLayoutAttr>(
op.getResult().getType().getEncoding());

// Get the shape of the tensor
//
// TODO(svuckovic): (#1435) ShapeAttr accepts int64_t, when it should be
// uint32_t
//
ttnn::ShapeAttr shapeAttr = ttnn::ShapeAttr::get(
rewriter.getContext(), llvm::SmallVector<int64_t, 4>(
op.getShape().begin(), op.getShape().end()));

// Get memref
//
mlir::MemRefType memref = layoutAttr.getMemref();

// Get data type, tensor layout, device and memory config
//
DataTypeAttr dTypeAttr =
DataTypeAttr::get(rewriter.getContext(), layoutAttr.getDataType());
ttnn::BufferType bufferType = layoutAttr.getBufferType();
ttnn::Layout ttnnLayoutEnum = llvm::isa<TileType>(memref.getElementType())
? ttnn::Layout::Tile
: ttnn::Layout::RowMajor;
ttnn::LayoutAttr tensorLayoutAttr =
ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum);
ttnn::TensorMemoryLayoutAttr memLayout = layoutAttr.getMemLayout();

// Device only exists if memLayout is *not* null
//
auto device =
memLayout ? ::ttnn::utils::getOrInsertDevice(rewriter, op) : nullptr;

// MemoryConfigAttr only exists if memLayout is *not* null
//
ttnn::MemoryConfigAttr memoryConfigAttr =
memLayout
? ttnn::MemoryConfigAttr::get(
op.getContext(),
ttnn::BufferTypeAttr::get(op.getContext(), bufferType),
ttnn::ShardSpecAttr::get(
op.getContext(),
ttnn::ShapeAttr::get(op.getContext(), memref.getShape())),
memLayout)
: nullptr;

rewriter.replaceOpWithNewOp<ttnn::OnesOp>(
op, this->getTypeConverter()->convertType(op.getType()), shapeAttr,
dTypeAttr, tensorLayoutAttr, device, memoryConfigAttr);

return success();
}
};
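
A pattern like this is exercised by MLIR's dialect conversion driver; a generic sketch (assumed boilerplate, not this repository's actual pass code) would mark ttir.ones illegal and apply the pattern set until conversion succeeds:

void runOnesConversion(mlir::ModuleOp module,
mlir::TypeConverter &typeConverter) {
mlir::MLIRContext *ctx = module.getContext();

mlir::ConversionTarget target(*ctx);
target.addIllegalOp<ttir::OnesOp>(); // must be rewritten away
target.addLegalOp<ttnn::OnesOp>(); // produced by the pattern

mlir::RewritePatternSet patterns(ctx);
patterns.add<OnesOpConversionPattern>(typeConverter, ctx);

if (mlir::failed(mlir::applyPartialConversion(module, target,
std::move(patterns))))
module.emitError("failed to lower ttir.ones to ttnn.ones");
}
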

class ToLayoutOpConversionPattern
: public OpConversionPattern<ttir::ToLayoutOp> {
public:
@@ -1018,6 +1081,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
// ANCHOR: op_rewriter_pattern_set
patterns
.add<TensorEmptyConversionPattern,
OnesOpConversionPattern,
ToLayoutOpConversionPattern,
ElementwiseOpConversionPattern<ttir::AbsOp, ttnn::AbsOp>,
ElementwiseOpConversionPattern<ttir::AddOp, ttnn::AddOp>,
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_library(TTMLIRTTNNToEmitC
TTNNToEmitC.cpp
TTNNToEmitCPass.cpp
Utils.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TTNNToEmitC