Skip to content

Commit

Permalink
[TTNN] Adding support for data type workarounds and introducing Embedding workarounds (#1583)
Browse files Browse the repository at this point in the history

This PR introduces a solution for handling data type workarounds for
operation operands and results. To address input operand data type
workarounds, we insert a `toLayout` operation between the input operands
and the operation itself. This casts the input to the desired data type.
If the data type of the output result changes due to a workaround, we
will revert it to the previous data type by inserting a `ToLayoutOp`
after the operation's output.

Additionally, this PR provides the necessary workarounds to ensure that the
embedding operation functions correctly. Specifically, it changes the
input to a row-major (RM) layout and casts both the weight input and the
output to bf16. Other ops will be onboarded to this type of workaround in a
separate PR.

Example of IR today:
```mlir
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> {
    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
    %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout3>
    %2 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout4>
    %3 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout5>
    %4 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout6>
    %5 = "ttnn.embedding"(%2, %3, %4) : (tensor<32x32xf32, #ttnn_layout4>, tensor<512x128xf32, #ttnn_layout5>, tensor<32x32x128xf32, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout6>
    %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xf32, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout2>
    return %6 : tensor<32x32x128xf32, #ttnn_layout2>
  }
}
```

An example of IR with this change where embedding op has bf16 workaround
applied for weight operand:
```mlir
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> {
    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
    %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout3>
    %2 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout4>
    %3 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xbf16, #ttnn_layout5>
    %4 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32x128xbf16, #ttnn_layout6>
    %5 = "ttnn.embedding"(%2, %3, %4) : (tensor<32x32xf32, #ttnn_layout4>, tensor<512x128xbf16, #ttnn_layout5>, tensor<32x32x128xbf16, #ttnn_layout6>) -> tensor<32x32x128xbf16, #ttnn_layout6>
    %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xbf16, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout2>
    return %6 : tensor<32x32x128xf32, #ttnn_layout2>
  }
}
```


- Closes #1433 
- Closes #1497 
- Closes #1215
  • Loading branch information
sdjordjevicTT authored Dec 23, 2024
1 parent 109d917 commit 9520cbb
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 29 deletions.
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds();
}
}];

let hasVerifier = 1;
Expand Down
40 changes: 32 additions & 8 deletions include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace mlir::tt::ttnn::wa {
using TensorLayoutWorkaround = std::optional<Layout>;
using TensorBufferTypeWorkaround = std::optional<BufferType>;
using TensorMemoryLayoutWorkaround = std::optional<TensorMemoryLayout>;
using TensorDataTypeWorkaround = std::optional<DataType>;

// Struct that encapsulates operand workarounds.
// It contains tensor layout, tensor buffer type and tensor memory layout
Expand All @@ -31,35 +32,47 @@ struct TTNNOperandWorkarounds {
// Tensor memory layout workaround.
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround;

// Tensor data format workaround.
TensorDataTypeWorkaround tensorDataTypeWorkaround;

// Default constructor.
TTNNOperandWorkarounds() = default;

// Constructor that takes tensor layout, tensor buffer type and tensor memory.
TTNNOperandWorkarounds(
TensorLayoutWorkaround tensorLayoutWorkaround,
TensorBufferTypeWorkaround tensorBufferTypeWorkaround,
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround)
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround,
TensorDataTypeWorkaround tensorDataTypeWorkaround)
: tensorLayoutWorkaround(tensorLayoutWorkaround),
tensorBufferTypeWorkaround(tensorBufferTypeWorkaround),
tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround) {}
tensorMemoryLayoutWorkaround(tensorMemoryLayoutWorkaround),
tensorDataTypeWorkaround(tensorDataTypeWorkaround) {}

// Constructor that takes tensor layout workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(TensorLayoutWorkaround tensorLayoutWorkaround)
: TTNNOperandWorkarounds(tensorLayoutWorkaround, std::nullopt,
std::nullopt) {}
std::nullopt, std::nullopt) {}

// Constructor that takes tensor buffer type workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(TensorBufferTypeWorkaround tensorBufferTypeWorkaround)
: TTNNOperandWorkarounds(std::nullopt, tensorBufferTypeWorkaround,
std::nullopt) {}
std::nullopt, std::nullopt) {}

// Constructor that takes tensor memory layout workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(
TensorMemoryLayoutWorkaround tensorMemoryLayoutWorkaround)
: TTNNOperandWorkarounds(std::nullopt, std::nullopt,
tensorMemoryLayoutWorkaround) {}
tensorMemoryLayoutWorkaround, std::nullopt) {}

