Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tilized dram-interleaved as default input-output layout #1744

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,20 @@ def TTNN_FullOp : TTNN_Op<"full"> {

let arguments = (ins TT_Device:$device, F32Attr:$fillValue);
let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
// Returns the set of operand workarounds required for this ttnn::FullOp.
// The factory-produced row-major workaround is applied only when the result
// is a rank-1 tensor whose layout encoding carries a TileType element
// (ttnn::FullOp does not support 1D tilized tensors); every other output
// gets an empty workaround set.
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
  mlir::tt::ttnn::FullOp op = mlir::cast<mlir::tt::ttnn::FullOp>(this->getOperation());
  auto outputType = mlir::cast<RankedTensorType>(op.getType());
  ttnn::TTNNLayoutAttr layoutAttr =
      mlir::cast<ttnn::TTNNLayoutAttr>(outputType.getEncoding());
  // Rank > 1 or non-tiled outputs need no workaround.
  if (outputType.getRank() > 1 ||
      !mlir::isa<TileType>(layoutAttr.getElementType())) {
    return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(this->getOperation());
  }
  // 1D tiled output: force row-major output and tilize separately.
  return wa::TTNNOperandsWorkaroundsFactory::createFullOpOperandsWorkarounds();
}
}];
}

def TTNN_AllocOp : TTNN_Op<"alloc"> {
Expand Down
6 changes: 1 addition & 5 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,7 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> {
let assemblyFormat = "`<` params `>`";

let extraClassDeclaration = [{
::llvm::ArrayRef<int64_t> getShardShapeArray() const
{
return this->getShardSpec().getShardShape().getShape();
}

llvm::ArrayRef<int64_t> getShardShape(bool convertTileToScalar = true) const;
MemoryConfigAttr withBufferType(::mlir::MLIRContext *context, BufferType bufferType);
MemoryConfigAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
}];
Expand Down
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ class TTNNOperandsWorkaroundsFactory {

// Create workarounds for embedding op operands.
static TTNNOperandsWorkarounds createEmbeddingOpOperandsWorkarounds();

// Create workarounds for full op operands.
static TTNNOperandsWorkarounds createFullOpOperandsWorkarounds();
};

} // namespace mlir::tt::ttnn::wa
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct TTIRToTTNNBackendPipelineOptions

// Option to enable/disable the workaround pass.
//
Option<bool> layouotWorkaroundsEnabled{
Option<bool> layoutWorkaroundsEnabled{
*this, "enable-layout-workaround-pass",
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(true)};

Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
}];

let options = [
Option<"layouotWorkaroundsEnabled",
Option<"layoutWorkaroundsEnabled",
"ttnn-enable-layout-workaround-pass",
"bool", /*default=*/"true",
"TTNN Layout Workarounds Pass">,
Expand Down
93 changes: 4 additions & 89 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/TTIRToTTNN/TTIRToTTNN.h"

#include "ttmlir/Conversion/TTIRToTTNN/Utils.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
Expand All @@ -13,6 +12,7 @@
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -170,6 +170,9 @@ class ToLayoutOpConversionPattern
rewriter.eraseOp(emptyOp);
}

assert(mlir::isa<mlir::RankedTensorType>(adaptor.getInput().getType()) &&
"Expected RankedTensorType for ToLayoutOp input");

auto outputLayoutAttr = mlir::cast<ttnn::TTNNLayoutAttr>(
op.getResult().getType().getEncoding());

Expand All @@ -186,32 +189,6 @@ class ToLayoutOpConversionPattern
bool isOutputOnHost = (outputBufferType == ttnn::BufferType::SystemMemory);

RankedTensorType result = mlir::cast<RankedTensorType>(op.getType());
if (!isOutputOnHost) {
// TODO(bug #665):
// Binary ops fail with row major layout in ttnn, defaulting to and
// assuming tile layout for all device tensors...
// Note: mlir doesn't know about this, so tensors may still appear as row
// major in the generated mlir
// TODO(bug #875):
// Remove the following code block once constraints modelling is
// implemented on dialect level
//
// Default to Tile layout unless op supports only RowMajor layout
//
ttnn::Layout newOutputLayoutEnum =
shouldForceRowMajor(op) ? ttnn::Layout::RowMajor : ttnn::Layout::Tile;

// If the layout of the output tensor changed as a result of forcing the
// layout update the tensor type
if (outputLayoutEnum != newOutputLayoutEnum) {
result =
getLayoutForcedResultTensor(rewriter, result, newOutputLayoutEnum);
op.getResult().setType(result);
outputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(result.getEncoding());
outputLayoutEnum = newOutputLayoutEnum;
}
}

