Skip to content

Commit

Permalink
Use tilized dram-interleaved as default input-output layout
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Jan 14, 2025
1 parent 5b4122a commit d4c5383
Show file tree
Hide file tree
Showing 95 changed files with 796 additions and 652 deletions.
6 changes: 1 addition & 5 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,7 @@ def TTNN_MemoryConfigAttr : TTNN_Attr<"MemoryConfig", "memory_config"> {
let assemblyFormat = "`<` params `>`";

let extraClassDeclaration = [{
::llvm::ArrayRef<int64_t> getShardShapeArray() const
{
return this->getShardSpec().getShardShape().getShape();
}

llvm::ArrayRef<int64_t> getShardShape(bool convertTileToScalar = true) const;
MemoryConfigAttr withBufferType(::mlir::MLIRContext *context, BufferType bufferType);
MemoryConfigAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
}];
Expand Down
144 changes: 50 additions & 94 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/TTIRToTTNN/TTIRToTTNN.h"

#include "ttmlir/Conversion/TTIRToTTNN/Utils.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
Expand All @@ -13,6 +12,7 @@
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand Down Expand Up @@ -170,6 +170,9 @@ class ToLayoutOpConversionPattern
rewriter.eraseOp(emptyOp);
}

assert(mlir::isa<mlir::RankedTensorType>(adaptor.getInput().getType()) &&
"Expected RankedTensorType for ToLayoutOp input");

auto outputLayoutAttr = mlir::cast<ttnn::TTNNLayoutAttr>(
op.getResult().getType().getEncoding());

Expand All @@ -186,32 +189,6 @@ class ToLayoutOpConversionPattern
bool isOutputOnHost = (outputBufferType == ttnn::BufferType::SystemMemory);

RankedTensorType result = mlir::cast<RankedTensorType>(op.getType());
if (!isOutputOnHost) {
// TODO(bug #665):
// Binary ops fail with row major layout in ttnn, defaulting to and
// assuming tile layout for all device tensors...
// Note: mlir doesn't know about this, so tensors may still appear as row
// major in the generated mlir
// TODO(bug #875):
// Remove the following code block once constraints modelling is
// implemented on dialect level
//
// Default to Tile layout unless op supports only RowMajor layout
//
ttnn::Layout newOutputLayoutEnum =
shouldForceRowMajor(op) ? ttnn::Layout::RowMajor : ttnn::Layout::Tile;

// If the layout of the output tensor changed as a result of forcing the
// layout update the tensor type
if (outputLayoutEnum != newOutputLayoutEnum) {
result =
getLayoutForcedResultTensor(rewriter, result, newOutputLayoutEnum);
op.getResult().setType(result);
outputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(result.getEncoding());
outputLayoutEnum = newOutputLayoutEnum;
}
}

ttnn::LayoutAttr outputLayout =
ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum);
Expand All @@ -235,68 +212,6 @@ class ToLayoutOpConversionPattern

return success();
}

private:
bool shouldForceRowMajor(ttir::ToLayoutOp op) const {
// Check if the output tensor is used by an op that only supports row major.
//
// EmbeddingBackwardOp supports row major layout for the first and second
// operands.
for (mlir::Operation *user : op.getResult().getUsers()) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::SliceOp>(user) ||
(isa<ttir::EmbeddingBackwardOp>(user) &&
(user->getOperand(0) == op || user->getOperand(1) == op))) {
return true;
}
}

return false;
}