// Constructor that takes tensor data type workaround and sets the other
// workarounds to nullopt.
TTNNOperandWorkarounds(TensorDataTypeWorkaround tensorDataTypeWorkaround)
: TTNNOperandWorkarounds(std::nullopt, std::nullopt, std::nullopt,
tensorDataTypeWorkaround) {}

// Operand workarounds factory methods.
static TTNNOperandWorkarounds createEmptyTTNNOperandWorkarounds();
Expand All @@ -68,7 +81,8 @@ struct TTNNOperandWorkarounds {
bool operator==(const TTNNOperandWorkarounds &rhs) const {
return tensorLayoutWorkaround == rhs.tensorLayoutWorkaround &&
tensorBufferTypeWorkaround == rhs.tensorBufferTypeWorkaround &&
tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround;
tensorMemoryLayoutWorkaround == rhs.tensorMemoryLayoutWorkaround &&
tensorDataTypeWorkaround == rhs.tensorDataTypeWorkaround;
}

// Inequality operator.
Expand All @@ -79,7 +93,7 @@ struct TTNNOperandWorkarounds {
// Returns true if any of the workarounds is set.
bool hasAnyWorkaround() const {
return tensorLayoutWorkaround || tensorBufferTypeWorkaround ||
tensorMemoryLayoutWorkaround;
tensorMemoryLayoutWorkaround || tensorDataTypeWorkaround;
}
};

Expand All @@ -103,6 +117,9 @@ struct BufferTypeWorkaroundResult : public WorkaroundResult<BufferType> {};
struct MemoryLayoutWorkaroundResult
: public WorkaroundResult<std::optional<TensorMemoryLayout>> {};

// Data type workaround result struct.
struct DataTypeWorkaroundResult : public WorkaroundResult<DataType> {};

// Struct that encapsulates the result of applying the workarounds.
// It contains the target tensor layout, buffer type and tensor memory layout
// results and a flag indicating whether the workarounds were applied.
Expand All @@ -116,11 +133,15 @@ struct WorkaroundResults {
// Tensor memory layout workaround result.
MemoryLayoutWorkaroundResult tensorMemoryLayoutResult;

// Tensor data type workaround result.
DataTypeWorkaroundResult tensorDataTypeResult;

// Returns true if any of the workarounds were applied.
bool isModified() const {
return tensorLayoutResult.isModified() ||
tensorBufferTypeResult.isModified() ||
tensorMemoryLayoutResult.isModified();
tensorMemoryLayoutResult.isModified() ||
tensorDataTypeResult.isModified();
}
};

Expand Down Expand Up @@ -194,6 +215,9 @@ class TTNNOperandsWorkaroundsFactory {
public:
// Create workarounds for max_pool2d op operands.
static TTNNOperandsWorkarounds createMaxPool2DOpOperandsWorkarounds();

// Create workarounds for embedding op operands.
static TTNNOperandsWorkarounds createEmbeddingOpOperandsWorkarounds();
};

} // namespace mlir::tt::ttnn::wa
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/Utils/TransformUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ createToLayoutOp(mlir::Operation *op,
mlir::TypedValue<RankedTensorType> inputValue,
PatternRewriter &rewriter, Layout targetTensorLayout,
BufferType targetTensorBufferType,
std::optional<TensorMemoryLayout> targetTensorMemoryLayout);
std::optional<TensorMemoryLayout> targetTensorMemoryLayout,
DataType targetTensorDataType);
} // namespace mlir::tt::ttnn::utils

#endif
7 changes: 6 additions & 1 deletion include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ toTTMemorySpace(const mlir::tt::ttnn::BufferType bufferType);
mlir::Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context,
DataType dtype);

// Helper method to create a RankedTensorType with the given encoding
// Helper method to create a RankedTensorType with the given encoding.
RankedTensorType
createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
ttnn::TTNNLayoutAttr encoding);

// Helper method to create a RankedTensorType with the given element type.
RankedTensorType
createRankedTensorTypeWithElementType(RankedTensorType tensorType,
Type elementType);