ttnn::LayoutAttr outputLayout =
ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum);
Expand All @@ -235,68 +212,6 @@ class ToLayoutOpConversionPattern

return success();
}

private:
// Returns true when the result of `op` feeds an op that only accepts
// row-major inputs, in which case the output layout must be forced to
// row major.
//
// Conv2dOp and SliceOp only support row major. EmbeddingBackwardOp
// supports row major layout for its first and second operands only.
bool shouldForceRowMajor(ttir::ToLayoutOp op) const {
  for (mlir::Operation *user : op.getResult().getUsers()) {
    if (isa<ttir::Conv2dOp>(user) || isa<ttir::SliceOp>(user)) {
      return true;
    }
    const bool feedsRowMajorEmbeddingOperand =
        isa<ttir::EmbeddingBackwardOp>(user) &&
        (user->getOperand(0) == op || user->getOperand(1) == op);
    if (feedsRowMajorEmbeddingOperand) {
      return true;
    }
  }
  return false;
}

// Rebuilds the ToLayoutOp result tensor type after a layout has been forced
// to RowMajor or Tile. Only the layout encoding (element type and, for
// row major, the shard shape) changes; the logical tensor shape and scalar
// element type are preserved.
//
// rewriter            - provides the MLIRContext for attribute construction.
// oldOutput           - the result type currently attached to the op.
// newOutputLayoutEnum - the layout being forced (RowMajor or Tile only).
RankedTensorType
getLayoutForcedResultTensor(ConversionPatternRewriter &rewriter,
                            RankedTensorType oldOutput,
                            ttnn::Layout newOutputLayoutEnum) const {
  auto oldOutputLayoutAttr =
      mlir::cast<ttnn::TTNNLayoutAttr>(oldOutput.getEncoding());
  DataType outputDtype = oldOutputLayoutAttr.getDataType();
  SmallVector<std::int64_t> oldShardShape =
      oldOutputLayoutAttr.getShardShape();
  size_t shardShapeSize = oldShardShape.size();
  // The last two dims are scaled below, so at least 2 dims must exist.
  assert(shardShapeSize >= 2 && "expected at least 2D shape");

  if (newOutputLayoutEnum == ttnn::Layout::RowMajor) {
    // Set shard shape to match convention of row major layout: the tiled
    // shard shape counts tiles, so the last two dims are multiplied by the
    // tile height/width to obtain the scalar shard shape.
    auto tileType =
        mlir::cast<TileType>(oldOutputLayoutAttr.getElementType());
    llvm::SmallVector<int64_t> newShardShape(oldShardShape.begin(),
                                             oldShardShape.end());
    newShardShape[shardShapeSize - 2] =
        oldShardShape[shardShapeSize - 2] * tileType.getHeight();
    newShardShape[shardShapeSize - 1] =
        oldShardShape[shardShapeSize - 1] * tileType.getWidth();
    // Replace the tiled element type with the plain scalar type for the
    // same dtype.
    Type newElementType = ttnn::utils::createRowMajorTypeFromDtype(
        rewriter.getContext(), outputDtype);
    RankedTensorType result = RankedTensorType::get(
        oldOutput.getShape(), oldOutput.getElementType(),
        oldOutputLayoutAttr
            .withElementType(rewriter.getContext(), newElementType)
            .withShardShape(rewriter.getContext(), newShardShape));
    return result;
  }

  if (newOutputLayoutEnum == ttnn::Layout::Tile) {
    // Forcing tile layout: wrap the dtype in a TILE_HEIGHT x TILE_WIDTH
    // TileType; withElementType keeps the rest of the encoding intact.
    TileType tileType =
        TileType::get(rewriter.getContext(),
                      {ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, outputDtype);
    RankedTensorType result = RankedTensorType::get(
        oldOutput.getShape(), oldOutput.getElementType(),
        oldOutputLayoutAttr.withElementType(rewriter.getContext(), tileType));
    return result;
  }

  llvm_unreachable("Unreachable code path. Unexpected output layout enum");
}
};

