From e0cc909bc0b3db8e95d8655074273c232707ce6b Mon Sep 17 00:00:00 2001 From: Jie Liu Date: Wed, 8 Apr 2026 17:10:25 -0700 Subject: [PATCH] Preserve memdesc_reshape encoding on type propagation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When replaceUsesAndPropagateType recreates a MemDescReshapeOp (e.g. during aref insertion), the old code re-inferred the encoding from the source. This silently changed encodings like nvmma_shared to shared_linear, since inferMemDescReshapeOpEncoding always produces shared_linear when the source has shared_linear encoding — even though the nvmma_shared and shared_linear encodings may have equivalent LinearLayouts. Call inferReturnTypes to compute the correct allocShape and mutability from the new source, assert that the inferred encoding is layout-equivalent to the original one, then rebuild the result type with the original reshape's encoding. Co-authored-by: Claude --- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 24 ++++++++++++++-- test/NVWS/insert_aref.mlir | 29 ++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index f8fe3361ce0d..85fd640d5cc5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1456,9 +1456,27 @@ void replaceUsesAndPropagateType( newVal = ttg::MemDescTransOp::create(builder, trans.getLoc(), val, trans.getOrder()); } else if (auto reshape = dyn_cast<ttg::MemDescReshapeOp>(user)) { - auto shape = reshape.getType().getShape(); - newVal = - ttg::MemDescReshapeOp::create(builder, reshape.getLoc(), val, shape); + // Use inferReturnTypes to compute the correct allocShape and mutability + // from the new source, but preserve the original reshape's encoding + // rather than re-inferring it (which can change e.g. nvmma_shared to + // shared_linear). 
+ ttg::MemDescType inferredType; + LogicalResult result = ttg::MemDescReshapeOp::inferReturnTypes( + builder.getContext(), reshape.getLoc(), + cast<ttg::MemDescType>(val.getType()), + reshape.getType().getShape(), inferredType); + assert(succeeded(result) && "failed to infer reshape return type"); + assert(ttg::areLayoutsEquivalent( + inferredType.getShape(), + cast<ttg::LayoutEncodingTrait>(reshape.getType().getEncoding()), + cast<ttg::LayoutEncodingTrait>(inferredType.getEncoding())) && + "preserved encoding is not equivalent to inferred encoding"); + Type newDstType = ttg::MemDescType::get( + inferredType.getShape(), inferredType.getElementType(), + reshape.getType().getEncoding(), inferredType.getMemorySpace(), + inferredType.getMutableMemory(), inferredType.getAllocShape()); + newVal = ttg::MemDescReshapeOp::create(builder, reshape.getLoc(), + newDstType, val); } assert(newVal && "unhandled memdesc view"); newVal.getDefiningOp()->setAttrs(user->getAttrs()); diff --git a/test/NVWS/insert_aref.mlir b/test/NVWS/insert_aref.mlir index 90cf2b19d5be..57bdfe3e148e 100644 --- a/test/NVWS/insert_aref.mlir +++ b/test/NVWS/insert_aref.mlir @@ -779,3 +779,32 @@ tt.func @aref_result_outside_scheduled_loop(%lb: i32, %ub: i32, %step: i32) { tt.return } } + +// ----- + +// Test that memdesc_reshape preserves nvmma_shared encoding after aref insertion. +// A 3D shared_linear alloc is reshaped to 2D nvmma_shared and fed to a TMA store. +// Without the fix, replaceUsesAndPropagateType re-infers the encoding as +// shared_linear, which fails the TMA store verifier ("TMA descriptor must have +// NVMMA shared layout"). 
+ +#blocked3d = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [2, 2, 1], order = [2, 1, 0]}> +#sl3d = #ttg.shared_linear<{offset = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 0, 16], [0, 0, 32], [1, 0, 8], [2, 0, 16], [4, 0, 32], [8, 0, 0], [16, 0, 0], [32, 0, 0], [64, 0, 0], [0, 1, 0]]}, alignment = 1024> +#nvmma = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { + // CHECK-LABEL: @reshape_preserves_encoding + tt.func @reshape_preserves_encoding(%src: tensor<128x2x64xbf16, #blocked3d>, + %desc: !tt.tensordesc<128x128xbf16, #nvmma>, + %lb: i32, %ub: i32, %step: i32) { + %c0 = arith.constant 0 : i32 + scf.for %iv = %lb to %ub step %step : i32 { + %alloc = ttg.local_alloc %src {ttg.partition = array} : (tensor<128x2x64xbf16, #blocked3d>) -> !ttg.memdesc<128x2x64xbf16, #sl3d, #smem> + %reshaped = ttg.memdesc_reshape %alloc {ttg.partition = array} : !ttg.memdesc<128x2x64xbf16, #sl3d, #smem> -> !ttg.memdesc<128x128xbf16, #nvmma, #smem> + // CHECK: ttng.async_tma_copy_local_to_global + ttng.async_tma_copy_local_to_global %desc[%c0, %c0] %reshaped {ttg.partition = array} : !tt.tensordesc<128x128xbf16, #nvmma>, !ttg.memdesc<128x128xbf16, #nvmma, #smem> + } {tt.warp_specialize, ttg.partition = array, ttg.partition.stages = [0, 2], ttg.warp_specialize.tag = 0 : i32} + tt.return + } +}