// Return the L1 memory usage of the output tensor of the given op.
// Used within L1 interleaved policies.
//
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ class ToLayoutOpConversionPattern
// operands.
for (mlir::Operation *user : op.getResult().getUsers()) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::SliceOp>(user) ||
isa<ttir::EmbeddingOp>(user) ||
(isa<ttir::EmbeddingBackwardOp>(user) &&
(user->getOperand(0) == op || user->getOperand(1) == op))) {
return true;
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ WorkaroundResults applyWorkarounds(const TTNNOperandWorkarounds &workaround,
results.tensorMemoryLayoutResult.previousValue =
inputLayoutAttr.getMemLayoutOpt();

results.tensorDataTypeResult.targetValue =
workaround.tensorDataTypeWorkaround.value_or(
inputLayoutAttr.getDataType());
results.tensorDataTypeResult.previousValue = inputLayoutAttr.getDataType();

return results;
}

Expand Down Expand Up @@ -87,4 +92,29 @@ TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
}

// Factory method producing the workaround set for the embedding operation.
//
// The embedding op requires its index input in row-major layout and its
// weight operand in the bf16 data type. Because the op's output follows the
// same format as the weight operand, the bf16 workaround is applied to the
// output operand as well.
//
// NOTE(review): the bf16 workaround is registered for two input operands in
// addition to the output — presumably the weight and the DPS init operand;
// confirm against the op's operand ordering.
//
// Metal issue tracking the input operand workaround:
// https://github.com/tenstorrent/tt-metal/issues/14915
//
// Metal issue tracking the weight operand workaround: TBD
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() {
  // Row-major layout workaround for the index input; bf16 data type
  // workaround shared by the weight operand and the output.
  TTNNOperandWorkarounds rowMajorInputWorkaround(Layout::RowMajor);
  TTNNOperandWorkarounds bf16Workaround(DataType::BFloat16);
  return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(0, 0)
      .addInputOperandWorkaround(rowMajorInputWorkaround)
      .addInputOperandWorkaround(bf16Workaround)
      .addInputOperandWorkaround(bf16Workaround)
      .addOutputOperandWorkaround(bf16Workaround);
}
} // namespace mlir::tt::ttnn::wa
39 changes: 28 additions & 11 deletions lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>
#include <tuple>
Expand Down Expand Up @@ -58,7 +58,8 @@ static void revertOutputLayout(wa::TTNNWorkaroundInterface &op,
op.getOperation(), newOpResult, rewriter,
workaroundResults.tensorLayoutResult.previousValue,
workaroundResults.tensorBufferTypeResult.previousValue,
workaroundResults.tensorMemoryLayoutResult.previousValue);
workaroundResults.tensorMemoryLayoutResult.previousValue,
workaroundResults.tensorDataTypeResult.previousValue);

// Replace the new output result with the casted output result.
rewriter.replaceUsesWithIf(
Expand Down Expand Up @@ -94,7 +95,8 @@ static bool workaroundInputOperand(
op.getOperation(), inputValue, rewriter,
inputWorkaroundResults.tensorLayoutResult.targetValue,
inputWorkaroundResults.tensorBufferTypeResult.targetValue,
inputWorkaroundResults.tensorMemoryLayoutResult.targetValue);
inputWorkaroundResults.tensorMemoryLayoutResult.targetValue,
inputWorkaroundResults.tensorDataTypeResult.targetValue);

// Insert to layout op between the current op and the input operand
// to convert the input operand to the desired tensor layout, buffer type.
Expand Down Expand Up @@ -137,7 +139,7 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
Type elementType = utils::getElementType(
rewriter.getContext(),
outputWorkaroundResults.tensorLayoutResult.targetValue,
opResultLayoutAttr.getDataType());
outputWorkaroundResults.tensorDataTypeResult.targetValue);

// Get the input operand type.
RankedTensorType opResultType =
Expand All @@ -151,16 +153,24 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
*outputWorkaroundResults.tensorMemoryLayoutResult.targetValue)
: nullptr;

// Create the new output result type with the updated tensor layout, buffer
// type and memory layout.
// Create the new output layout attribute with the updated tensor layout,
// buffer type, memory layout and data type.
TTNNLayoutAttr newOutputLayoutAttr =
opResultLayoutAttr.withElementType(rewriter.getContext(), elementType)
.withBufferType(
rewriter.getContext(),
outputWorkaroundResults.tensorBufferTypeResult.targetValue)
.withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr);

// Create the new output result type with the updated data type and layout.
RankedTensorType newOutputResultType =
ttnn::utils::createRankedTensorTypeWithEncoding(
opResultType,
opResultLayoutAttr.withElementType(rewriter.getContext(), elementType)
.withBufferType(
ttnn::utils::createRankedTensorTypeWithElementType(
opResultType,
ttnn::utils::createRowMajorTypeFromDtype(
rewriter.getContext(),
outputWorkaroundResults.tensorBufferTypeResult.targetValue)
.withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr));
outputWorkaroundResults.tensorDataTypeResult.targetValue)),
newOutputLayoutAttr);

