Skip to content

Commit

Permalink
Fix toMemoryConfig Op hacks and hardcodings, create memory ops as necessary and obey constraints in a separate pass, properly model toLayout op.
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Oct 28, 2024
1 parent d22d7eb commit 4088fc9
Showing 47 changed files with 1,758 additions and 644 deletions.
5 changes: 4 additions & 1 deletion include/ttmlir-c/TTNNAttrs.h
Original file line number Diff line number Diff line change
@@ -29,9 +29,12 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNTensorMemoryLayoutAttrGet(
MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTNNBufferTypeAttrGet(MlirContext ctx, uint32_t bufferType);

MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTNNShardSpecAttrGet(MlirContext ctx, MlirAttribute shardShapeAttr);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNMemoryConfigAttrGet(
MlirContext ctx, MlirAttribute tensorMemoryLayoutAttr,
MlirAttribute bufferTypeAttr);
MlirAttribute bufferTypeAttr, MlirAttribute shardSpecAttr);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNShapeAttrGet(MlirContext ctx,
int64_t *shape,
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
@@ -279,6 +279,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
LayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType);
LayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace);
LayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
LayoutAttr withShardShape(::mlir::MLIRContext *context, llvm::SmallVector<int64_t> shardShape);

uint64_t getMemrefSizeBytes() const;
MemorySpace getMemorySpace() const;
59 changes: 48 additions & 11 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
@@ -27,13 +27,36 @@ def TTNN_GetDeviceOp : TTNN_Op<"get_device"> {
let results = (outs TT_Device:$device);
}

def TTNN_CompositeToLayoutOp : TTNN_Op<"composite_to_layout"> {
let summary = "Composite toLayout op.";
let description = [{
This op wraps all layout information gathered from ttir.toLayout. It is used/updated by the optimizer
to perform optimizations, and later broken down into specific memory/layout operations (toDevice, toMemoryConfig, toLayout etc.).
}];

let arguments = (ins AnyRankedTensor:$input,
TTNN_LayoutAttr:$layout,
TT_DataTypeAttr:$dtype,
TTNN_MemoryConfigAttr:$memory_config,
Optional<TT_Device>:$device);

let results = (outs AnyRankedTensor:$result);
}

def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
let summary = "ToMemoryConfig op.";
let description = [{
This op converts the memory config of the input tensor based on the given memory config.
It handles:
- Dram to L1
- L1 to Dram
- Interleaved to sharded
- Sharded to interleaved
- Sharded to sharded (reshard)
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device);
TTNN_MemoryConfigAttr:$memory_config);
let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
@@ -42,28 +65,49 @@ def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
let summary = "ToLayout op.";
let description = [{
This op converts the layout of the input tensor based on the given layout.
It handles:
- ROW_MAJOR to TILE
- TILE to ROW_MAJOR
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TTNN_LayoutAttr:$layout);
TTNN_LayoutAttr:$layout,
OptionalAttr<TT_DataTypeAttr>:$dtype,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
Optional<TT_Device>:$device);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_TypecastOp : TTNN_Op<"typecast"> {
let summary = "Typecast op.";
let description = [{
This op converts the data type of the input tensor based on the given data type.
It handles:
- conversions of data types.
}];

let arguments = (ins AnyRankedTensor:$input,
TT_DataTypeAttr:$dtype);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_ToDeviceOp : TTNN_Op<"to_device"> {
let summary = "ToDevice op.";
let description = [{
This op sends the input tensor to the given device with the given memory config.
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TTNN_MemoryConfigAttr:$memory_config);
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
let results = (outs AnyRankedTensor:$result);
}

def TTNN_FromDeviceOp : TTNN_Op<"from_device"> {
let summary = "FromDevice op.";
let description = [{
This op retrieves the input tensor from the given device.
}];

let arguments = (ins AnyRankedTensor:$input);
@@ -174,13 +218,6 @@ def TTNN_SigmoidOp : TTNN_ElementwiseUnaryOp<"sigmoid"> {
}];
}

def TTNN_TypecastOp : TTNN_ElementwiseUnaryOp<"typecast"> {
let summary = "Eltwise typecast.";
let description = [{
Eltwise typecast operation.
}];
}

def TTNN_ExpOp : TTNN_ElementwiseUnaryOp<"exp"> {
let summary = "Eltwise exponential.";
let description = [{
40 changes: 29 additions & 11 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
@@ -42,8 +42,6 @@ def TTNN_CoreRangeAttr : TTNN_Attr<"CoreRange", "core_range"> {
}];
}

def TTNN_CoreRangeArrayAttr : TypedArrayAttrBase<TTNN_CoreRangeAttr, "">;

def TTNN_LayoutAttr : EnumAttr<TTNN_Dialect, TTNN_Layout, "layout"> {
let assemblyFormat = "`<` $value `>`";
}
@@ -56,25 +54,45 @@ def TTNN_BufferTypeAttr : EnumAttr<TTNN_Dialect, TTNN_BufferType, "buffer_type">
let assemblyFormat = "`<` $value `>`";
}

def TTNN_ShapeAttr : TTNN_Attr<"Shape", "shape"> {
let summary = "TTNN Shape attribute";
let description = [{
TTNN shape attribute
}];

let parameters = (ins ArrayRefParameter<"int64_t">:$shape);
let assemblyFormat = "`<` custom<DimensionList>($shape) `>`";
}

def TTNN_ShardSpecAttr : TTNN_Attr<"ShardSpec", "shard_spec"> {
let summary = "TTNN ShardSpec attribute";
let description = [{
TTNN ShardSpec attribute
}];

// TODO (#620): Add other fields like core_ranges, shard orientation etc.
let parameters = (ins AttrParameter<"ShapeAttr", "">:$shardShape);
let assemblyFormat = "`<` params `>`";
}

def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> {
let summary = "TTNN MemoryConfig attribute";
let description = [{
TTNN memory config attribute
}];

let parameters = (ins AttrParameter<"TensorMemoryLayoutAttr", "">:$tensorMemoryLayout,
AttrParameter<"BufferTypeAttr", "">:$bufferType);
AttrParameter<"BufferTypeAttr", "">:$bufferType,
AttrParameter<"ShardSpecAttr", "">:$shardSpec);

let assemblyFormat = "`<` params `>`";
}