template <typename TTIROpTy, typename TTNNOpTy,
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
mlir::tt::SystemDescAttr
mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {
// Check if file exists
assert(!path.empty() && "cluster desc path must not be empty!");
assert(!path.empty() && "system desc path must not be empty!");
std::ifstream fbb(path, std::ios::binary | std::ios::ate);
assert(fbb.good() && "cluster desc does not exist!");
assert(fbb.good() && "system desc does not exist!");
std::streampos size = fbb.tellg();
fbb.seekg(0, std::ios::beg);
auto buffer = std::shared_ptr<void>(std::malloc(size), std::free);
Expand Down
19 changes: 19 additions & 0 deletions lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ void BFInterleavedPolicy::run() {
scheduler.scheduleOp(nextOpForScheduling);
}

// TODO (#0000): This is a temporary solution
// Currently ReturnOps are not considered when calculating L1 usage
llvm::SmallVector<mlir::Operation *> eraseableL1UsageOps;
for (auto &[op, usage] : currentL1UsagePerOp) {
for (Operation *user : op->getUsers()) {
if (isa<mlir::func::ReturnOp>(user)) {
usage.numOfUnscheduledUsers -= 1;
}
}
if (usage.numOfUnscheduledUsers == 0) {
eraseableL1UsageOps.push_back(op);
}
}

for (Operation *op : eraseableL1UsageOps) {
currentL1Usage -= currentL1UsagePerOp[op].l1MemUsagePerUser;
currentL1UsagePerOp.erase(op);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fbajraktariTT, can you review this file?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @odjuricicTT, since @fbajraktariTT recently completed their internship.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnie-TT I'm not sure that this extra logic is needed. Was a test failing without this temp fix? If so, can you provide more details?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@odjuricicTT there's an assert below that checks if the currentL1Usage is 0. This error only surfaces with my changes - it's fine without my changes because we always untilize (to_layout) before returning. However it's possible now with my changes that we will return directly without any intermediate ops between the current op and the return op, and this causes issues because we wouldn't have zeroed out currentL1Usage.

Since this function doesn't decrement l1 usage on return op, the assert will fire and say that the l1 usage is non 0. My change basically adds a check that if the consumer op is a return op, we decrement the l1 usage.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnie-TT Thanks! Your solution is fine for now. Just please file the followup issue and reference it in the comment.

assert(currentL1Usage == 0);
assert(currentL1UsagePerOp.size() == 0);

Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,16 @@ TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() {
.addInputOperandWorkaround(weightWorkaround)
.addOutputOperandWorkaround(weightWorkaround);
}

// Factory method that creates the workaround set for the full op's output
// operand. ttnn::FullOp does not support 1D tilized tensors, so when the
// output of full is a 1D tiled tensor it must first be produced in row-major
// layout and then tilized in a separate step.
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createFullOpOperandsWorkarounds() {
  wa::TTNNOperandWorkarounds forceRowMajor;
  forceRowMajor.tensorLayoutWorkaround = Layout::RowMajor;
  return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
      .addOutputOperandWorkaround(forceRowMajor);
}
} // namespace mlir::tt::ttnn::wa
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void createTTNNPipelineLoweringPasses(
void createTTNNPipelineWorkaroundPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
TTNNWorkaroundsOptions workaroundOptions{
options.layouotWorkaroundsEnabled,
options.layoutWorkaroundsEnabled,
options.decompositionWorkaroundsEnabled};
pm.addPass(createTTNNWorkarounds(workaroundOptions));
pm.addPass(mlir::createCanonicalizerPass());
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,22 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
// If schedule is set, apply order of operations to func.
//
if (opSchedule[func].size() > 1) {
// TODO (#0000): This is a temporary solution - when defaulting to dram
// tile input/output layout, GetDeviceOp can randomly appear as the last
// op in the graph instead of the first. This workaround ensures
// getDeviceOp is always in the beginning of the schedule.
// To reproduce, remove this workaround and run
// Silicon/TTNN/optimizer/mnist_sharding.mlir multiple times (as it is
// non-deterministic).
Operation **it =
std::find_if(opSchedule[func].begin(), opSchedule[func].end(),
[](Operation *op) { return isa<GetDeviceOp>(op); });
if (it != opSchedule[func].end()) {
GetDeviceOp deviceOp = mlir::cast<GetDeviceOp>(*it);
opSchedule[func].erase(it);
opSchedule[func].insert(opSchedule[func].begin(), deviceOp);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@odjuricicTT, can you review this file?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnie-TT The proper fix for this would be to add it here:
https://github.com/tenstorrent/tt-mlir/blob/main/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp#L37

Try changing the if to check for GetDeviceOp as well as ToLayoutOp.

for (size_t i = 0; i < opSchedule[func].size() - 1; i++) {
Operation *op = opSchedule[func][i];

Expand Down
54 changes: 33 additions & 21 deletions lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ class TTNNDecomposeLayouts
});
});
for (Operation *op : opsToReplace) {
this->createLayoutConversionOps(mlir::cast<ttnn::ToLayoutOp>(op),
rewriter);
if (failed(createLayoutConversionOps(mlir::cast<ttnn::ToLayoutOp>(op),
rewriter))) {
signalPassFailure();
return;
}
rewriter.eraseOp(op);
}
}
Expand All @@ -42,6 +45,7 @@ class TTNNDecomposeLayouts
ttnn::Layout layoutEnum;
DataType dataType;
ttnn::TensorMemoryLayoutAttr tensorMemoryLayout;
GridAttr shardGrid;
llvm::SmallVector<int64_t> shardShape;

