ASAN build fixes for TTNNLayoutAttr class #1566

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td

@@ -165,7 +165,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
uint64_t getElementSizeBytes() const;
int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
-llvm::SmallVector<int64_t> getShardShape() const;
+llvm::ArrayRef<int64_t> getShardShape() const;
llvm::SmallVector<int64_t> getScalarShardShape() const;
AffineMap getIdentityTileLinearMap() const;
llvm::SmallVector<int64_t> getTiledShape(ArrayRef<int64_t> logicalTensorShape) const;
8 changes: 3 additions & 5 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp

@@ -77,7 +77,7 @@ class TensorEmptyConversionPattern
// Create MemoryConfigAttr
//
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
-llvm::SmallVector<int64_t> shardShape = layoutAttr.getShardShape();
+llvm::ArrayRef<int64_t> shardShape = layoutAttr.getShardShape();
ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get(
op.getContext(), ttnn::BufferTypeAttr::get(op.getContext(), bufferType),
ttnn::ShardSpecAttr::get(
@@ -154,8 +154,7 @@ class ToLayoutOpConversionPattern

ttnn::LayoutAttr outputLayout =
ttnn::LayoutAttr::get(rewriter.getContext(), outputLayoutEnum);
-llvm::SmallVector<int64_t> outputShardShape =
-    outputLayoutAttr.getShardShape();
+llvm::ArrayRef<int64_t> outputShardShape = outputLayoutAttr.getShardShape();

ttnn::MemoryConfigAttr outputMemConfigAttr = ttnn::MemoryConfigAttr::get(
rewriter.getContext(),
@@ -193,8 +192,7 @@ class ToLayoutOpConversionPattern
auto oldOutputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(oldOutput.getEncoding());
DataType outputDtype = oldOutputLayoutAttr.getDataType();
-SmallVector<std::int64_t> oldShardShape =
-    oldOutputLayoutAttr.getShardShape();
+ArrayRef<std::int64_t> oldShardShape = oldOutputLayoutAttr.getShardShape();
size_t shardShapeSize = oldShardShape.size();
assert(shardShapeSize >= 2 && "expected at least 2D shape");

2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/IR/TTNNOps.cpp

@@ -582,7 +582,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ToMemoryConfigOp::verify() {
if (not outputLayout.hasShardedL1TensorMemoryLayout()) {
return emitOpError("Sharded tensors layout must reside in L1");
}
-::llvm::SmallVector<int64_t> shardShape = outputLayout.getShardShape();
+::llvm::ArrayRef<int64_t> shardShape = outputLayout.getShardShape();
// Currently TTNN backend only supports 2D shard shape
if (shardShape.size() != 2) {
return emitOpError("Shard shape must be 2D");
8 changes: 4 additions & 4 deletions lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp

@@ -180,8 +180,8 @@ uint64_t TTNNLayoutAttr::getElementSizeBytes() const {
// Example: memref<2x3!tt.tile<32x32xf32>> -> { 2, 3 }
//
// return The shape of the shard.
-llvm::SmallVector<int64_t> TTNNLayoutAttr::getShardShape() const {
-  return SmallVector<int64_t>(getMemref().getShape());
+llvm::ArrayRef<int64_t> TTNNLayoutAttr::getShardShape() const {
+  return getMemref().getShape();
 }

@uazizTT (Contributor) commented on Dec 11, 2024:

I think this requires using a different constructor to create an owning SmallVector; the following should resolve it:

llvm::SmallVector<int64_t> TTNNLayoutAttr::getShardShape() const {
  return SmallVector<int64_t>(getMemref().getShape().begin(), getMemref().getShape().end());
}

We saw a similar issue before. https://github.com/tenstorrent/tt-mlir/pull/882/files

Contributor:

I'm looking at the source code (include/llvm/ADT/SmallVector.h:1232), and this doesn't seem like an issue.

Contributor:

OK, you are right. In that case, I wonder what the cause of this issue is, since this already appears to make an owning copy.

Contributor:

I've looked into that PR; the difference is that options.meshShape is a ListOption, but the same effect could have been achieved with implicitDeviceOptions.meshShape = ::llvm::SmallVector<int64_t>(*options.meshShape);. My guess is that a big part of the reason all SmallVector constructors are marked explicit is exactly to make you say: "Yes, I know SmallVector will acquire new memory and make a copy, and I agree to that."
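
To illustrate the point about explicit constructors, here is a minimal sketch (not from this PR; the names are made up). With recent LLVM, copying an ArrayRef into a SmallVector compiles only when the copy is spelled out explicitly:

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

void copyShape(llvm::ArrayRef<int64_t> shape) {
  // llvm::SmallVector<int64_t> bad = shape;  // does not compile: the
  //                                          // ArrayRef ctor is explicit
  llvm::SmallVector<int64_t> owned(shape);             // explicit copy
  llvm::SmallVector<int64_t> owned2(shape.begin(),     // iterator-range
                                    shape.end());      // ctor also works
  (void)owned;
  (void)owned2;
}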

Contributor:

Regarding the issue in this PR, the root cause is in lib/Dialect/TTNN/Transforms/Passes.cpp:235. In input.shardShape = inputLayoutAttr.getShardShape();, inputLayoutAttr.getShardShape() returns a SmallVector. That SmallVector is an unnamed temporary object, so it dies at the end of the statement, and input.shardShape, as an ArrayRef, is left with a dangling reference.
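
A minimal, self-contained sketch of that lifetime bug (the helper makeShardShape is hypothetical; the real assignment lives in lib/Dialect/TTNN/Transforms/Passes.cpp:235):

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Stand-in for the old getShardShape(): returns an owning temporary.
llvm::SmallVector<int64_t> makeShardShape() { return {32, 32}; }

void dangles() {
  // The temporary SmallVector is destroyed at the end of this full
  // expression, so shardShape is left pointing at freed memory -- the
  // heap-use-after-free that ASAN reports.
  llvm::ArrayRef<int64_t> shardShape = makeShardShape();
  for (int64_t dim : shardShape) // undefined behavior: reads freed memory
    (void)dim;
}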

Contributor:

OK, thanks. Yeah, that was my question: how the changes in this PR resolve the issue when the output of getShardShape() is already owning. But from your explanation it seems the issue propagated from lib/Dialect/TTNN/Transforms/Passes.cpp:235.
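
For context on why the fix is safe (my reading, not stated explicitly in the thread): MLIR attributes such as TTNNLayoutAttr are immutable and uniqued in the MLIRContext, so the shape storage behind getMemref().getShape() lives as long as the context itself. Returning a non-owning ArrayRef view therefore cannot dangle, and callers that need ownership can still copy explicitly:

// Sketch of caller-side usage after the change; layoutAttr is assumed to be
// a valid TTNNLayoutAttr whose storage the MLIRContext owns.
llvm::ArrayRef<int64_t> view = layoutAttr.getShardShape();   // non-owning view
llvm::SmallVector<int64_t> owned(view.begin(), view.end());  // explicit copy
                                                             // when ownership
                                                             // is needed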


// Get scalar shard shape
@@ -249,7 +249,7 @@ TTNNLayoutAttr::getTiledShape(llvm::ArrayRef<int64_t> tensorShape) const {
//
// return The size of the shard in bytes.
uint64_t TTNNLayoutAttr::getShardSizeInBytes() const {
-SmallVector<int64_t> shape = getShardShape();
+llvm::ArrayRef<int64_t> shape = getShardShape();
uint64_t size = getElementSizeBytes();
return std::accumulate(shape.begin(), shape.end(), size,
std::multiplies<uint64_t>());
@@ -277,7 +277,7 @@ mlir::AffineMap TTNNLayoutAttr::getIdentityTileLinearMap() const {
// return New memory map with symbols replaced with shard shape.
mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape(
AffineMap physicalMemoryMap) const {
-mlir::SmallVector<int64_t> shardShape = getShardShape();
+llvm::ArrayRef<int64_t> shardShape = getShardShape();
assert(physicalMemoryMap.getNumSymbols() == shardShape.size() &&
"Physical memory map must have same number of symbols as logical "
"shard rank");
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Transforms/Optimizer.cpp

@@ -460,7 +460,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
TensorMemoryLayoutAttr outputTensorMemoryLayoutAttr =
consumerOpOutputLayout.getMemLayout();

-llvm::SmallVector<int64_t> shardShape =
+llvm::ArrayRef<int64_t> shardShape =
     consumerOpOutputLayout.getShardShape();
MemoryConfigAttr outputMemConfigAttr = MemoryConfigAttr::get(
consumerOp->getContext(),