Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TTNN Dialect to better model TTNN lib #624

Merged
merged 9 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@ add_mlir_dialect(TTNNOps ttnn)
add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc)

set(LLVM_TARGET_DEFINITIONS TTNNOpsTypes.td)
mlir_tablegen(TTNNOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TTNNOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRTTNNOpsAttributesIncGen)
add_dependencies(mlir-headers MLIRTTNNOpsAttributesIncGen)

set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td)
mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(TTNNOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTTNNOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRTTNNOpsEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS TTNNOpsAttrs.td)
mlir_tablegen(TTNNOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TTNNOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRTTNNOpsAttrsIncGen)
add_dependencies(mlir-headers MLIRTTNNOpsAttrsIncGen)

set(LLVM_TARGET_DEFINITIONS TTNNOpsTypes.td)
mlir_tablegen(TTNNOpsTypeDefs.h.inc -gen-typedef-decls)
mlir_tablegen(TTNNOpsTypeDefs.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRTTNNOpsTypesIncGen)
add_dependencies(mlir-headers MLIRTTNNOpsTypesIncGen)
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#define GET_OP_CLASSES
Expand Down
24 changes: 24 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPS_TD

include "ttmlir/Dialect/TT/IR/TTOpsTypes.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td"
include "ttmlir/Dialect/TTNN/IR/TTNNBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
Expand Down Expand Up @@ -35,6 +37,28 @@ def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
let hasVerifier = 1;
}