ttnn::MemoryConfigAttr createMemoryConfigAttr(MLIRContext *context) const {
Expand All @@ -51,7 +55,9 @@ class TTNNDecomposeLayouts
ttnn::ShapeAttr::get(context, shardShape)),
tensorMemoryLayout);
}

bool isL1Sharded() const {
return isShardedMemoryLayout(tensorMemoryLayout.getValue());
}
bool isOnHost() const {
return bufferType == ttnn::BufferType::SystemMemory;
}
Expand Down Expand Up @@ -115,6 +121,9 @@ class TTNNDecomposeLayouts
auto inputLayoutAttr =
mlir::cast<TTNNLayoutAttr>(op.getInput().getType().getEncoding());

auto outputLayoutAttr =
mlir::cast<TTNNLayoutAttr>(op.getResult().getType().getEncoding());

assert(op.getMemoryConfig().has_value());
MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig().value();

Expand All @@ -131,9 +140,12 @@ class TTNNDecomposeLayouts
input.tensorMemoryLayout = inputLayoutAttr.getMemLayout();
output.tensorMemoryLayout = outputMemoryConfig.getTensorMemoryLayout();

input.shardGrid = inputLayoutAttr.getGrid();
output.shardGrid = outputLayoutAttr.getGrid();

input.shardShape = inputLayoutAttr.getShardShape();
output.shardShape =
llvm::SmallVector<int64_t>{outputMemoryConfig.getShardShapeArray()};
output.shardShape = outputLayoutAttr.getShardShape();

return {input, output};
}

Expand All @@ -148,14 +160,6 @@ class TTNNDecomposeLayouts

opsToCreate.createTypecastOp = input.dataType != output.dataType;
opsToCreate.createToLayoutOp = input.layoutEnum != output.layoutEnum;
// TODO(bug #665):
// Insert a ToLayoutOp manually if we're moving from device to host to
// untilize. Since we're hardcoding tile layout, the tensor may be row
// major in mlir, and therefore it would appear as if we don't need to
// untilize
opsToCreate.createToLayoutOp |=
(opsToCreate.createFromDeviceOp and
output.layoutEnum == ttnn::Layout::RowMajor);

// ToDeviceOp can handle the creation of the memory config of the initial
// device tensor
Expand All @@ -168,8 +172,10 @@ class TTNNDecomposeLayouts
output.bufferType == ttnn::BufferType::L1) or
(input.bufferType == ttnn::BufferType::L1 and
output.bufferType == ttnn::BufferType::DRAM);
// If shard grids don't match we need to reshard
opsToCreate.createToMemoryConfigOp |=
(input.shardShape != output.shardShape);
(input.isL1Sharded() and output.isL1Sharded() and
input.shardGrid != output.shardGrid);
}
return opsToCreate;
}
Expand Down Expand Up @@ -764,24 +770,30 @@ class TTNNDecomposeLayouts
* sizeof(uint32_t). For now, we will always untilize on host. We rarely
* need device to device untilize, so the perf hit should be acceptable.
*/
void createLayoutConversionOps(ttnn::ToLayoutOp op,
IRRewriter &rewriter) const {
mlir::LogicalResult createLayoutConversionOps(ttnn::ToLayoutOp op,
IRRewriter &rewriter) const {
auto [input, output] = getInputOutputLayouts(op);
OpsToCreate opsToCreate = determineRequiredOps(input, output);
assert(isCreationValid(op, input, output, opsToCreate) &&
"Invalid layout conversion");
if (not isCreationValid(op, input, output, opsToCreate)) {
return failure();
}

auto device = op.getDevice();
assert((device || output.isOnHost()) &&
"Op device must be set for output tensors on device");
if (not device and not output.isOnHost()) {
op->emitError("Device not specified for device tensor");
return failure();
}

OpCreationInfo info(device, input, output, opsToCreate);

Value currentInput = op.getInput();

if (input.isOnHost()) {
handleHostInputLayoutConversion(op, rewriter, currentInput, info);
return;
return success();
}
handleDeviceInputLayoutConversion(op, rewriter, currentInput, info);
return success();
}
};
} // namespace mlir::tt::ttnn
Loading