Skip to content

Commit

Permalink
#1090: Remove ttnn::CompositeToLayoutOp, use ttnn::ToLayoutOp instead
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Oct 30, 2024
1 parent 22a06f2 commit 6b1897c
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 96 deletions.
24 changes: 4 additions & 20 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,6 @@ def TTNN_GetDeviceOp : TTNN_Op<"get_device"> {
let results = (outs TT_Device:$device);
}

def TTNN_CompositeToLayoutOp : TTNN_Op<"composite_to_layout"> {
let summary = "Composite toLayout op.";
let description = [{
This op wraps all layout information gathered from ttir.toLayout. It is used/updated by the optimizer
to perform optimizations, and later broken down into specific memory/layout operations (toDevice, toMemoryConfig, toLayout etc.).
}];

let arguments = (ins AnyRankedTensor:$input,
TTNN_LayoutAttr:$layout,
TT_DataTypeAttr:$dtype,
TTNN_MemoryConfigAttr:$memory_config,
Optional<TT_Device>:$device);

let results = (outs AnyRankedTensor:$result);
}

def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
let summary = "ToMemoryConfig op.";
let description = [{
Expand All @@ -65,10 +49,10 @@ def TTNN_ToMemoryConfigOp : TTNN_Op<"to_memory_config"> {
def TTNN_ToLayoutOp : TTNN_Op<"to_layout"> {
let summary = "ToLayout op.";
let description = [{
This op converts the layout of the input tensor based on the given layout.
It handles:
- ROW_MAJOR to TILE
- TILE to ROW_MAJOR
This op wraps all layout information gathered from ttir.toLayout. It is used/updated by the optimizer
to perform optimizations, and later broken down into specific memory/layout operations (toDevice, toMemoryConfig etc.).
Currently in the TTNN backend, we use this op solely for tilize/untilize, therefore marking all other attrs as optional.
Once ttnn::to_layout supports other attrs, we can remove the optional tag.
}];

let arguments = (ins AnyRankedTensor:$input,
Expand Down
6 changes: 3 additions & 3 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def TTNNDeallocate: Pass<"ttnn-deallocate", "::mlir::ModuleOp"> {
}];
}

def TTNNDecomposeCompositeLayouts: Pass<"ttnn-decompose-composite-layouts", "::mlir::ModuleOp"> {
let summary = "Decompose composite layout ops to according memory ops.";
def TTNNDecomposeLayouts: Pass<"ttnn-decompose-layouts", "::mlir::ModuleOp"> {
let summary = "Decompose ToLayoutOps to more granular memory ops.";
let description = [{
This pass decomposes composite layouts to memory ops (e.g. toDevice, toMemoryConfig etc.).
This pass decomposes ToLayoutOps to memory ops (e.g. toDevice, toMemoryConfig etc.).
}];
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class ToLayoutOpConversionPattern
op.getContext(), ttnn::ShapeAttr::get(rewriter.getContext(),
outputMemref.getShape())));

rewriter.replaceOpWithNewOp<ttnn::CompositeToLayoutOp>(
rewriter.replaceOpWithNewOp<ttnn::ToLayoutOp>(
op, this->getTypeConverter()->convertType(result), adaptor.getInput(),
outputLayout, outputDataType, outputMemConfigAttr,
isOutputOnHost ? nullptr : getOrInsertDevice(rewriter, op));
Expand Down
8 changes: 2 additions & 6 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

namespace mlir::tt::ttnn {

bool isCompositeToLayoutOp(mlir::Operation *op) {
return isa<ttnn::CompositeToLayoutOp>(op);
}

void DFShardingPolicy::run(
const std::unordered_set<Edge> &overrideReshardEdges) {
rootOp->walk([&](func::FuncOp func) {
Expand Down Expand Up @@ -39,7 +35,7 @@ void DFShardingPolicy::run(
//
if (l1ChainConfigs->back().isEmpty()) {
for (auto *op : scheduleableOps) {
if (isCompositeToLayoutOp(op)) {
if (isa<ttnn::ToLayoutOp>(op)) {
currentOp = op;
break;
}
Expand All @@ -57,7 +53,7 @@ void DFShardingPolicy::run(
// Skip starting sharding chain if currentOp is a memory management op.
//
if (l1ChainConfigs->back().isEmpty() &&
isCompositeToLayoutOp(currentOp)) {
isa<ttnn::ToLayoutOp>(currentOp)) {
currentOp = nullptr;
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ bool cantChangeOutputLayout(Operation *op) {
return true;
}

if (llvm::isa<CompositeToLayoutOp>(op)) {
if (llvm::isa<ToLayoutOp>(op)) {
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void createTTNNPipelineLoweringPasses(

void createTTNNPipelineLayoutDecompositionPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(createTTNNDecomposeCompositeLayouts());
pm.addPass(createTTNNDecomposeLayouts());
}

void createTTNNPipelineDeallocPass(
Expand Down
45 changes: 20 additions & 25 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,15 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
// TTNN will need to be special handled as well. Depends on ttnn
// layout attr refactor and lowering.
//
else if (isa<ttnn::CompositeToLayoutOp>(op)) {
else if (isa<ttnn::ToLayoutOp>(op)) {
BufferType bufferType =
utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
TensorMemoryLayout tensorMemoryLayout =
utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());
// Update the device op with the new tensor type.
//
ttnn::CompositeToLayoutOp compositeLayoutOp =
llvm::cast<ttnn::CompositeToLayoutOp>(op);
compositeLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(op);
toLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
op->getContext(),
ttnn::TensorMemoryLayoutAttr::get(op->getContext(),
tensorMemoryLayout),
Expand Down Expand Up @@ -243,45 +242,41 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
Operation *producerOp = edge.producerOp;
Operation *consumerOp = edge.consumerOp;

// If producerOp is a compositeToLayoutOp, adjust its output layout(update
// If producerOp is a toLayoutOp, adjust its output layout(update
// inplace) to reflect consumerOp's output layout. If producerOp is not a
// compositeToLayoutOp, insert a compositeToLayoutOp in between producerOp
// toLayoutOp, insert a toLayoutOp in between producerOp
// and consumerOp.
//
if (isa<ttnn::CompositeToLayoutOp>(producerOp)) {
ttnn::CompositeToLayoutOp compositeLayoutOp =
llvm::cast<ttnn::CompositeToLayoutOp>(producerOp);
if (isa<ttnn::ToLayoutOp>(producerOp)) {
ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(producerOp);
tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
.getEncoding());

RankedTensorType compositeLayoutOpTensorType =
mlir::cast<RankedTensorType>(
compositeLayoutOp.getResult().getType());
llvm::ArrayRef<int64_t> compositeLayoutOpTensorShape =
compositeLayoutOpTensorType.getShape();
tt::LayoutAttr compositeLayoutOpLayout = mlir::cast<tt::LayoutAttr>(
compositeLayoutOpTensorType.getEncoding());
RankedTensorType toLayoutOpTensorType =
mlir::cast<RankedTensorType>(toLayoutOp.getResult().getType());
llvm::ArrayRef<int64_t> toLayoutOpTensorShape =
toLayoutOpTensorType.getShape();
tt::LayoutAttr toLayoutOpLayout =
mlir::cast<tt::LayoutAttr>(toLayoutOpTensorType.getEncoding());

// TODO(nobradovic): Match memory space and layout of consumer op. This
// actually needs to be properly resolved based on op type, output
// layout and other inputs.
//
RankedTensorType newTensorType = RankedTensorType::get(
compositeLayoutOpTensorShape,
compositeLayoutOpTensorType.getElementType(),
compositeLayoutOpLayout
.withElementType(compositeLayoutOp->getContext(),
toLayoutOpTensorShape, toLayoutOpTensorType.getElementType(),
toLayoutOpLayout
.withElementType(toLayoutOp->getContext(),
consumerOpOutputLayout.getElementType())
.withMemorySpace(compositeLayoutOp.getContext(),
.withMemorySpace(toLayoutOp.getContext(),
consumerOpOutputLayout.getMemorySpace())
.withMemoryLayout(compositeLayoutOp.getContext(),
.withMemoryLayout(toLayoutOp.getContext(),
consumerOpOutputLayout.getMemLayout())
.withGrid(compositeLayoutOp.getContext(),
compositeLayoutOpTensorType,
.withGrid(toLayoutOp.getContext(), toLayoutOpTensorType,
consumerOpOutputLayout.getGrid()));

compositeLayoutOp.getResult().setType(newTensorType);
toLayoutOp.getResult().setType(newTensorType);
}
// TODO (nobradovic): Memory layout reconfig needs to be reimplemented for
// TTNN dialect.
Expand Down
Loading

0 comments on commit 6b1897c

Please sign in to comment.