def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
let summary = "ToLayout op.";
let description = [{
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TTNN_LayoutAttr:$layout);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_ToDeviceOp : TTNN_Op<"to_device"> {
let summary = "ToDevice op.";
let description = [{
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TTNN_MemoryConfigAttr:$memory_config);
let results = (outs AnyRankedTensor:$result);
}

class TTNN_NamedDPSOp<string mnemonic, list<Trait> traits = []> :
TTNN_Op<mnemonic, !listconcat(traits, [DestinationStyleOpInterface])> {
let extraClassDeclaration = [{
Expand Down
17 changes: 17 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNOPSATTRS_H
#define TTMLIR_DIALECT_TTNN_IR_TTNNOPSATTRS_H

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrDefs.h.inc"

#endif // TTMLIR_DIALECT_TTNN_IR_TTNNOPSATTRS_H
70 changes: 70 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSATTRS_TD
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSATTRS_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"
include "ttmlir/Dialect/TTNN/IR/TTNNBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td"

//===----------------------------------------------------------------------===//
// TTNN attr definitions
//===----------------------------------------------------------------------===//

class TTNN_Attr<string name, string attrMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TTNN_Dialect, name, traits, baseCppClass> {
let mnemonic = attrMnemonic;
let attrName = "ttnn." # attrMnemonic;
}

def TTNN_CoreRangeAttr : TTNN_Attr<"CoreRange", "core_range"> {
let summary = "TTNN grid attribute";
let description = [{
TTNN grid attribute
}];

let parameters = (ins ArrayRefParameter<"int64_t">:$offset,
ArrayRefParameter<"int64_t">:$size);
let assemblyFormat = "`<` custom<DimensionList>($offset) `,` custom<DimensionList>($size) `>`";

let extraClassDeclaration = [{
static CoreRangeAttr get(::mlir::MLIRContext *context, ::mlir::tt::GridAttr grid, SmallVector<int64_t> offset = {0, 0})
{
assert(grid.getShape().size() == 2 && "Grid shape must be 2D for now");
return CoreRangeAttr::get(context, {0, 0}, grid.getShape());
}
}];
}

def TTNN_CoreRangeArrayAttr : TypedArrayAttrBase<TTNN_CoreRangeAttr, "">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this map to ttnn::CoreRangeSet? We might need a proper def for this because they are kind of non-trivial to construct and depends on TensorMemoryLayout. Happy to tackle this in a follow on change though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually not sure! All the CoreRange* attrs here show as additions because I separated Types and Attrs, and these were previously in TTNNOpsTypes.td file.


def TTNN_LayoutAttr : EnumAttr<TTNN_Dialect, TTNN_Layout, "layout"> {
let assemblyFormat = "`<` $value `>`";
}

def TTNN_TensorMemoryLayoutAttr : EnumAttr<TTNN_Dialect, TTNN_TensorMemoryLayout, "tensor_memory_layout"> {
let assemblyFormat = "`<` $value `>`";
}

def TTNN_BufferTypeAttr : EnumAttr<TTNN_Dialect, TTNN_BufferType, "buffer_type"> {
let assemblyFormat = "`<` $value `>`";
}

def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> {
let summary = "TTNN MemoryConfig attribute";
let description = [{
TTNN memory config attribute
}];

let parameters = (ins AttrParameter<"TensorMemoryLayoutAttr", "">:$tensorMemoryLayout,
AttrParameter<"BufferTypeAttr", "">:$bufferType);
let assemblyFormat = "`<` params `>`";
}

#endif // TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSATTRS_TD
50 changes: 50 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,54 @@

include "mlir/IR/EnumAttr.td"

def TTNN_Layout_RowMajor : I32EnumAttrCase<"RowMajor", 0, "row_major">;
def TTNN_Layout_Tile : I32EnumAttrCase<"Tile", 1, "tile">;
def TTNN_Layout_Invalid : I32EnumAttrCase<"Invalid", 2, "invalid">;

def TTNN_Layout : I32EnumAttr<"Layout", "TTNN Layout",
[
TTNN_Layout_RowMajor,
TTNN_Layout_Tile,
TTNN_Layout_Invalid,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 0, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 1, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 2, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 3, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 4, "block_sharded">;

def TTNN_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TTNN Tensor Memory Layout",
[
TTNN_TensorMemoryLayout_Interleaved,
TTNN_TensorMemoryLayout_SingleBank,
TTNN_TensorMemoryLayout_HeightSharded,
TTNN_TensorMemoryLayout_WidthSharded,
TTNN_TensorMemoryLayout_BlockSharded,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_BufferType_DRAM : I32EnumAttrCase<"DRAM", 0, "dram">;
def TTNN_BufferType_L1 : I32EnumAttrCase<"L1", 1, "l1">;
def TTNN_BufferType_SystemMemory : I32EnumAttrCase<"SystemMemory", 2, "system_memory">;
def TTNN_BufferType_L1Small : I32EnumAttrCase<"L1Small", 3, "l1_small">;
def TTNN_BufferType_Trace : I32EnumAttrCase<"Trace", 4, "trace">;

def TTNN_BufferType : I32EnumAttr<"BufferType", "TTNN Buffer Type",
[
TTNN_BufferType_DRAM,
TTNN_BufferType_L1,
TTNN_BufferType_SystemMemory,
TTNN_BufferType_L1Small,
TTNN_BufferType_Trace,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttnn";
}

#endif
11 changes: 4 additions & 7 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@
#ifndef TTMLIR_DIALECT_TTNN_IR_TTNNOPSTYPES_H
#define TTMLIR_DIALECT_TTNN_IR_TTNNOPSTYPES_H

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.h.inc"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

#define GET_TYPEDEF_CLASSES
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h.inc"

#define GET_ATTRDEF_CLASSES
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrDefs.h.inc"

#endif
#endif // TTMLIR_DIALECT_TTNN_IR_TTNNOPSTYPES_H
36 changes: 1 addition & 35 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSTYPES_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"
include "ttmlir/Dialect/TTNN/IR/TTNNBase.td"
Expand All @@ -21,37 +20,4 @@ class TTNN_Type<string name, string typeMnemonic, list<Trait> traits = []>
let mnemonic = typeMnemonic;
}

//===----------------------------------------------------------------------===//
// TTNN attr definitions
//===----------------------------------------------------------------------===//
// Should Attr be a separate file?

class TTNN_Attr<string name, string attrMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TTNN_Dialect, name, traits, baseCppClass> {
let mnemonic = attrMnemonic;
let attrName = "ttnn." # attrMnemonic;
}

def TTNN_CoreRangeAttr : TTNN_Attr<"CoreRange", "core_range"> {
let summary = "TTNN grid attribute";
let description = [{
TTNN grid attribute
}];

let parameters = (ins ArrayRefParameter<"int64_t">:$offset,
ArrayRefParameter<"int64_t">:$size);
let assemblyFormat = "`<` custom<DimensionList>($offset) `,` custom<DimensionList>($size) `>`";

let extraClassDeclaration = [{
static CoreRangeAttr get(::mlir::MLIRContext *context, ::mlir::tt::GridAttr grid, SmallVector<int64_t> offset = {0, 0})
{
assert(grid.getShape().size() == 2 && "Grid shape must be 2D for now");
return CoreRangeAttr::get(context, {0, 0}, grid.getShape());
}
}];
}

def TTNN_CoreRangeArrayAttr : TypedArrayAttrBase<TTNN_CoreRangeAttr, "">;

#endif
#endif // TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSTYPES_TD
19 changes: 19 additions & 0 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ enum TensorMemoryLayout: ushort {
BlockSharded = 5,
}

enum TensorLayout: ushort {
RowMajor = 0,
Tile = 1,
Invalid = 2,
}

enum BufferType: ushort {
DRAM = 0,
L1 = 1,
SystemMemory = 2,
L1Small = 3,
Trace = 4,
}

table MemoryConfigDesc {
tensor_memory_layout: TensorMemoryLayout;
buffer_type: BufferType;
}

table MemoryDesc {
shape: [int];
tile_shape: Dim2d;
Expand Down
16 changes: 16 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ table ToMemoryConfigOp {
out: tt.target.TensorRef;
}

table ToLayoutOp {
in: tt.target.TensorRef;
layout: tt.target.TensorLayout;
device: tt.target.DeviceRef;
out: tt.target.TensorRef;
}

table ToDeviceOp {
in: tt.target.TensorRef;
device: tt.target.DeviceRef;
memcfg: tt.target.MemoryConfigDesc;
out: tt.target.TensorRef;
}

table EmptyOp {
device: tt.target.DeviceRef;
out: tt.target.TensorRef;
Expand Down Expand Up @@ -148,6 +162,8 @@ table DeallocOp {
union OpType {
GetDeviceOp,
ToMemoryConfigOp,
ToLayoutOp,
ToDeviceOp,
EmptyOp,
FullOp,
EltwiseOp,
Expand Down
Loading
Loading