// Update the type of result with applied workarounds.
rewriter.modifyOpInPlace(op, [&]() {
Expand All @@ -176,6 +186,13 @@ workaroundOutputOperand(mlir::TypedValue<RankedTensorType> opResult,
op->setAttr("layout", updatedLayoutAttr);
}

if (outputWorkaroundResults.tensorDataTypeResult.isModified() &&
op->getAttrDictionary().get("dtype")) {
DataTypeAttr updatedDataTypeAttr = rewriter.getAttr<DataTypeAttr>(
outputWorkaroundResults.tensorDataTypeResult.targetValue);
op->setAttr("dtype", updatedDataTypeAttr);
}

if ((outputWorkaroundResults.tensorBufferTypeResult.isModified() ||
outputWorkaroundResults.tensorMemoryLayoutResult.isModified()) &&
op->getAttrDictionary().get("memory_config")) {
Expand Down
17 changes: 11 additions & 6 deletions lib/Dialect/TTNN/Utils/TransformUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ ToLayoutOp
createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
PatternRewriter &rewriter, Layout targetTensorLayout,
BufferType targetTensorBufferType,
std::optional<TensorMemoryLayout> targetTensorMemoryLayout) {
std::optional<TensorMemoryLayout> targetTensorMemoryLayout,
DataType targetTensorDataType) {
TTNNLayoutAttr inputLayoutAttr = getLayoutAttrFromTensor(inputValue);

// Create element type based on tensor layout.
Type elementType = getElementType(rewriter.getContext(), targetTensorLayout,
inputLayoutAttr.getDataType());
targetTensorDataType);

// Create tensor memory layout attribute.
ttnn::TensorMemoryLayoutAttr outputMemLayoutAttr =
Expand All @@ -63,10 +64,14 @@ createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
.withBufferType(rewriter.getContext(), targetTensorBufferType)
.withMemoryLayout(rewriter.getContext(), outputMemLayoutAttr);

// Create the output result type with the new encoding.
// Create the output result type with the new data type and encoding.
RankedTensorType toLayoutOpResultType =
ttnn::utils::createRankedTensorTypeWithEncoding(inputToLayoutOpType,
toLayoutOpResultEncoding);
ttnn::utils::createRankedTensorTypeWithEncoding(
ttnn::utils::createRankedTensorTypeWithElementType(
inputToLayoutOpType,
utils::createRowMajorTypeFromDtype(rewriter.getContext(),
targetTensorDataType)),
toLayoutOpResultEncoding);

// Create the output memory config attribute.
ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get(
Expand All @@ -88,7 +93,7 @@ createToLayoutOp(Operation *op, mlir::TypedValue<RankedTensorType> inputValue,
return rewriter.create<ttnn::ToLayoutOp>(
op->getLoc(), toLayoutOpResultType, inputValue,
LayoutAttr::get(rewriter.getContext(), targetTensorLayout),
DataTypeAttr::get(rewriter.getContext(), inputLayoutAttr.getDataType()),
DataTypeAttr::get(rewriter.getContext(), targetTensorDataType),
outputMemConfigAttr, deviceValue);
}
} // namespace mlir::tt::ttnn::utils
10 changes: 9 additions & 1 deletion lib/Dialect/TTNN/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,22 @@ Type createRowMajorTypeFromDtype(::mlir::MLIRContext *context, DataType dtype) {
}
}

// Helper method to create a RankedTensorType with the given encoding
// Helper method to create a RankedTensorType with the given encoding.
// Shape and element type are carried over from the input tensor type; only
// the layout encoding attribute is replaced.
RankedTensorType
createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
                                   ttnn::TTNNLayoutAttr encoding) {
  auto shape = tensorType.getShape();
  auto elementType = tensorType.getElementType();
  return RankedTensorType::get(shape, elementType, encoding);
}

// Helper method to create a RankedTensorType with the given element type.
// Shape and layout encoding are carried over from the input tensor type;
// only the element type is replaced.
RankedTensorType
createRankedTensorTypeWithElementType(RankedTensorType tensorType,
                                      Type elementType) {
  auto encoding = tensorType.getEncoding();
  return RankedTensorType::get(tensorType.getShape(), elementType, encoding);
}

uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout) {
// In case the opLayout is not in L1 memory space, L1 memory usage is 0.
//
Expand Down
Loading

0 comments on commit 9520cbb

Please sign in to comment.