Skip to content

Commit

Permalink
Fix toMemoryConfig Op hacks and hardcodings, create memory ops as necessary and obey constraints, properly model toLayout op.
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Oct 21, 2024
1 parent ede24ed commit 376d9ff
Show file tree
Hide file tree
Showing 28 changed files with 1,268 additions and 478 deletions.
43 changes: 32 additions & 11 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,17 @@ def TTNN_GetDeviceOp : TTNN_Op<"get_device"> {
def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
let summary = "ToMemoryConfig op.";
let description = [{
This op converts the memory config of the input tensor based on the given memory config.
It handles:
- Dram to L1
- L1 to Dram
- Interleaved to sharded
- Sharded to interleaved
- Sharded to sharded (reshard)
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device);
TTNN_MemoryConfigAttr:$memory_config);
let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
Expand All @@ -42,28 +49,49 @@ def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
let summary = "ToLayout op.";
let description = [{
This op converts the layout of the input tensor based on the given layout.
It handles:
- ROW_MAJOR to TILE
- TILE to ROW_MAJOR
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TTNN_LayoutAttr:$layout);
TTNN_LayoutAttr:$layout,
OptionalAttr<TT_DataTypeAttr>:$dtype,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
Optional<TT_Device>:$device);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_TypecastOp : TTNN_Op<"typecast"> {
let summary = "Typecast op.";
let description = [{
This op converts the data type of the input tensor based on the given data type.
It handles:
- conversions of data types.
}];

let arguments = (ins AnyRankedTensor:$input,
TT_DataTypeAttr:$dtype);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_ToDeviceOp : TTNN_Op<"to_device"> {
let summary = "ToDevice op.";
let description = [{
This op sends the input tensor to the given device with the given memory config.
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TTNN_MemoryConfigAttr:$memory_config);
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_FromDeviceOp : TTNN_Op<"from_device"> {
let summary = "FromDevice op.";
let description = [{
This op retrieves the input tensor from the given device.
}];

let arguments = (ins AnyRankedTensor:$input);
Expand Down Expand Up @@ -174,13 +202,6 @@ def TTNN_SigmoidOp : TTNN_ElementwiseUnaryOp<"sigmoid"> {
}];
}

def TTNN_TypecastOp : TTNN_ElementwiseUnaryOp<"typecast"> {
let summary = "Eltwise typecast.";
let description = [{
Eltwise typecast operation.
}];
}

def TTNN_ExpOp : TTNN_ElementwiseUnaryOp<"exp"> {
let summary = "Eltwise exponential.";
let description = [{
Expand Down
12 changes: 7 additions & 5 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ def TTNN_Layout : I32EnumAttr<"Layout", "TTNN Layout",
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 0, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 1, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 2, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 3, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 4, "block_sharded">;
def TTNN_TensorMemoryLayout_None : I32EnumAttrCase<"None", 0, "none">;
def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 1, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 2, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 3, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 4, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 5, "block_sharded">;

def TTNN_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TTNN Tensor Memory Layout",
[
TTNN_TensorMemoryLayout_None,
TTNN_TensorMemoryLayout_Interleaved,
TTNN_TensorMemoryLayout_SingleBank,
TTNN_TensorMemoryLayout_HeightSharded,
Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ toTTNNBufferType(const mlir::tt::MemorySpace memorySpace);
ttnn::TensorMemoryLayout
toTTNNTensorMemoryLayout(const tt::TensorMemoryLayout ttTensorMemoryLayout);

DataType getDataTypeFromMemRef(mlir::MemRefType memref);

} // namespace mlir::tt::ttnn::utils

#endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
25 changes: 13 additions & 12 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ enum Arch: uint {
}

enum DataType: uint16 {
Float32 = 0,
Float16 = 1,
BFloat16 = 2,
BFP_Float8 = 3,
BFP_BFloat8 = 4,
BFP_Float4 = 5,
BFP_BFloat4 = 6,
BFP_Float2 = 7,
BFP_BFloat2 = 8,
UInt32 = 9,
UInt16 = 10,
UInt8 = 11,
None = 0,
Float32 = 1,
Float16 = 2,
BFloat16 = 3,
BFP_Float8 = 4,
BFP_BFloat8 = 5,
BFP_Float4 = 6,
BFP_BFloat4 = 7,
BFP_Float2 = 8,
BFP_BFloat2 = 9,
UInt32 = 10,
UInt16 = 11,
UInt8 = 12,
}

enum OOBVal: ushort {
Expand Down
11 changes: 10 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,25 @@ table GetDeviceOp {

table ToMemoryConfigOp {
in0: tt.target.TensorRef;
device: tt.target.DeviceRef;
memcfg: MemoryConfigDesc;
out: tt.target.TensorRef;
}

table ToLayoutOp {
in: tt.target.TensorRef;
layout: tt.target.TensorLayout;
dtype: tt.target.DataType;
memcfg: tt.target.MemoryConfigDesc;
device: tt.target.DeviceRef;
out: tt.target.TensorRef;
}

table TypecastOp {
in: tt.target.TensorRef;
dtype: tt.target.DataType;
out: tt.target.TensorRef;
}

table ToDeviceOp {
in: tt.target.TensorRef;
device: tt.target.DeviceRef;
Expand Down Expand Up @@ -191,6 +199,7 @@ union OpType {
GetDeviceOp,
ToMemoryConfigOp,
ToLayoutOp,
TypecastOp,
ToDeviceOp,
FromDeviceOp,
EmptyOp,
Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Target/TTNN/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ ::tt::target::TensorMemoryLayout toTargetTensorMemoryLayout(
return ::tt::target::TensorMemoryLayout::WidthSharded;
case ::mlir::tt::ttnn::TensorMemoryLayout::BlockSharded:
return ::tt::target::TensorMemoryLayout::BlockSharded;
case ::mlir::tt::ttnn::TensorMemoryLayout::None:
return ::tt::target::TensorMemoryLayout::None;
}

llvm_unreachable("Unsupported TensorMemoryLayout");
Expand Down
Loading

0 comments on commit 376d9ff

Please sign in to comment.