def TTNN_ShapeAttr : TTNN_Attr<"Shape", "shape"> {
let summary = "TTNN Shape attribute";
let description = [{
TTNN shape attribute
let extraClassDeclaration = [{
::llvm::ArrayRef<int64_t> getShardShapeArray() const
{
return this->getShardSpec().getShardShape().getShape();
}
}];

let parameters = (ins ArrayRefParameter<"int64_t">:$shape);
let assemblyFormat = "`<` custom<DimensionList>($shape) `>`";
}

def TTNN_MeshShapeAttr : TTNN_Attr<"MeshShape", "mesh_shape"> {
12 changes: 7 additions & 5 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
Original file line number Diff line number Diff line change
@@ -21,14 +21,16 @@ def TTNN_Layout : I32EnumAttr<"Layout", "TTNN Layout",
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 0, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 1, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 2, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 3, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 4, "block_sharded">;
def TTNN_TensorMemoryLayout_None : I32EnumAttrCase<"None", 0, "none">;
def TTNN_TensorMemoryLayout_Interleaved : I32EnumAttrCase<"Interleaved", 1, "interleaved">;
def TTNN_TensorMemoryLayout_SingleBank : I32EnumAttrCase<"SingleBank", 2, "single_bank">;
def TTNN_TensorMemoryLayout_HeightSharded : I32EnumAttrCase<"HeightSharded", 3, "height_sharded">;
def TTNN_TensorMemoryLayout_WidthSharded : I32EnumAttrCase<"WidthSharded", 4, "width_sharded">;
def TTNN_TensorMemoryLayout_BlockSharded : I32EnumAttrCase<"BlockSharded", 5, "block_sharded">;

def TTNN_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TTNN Tensor Memory Layout",
[
TTNN_TensorMemoryLayout_None,
TTNN_TensorMemoryLayout_Interleaved,
TTNN_TensorMemoryLayout_SingleBank,
TTNN_TensorMemoryLayout_HeightSharded,
12 changes: 12 additions & 0 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
@@ -186,6 +186,12 @@ void createTTNNPipelineAnalysisPasses(
void createTTNNPipelineLoweringPasses(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

void createTTNNPipelineLayoutDecompositionPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

void createTTNNPipelineDeallocPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
std::string options);

@@ -195,6 +201,12 @@ void createTTNNPipelineAnalysisPassesFromString(OpPassManager &pm,
void createTTNNPipelineLoweringPassesFromString(OpPassManager &pm,
std::string options);

void createTTNNPipelineLayoutDecompositionPassFromString(OpPassManager &pm,
std::string options);

void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
std::string options);

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -14,6 +14,13 @@ def TTNNDeallocate: Pass<"ttnn-deallocate", "::mlir::ModuleOp"> {
}];
}

def TTNNDecomposeCompositeLayouts: Pass<"ttnn-decompose-composite-layouts", "::mlir::ModuleOp"> {
let summary = "Decompose composite layout ops into their corresponding memory ops.";
let description = [{
This pass decomposes composite layout ops into memory ops (e.g. toDevice, toMemoryConfig, etc.).
}];
}

def TTNNOptimizer: Pass<"ttnn-optimizer", "::mlir::ModuleOp"> {
let summary = "Determine op configurations for maximum performance.";
let description = [{
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
@@ -21,6 +21,13 @@ toTTNNBufferType(const mlir::tt::MemorySpace memorySpace);
ttnn::TensorMemoryLayout
toTTNNTensorMemoryLayout(const tt::TensorMemoryLayout ttTensorMemoryLayout);

DataType getDataTypeFromMemRef(mlir::MemRefType memref);

Layout getLayoutFromMemRef(mlir::MemRefType memref);

mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
DataType dtype);

} // namespace mlir::tt::ttnn::utils

#endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
35 changes: 19 additions & 16 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
@@ -17,18 +17,19 @@ enum Arch: uint {
}

enum DataType: uint16 {
Float32 = 0,
Float16 = 1,
BFloat16 = 2,
BFP_Float8 = 3,
BFP_BFloat8 = 4,
BFP_Float4 = 5,
BFP_BFloat4 = 6,
BFP_Float2 = 7,
BFP_BFloat2 = 8,
UInt32 = 9,
UInt16 = 10,
UInt8 = 11,
None = 0,
Float32 = 1,
Float16 = 2,
BFloat16 = 3,
BFP_Float8 = 4,
BFP_BFloat8 = 5,
BFP_Float4 = 6,
BFP_BFloat4 = 7,
BFP_Float2 = 8,
BFP_BFloat2 = 9,
UInt32 = 10,
UInt16 = 11,
UInt8 = 12,
}

enum OOBVal: ushort {
@@ -74,13 +75,15 @@ enum BufferType: ushort {
Trace = 4,
}

// TODO (#620): Add other fields like core_ranges, shard orientation etc.
table ShardSpec {
shard_shape: [int64];
}

table MemoryConfigDesc {
tensor_memory_layout: TensorMemoryLayout;
buffer_type: BufferType;
// TODO(bug #620):
// Add support for ShardSpec and remove shard_shape
//
shard_shape: [int64];
shard_spec: ShardSpec;
}

table MemoryDesc {
11 changes: 10 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
@@ -11,17 +11,25 @@ table GetDeviceOp {

table ToMemoryConfigOp {
in0: tt.target.TensorRef;
device: tt.target.DeviceRef;
memcfg: MemoryConfigDesc;
out: tt.target.TensorRef;
}

table ToLayoutOp {
in: tt.target.TensorRef;
layout: tt.target.TensorLayout;
dtype: tt.target.DataType;
memcfg: tt.target.MemoryConfigDesc;
device: tt.target.DeviceRef;
out: tt.target.TensorRef;
}

table TypecastOp {
in: tt.target.TensorRef;
dtype: tt.target.DataType;
out: tt.target.TensorRef;
}

table ToDeviceOp {
in: tt.target.TensorRef;
device: tt.target.DeviceRef;
@@ -198,6 +206,7 @@ union OpType {
GetDeviceOp,
ToMemoryConfigOp,
ToLayoutOp,
TypecastOp,
ToDeviceOp,
FromDeviceOp,
EmptyOp,
2 changes: 2 additions & 0 deletions include/ttmlir/Target/TTNN/utils.h
Original file line number Diff line number Diff line change
@@ -26,6 +26,8 @@ ::tt::target::TensorMemoryLayout toTargetTensorMemoryLayout(
return ::tt::target::TensorMemoryLayout::WidthSharded;
case ::mlir::tt::ttnn::TensorMemoryLayout::BlockSharded:
return ::tt::target::TensorMemoryLayout::BlockSharded;
case ::mlir::tt::ttnn::TensorMemoryLayout::None:
return ::tt::target::TensorMemoryLayout::None;
}

llvm_unreachable("Unsupported TensorMemoryLayout");
16 changes: 11 additions & 5 deletions lib/CAPI/TTNNAttrs.cpp
Original file line number Diff line number Diff line change
@@ -43,14 +43,20 @@ MlirAttribute ttmlirTTNNBufferTypeAttrGet(MlirContext ctx,
BufferTypeAttr::get(unwrap(ctx), static_cast<BufferType>(bufferType)));
}

MlirAttribute
ttmlirTTNNMemoryConfigAttrGet(MlirContext ctx,
MlirAttribute tensorMemoryLayoutAttr,
MlirAttribute bufferTypeAttr) {
MlirAttribute ttmlirTTNNShardSpecAttrGet(MlirContext ctx,
MlirAttribute shardShapeAttr) {
return wrap(ShardSpecAttr::get(
unwrap(ctx), mlir::cast<ShapeAttr>(unwrap(shardShapeAttr))));
}

MlirAttribute ttmlirTTNNMemoryConfigAttrGet(
MlirContext ctx, MlirAttribute tensorMemoryLayoutAttr,
MlirAttribute bufferTypeAttr, MlirAttribute shardSpecAttr) {
return wrap(MemoryConfigAttr::get(
unwrap(ctx),
mlir::cast<TensorMemoryLayoutAttr>(unwrap(tensorMemoryLayoutAttr)),
mlir::cast<BufferTypeAttr>(unwrap(bufferTypeAttr))));
mlir::cast<BufferTypeAttr>(unwrap(bufferTypeAttr)),
mlir::cast<ShardSpecAttr>(unwrap(shardSpecAttr))));
}

MlirAttribute ttmlirTTNNShapeAttrGet(MlirContext ctx, int64_t *shape,
Loading

0 comments on commit 4088fc9

Please sign in to comment.