RankedTensorType
getLayoutForcedResultTensor(ConversionPatternRewriter &rewriter,
RankedTensorType oldOutput,
ttnn::Layout newOutputLayoutEnum) const {
auto oldOutputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(oldOutput.getEncoding());
DataType outputDtype = oldOutputLayoutAttr.getDataType();
SmallVector<std::int64_t> oldShardShape =
oldOutputLayoutAttr.getShardShape();
size_t shardShapeSize = oldShardShape.size();
assert(shardShapeSize >= 2 && "expected at least 2D shape");

if (newOutputLayoutEnum == ttnn::Layout::RowMajor) {
// Set shard shape to match convention of row major layout
auto tileType =
mlir::cast<TileType>(oldOutputLayoutAttr.getElementType());
llvm::SmallVector<int64_t> newShardShape(oldShardShape.begin(),
oldShardShape.end());
newShardShape[shardShapeSize - 2] =
oldShardShape[shardShapeSize - 2] * tileType.getHeight();
newShardShape[shardShapeSize - 1] =
oldShardShape[shardShapeSize - 1] * tileType.getWidth();
Type newElementType = ttnn::utils::createRowMajorTypeFromDtype(
rewriter.getContext(), outputDtype);
RankedTensorType result = RankedTensorType::get(
oldOutput.getShape(), oldOutput.getElementType(),
oldOutputLayoutAttr
.withElementType(rewriter.getContext(), newElementType)
.withShardShape(rewriter.getContext(), newShardShape));
return result;
}

if (newOutputLayoutEnum == ttnn::Layout::Tile) {
TileType tileType =
TileType::get(rewriter.getContext(),
{ttnn::TILE_HEIGHT, ttnn::TILE_WIDTH}, outputDtype);
RankedTensorType result = RankedTensorType::get(
oldOutput.getShape(), oldOutput.getElementType(),
oldOutputLayoutAttr.withElementType(rewriter.getContext(), tileType));
return result;
}

llvm_unreachable("Unreachable code path. Unexpected output layout enum");
}
};

