Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,9 +1456,27 @@ void replaceUsesAndPropagateType(
newVal = ttg::MemDescTransOp::create(builder, trans.getLoc(), val,
trans.getOrder());
} else if (auto reshape = dyn_cast<ttg::MemDescReshapeOp>(user)) {
auto shape = reshape.getType().getShape();
newVal =
ttg::MemDescReshapeOp::create(builder, reshape.getLoc(), val, shape);
// Use inferReturnTypes to compute the correct allocShape and mutability
// from the new source, but preserve the original reshape's encoding
// rather than re-inferring it (which can change e.g. nvmma_shared to
// shared_linear).
ttg::MemDescType inferredType;
LogicalResult result = ttg::MemDescReshapeOp::inferReturnTypes(
builder.getContext(), reshape.getLoc(),
cast<ttg::MemDescType>(val.getType()),
reshape.getType().getShape(), inferredType);
assert(succeeded(result) && "failed to infer reshape return type");
assert(ttg::areLayoutsEquivalent(
inferredType.getShape(),
cast<ttg::LayoutEncodingTrait>(reshape.getType().getEncoding()),
cast<ttg::LayoutEncodingTrait>(inferredType.getEncoding())) &&
"preserved encoding is not equivalent to inferred encoding");
Comment on lines +1462 to +1473
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All this is unnecessary. You can asser tthat the initial types are the same and that's it.

Copy link
Copy Markdown
Author

@Sibylau Sibylau Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the inferReturnTypes to re-compute the allocShape field which changes with aref insertion, while still keep the #ttg.nvvma_shared. The issue is that prior to this fix, the shape-only overload MemDescReshapeOp::create(builder, loc, val, shape) calls inferReturnTypes and it falls back the encoding to #ttg.shared_linear, which downstream ops such as TMA ops would reject because they check for NVMMASharedEncoding, even though verifier of MemDescReshapeOp already checks the inferred #ttg.shared_linear and #ttg.nvmma_shared are layout-equivalent.

Type newDstType = ttg::MemDescType::get(
inferredType.getShape(), inferredType.getElementType(),
reshape.getType().getEncoding(), inferredType.getMemorySpace(),
inferredType.getMutableMemory(), inferredType.getAllocShape());
Comment on lines +1474 to +1477
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also unnecessary as you are creating the same type we already had?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to use the explicit type overload MemDescReshapeOp::create(builder, loc, newDstType, val) to keep the original #ttg.nvmma_shared encoding, and the newly inferred allocShape.

newVal = ttg::MemDescReshapeOp::create(builder, reshape.getLoc(),
newDstType, val);
}
assert(newVal && "unhandled memdesc view");
newVal.getDefiningOp()->setAttrs(user->getAttrs());
Expand Down
29 changes: 29 additions & 0 deletions test/NVWS/insert_aref.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,32 @@ tt.func @aref_result_outside_scheduled_loop(%lb: i32, %ub: i32, %step: i32) {
tt.return
}
}

// -----

// Test that memdesc_reshape preserves nvmma_shared encoding after aref insertion.
// A 3D shared_linear alloc is reshaped to 2D nvmma_shared and fed to a TMA store.
// Without the fix, replaceUsesAndPropagateType re-infers the encoding as
// shared_linear, which fails the TMA store verifier ("TMA descriptor must have
// NVMMA shared layout").

#blocked3d = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}>
#sl3d = #ttg.shared_linear<{offset = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 0, 16], [0, 0, 32], [1, 0, 8], [2, 0, 16], [4, 0, 32], [8, 0, 0], [16, 0, 0], [32, 0, 0], [64, 0, 0], [0, 1, 0]]}, alignment = 1024>
#nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory

module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
// CHECK-LABEL: @reshape_preserves_encoding
tt.func @reshape_preserves_encoding(%src: tensor<128x2x64xbf16, #blocked3d>,
%desc: !tt.tensordesc<128x128xbf16, #nvmma>,
%lb: i32, %ub: i32, %step: i32) {
%c0 = arith.constant 0 : i32
scf.for %iv = %lb to %ub step %step : i32 {
%alloc = ttg.local_alloc %src {ttg.partition = array<i32: 0>} : (tensor<128x2x64xbf16, #blocked3d>) -> !ttg.memdesc<128x2x64xbf16, #sl3d, #smem>
%reshaped = ttg.memdesc_reshape %alloc {ttg.partition = array<i32: 1>} : !ttg.memdesc<128x2x64xbf16, #sl3d, #smem> -> !ttg.memdesc<128x128xbf16, #nvmma, #smem>
// CHECK: ttng.async_tma_copy_local_to_global
ttng.async_tma_copy_local_to_global %desc[%c0, %c0] %reshaped {ttg.partition = array<i32: 1>} : !tt.tensordesc<128x128xbf16, #nvmma>, !ttg.memdesc<128x128xbf16, #nvmma, #smem>
} {tt.warp_specialize, ttg.partition = array<i32: 0, 1>, ttg.partition.stages = [0, 2], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
}