template <typename TTIROpTy, typename TTNNOpTy,
Expand Down Expand Up @@ -703,24 +618,65 @@ class ConstantOpConversionPattern
matchAndRewrite(ttir::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
::mlir::ElementsAttr valueAttr = op.getValue();

LogicalResult legalityResult = checkBasicLegality(op, valueAttr, rewriter);
if (!legalityResult.succeeded()) {
return legalityResult;
}

if (valueAttr.isSplat()) {
Value device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
auto outputType = mlir::cast<RankedTensorType>(op.getType());
ttnn::TTNNLayoutAttr layoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(outputType.getEncoding());
float fillValue =
valueAttr.getElementType().isInteger()
? getIntegerValue(valueAttr)
: valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();

::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue);
rewriter.replaceOpWithNewOp<ttnn::FullOp>(
op, this->getTypeConverter()->convertType(op.getType()), device,
fillValueAttr);

if (outputType.getRank() > 1 ||
!mlir::isa<TileType>(layoutAttr.getElementType())) {
rewriter.replaceOpWithNewOp<ttnn::FullOp>(
op, this->getTypeConverter()->convertType(outputType), device,
fillValueAttr);
} else {
// ttnn::FullOp does not support 1D tilized tensors
// If the output of full is a 1D tensor and is tiled
// we need to convert it to row major layout then tilize separately

// Can't use withElementType because the shard shape would be wrong
ttnn::TTNNLayoutAttr rowMajorLayoutAttr = ttnn::TTNNLayoutAttr::get(
rewriter.getContext(), outputType.getShape(),
layoutAttr.getScalarElementType(), layoutAttr.getBufferType(),
layoutAttr.getGrid(), layoutAttr.getMemLayout());

auto fullOpOutputType = RankedTensorType::get(
outputType.getShape(), outputType.getElementType(),
rowMajorLayoutAttr);
auto fullOp = rewriter.create<ttnn::FullOp>(
op.getLoc(),
this->getTypeConverter()->convertType(fullOpOutputType), device,
fillValueAttr);

// Tilize the fullOp output separately
ttnn::MemoryConfigAttr memConfigAttr =
rewriter.getAttr<ttnn::MemoryConfigAttr>(
rewriter.getAttr<ttnn::BufferTypeAttr>(
layoutAttr.getBufferType()),
rewriter.getAttr<ttnn::ShardSpecAttr>(
rewriter.getAttr<ttnn::ShapeAttr>(
layoutAttr.getShardShape())),
layoutAttr.getMemLayout());

bool isOutputOnHost =
(layoutAttr.getBufferType() == ttnn::BufferType::SystemMemory);

rewriter.replaceOpWithNewOp<ttnn::ToLayoutOp>(
op, this->getTypeConverter()->convertType(op.getType()), fullOp,
ttnn::Layout::Tile,
DataTypeAttr::get(rewriter.getContext(), layoutAttr.getDataType()),
memConfigAttr, isOutputOnHost ? nullptr : device);
}
} else {
return rewriter.notifyMatchFailure(
op, "TTNN doesn't currently support tensor creation from multiple "
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
mlir::tt::SystemDescAttr
mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {
// Check if file exists
assert(!path.empty() && "cluster desc path must not be empty!");
assert(!path.empty() && "system desc path must not be empty!");
std::ifstream fbb(path, std::ios::binary | std::ios::ate);
assert(fbb.good() && "cluster desc does not exist!");
assert(fbb.good() && "system desc does not exist!");
std::streampos size = fbb.tellg();
fbb.seekg(0, std::ios::beg);
auto buffer = std::shared_ptr<void>(std::malloc(size), std::free);
Expand Down
19 changes: 19 additions & 0 deletions lib/Dialect/TTNN/Analysis/BFInterleavedPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ void BFInterleavedPolicy::run() {
scheduler.scheduleOp(nextOpForScheduling);
}

// TODO (#0000): This is a temporary solution
// Currently ReturnOps are not considered when calculating L1 usage
llvm::SmallVector<mlir::Operation *> eraseableL1UsageOps;
for (auto &[op, usage] : currentL1UsagePerOp) {
for (Operation *user : op->getUsers()) {
if (isa<mlir::func::ReturnOp>(user)) {
usage.numOfUnscheduledUsers -= 1;
}
}
if (usage.numOfUnscheduledUsers == 0) {
eraseableL1UsageOps.push_back(op);
}
}

for (Operation *op : eraseableL1UsageOps) {
currentL1Usage -= currentL1UsagePerOp[op].l1MemUsagePerUser;
currentL1UsagePerOp.erase(op);
}

assert(currentL1Usage == 0);
assert(currentL1UsagePerOp.size() == 0);

Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,22 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
// If schedule is set, apply order of operations to func.
//
if (opSchedule[func].size() > 1) {
// TODO (#0000): This is a temporary solution - when defaulting to dram
// tile input/output layout, GetDeviceOp can randomly appear as the last
// op in the graph instead of the first. This workaround ensures
// getDeviceOp is always in the beginning of the schedule.
// To reproduce, remove this workaround and run
// Silicon/TTNN/optimizer/mnist_sharding.mlir multiple times (as it is
// non-deterministic).
Operation **it =
std::find_if(opSchedule[func].begin(), opSchedule[func].end(),
[](Operation *op) { return isa<GetDeviceOp>(op); });
if (it != opSchedule[func].end()) {
GetDeviceOp deviceOp = mlir::cast<GetDeviceOp>(*it);
opSchedule[func].erase(it);
opSchedule[func].insert(opSchedule[func].begin(), deviceOp);
}

for (size_t i = 0; i < opSchedule[func].size() - 1; i++) {
Operation *op = opSchedule[func][i];

Expand Down
54 changes: 33 additions & 21 deletions lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ class TTNNDecomposeLayouts
});
});
for (Operation *op : opsToReplace) {
this->createLayoutConversionOps(mlir::cast<ttnn::ToLayoutOp>(op),
rewriter);
if (failed(createLayoutConversionOps(mlir::cast<ttnn::ToLayoutOp>(op),
rewriter))) {
signalPassFailure();
return;
}
rewriter.eraseOp(op);
}
}
Expand All @@ -42,6 +45,7 @@ class TTNNDecomposeLayouts
ttnn::Layout layoutEnum;
DataType dataType;
ttnn::TensorMemoryLayoutAttr tensorMemoryLayout;
GridAttr shardGrid;
llvm::SmallVector<int64_t> shardShape;

ttnn::MemoryConfigAttr createMemoryConfigAttr(MLIRContext *context) const {
Expand All @@ -51,7 +55,9 @@ class TTNNDecomposeLayouts
ttnn::ShapeAttr::get(context, shardShape)),
tensorMemoryLayout);
}

bool isL1Sharded() const {
return isShardedMemoryLayout(tensorMemoryLayout.getValue());
}
bool isOnHost() const {
return bufferType == ttnn::BufferType::SystemMemory;
}
Expand Down Expand Up @@ -115,6 +121,9 @@ class TTNNDecomposeLayouts
auto inputLayoutAttr =
mlir::cast<TTNNLayoutAttr>(op.getInput().getType().getEncoding());

auto outputLayoutAttr =
mlir::cast<TTNNLayoutAttr>(op.getResult().getType().getEncoding());

assert(op.getMemoryConfig().has_value());
MemoryConfigAttr outputMemoryConfig = op.getMemoryConfig().value();

Expand All @@ -131,9 +140,12 @@ class TTNNDecomposeLayouts
input.tensorMemoryLayout = inputLayoutAttr.getMemLayout();
output.tensorMemoryLayout = outputMemoryConfig.getTensorMemoryLayout();

input.shardGrid = inputLayoutAttr.getGrid();
output.shardGrid = outputLayoutAttr.getGrid();

input.shardShape = inputLayoutAttr.getShardShape();
output.shardShape =
llvm::SmallVector<int64_t>{outputMemoryConfig.getShardShapeArray()};
output.shardShape = outputLayoutAttr.getShardShape();

return {input, output};
}

Expand All @@ -148,14 +160,6 @@ class TTNNDecomposeLayouts

opsToCreate.createTypecastOp = input.dataType != output.dataType;
opsToCreate.createToLayoutOp = input.layoutEnum != output.layoutEnum;
// TODO(bug #665):
// Insert a ToLayoutOp manually if we're moving from device to host to
// untilize. Since we're hardcoding tile layout, the tensor may be row
// major in mlir, and therefore it would appear as if we don't need to
// untilize
opsToCreate.createToLayoutOp |=
(opsToCreate.createFromDeviceOp and
output.layoutEnum == ttnn::Layout::RowMajor);

// ToDeviceOp can handle the creation of the memory config of the initial
// device tensor
Expand All @@ -168,8 +172,10 @@ class TTNNDecomposeLayouts
output.bufferType == ttnn::BufferType::L1) or
(input.bufferType == ttnn::BufferType::L1 and
output.bufferType == ttnn::BufferType::DRAM);
// If shard grids don't match we need to reshard
opsToCreate.createToMemoryConfigOp |=
(input.shardShape != output.shardShape);
(input.isL1Sharded() and output.isL1Sharded() and
input.shardGrid != output.shardGrid);
}
return opsToCreate;
}
Expand Down Expand Up @@ -764,24 +770,30 @@ class TTNNDecomposeLayouts
* sizeof(uint32_t). For now, we will always untilize on host. We rarely
* need device to device untilize, so the perf hit should be acceptable.
*/
void createLayoutConversionOps(ttnn::ToLayoutOp op,
IRRewriter &rewriter) const {
mlir::LogicalResult createLayoutConversionOps(ttnn::ToLayoutOp op,
IRRewriter &rewriter) const {
auto [input, output] = getInputOutputLayouts(op);
OpsToCreate opsToCreate = determineRequiredOps(input, output);
assert(isCreationValid(op, input, output, opsToCreate) &&
"Invalid layout conversion");
if (not isCreationValid(op, input, output, opsToCreate)) {
return failure();
}

auto device = op.getDevice();
assert((device || output.isOnHost()) &&
"Op device must be set for output tensors on device");
if (not device and not output.isOnHost()) {
op->emitError("Device not specified for device tensor");
return failure();
}

OpCreationInfo info(device, input, output, opsToCreate);

Value currentInput = op.getInput();

if (input.isOnHost()) {
handleHostInputLayoutConversion(op, rewriter, currentInput, info);
return;
return success();
}
handleDeviceInputLayoutConversion(op, rewriter, currentInput, info);
return success();
}
};
} // namespace mlir::tt::ttnn
Loading

0 comments on commit d4c5383

Please sign in to comment.