From 8317097968fa014b1f2f29bacee290114afc8bd0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 04:02:00 +0900 Subject: [PATCH 01/54] Add swizzle=0 TCGen5 operand-view memdesc rewrite and lit test --- .../Transforms/OptimizeDotOperands.cpp | 130 +++++++++++++++++- test/TritonGPU/dot-operands.mlir | 39 ++++++ 2 files changed, 168 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 8fa3c6db297d..b753dfc1fe65 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -317,6 +317,133 @@ class UseShmemForScales } }; +class RewriteSwizzle0OperandViewsToMemDescForTCGen5MMA + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::TCGen5MMAOp mmaOp, + PatternRewriter &rewriter) const override { + bool changed = false; + if (succeeded(rewriteOperand(mmaOp.getAMutable(), rewriter))) + changed = true; + if (succeeded(rewriteOperand(mmaOp.getBMutable(), rewriter))) + changed = true; + return changed ? success() : failure(); + } + +private: + struct ViewStep { + enum Kind { Reshape, Transpose } kind; + SmallVector shape; + SmallVector order; + Location loc; + }; + + LogicalResult rewriteOperand(OpOperand &operand, + PatternRewriter &rewriter) const { + Value orig = operand.get(); + auto origTy = dyn_cast(orig.getType()); + if (!origTy) + return failure(); + + SmallVector trailingMemDescTransOrder; + Value beforeTrailing = orig; + if (auto trailing = beforeTrailing.getDefiningOp()) { + trailingMemDescTransOrder.assign(trailing.getOrder().begin(), + trailing.getOrder().end()); + beforeTrailing = trailing.getSrc(); + } + + auto localAlloc = beforeTrailing.getDefiningOp(); + if (!localAlloc || !localAlloc.getSrc()) + return failure(); + + auto allocTy = cast(localAlloc.getType()); + auto allocEnc = dyn_cast(allocTy.getEncoding()); + if (!allocEnc || allocEnc.getSwizzlingByteWidth() != 0) + return failure(); + + SmallVector reverseSteps; + Value baseTensor = localAlloc.getSrc(); + while (true) { + if (auto cvt = baseTensor.getDefiningOp()) { + baseTensor = cvt.getSrc(); + continue; + } + if (auto reshape = baseTensor.getDefiningOp()) { + SmallVector shape(reshape.getType().getShape().begin(), + reshape.getType().getShape().end()); + reverseSteps.push_back(ViewStep{ + ViewStep::Reshape, std::move(shape), {}, reshape.getLoc()}); + baseTensor = reshape.getSrc(); + continue; + } + if (auto trans = baseTensor.getDefiningOp()) { + SmallVector order(trans.getOrder().begin(), + trans.getOrder().end()); + reverseSteps.push_back(ViewStep{ + ViewStep::Transpose, {}, std::move(order), trans.getLoc()}); + baseTensor = trans.getSrc(); + continue; + } + break; + } + + if (reverseSteps.empty()) + return failure(); + + auto baseTensorTy = dyn_cast(baseTensor.getType()); + if (!baseTensorTy) + return failure(); + + auto cgaLayout = CGAEncodingAttr::get1CTALayout(rewriter.getContext(), + baseTensorTy.getRank()); + auto baseEnc = NVMMASharedEncodingAttr::get( + rewriter.getContext(), /*swizzlingByteWidth=*/0, + /*transposed=*/false, allocEnc.getElementBitWidth(), + allocEnc.getFp4Padded(), cgaLayout); + auto baseMemTy = MemDescType::get( + baseTensorTy.getShape(), baseTensorTy.getElementType(), baseEnc, + allocTy.getMemorySpace(), allocTy.getMutableMemory()); + + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(localAlloc); + + Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), + baseMemTy, baseTensor); + + for (ViewStep &step : llvm::reverse(reverseSteps)) { + if (step.kind == ViewStep::Reshape) { + MemDescType reshapedTy; + if (failed(MemDescReshapeOp::inferReturnTypes( + rewriter.getContext(), step.loc, + cast(rewritten.getType()), step.shape, + reshapedTy))) + return failure(); + rewritten = + MemDescReshapeOp::create(rewriter, step.loc, reshapedTy, rewritten); + } else { + rewritten = + MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); + } + } + + if (!trailingMemDescTransOrder.empty()) { + rewritten = MemDescTransOp::create(rewriter, localAlloc.getLoc(), + rewritten, trailingMemDescTransOrder); + } + + auto rewrittenTy = cast(rewritten.getType()); + if (rewrittenTy.getShape() != origTy.getShape() || + rewrittenTy.getElementType() != origTy.getElementType()) + return failure(); + + operand.assign(rewritten); + return success(); + } +}; + } // namespace #define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS @@ -341,7 +468,8 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); - patterns.add(context); + patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) signalPassFailure(); diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index cfdd2b4a84b4..73b718b2797f 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -279,6 +279,45 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num // ----- +#blockedA2 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blockedA3 = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> +#blockedB3 = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> +#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}> +#smem = #ttg.shared_memory +#tmem0 = #ttng.tensor_memory_encoding +module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views + // CHECK-DAG: %[[A_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x256xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_LOAD]] : (tensor<1x64x256xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x64x256xf8E4M3FN, #{{.*}}, #smem> + // CHECK-DAG: %[[A_RS:.*]] = ttg.memdesc_reshape %[[A_BASE]] : !ttg.memdesc<1x64x256xf8E4M3FN, #{{.*}}, #smem> -> !ttg.memdesc<64x256xf8E4M3FN, #{{.*}}, #smem> + // CHECK-DAG: %[[A_TR:.*]] = ttg.memdesc_trans %[[A_RS]] {order = array} : !ttg.memdesc<64x256xf8E4M3FN, #{{.*}}, #smem> -> !ttg.memdesc<256x64xf8E4M3FN, #{{.*}}, #smem> + // CHECK-DAG: %[[B_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_LOAD]] : (tensor<1x64x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x64x128xf8E4M3FN, #{{.*}}, #smem> + // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x128xf8E4M3FN, #{{.*}}, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #{{.*}}, #smem> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: ttng.tc_gen5_mma %[[A_TR]], %[[B_RS]], %arg2, %true, %true + tt.func @swizzle0_operand_views( + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: !ttg.memdesc<256x128xf32, #tmem0, #ttng.tensor_memory>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x256xf8E4M3FN, #blockedA3> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x128xf8E4M3FN, #blockedB3> + %a2d = tt.reshape %a : tensor<1x64x256xf8E4M3FN, #blockedA3> -> tensor<64x256xf8E4M3FN, #blockedB2> + %aT = tt.trans %a2d {order = array} : tensor<64x256xf8E4M3FN, #blockedB2> -> tensor<256x64xf8E4M3FN, #blockedA2> + %b2d = tt.reshape %b : tensor<1x64x128xf8E4M3FN, #blockedB3> -> tensor<64x128xf8E4M3FN, #blockedB2> + %a_s = ttg.local_alloc %aT : (tensor<256x64xf8E4M3FN, #blockedA2>) -> !ttg.memdesc<256x64xf8E4M3FN, #shared0, #smem> + %b_s = ttg.local_alloc %b2d : (tensor<64x128xf8E4M3FN, #blockedB2>) -> !ttg.memdesc<64x128xf8E4M3FN, #shared0, #smem> + ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<256x64xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<64x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<256x128xf32, #tmem0, #ttng.tensor_memory> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 32], warpsPerCTA = [1, 2, 2, 1], order = [3, 2, 1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}> From 19398572b9d4099894e89846e4432dfeaf01e624 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 24 Mar 2026 20:10:45 +0000 Subject: [PATCH 02/54] cmake fix --- third_party/nvidia/CMakeLists.txt | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/third_party/nvidia/CMakeLists.txt b/third_party/nvidia/CMakeLists.txt index 22d6a9e397f4..506cba3b2cc6 100644 --- a/third_party/nvidia/CMakeLists.txt +++ b/third_party/nvidia/CMakeLists.txt @@ -18,12 +18,34 @@ if(TRITON_BUILD_PYTHON_MODULE) message(FATAL_ERROR "clang++ is required to build gsan.ll") endif() + if(DEFINED TRITON_CUDART_PATH AND NOT "${TRITON_CUDART_PATH}" STREQUAL "") + set(GSAN_RUNTIME_CUDA_PATH "${TRITON_CUDART_PATH}") + else() + set(GSAN_RUNTIME_CUDA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/backend") + endif() set(GSAN_RUNTIME_PLATFORM_FLAGS - "--cuda-path=${CMAKE_CURRENT_SOURCE_DIR}/backend") + "--cuda-path=${GSAN_RUNTIME_CUDA_PATH}") if(APPLE) list(APPEND GSAN_RUNTIME_PLATFORM_FLAGS -isysroot "${CMAKE_OSX_SYSROOT}") endif() + set(GSAN_RUNTIME_TOOLCHAIN_FLAGS) + set(GSAN_HOST_GNU_CXX "${CMAKE_CXX_COMPILER}") + if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + find_program(GSAN_HOST_GNU_CXX NAMES g++ c++) + endif() + if(GSAN_HOST_GNU_CXX) + execute_process( + COMMAND "${GSAN_HOST_GNU_CXX}" -print-file-name=libstdc++.so + OUTPUT_VARIABLE LIBSTDCXX_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(IS_ABSOLUTE "${LIBSTDCXX_PATH}") + get_filename_component(GCC_INSTALL_DIR "${LIBSTDCXX_PATH}" DIRECTORY) + list(APPEND GSAN_RUNTIME_TOOLCHAIN_FLAGS "--gcc-install-dir=${GCC_INSTALL_DIR}") + endif() + endif() + add_custom_command( OUTPUT "${GSAN_RUNTIME_IR}" COMMAND "${CMAKE_COMMAND}" -E make_directory @@ -38,6 +60,7 @@ if(TRITON_BUILD_PYTHON_MODULE) -fcuda-flush-denormals-to-zero --cuda-gpu-arch=sm_80 -Wno-unknown-cuda-version + ${GSAN_RUNTIME_TOOLCHAIN_FLAGS} ${GSAN_RUNTIME_PLATFORM_FLAGS} -isystem "${CMAKE_CURRENT_SOURCE_DIR}/clang_cuda_shims" -isystem "${CMAKE_CURRENT_SOURCE_DIR}/backend/include" From 7d1e42c680d81afad76588bf9014191451166ef0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 24 Mar 2026 20:56:45 +0000 Subject: [PATCH 03/54] works --- .../unit/language/test_tensor_descriptor.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index 64c714ab4571..fa728d3301db 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -1734,6 +1734,64 @@ def matmul_kernel_host_tensor_descriptor(a_desc, b_desc, c_desc): c_desc.store([offs_am, offs_bn], accumulator) +@triton.jit +def matmul_kernel_host_tensor_descriptor_swizzle0_b(a_desc, b_desc, c_desc): + K = a_desc.shape[1] + BLOCK_M: tl.constexpr = a_desc.block_shape[0] + BLOCK_K: tl.constexpr = a_desc.block_shape[1] + BLOCK_N: tl.constexpr = c_desc.block_shape[1] + + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + offs_k = 0 + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k_tile in range(0, tl.cdiv(K, BLOCK_K)): + a = a_desc.load([offs_am, offs_k]) + # Inverse core-matrices reorder: + # [num_cm_k * num_cm_n, 8, 16] -> [num_cm_k, num_cm_n, 8, 16] + # -> [num_cm_n, 8, num_cm_k, 16] -> [BLOCK_N, BLOCK_K]. + b_cm = b_desc.load([pid_n, k_tile, 0, 0, 0]) + num_cm_k: tl.constexpr = BLOCK_K // 16 + num_cm_n: tl.constexpr = BLOCK_N // 8 + b = b_cm.reshape((num_cm_k, num_cm_n, 8, 16)) + b = tl.permute(b, (1, 2, 0, 3)) + b = b.reshape((BLOCK_N, BLOCK_K)) + accumulator = tl.dot(a, b.T, acc=accumulator) + offs_k += BLOCK_K + c_desc.store([offs_am, offs_bn], accumulator) + + +def transform_b_to_core_matrices_layout(B, BLOCK_N, BLOCK_K): + CM_ROWS = 8 + CM_COLS = 16 + + N, K = B.shape + assert N % BLOCK_N == 0 + assert K % BLOCK_K == 0 + assert BLOCK_N % CM_ROWS == 0 + assert BLOCK_K % CM_COLS == 0 + + num_blocks_n = N // BLOCK_N + num_blocks_k = K // BLOCK_K + num_cm_n = BLOCK_N // CM_ROWS + num_cm_k = BLOCK_K // CM_COLS + + # [N, K] -> [num_blocks_n, num_cm_n, CM_ROWS, num_blocks_k, num_cm_k, CM_COLS] + b_reshaped = B.reshape(num_blocks_n, num_cm_n, CM_ROWS, num_blocks_k, num_cm_k, CM_COLS) + # [num_blocks_n, num_cm_n, num_cm_k, num_blocks_k, CM_ROWS, CM_COLS] + b_perm = b_reshaped.permute(0, 1, 4, 3, 2, 5) + # N-major core-matrices: + # [num_blocks_n, num_blocks_k, num_cm_k, num_cm_n, CM_ROWS, CM_COLS] + b_perm = b_perm.permute(0, 3, 2, 1, 4, 5) + # Collapse cm-count dims to keep descriptor rank <= 5 while retaining + # explicit core-matrix rows/cols in the innermost axes. + b_transformed = b_perm.reshape(num_blocks_n, num_blocks_k, num_cm_k * num_cm_n, CM_ROWS, CM_COLS) + return b_transformed.contiguous() + + @pytest.mark.interpreter() @pytest.mark.parametrize("num_ctas", [1, 2]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, num_stages", [ @@ -1784,6 +1842,53 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B "ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"] +# TODO: require blackwell +def test_host_tensor_descriptor_matmul_fp8_swizzle0_b(device): + M = N = K = 512 + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 128 + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), 1) + + torch.manual_seed(0) + a_fp32 = torch.randn((M, K), dtype=torch.float32, device=device) + b_fp32 = torch.randn((N, K), dtype=torch.float32, device=device) + a = a_fp32.to(torch.float8_e4m3fn) + b = b_fp32.to(torch.float8_e4m3fn) + b_transformed = transform_b_to_core_matrices_layout(b, BLOCK_N, BLOCK_K) + c = torch.empty((M, N), dtype=torch.float32, device=device) + + a_desc = TensorDescriptor(a, a.shape, a.stride(), [BLOCK_M, BLOCK_K]) + num_cm_n = BLOCK_N // 8 + num_cm_k = BLOCK_K // 16 + b_desc = TensorDescriptor( + b_transformed, + b_transformed.shape, + b_transformed.stride(), + [1, 1, num_cm_k * num_cm_n, 8, 16], + ) + c_desc = TensorDescriptor(c, c.shape, c.stride(), [BLOCK_M, BLOCK_N]) + + kernel = matmul_kernel_host_tensor_descriptor_swizzle0_b[grid]( + a_desc, + b_desc, + c_desc, + num_warps=4, + num_stages=1, + ) + + # Compare against quantized operands actually consumed by the kernel. + ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32).T) + torch.testing.assert_close(ref_out, c, rtol=1.5e-1, atol=2.5e-1) + + ttgir = kernel.asm["ttgir"] + # assert "ttng.tc_gen5_mma" in ttgir + assert "swizzlingByteWidth = 0" in ttgir and "#ttg.shared_linear" in ttgir + # assert any(f"swizzlingByteWidth = {w}" in ttgir for w in [32, 64, 128]) + +test_host_tensor_descriptor_matmul_fp8_swizzle0_b("cuda") + + @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) def test_tensor_descriptor_store_downcast(dtype_str, device): From a86d083cabce6a40372d75bce37d9028052627a8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 06:51:24 +0900 Subject: [PATCH 04/54] make it work for other dot ops --- .../Transforms/OptimizeDotOperands.cpp | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index b753dfc1fe65..1621bb237c54 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -317,17 +317,18 @@ class UseShmemForScales } }; -class RewriteSwizzle0OperandViewsToMemDescForTCGen5MMA - : public OpRewritePattern { +template +class RewriteSwizzle0OperandViewsToMemDescForDotOp + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(triton::nvidia_gpu::TCGen5MMAOp mmaOp, + LogicalResult matchAndRewrite(DotOpTy dotOp, PatternRewriter &rewriter) const override { bool changed = false; - if (succeeded(rewriteOperand(mmaOp.getAMutable(), rewriter))) + if (succeeded(rewriteOperand(dotOp.getAMutable(), rewriter))) changed = true; - if (succeeded(rewriteOperand(mmaOp.getBMutable(), rewriter))) + if (succeeded(rewriteOperand(dotOp.getBMutable(), rewriter))) changed = true; return changed ? success() : failure(); } @@ -468,8 +469,12 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); - patterns.add(context); + patterns.add< + UseShmemForScales, + RewriteSwizzle0OperandViewsToMemDescForDotOp, + RewriteSwizzle0OperandViewsToMemDescForDotOp, + RewriteSwizzle0OperandViewsToMemDescForDotOp>( + context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) signalPassFailure(); From d2955e7ea6800767a34225fba2bae094d3274fc2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 07:02:28 +0900 Subject: [PATCH 05/54] fix --- .../Transforms/OptimizeDotOperands.cpp | 26 ++++++++++++------- test/TritonGPU/dot-operands.mlir | 26 +++++++++++++++++++ 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 1621bb237c54..10103d494774 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -7,6 +7,7 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -334,6 +335,15 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } private: + static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { + if (auto nvmma = dyn_cast(encoding)) + return nvmma.getSwizzlingByteWidth() == 0; + if (auto swizzled = dyn_cast(encoding)) + return swizzled.getVec() == 1 && swizzled.getPerPhase() == 1 && + swizzled.getMaxPhase() == 1; + return false; + } + struct ViewStep { enum Kind { Reshape, Transpose } kind; SmallVector shape; @@ -361,8 +371,10 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); auto allocTy = cast(localAlloc.getType()); - auto allocEnc = dyn_cast(allocTy.getEncoding()); - if (!allocEnc || allocEnc.getSwizzlingByteWidth() != 0) + auto allocSharedEnc = + dyn_cast(allocTy.getEncoding()); + if (!allocSharedEnc || + !isZeroSwizzleCompatibleEncoding(allocTy.getEncoding())) return failure(); SmallVector reverseSteps; @@ -398,14 +410,10 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!baseTensorTy) return failure(); - auto cgaLayout = CGAEncodingAttr::get1CTALayout(rewriter.getContext(), - baseTensorTy.getRank()); - auto baseEnc = NVMMASharedEncodingAttr::get( - rewriter.getContext(), /*swizzlingByteWidth=*/0, - /*transposed=*/false, allocEnc.getElementBitWidth(), - allocEnc.getFp4Padded(), cgaLayout); + auto baseEnc = updateEncodingForShape(localAlloc, allocSharedEnc, baseTensorTy); auto baseMemTy = MemDescType::get( - baseTensorTy.getShape(), baseTensorTy.getElementType(), baseEnc, + baseTensorTy.getShape(), baseTensorTy.getElementType(), + cast(baseEnc), allocTy.getMemorySpace(), allocTy.getMutableMemory()); PatternRewriter::InsertionGuard guard(rewriter); diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 73b718b2797f..76fe7a370da0 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -53,6 +53,32 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +#blockedA_h = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB_h = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blockedB3_h = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> +#mma_h = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared_h = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#shared0_h = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, rank = 2}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @warp_group_dot_swizzle0_like_operand_views + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %arg1 : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> + // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> + // CHECK-NOT: tt.reshape + // CHECK: ttng.warp_group_dot %arg0, %[[B_RS]], %arg2 + tt.func @warp_group_dot_swizzle0_like_operand_views( + %a: !ttg.memdesc<128x64xf16, #shared_h, #smem>, + %b3: tensor<1x64x64xf16, #blockedB3_h>, + %c: tensor<128x64xf32, #mma_h>) -> tensor<128x64xf32, #mma_h> { + %b2d = tt.reshape %b3 : tensor<1x64x64xf16, #blockedB3_h> -> tensor<64x64xf16, #blockedB_h> + %b_s = ttg.local_alloc %b2d : (tensor<64x64xf16, #blockedB_h>) -> !ttg.memdesc<64x64xf16, #shared0_h, #smem> + %r = ttng.warp_group_dot %a, %b_s, %c : !ttg.memdesc<128x64xf16, #shared_h, #smem> * !ttg.memdesc<64x64xf16, #shared0_h, #smem> -> tensor<128x64xf32, #mma_h> + tt.return %r : tensor<128x64xf32, #mma_h> + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> From 28d35fab154fa7c6e22439542fe5da250ca92650 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 08:13:30 +0900 Subject: [PATCH 06/54] fix --- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 16 ++++++++++++---- test/TritonGPU/dot-operands.mlir | 13 ++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 10103d494774..010d8544c505 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -371,10 +371,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); auto allocTy = cast(localAlloc.getType()); - auto allocSharedEnc = - dyn_cast(allocTy.getEncoding()); - if (!allocSharedEnc || - !isZeroSwizzleCompatibleEncoding(allocTy.getEncoding())) + auto allocSharedEnc = dyn_cast(allocTy.getEncoding()); + if (!allocSharedEnc) return failure(); SmallVector reverseSteps; @@ -406,6 +404,16 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (reverseSteps.empty()) return failure(); + bool sourceIsZeroSwizzleLike = false; + if (auto localLoad = baseTensor.getDefiningOp()) { + auto srcTy = dyn_cast(localLoad.getSrc().getType()); + sourceIsZeroSwizzleLike = + srcTy && isZeroSwizzleCompatibleEncoding(srcTy.getEncoding()); + } + if (!isZeroSwizzleCompatibleEncoding(allocTy.getEncoding()) && + !sourceIsZeroSwizzleLike) + return failure(); + auto baseTensorTy = dyn_cast(baseTensor.getType()); if (!baseTensorTy) return failure(); diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 76fe7a370da0..95f84c85be74 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -58,11 +58,12 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #blockedB3_h = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> #mma_h = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> #shared_h = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> -#shared0_h = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, rank = 2}> +#shared0_h = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, rank = 3}> #smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @warp_group_dot_swizzle0_like_operand_views - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %arg1 : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> + // CHECK-DAG: %[[B_SRC:.*]] = ttg.local_load %{{.*}} : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem{{.*}}> -> tensor<1x64x64xf16, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_SRC]] : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> // CHECK-NOT: tt.reshape // CHECK: ttng.warp_group_dot %arg0, %[[B_RS]], %arg2 @@ -70,9 +71,11 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %a: !ttg.memdesc<128x64xf16, #shared_h, #smem>, %b3: tensor<1x64x64xf16, #blockedB3_h>, %c: tensor<128x64xf32, #mma_h>) -> tensor<128x64xf32, #mma_h> { - %b2d = tt.reshape %b3 : tensor<1x64x64xf16, #blockedB3_h> -> tensor<64x64xf16, #blockedB_h> - %b_s = ttg.local_alloc %b2d : (tensor<64x64xf16, #blockedB_h>) -> !ttg.memdesc<64x64xf16, #shared0_h, #smem> - %r = ttng.warp_group_dot %a, %b_s, %c : !ttg.memdesc<128x64xf16, #shared_h, #smem> * !ttg.memdesc<64x64xf16, #shared0_h, #smem> -> tensor<128x64xf32, #mma_h> + %b3_s = ttg.local_alloc %b3 : (tensor<1x64x64xf16, #blockedB3_h>) -> !ttg.memdesc<1x64x64xf16, #shared0_h, #smem, mutable> + %b3_l = ttg.local_load %b3_s : !ttg.memdesc<1x64x64xf16, #shared0_h, #smem, mutable> -> tensor<1x64x64xf16, #blockedB3_h> + %b2d = tt.reshape %b3_l : tensor<1x64x64xf16, #blockedB3_h> -> tensor<64x64xf16, #blockedB_h> + %b_s = ttg.local_alloc %b2d : (tensor<64x64xf16, #blockedB_h>) -> !ttg.memdesc<64x64xf16, #shared_h, #smem> + %r = ttng.warp_group_dot %a, %b_s, %c : !ttg.memdesc<128x64xf16, #shared_h, #smem> * !ttg.memdesc<64x64xf16, #shared_h, #smem> -> tensor<128x64xf32, #mma_h> tt.return %r : tensor<128x64xf32, #mma_h> } } From 638c3b090b72a68f6b4fbac8187bb4df713080a4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 08:37:46 +0900 Subject: [PATCH 07/54] [TritonGPU] Match swizzle0 operand-view rewrite from local_load source on Hopper --- .../Transforms/OptimizeDotOperands.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 010d8544c505..53d305bed912 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -405,10 +405,14 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); bool sourceIsZeroSwizzleLike = false; + SharedEncodingTrait sourceSharedEnc; if (auto localLoad = baseTensor.getDefiningOp()) { auto srcTy = dyn_cast(localLoad.getSrc().getType()); - sourceIsZeroSwizzleLike = - srcTy && isZeroSwizzleCompatibleEncoding(srcTy.getEncoding()); + if (srcTy) { + sourceIsZeroSwizzleLike = + isZeroSwizzleCompatibleEncoding(srcTy.getEncoding()); + sourceSharedEnc = dyn_cast(srcTy.getEncoding()); + } } if (!isZeroSwizzleCompatibleEncoding(allocTy.getEncoding()) && !sourceIsZeroSwizzleLike) @@ -418,7 +422,15 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!baseTensorTy) return failure(); - auto baseEnc = updateEncodingForShape(localAlloc, allocSharedEnc, baseTensorTy); + // If swizzle=0 comes from the source memdesc (common for descriptor+TMA + // staging on Hopper), use that encoding as the reference. Using the final + // alloc encoding can be rank-incompatible with the source tensor view + // chain and make this rewrite fail to materialize. + SharedEncodingTrait refSharedEnc = allocSharedEnc; + if (sourceIsZeroSwizzleLike && sourceSharedEnc) + refSharedEnc = sourceSharedEnc; + + auto baseEnc = updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); auto baseMemTy = MemDescType::get( baseTensorTy.getShape(), baseTensorTy.getElementType(), cast(baseEnc), From 3375a12db77ecb8998c3787a911331e23207b831 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 08:42:05 +0900 Subject: [PATCH 08/54] [TritonGPU] Use source shared encoding for swizzle0 operand-view rewrite --- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 53d305bed912..f4f64d77710e 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -444,14 +444,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp for (ViewStep &step : llvm::reverse(reverseSteps)) { if (step.kind == ViewStep::Reshape) { - MemDescType reshapedTy; - if (failed(MemDescReshapeOp::inferReturnTypes( - rewriter.getContext(), step.loc, - cast(rewritten.getType()), step.shape, - reshapedTy))) - return failure(); rewritten = - MemDescReshapeOp::create(rewriter, step.loc, reshapedTy, rewritten); + MemDescReshapeOp::create(rewriter, step.loc, rewritten, step.shape); } else { rewritten = MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); From 9f559e93a19e40f30d0737dfacaf3337a99e2507 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 09:16:45 +0900 Subject: [PATCH 09/54] fix --- .../Transforms/OptimizeDotOperands.cpp | 11 +++++ test/TritonGPU/dot-operands.mlir | 45 ++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index f4f64d77710e..2687e102e0ee 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -4,6 +4,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -413,6 +414,16 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp isZeroSwizzleCompatibleEncoding(srcTy.getEncoding()); sourceSharedEnc = dyn_cast(srcTy.getEncoding()); } + } else if (auto descLoad = baseTensor.getDefiningOp()) { + auto descTy = dyn_cast(descLoad.getDesc().getType()); + if (descTy) { + auto blockTy = descTy.getBlockType(); + sourceSharedEnc = + dyn_cast_or_null(blockTy.getEncoding()); + sourceIsZeroSwizzleLike = sourceSharedEnc && + isZeroSwizzleCompatibleEncoding( + cast(sourceSharedEnc)); + } } if (!isZeroSwizzleCompatibleEncoding(allocTy.getEncoding()) && !sourceIsZeroSwizzleLike) diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 95f84c85be74..6ae4df1fa4e1 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -82,6 +82,48 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> +#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +#sharedB_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @descriptor_load_warp_group_dot_swizzle0_operand_views + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[ + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 + tt.func @descriptor_load_warp_group_dot_swizzle0_operand_views( + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { + %c0 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> + %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> + %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> + %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> + %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> + %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} + : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> -> tensor<128x256xf32, #mma_desc> + tt.return %r : tensor<128x256xf32, #mma_desc> + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> @@ -89,13 +131,12 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding -// CHECK: #[[$SHARED:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mma_reorder_transpose_mmav5 tt.func @mma_reorder_transpose_mmav5(%t: tensor<64x256xf8E4M3FN, #blocked1>, %dotb: !ttg.memdesc<64x128xf8E4M3FN, #shared1, #smem>, %dotc: !ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory>) { %true = arith.constant true %a = tt.trans %t {order = array} : tensor<64x256xf8E4M3FN, #blocked1> -> tensor<256x64xf8E4M3FN, #blocked> - // CHECK: %[[A:.+]] = ttg.local_alloc {{.*}} -> !ttg.memdesc<64x256xf8E4M3FN, #[[$SHARED]], #smem> + // CHECK: %[[A:.+]] = ttg.local_alloc {{.*}} -> !ttg.memdesc<64x256xf8E4M3FN, #{{.*}}, #smem> // CHECK: %[[T:.+]] = ttg.memdesc_trans %[[A]] {order = array} // CHECK: ttng.tc_gen5_mma %[[T]] %dota = ttg.local_alloc %a: (tensor<256x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<256x64xf8E4M3FN, #shared1, #smem> From 390b118d39dd9f28313ae6f88f2a8de4b2cc5261 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 25 Mar 2026 09:24:08 +0900 Subject: [PATCH 10/54] clean --- .../Transforms/OptimizeDotOperands.cpp | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 2687e102e0ee..417aae4a3634 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -352,6 +352,22 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp Location loc; }; + static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { + if (auto localLoad = baseTensor.getDefiningOp()) { + if (auto srcTy = dyn_cast(localLoad.getSrc().getType())) + return dyn_cast(srcTy.getEncoding()); + return nullptr; + } + if (auto descLoad = baseTensor.getDefiningOp()) { + auto descTy = dyn_cast(descLoad.getDesc().getType()); + if (!descTy) + return nullptr; + return dyn_cast_or_null( + descTy.getBlockType().getEncoding()); + } + return nullptr; + } + LogicalResult rewriteOperand(OpOperand &operand, PatternRewriter &rewriter) const { Value orig = operand.get(); @@ -405,26 +421,10 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (reverseSteps.empty()) return failure(); - bool sourceIsZeroSwizzleLike = false; - SharedEncodingTrait sourceSharedEnc; - if (auto localLoad = baseTensor.getDefiningOp()) { - auto srcTy = dyn_cast(localLoad.getSrc().getType()); - if (srcTy) { - sourceIsZeroSwizzleLike = - isZeroSwizzleCompatibleEncoding(srcTy.getEncoding()); - sourceSharedEnc = dyn_cast(srcTy.getEncoding()); - } - } else if (auto descLoad = baseTensor.getDefiningOp()) { - auto descTy = dyn_cast(descLoad.getDesc().getType()); - if (descTy) { - auto blockTy = descTy.getBlockType(); - sourceSharedEnc = - dyn_cast_or_null(blockTy.getEncoding()); - sourceIsZeroSwizzleLike = sourceSharedEnc && - isZeroSwizzleCompatibleEncoding( - cast(sourceSharedEnc)); - } - } + auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); + bool sourceIsZeroSwizzleLike = + sourceSharedEnc && + isZeroSwizzleCompatibleEncoding(cast(sourceSharedEnc)); if (!isZeroSwizzleCompatibleEncoding(allocTy.getEncoding()) && !sourceIsZeroSwizzleLike) return failure(); @@ -441,7 +441,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (sourceIsZeroSwizzleLike && sourceSharedEnc) refSharedEnc = sourceSharedEnc; - auto baseEnc = updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); + auto baseEnc = + updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); auto baseMemTy = MemDescType::get( baseTensorTy.getShape(), baseTensorTy.getElementType(), cast(baseEnc), From 37820689722bd3ce75a4a2da974bb3b60ec1c626 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Mar 2026 07:09:53 +0900 Subject: [PATCH 11/54] simplify --- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 417aae4a3634..9e1d89916bb0 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -339,9 +339,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { if (auto nvmma = dyn_cast(encoding)) return nvmma.getSwizzlingByteWidth() == 0; - if (auto swizzled = dyn_cast(encoding)) - return swizzled.getVec() == 1 && swizzled.getPerPhase() == 1 && - swizzled.getMaxPhase() == 1; return false; } @@ -353,11 +350,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp }; static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { - if (auto localLoad = baseTensor.getDefiningOp()) { - if (auto srcTy = dyn_cast(localLoad.getSrc().getType())) - return dyn_cast(srcTy.getEncoding()); - return nullptr; - } if (auto descLoad = baseTensor.getDefiningOp()) { auto descTy = dyn_cast(descLoad.getDesc().getType()); if (!descTy) From 8707f6d01eb3508e0a908005b092f4861e3e5a24 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Mar 2026 07:33:48 +0900 Subject: [PATCH 12/54] remove pattern matching against desc load --- .../Transforms/OptimizeDotOperands.cpp | 61 +++++-------------- test/TritonGPU/dot-operands.mlir | 42 ------------- 2 files changed, 14 insertions(+), 89 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 9e1d89916bb0..604f122e33dc 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -336,12 +336,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } private: - static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { - if (auto nvmma = dyn_cast(encoding)) - return nvmma.getSwizzlingByteWidth() == 0; - return false; - } - struct ViewStep { enum Kind { Reshape, Transpose } kind; SmallVector shape; @@ -349,17 +343,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp Location loc; }; - static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { - if (auto descLoad = baseTensor.getDefiningOp()) { - auto descTy = dyn_cast(descLoad.getDesc().getType()); - if (!descTy) - return nullptr; - return dyn_cast_or_null( - descTy.getBlockType().getEncoding()); - } - return nullptr; - } - LogicalResult rewriteOperand(OpOperand &operand, PatternRewriter &rewriter) const { Value orig = operand.get(); @@ -380,8 +363,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); auto allocTy = cast(localAlloc.getType()); - auto allocSharedEnc = dyn_cast(allocTy.getEncoding()); - if (!allocSharedEnc) + auto allocEnc = dyn_cast(allocTy.getEncoding()); + if (!allocEnc || allocEnc.getSwizzlingByteWidth() != 0) return failure(); SmallVector reverseSteps; @@ -413,32 +396,15 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (reverseSteps.empty()) return failure(); - auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); - bool sourceIsZeroSwizzleLike = - sourceSharedEnc && - isZeroSwizzleCompatibleEncoding(cast(sourceSharedEnc)); - if (!isZeroSwizzleCompatibleEncoding(allocTy.getEncoding()) && - !sourceIsZeroSwizzleLike) - return failure(); - auto baseTensorTy = dyn_cast(baseTensor.getType()); if (!baseTensorTy) return failure(); - // If swizzle=0 comes from the source memdesc (common for descriptor+TMA - // staging on Hopper), use that encoding as the reference. Using the final - // alloc encoding can be rank-incompatible with the source tensor view - // chain and make this rewrite fail to materialize. - SharedEncodingTrait refSharedEnc = allocSharedEnc; - if (sourceIsZeroSwizzleLike && sourceSharedEnc) - refSharedEnc = sourceSharedEnc; - - auto baseEnc = - updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); - auto baseMemTy = MemDescType::get( - baseTensorTy.getShape(), baseTensorTy.getElementType(), - cast(baseEnc), - allocTy.getMemorySpace(), allocTy.getMutableMemory()); + auto baseEnc = updateEncodingForShape(localAlloc, allocEnc, baseTensorTy); + auto baseMemTy = + MemDescType::get(baseTensorTy.getShape(), baseTensorTy.getElementType(), + cast(baseEnc), allocTy.getMemorySpace(), + allocTy.getMutableMemory()); PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(localAlloc); @@ -495,12 +461,13 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); - patterns.add< - UseShmemForScales, - RewriteSwizzle0OperandViewsToMemDescForDotOp, - RewriteSwizzle0OperandViewsToMemDescForDotOp, - RewriteSwizzle0OperandViewsToMemDescForDotOp>( - context); + patterns.add, + RewriteSwizzle0OperandViewsToMemDescForDotOp< + triton::nvidia_gpu::TCGen5MMAScaledOp>, + RewriteSwizzle0OperandViewsToMemDescForDotOp< + triton::nvidia_gpu::WarpGroupDotOp>>(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) signalPassFailure(); diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 6ae4df1fa4e1..2455b819b509 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -82,48 +82,6 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- -#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> -#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -#sharedB_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> -#smem = #ttg.shared_memory -module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @descriptor_load_warp_group_dot_swizzle0_operand_views - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[ - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK: ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 - tt.func @descriptor_load_warp_group_dot_swizzle0_operand_views( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, - %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { - %c0 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> - %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> - %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> - %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> - %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> - %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> - %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> - %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} - : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> -> tensor<128x256xf32, #mma_desc> - tt.return %r : tensor<128x256xf32, #mma_desc> - } -} - -// ----- - #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> From 5ea972429e9553718be687d6602f1655d785e58d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Mar 2026 07:57:43 +0900 Subject: [PATCH 13/54] upd lit test --- test/TritonGPU/dot-operands.mlir | 72 +++++++++++++++++++------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 2455b819b509..fe9ba0dd115a 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -53,35 +53,6 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- -#blockedA_h = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB_h = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blockedB3_h = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> -#mma_h = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared_h = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> -#shared0_h = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16, rank = 3}> -#smem = #ttg.shared_memory -module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @warp_group_dot_swizzle0_like_operand_views - // CHECK-DAG: %[[B_SRC:.*]] = ttg.local_load %{{.*}} : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem{{.*}}> -> tensor<1x64x64xf16, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_SRC]] : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> - // CHECK-NOT: tt.reshape - // CHECK: ttng.warp_group_dot %arg0, %[[B_RS]], %arg2 - tt.func @warp_group_dot_swizzle0_like_operand_views( - %a: !ttg.memdesc<128x64xf16, #shared_h, #smem>, - %b3: tensor<1x64x64xf16, #blockedB3_h>, - %c: tensor<128x64xf32, #mma_h>) -> tensor<128x64xf32, #mma_h> { - %b3_s = ttg.local_alloc %b3 : (tensor<1x64x64xf16, #blockedB3_h>) -> !ttg.memdesc<1x64x64xf16, #shared0_h, #smem, mutable> - %b3_l = ttg.local_load %b3_s : !ttg.memdesc<1x64x64xf16, #shared0_h, #smem, mutable> -> tensor<1x64x64xf16, #blockedB3_h> - %b2d = tt.reshape %b3_l : tensor<1x64x64xf16, #blockedB3_h> -> tensor<64x64xf16, #blockedB_h> - %b_s = ttg.local_alloc %b2d : (tensor<64x64xf16, #blockedB_h>) -> !ttg.memdesc<64x64xf16, #shared_h, #smem> - %r = ttng.warp_group_dot %a, %b_s, %c : !ttg.memdesc<128x64xf16, #shared_h, #smem> * !ttg.memdesc<64x64xf16, #shared_h, #smem> -> tensor<128x64xf32, #mma_h> - tt.return %r : tensor<128x64xf32, #mma_h> - } -} - -// ----- - #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> @@ -89,12 +60,13 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding +// CHECK: #[[$SHARED:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: mma_reorder_transpose_mmav5 tt.func @mma_reorder_transpose_mmav5(%t: tensor<64x256xf8E4M3FN, #blocked1>, %dotb: !ttg.memdesc<64x128xf8E4M3FN, #shared1, #smem>, %dotc: !ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory>) { %true = arith.constant true %a = tt.trans %t {order = array} : tensor<64x256xf8E4M3FN, #blocked1> -> tensor<256x64xf8E4M3FN, #blocked> - // CHECK: %[[A:.+]] = ttg.local_alloc {{.*}} -> !ttg.memdesc<64x256xf8E4M3FN, #{{.*}}, #smem> + // CHECK: %[[A:.+]] = ttg.local_alloc {{.*}} -> !ttg.memdesc<64x256xf8E4M3FN, #[[$SHARED]], #smem> // CHECK: %[[T:.+]] = ttg.memdesc_trans %[[A]] {order = array} // CHECK: ttng.tc_gen5_mma %[[T]] %dota = ttg.local_alloc %a: (tensor<256x64xf8E4M3FN, #blocked>) -> !ttg.memdesc<256x64xf8E4M3FN, #shared1, #smem> @@ -342,6 +314,46 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<256x64xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<64x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<256x128xf32, #tmem0, #ttng.tensor_memory> tt.return } + +} + +// ----- + +#blockedA2_h = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB2_h = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blockedA3_h = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> +#blockedB3_h = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> +#mma_h2 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared0_h2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16}> +#sharedA_h2 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot + // CHECK-DAG: %[[A_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x64xf16, #{{.*}}> + // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_LOAD]] : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> + // CHECK-DAG: %[[A_RS:.*]] = ttg.memdesc_reshape %[[A_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> + // CHECK-DAG: %[[A_TR:.*]] = ttg.memdesc_trans %[[A_RS]] {order = array} : !ttg.memdesc<64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> + // CHECK-DAG: %[[B_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x64xf16, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_LOAD]] : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> + // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: ttng.warp_group_dot %[[A_TR]], %[[B_RS]], %arg2 + tt.func @swizzle0_operand_views_warp_group_dot( + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: tensor<64x64xf32, #mma_h2>) -> tensor<64x64xf32, #mma_h2> { + %c0_i32 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x64xf16, #blockedA3_h> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x64xf16, #blockedB3_h> + %a2d = tt.reshape %a : tensor<1x64x64xf16, #blockedA3_h> -> tensor<64x64xf16, #blockedB2_h> + %aT = tt.trans %a2d {order = array} : tensor<64x64xf16, #blockedB2_h> -> tensor<64x64xf16, #blockedA2_h> + %b2d = tt.reshape %b : tensor<1x64x64xf16, #blockedB3_h> -> tensor<64x64xf16, #blockedB2_h> + %a_s = ttg.local_alloc %aT : (tensor<64x64xf16, #blockedA2_h>) -> !ttg.memdesc<64x64xf16, #shared0_h2, #smem> + %b_s = ttg.local_alloc %b2d : (tensor<64x64xf16, #blockedB2_h>) -> !ttg.memdesc<64x64xf16, #shared0_h2, #smem> + %r = ttng.warp_group_dot %a_s, %b_s, %acc : !ttg.memdesc<64x64xf16, #shared0_h2, #smem> * !ttg.memdesc<64x64xf16, #shared0_h2, #smem> -> tensor<64x64xf32, #mma_h2> + tt.return %r : tensor<64x64xf32, #mma_h2> + } } // ----- From 12cb8e012ffb81fb5135e7edc379f86726230d14 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 28 Mar 2026 09:49:54 +0900 Subject: [PATCH 14/54] fix --- .../Transforms/OptimizeDotOperands.cpp | 36 +++++++++++++-- test/TritonGPU/dot-operands.mlir | 44 +++++++++++++++++++ 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 604f122e33dc..a0261e359045 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -336,6 +336,12 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } private: + static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { + if (auto nvmma = dyn_cast(encoding)) + return nvmma.getSwizzlingByteWidth() == 0; + return false; + } + struct ViewStep { enum Kind { Reshape, Transpose } kind; SmallVector shape; @@ -343,6 +349,18 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp Location loc; }; + static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { + if (auto descLoad = baseTensor.getDefiningOp()) { + auto descTy = + dyn_cast(descLoad.getDesc().getType()); + if (!descTy) + return nullptr; + return dyn_cast_or_null( + descTy.getBlockType().getEncoding()); + } + return nullptr; + } + LogicalResult rewriteOperand(OpOperand &operand, PatternRewriter &rewriter) const { Value orig = operand.get(); @@ -363,8 +381,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); auto allocTy = cast(localAlloc.getType()); - auto allocEnc = dyn_cast(allocTy.getEncoding()); - if (!allocEnc || allocEnc.getSwizzlingByteWidth() != 0) + auto allocSharedEnc = dyn_cast(allocTy.getEncoding()); + if (!allocSharedEnc) return failure(); SmallVector reverseSteps; @@ -396,11 +414,23 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (reverseSteps.empty()) return failure(); + auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); + bool sourceIsZeroSwizzleLike = + sourceSharedEnc && + isZeroSwizzleCompatibleEncoding(cast(sourceSharedEnc)); + if (!isZeroSwizzleCompatibleEncoding(allocTy.getEncoding()) && + !sourceIsZeroSwizzleLike) + return failure(); + auto baseTensorTy = dyn_cast(baseTensor.getType()); if (!baseTensorTy) return failure(); - auto baseEnc = updateEncodingForShape(localAlloc, allocEnc, baseTensorTy); + SharedEncodingTrait refSharedEnc = allocSharedEnc; + if (sourceIsZeroSwizzleLike && sourceSharedEnc) + refSharedEnc = sourceSharedEnc; + + auto baseEnc = updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); auto baseMemTy = MemDescType::get(baseTensorTy.getShape(), baseTensorTy.getElementType(), cast(baseEnc), allocTy.getMemorySpace(), diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index fe9ba0dd115a..eb4bc9f9f79f 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -30,6 +30,50 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- } } + +// ----- + +#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> +#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +#sharedB_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @descriptor_load_warp_group_dot_swizzle0_operand_views + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[ + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 + tt.func @descriptor_load_warp_group_dot_swizzle0_operand_views( + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { + %c0 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> + %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> + %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> + %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> + %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> + %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} + : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> -> tensor<128x256xf32, #mma_desc> + tt.return %r : tensor<128x256xf32, #mma_desc> + } +} + + // ----- #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> From 07119d3acaf8a036b6c0f96be3689ed17912d833 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 28 Mar 2026 10:04:53 +0900 Subject: [PATCH 15/54] fix for bw --- .../Transforms/OptimizeDotOperands.cpp | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index a0261e359045..cb778fe36d15 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -368,12 +368,29 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!origTy) return failure(); - SmallVector trailingMemDescTransOrder; + SmallVector trailingMemDescSteps; Value beforeTrailing = orig; - if (auto trailing = beforeTrailing.getDefiningOp()) { - trailingMemDescTransOrder.assign(trailing.getOrder().begin(), - trailing.getOrder().end()); - beforeTrailing = trailing.getSrc(); + while (true) { + if (auto trailingTrans = beforeTrailing.getDefiningOp()) { + SmallVector order(trailingTrans.getOrder().begin(), + trailingTrans.getOrder().end()); + trailingMemDescSteps.push_back( + ViewStep{ViewStep::Transpose, {}, std::move(order), + trailingTrans.getLoc()}); + beforeTrailing = trailingTrans.getSrc(); + continue; + } + if (auto trailingReshape = + beforeTrailing.getDefiningOp()) { + auto ty = cast(trailingReshape.getType()); + SmallVector shape(ty.getShape().begin(), ty.getShape().end()); + trailingMemDescSteps.push_back( + ViewStep{ViewStep::Reshape, std::move(shape), {}, + trailingReshape.getLoc()}); + beforeTrailing = trailingReshape.getSrc(); + continue; + } + break; } auto localAlloc = beforeTrailing.getDefiningOp(); @@ -452,9 +469,14 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } } - if (!trailingMemDescTransOrder.empty()) { - rewritten = MemDescTransOp::create(rewriter, localAlloc.getLoc(), - rewritten, trailingMemDescTransOrder); + for (ViewStep &step : llvm::reverse(trailingMemDescSteps)) { + if (step.kind == ViewStep::Reshape) { + rewritten = + MemDescReshapeOp::create(rewriter, step.loc, rewritten, step.shape); + } else { + rewritten = + MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); + } } auto rewrittenTy = cast(rewritten.getType()); From 746c28a63f03da83a24e5d05e72624e0380fa5cd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 28 Mar 2026 11:11:13 +0900 Subject: [PATCH 16/54] update bw lit --- test/TritonGPU/dot-operands.mlir | 57 ++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index eb4bc9f9f79f..1d018f183b01 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -323,39 +323,46 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num // ----- -#blockedA2 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blockedA3 = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> -#blockedB3 = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> -#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}> +#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> +#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> #smem = #ttg.shared_memory -#tmem0 = #ttng.tensor_memory_encoding +#tmem0 = #ttng.tensor_memory_encoding module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views - // CHECK-DAG: %[[A_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x256xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_LOAD]] : (tensor<1x64x256xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x64x256xf8E4M3FN, #{{.*}}, #smem> - // CHECK-DAG: %[[A_RS:.*]] = ttg.memdesc_reshape %[[A_BASE]] : !ttg.memdesc<1x64x256xf8E4M3FN, #{{.*}}, #smem> -> !ttg.memdesc<64x256xf8E4M3FN, #{{.*}}, #smem> - // CHECK-DAG: %[[A_TR:.*]] = ttg.memdesc_trans %[[A_RS]] {order = array} : !ttg.memdesc<64x256xf8E4M3FN, #{{.*}}, #smem> -> !ttg.memdesc<256x64xf8E4M3FN, #{{.*}}, #smem> - // CHECK-DAG: %[[B_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x128xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_LOAD]] : (tensor<1x64x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x64x128xf8E4M3FN, #{{.*}}, #smem> - // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x128xf8E4M3FN, #{{.*}}, #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #{{.*}}, #smem> + // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-NOT: tt.reshape // CHECK-NOT: tt.trans - // CHECK: ttng.tc_gen5_mma %[[A_TR]], %[[B_RS]], %arg2, %true, %true + // CHECK-NOT: ttg.local_load + // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true tt.func @swizzle0_operand_views( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, - %acc: !ttg.memdesc<256x128xf32, #tmem0, #ttng.tensor_memory>) { + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x256xf8E4M3FN, #blockedA3> - %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x128xf8E4M3FN, #blockedB3> - %a2d = tt.reshape %a : tensor<1x64x256xf8E4M3FN, #blockedA3> -> tensor<64x256xf8E4M3FN, #blockedB2> - %aT = tt.trans %a2d {order = array} : tensor<64x256xf8E4M3FN, #blockedB2> -> tensor<256x64xf8E4M3FN, #blockedA2> - %b2d = tt.reshape %b : tensor<1x64x128xf8E4M3FN, #blockedB3> -> tensor<64x128xf8E4M3FN, #blockedB2> - %a_s = ttg.local_alloc %aT : (tensor<256x64xf8E4M3FN, #blockedA2>) -> !ttg.memdesc<256x64xf8E4M3FN, #shared0, #smem> - %b_s = ttg.local_alloc %b2d : (tensor<64x128xf8E4M3FN, #blockedB2>) -> !ttg.memdesc<64x128xf8E4M3FN, #shared0, #smem> - ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<256x64xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<64x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<256x128xf32, #tmem0, #ttng.tensor_memory> + %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked0> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> + %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> + %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> + %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> + %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> + %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared0, #smem> + ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> tt.return } From 1d02e00f498083fede5010424b66182e3f17e7de Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 28 Mar 2026 11:18:36 +0900 Subject: [PATCH 17/54] update for hop --- .../Transforms/OptimizeDotOperands.cpp | 36 +++++++++++++++++++ test/TritonGPU/dot-operands.mlir | 6 ++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index cb778fe36d15..f771ea1b93bd 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -14,7 +14,9 @@ #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" +#include #include +#include namespace mlir::triton::gpu { @@ -327,11 +329,45 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp LogicalResult matchAndRewrite(DotOpTy dotOp, PatternRewriter &rewriter) const override { + Value oldA = dotOp.getA(); + Value oldB = dotOp.getB(); bool changed = false; if (succeeded(rewriteOperand(dotOp.getAMutable(), rewriter))) changed = true; if (succeeded(rewriteOperand(dotOp.getBMutable(), rewriter))) changed = true; + + // Keep warp_group_dot_wait operands consistent with rewritten dot operands. + // The wait op is variadic and can carry [dot_result, A, B], so after + // changing dotOp's A/B we need to retarget corresponding wait operands. + if constexpr (std::is_same_v) { + if (changed) { + Value newA = dotOp.getA(); + Value newB = dotOp.getB(); + for (Operation *user : dotOp.getResult().getUsers()) { + auto waitOp = dyn_cast(user); + if (!waitOp) + continue; + bool waitChanged = false; + for (OpOperand &operand : waitOp->getOpOperands()) { + if (operand.get() == oldA) { + operand.assign(newA); + waitChanged = true; + } else if (operand.get() == oldB) { + operand.assign(newB); + waitChanged = true; + } + } + if (!waitChanged) + continue; + rewriter.modifyOpInPlace(waitOp, [&]() { + unsigned n = std::min(waitOp->getNumOperands(), waitOp->getNumResults()); + for (unsigned i = 0; i < n; ++i) + waitOp->getResult(i).setType(waitOp->getOperand(i).getType()); + }); + } + } + } return changed ? success() : failure(); } diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 1d018f183b01..d18f46d7848d 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -389,7 +389,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> // CHECK-NOT: tt.reshape // CHECK-NOT: tt.trans - // CHECK: ttng.warp_group_dot %[[A_TR]], %[[B_RS]], %arg2 + // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %[[A_TR]], %[[B_RS]], %arg2 + // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], %[[A_TR]], %[[B_RS]] {pendings = 0 : i32} tt.func @swizzle0_operand_views_warp_group_dot( %a_desc: !tt.tensordesc>, %b_desc: !tt.tensordesc>, @@ -403,7 +404,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %a_s = ttg.local_alloc %aT : (tensor<64x64xf16, #blockedA2_h>) -> !ttg.memdesc<64x64xf16, #shared0_h2, #smem> %b_s = ttg.local_alloc %b2d : (tensor<64x64xf16, #blockedB2_h>) -> !ttg.memdesc<64x64xf16, #shared0_h2, #smem> %r = ttng.warp_group_dot %a_s, %b_s, %acc : !ttg.memdesc<64x64xf16, #shared0_h2, #smem> * !ttg.memdesc<64x64xf16, #shared0_h2, #smem> -> tensor<64x64xf32, #mma_h2> - tt.return %r : tensor<64x64xf32, #mma_h2> + %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_s {pendings = 0 : i32} : tensor<64x64xf32, #mma_h2>, !ttg.memdesc<64x64xf16, #shared0_h2, #smem>, !ttg.memdesc<64x64xf16, #shared0_h2, #smem> + tt.return %w#0 : tensor<64x64xf32, #mma_h2> } } From be6eb9323677886857a62431f0fc4431b0065d22 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 28 Mar 2026 11:24:00 +0900 Subject: [PATCH 18/54] upd --- test/TritonGPU/dot-operands.mlir | 51 ++++---------------------------- 1 file changed, 6 insertions(+), 45 deletions(-) diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index d18f46d7848d..2fe5d32d0744 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -44,7 +44,7 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- #sharedB_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> #smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @descriptor_load_warp_group_dot_swizzle0_operand_views + // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[ // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> @@ -53,8 +53,9 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, {{.*}}, #smem{{.*}}> // CHECK-NOT: tt.reshape // CHECK-NOT: tt.trans - // CHECK: ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 - tt.func @descriptor_load_warp_group_dot_swizzle0_operand_views( + // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 + // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} + tt.func @swizzle0_operand_views_warp_group_dot( %a_desc: !tt.tensordesc>, %b_desc: !tt.tensordesc>, %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { @@ -69,7 +70,8 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> -> tensor<128x256xf32, #mma_desc> - tt.return %r : tensor<128x256xf32, #mma_desc> + %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> + tt.return %w#0 : tensor<128x256xf32, #mma_desc> } } @@ -370,47 +372,6 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num // ----- -#blockedA2_h = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB2_h = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> -#blockedA3_h = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> -#blockedB3_h = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> -#mma_h2 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> -#shared0_h2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 16}> -#sharedA_h2 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> -#smem = #ttg.shared_memory -module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[A_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x64xf16, #{{.*}}> - // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_LOAD]] : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> - // CHECK-DAG: %[[A_RS:.*]] = ttg.memdesc_reshape %[[A_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> - // CHECK-DAG: %[[A_TR:.*]] = ttg.memdesc_trans %[[A_RS]] {order = array} : !ttg.memdesc<64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> - // CHECK-DAG: %[[B_LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x64x64xf16, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_LOAD]] : (tensor<1x64x64xf16, #{{.*}}>) -> !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> - // CHECK-DAG: %[[B_RS:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x64x64xf16, #{{.*}}, #smem> -> !ttg.memdesc<64x64xf16, #{{.*}}, #smem> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %[[A_TR]], %[[B_RS]], %arg2 - // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], %[[A_TR]], %[[B_RS]] {pendings = 0 : i32} - tt.func @swizzle0_operand_views_warp_group_dot( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, - %acc: tensor<64x64xf32, #mma_h2>) -> tensor<64x64xf32, #mma_h2> { - %c0_i32 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x64xf16, #blockedA3_h> - %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x64x64xf16, #blockedB3_h> - %a2d = tt.reshape %a : tensor<1x64x64xf16, #blockedA3_h> -> tensor<64x64xf16, #blockedB2_h> - %aT = tt.trans %a2d {order = array} : tensor<64x64xf16, #blockedB2_h> -> tensor<64x64xf16, #blockedA2_h> - %b2d = tt.reshape %b : tensor<1x64x64xf16, #blockedB3_h> -> tensor<64x64xf16, #blockedB2_h> - %a_s = ttg.local_alloc %aT : (tensor<64x64xf16, #blockedA2_h>) -> !ttg.memdesc<64x64xf16, #shared0_h2, #smem> - %b_s = ttg.local_alloc %b2d : (tensor<64x64xf16, #blockedB2_h>) -> !ttg.memdesc<64x64xf16, #shared0_h2, #smem> - %r = ttng.warp_group_dot %a_s, %b_s, %acc : !ttg.memdesc<64x64xf16, #shared0_h2, #smem> * !ttg.memdesc<64x64xf16, #shared0_h2, #smem> -> tensor<64x64xf32, #mma_h2> - %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_s {pendings = 0 : i32} : tensor<64x64xf32, #mma_h2>, !ttg.memdesc<64x64xf16, #shared0_h2, #smem>, !ttg.memdesc<64x64xf16, #shared0_h2, #smem> - tt.return %w#0 : tensor<64x64xf32, #mma_h2> - } -} - -// ----- - #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 32], warpsPerCTA = [1, 2, 2, 1], order = [3, 2, 1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}> From 0fa2e71439f2631efee53acf0e520394cca963a3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 28 Mar 2026 11:34:50 +0900 Subject: [PATCH 19/54] upd --- test/TritonGPU/dot-operands.mlir | 98 ++++++++++++++++---------------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 2fe5d32d0744..677644133964 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -31,51 +31,6 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- } -// ----- - -#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> -#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -#sharedB_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> -#smem = #ttg.shared_memory -module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[ - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 - // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} - tt.func @swizzle0_operand_views_warp_group_dot( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, - %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { - %c0 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> - %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> - %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> - %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> - %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> - %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> - %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> - %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} - : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> -> tensor<128x256xf32, #mma_desc> - %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> - tt.return %w#0 : tensor<128x256xf32, #mma_desc> - } -} - - // ----- #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -336,13 +291,15 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> #smem = #ttg.shared_memory #tmem0 = #ttng.tensor_memory_encoding +// CHECK-DAG: #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +// CHECK-DAG: #shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 4}> module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #shared2, #smem{{.*}}> // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> @@ -370,6 +327,51 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num } + +// ----- + +#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> +#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +#sharedB_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#smem = #ttg.shared_memory +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[ + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 + // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} + tt.func @swizzle0_operand_views_warp_group_dot( + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { + %c0 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> + %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> + %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> + %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> + %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> + %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} + : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> -> tensor<128x256xf32, #mma_desc> + %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> + tt.return %w#0 : tensor<128x256xf32, #mma_desc> + } +} + // ----- #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 32], warpsPerCTA = [1, 2, 2, 1], order = [3, 2, 1, 0]}> From 5e45dac401f6aab6177ebe8996869b62301c6714 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 31 Mar 2026 20:31:00 +0000 Subject: [PATCH 20/54] clean test --- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 16 ++++++++-------- .../test/unit/language/test_tensor_descriptor.py | 9 ++------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index f771ea1b93bd..0d747f295dd0 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -361,7 +361,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!waitChanged) continue; rewriter.modifyOpInPlace(waitOp, [&]() { - unsigned n = std::min(waitOp->getNumOperands(), waitOp->getNumResults()); + unsigned n = + std::min(waitOp->getNumOperands(), waitOp->getNumResults()); for (unsigned i = 0; i < n; ++i) waitOp->getResult(i).setType(waitOp->getOperand(i).getType()); }); @@ -410,9 +411,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (auto trailingTrans = beforeTrailing.getDefiningOp()) { SmallVector order(trailingTrans.getOrder().begin(), trailingTrans.getOrder().end()); - trailingMemDescSteps.push_back( - ViewStep{ViewStep::Transpose, {}, std::move(order), - trailingTrans.getLoc()}); + trailingMemDescSteps.push_back(ViewStep{ + ViewStep::Transpose, {}, std::move(order), trailingTrans.getLoc()}); beforeTrailing = trailingTrans.getSrc(); continue; } @@ -420,9 +420,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp beforeTrailing.getDefiningOp()) { auto ty = cast(trailingReshape.getType()); SmallVector shape(ty.getShape().begin(), ty.getShape().end()); - trailingMemDescSteps.push_back( - ViewStep{ViewStep::Reshape, std::move(shape), {}, - trailingReshape.getLoc()}); + trailingMemDescSteps.push_back(ViewStep{ + ViewStep::Reshape, std::move(shape), {}, trailingReshape.getLoc()}); beforeTrailing = trailingReshape.getSrc(); continue; } @@ -483,7 +482,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (sourceIsZeroSwizzleLike && sourceSharedEnc) refSharedEnc = sourceSharedEnc; - auto baseEnc = updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); + auto baseEnc = + updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); auto baseMemTy = MemDescType::get(baseTensorTy.getShape(), baseTensorTy.getElementType(), cast(baseEnc), allocTy.getMemorySpace(), diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index fa728d3301db..98f90861f402 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from triton._internal_testing import is_hopper, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy +from triton._internal_testing import is_hopper, is_blackwell, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor from typing import Optional from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3 @@ -1842,7 +1842,7 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B "ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"] -# TODO: require blackwell +@pytest.mark.skipif(not (is_hopper() or is_blackwell()), reason="Requires Hopper or Blackwell") def test_host_tensor_descriptor_matmul_fp8_swizzle0_b(device): M = N = K = 512 BLOCK_M = 128 @@ -1877,16 +1877,11 @@ def test_host_tensor_descriptor_matmul_fp8_swizzle0_b(device): num_stages=1, ) - # Compare against quantized operands actually consumed by the kernel. ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32).T) torch.testing.assert_close(ref_out, c, rtol=1.5e-1, atol=2.5e-1) ttgir = kernel.asm["ttgir"] - # assert "ttng.tc_gen5_mma" in ttgir assert "swizzlingByteWidth = 0" in ttgir and "#ttg.shared_linear" in ttgir - # assert any(f"swizzlingByteWidth = {w}" in ttgir for w in [32, 64, 128]) - -test_host_tensor_descriptor_matmul_fp8_swizzle0_b("cuda") @pytest.mark.interpreter From e7d54f851036cc517305b4d98e6b7518d9875b98 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 05:52:20 +0900 Subject: [PATCH 21/54] refactoring operand update --- .../Transforms/OptimizeDotOperands.cpp | 79 ++++++++++--------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 0d747f295dd0..e821d969cd56 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -331,48 +331,53 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp PatternRewriter &rewriter) const override { Value oldA = dotOp.getA(); Value oldB = dotOp.getB(); - bool changed = false; - if (succeeded(rewriteOperand(dotOp.getAMutable(), rewriter))) - changed = true; - if (succeeded(rewriteOperand(dotOp.getBMutable(), rewriter))) - changed = true; - - // Keep warp_group_dot_wait operands consistent with rewritten dot operands. - // The wait op is variadic and can carry [dot_result, A, B], so after - // changing dotOp's A/B we need to retarget corresponding wait operands. - if constexpr (std::is_same_v) { - if (changed) { - Value newA = dotOp.getA(); - Value newB = dotOp.getB(); - for (Operation *user : dotOp.getResult().getUsers()) { - auto waitOp = dyn_cast(user); - if (!waitOp) - continue; - bool waitChanged = false; - for (OpOperand &operand : waitOp->getOpOperands()) { - if (operand.get() == oldA) { - operand.assign(newA); - waitChanged = true; - } else if (operand.get() == oldB) { - operand.assign(newB); - waitChanged = true; - } - } - if (!waitChanged) + bool changedA = rewriteOperand(dotOp.getAMutable(), rewriter).succeeded(); + bool changedB = rewriteOperand(dotOp.getBMutable(), rewriter).succeeded(); + + if (changedA || changedB) { + updateDependentOps(dotOp, oldA, oldB, rewriter); + return success(); + } + + return failure(); + } + +private: + template + static void updateDependentOps(T, Value, Value, PatternRewriter &) {} + + static void updateDependentOps(triton::nvidia_gpu::WarpGroupDotOp dotOp, + Value oldA, Value oldB, + PatternRewriter &rewriter) { + // Keep warp_group_dot_wait operands consistent with rewritten dot + // operands. The wait op is variadic and can carry [dot_result, A, B], so + // after changing dotOp's A/B we need to retarget corresponding wait + // operands. + Value newA = dotOp.getA(); + Value newB = dotOp.getB(); + for (Operation *user : dotOp.getResult().getUsers()) { + auto waitOp = dyn_cast(user); + if (!waitOp) + continue; + rewriter.modifyOpInPlace(waitOp, [&]() { + for (OpOperand &operand : waitOp->getOpOperands()) { + Value replacement; + if (operand.get() == oldA) + replacement = newA; + else if (operand.get() == oldB) + replacement = newB; + else continue; - rewriter.modifyOpInPlace(waitOp, [&]() { - unsigned n = - std::min(waitOp->getNumOperands(), waitOp->getNumResults()); - for (unsigned i = 0; i < n; ++i) - waitOp->getResult(i).setType(waitOp->getOperand(i).getType()); - }); + + operand.assign(replacement); + if (operand.getOperandNumber() < waitOp->getNumResults()) + waitOp->getResult(operand.getOperandNumber()) + .setType(replacement.getType()); } - } + }); } - return changed ? success() : failure(); } -private: static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { if (auto nvmma = dyn_cast(encoding)) return nvmma.getSwizzlingByteWidth() == 0; From 3291122fc6187e43307a92cc04a048814ac353ae Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 06:39:16 +0900 Subject: [PATCH 22/54] wip --- .../Transforms/OptimizeDotOperands.cpp | 84 ++++++++----------- 1 file changed, 33 insertions(+), 51 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index e821d969cd56..2fff034caba7 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -12,6 +12,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" #include @@ -391,14 +392,36 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp Location loc; }; + template + static SmallVector collectViewSteps(Value &value) { + SmallVector steps; + while (true) { + if (auto reshape = value.getDefiningOp()) { + auto ty = reshape.getType(); + SmallVector shape(ty.getShape().begin(), ty.getShape().end()); + steps.push_back( + ViewStep{ViewStep::Reshape, std::move(shape), {}, reshape.getLoc()}); + value = reshape.getSrc(); + continue; + } + if (auto trans = value.getDefiningOp()) { + SmallVector order(trans.getOrder().begin(), + trans.getOrder().end()); + steps.push_back(ViewStep{ViewStep::Transpose, {}, std::move(order), + trans.getLoc()}); + value = trans.getSrc(); + continue; + } + break; + } + return steps; + } + static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { if (auto descLoad = baseTensor.getDefiningOp()) { - auto descTy = - dyn_cast(descLoad.getDesc().getType()); - if (!descTy) - return nullptr; - return dyn_cast_or_null( - descTy.getBlockType().getEncoding()); + if (auto tensorTy = dyn_cast(descLoad.getType())) + return triton::nvidia_gpu::getEncodingFromDescriptor( + descLoad, tensorTy, descLoad.getDesc()); } return nullptr; } @@ -410,28 +433,9 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!origTy) return failure(); - SmallVector trailingMemDescSteps; Value beforeTrailing = orig; - while (true) { - if (auto trailingTrans = beforeTrailing.getDefiningOp()) { - SmallVector order(trailingTrans.getOrder().begin(), - trailingTrans.getOrder().end()); - trailingMemDescSteps.push_back(ViewStep{ - ViewStep::Transpose, {}, std::move(order), trailingTrans.getLoc()}); - beforeTrailing = trailingTrans.getSrc(); - continue; - } - if (auto trailingReshape = - beforeTrailing.getDefiningOp()) { - auto ty = cast(trailingReshape.getType()); - SmallVector shape(ty.getShape().begin(), ty.getShape().end()); - trailingMemDescSteps.push_back(ViewStep{ - ViewStep::Reshape, std::move(shape), {}, trailingReshape.getLoc()}); - beforeTrailing = trailingReshape.getSrc(); - continue; - } - break; - } + SmallVector trailingMemDescSteps = + collectViewSteps(beforeTrailing); auto localAlloc = beforeTrailing.getDefiningOp(); if (!localAlloc || !localAlloc.getSrc()) @@ -442,31 +446,9 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!allocSharedEnc) return failure(); - SmallVector reverseSteps; Value baseTensor = localAlloc.getSrc(); - while (true) { - if (auto cvt = baseTensor.getDefiningOp()) { - baseTensor = cvt.getSrc(); - continue; - } - if (auto reshape = baseTensor.getDefiningOp()) { - SmallVector shape(reshape.getType().getShape().begin(), - reshape.getType().getShape().end()); - reverseSteps.push_back(ViewStep{ - ViewStep::Reshape, std::move(shape), {}, reshape.getLoc()}); - baseTensor = reshape.getSrc(); - continue; - } - if (auto trans = baseTensor.getDefiningOp()) { - SmallVector order(trans.getOrder().begin(), - trans.getOrder().end()); - reverseSteps.push_back(ViewStep{ - ViewStep::Transpose, {}, std::move(order), trans.getLoc()}); - baseTensor = trans.getSrc(); - continue; - } - break; - } + SmallVector reverseSteps = + collectViewSteps(baseTensor); if (reverseSteps.empty()) return failure(); From 6637c0dd54401772c5f775bd522b795be207f054 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 07:05:16 +0900 Subject: [PATCH 23/54] more --- .../Transforms/OptimizeDotOperands.cpp | 100 ++++++++++-------- 1 file changed, 55 insertions(+), 45 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 2fff034caba7..fee3daa36f39 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -17,6 +17,7 @@ #include "triton/Tools/LinearLayout.h" #include #include +#include #include namespace mlir::triton::gpu { @@ -393,28 +394,30 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp }; template - static SmallVector collectViewSteps(Value &value) { + static std::tuple> + collectViewSteps(Value value) { + Value current = value; SmallVector steps; while (true) { - if (auto reshape = value.getDefiningOp()) { + if (auto reshape = current.template getDefiningOp()) { auto ty = reshape.getType(); SmallVector shape(ty.getShape().begin(), ty.getShape().end()); - steps.push_back( - ViewStep{ViewStep::Reshape, std::move(shape), {}, reshape.getLoc()}); - value = reshape.getSrc(); + steps.push_back(ViewStep{ + ViewStep::Reshape, std::move(shape), {}, reshape.getLoc()}); + current = reshape.getSrc(); continue; } - if (auto trans = value.getDefiningOp()) { + if (auto trans = current.template getDefiningOp()) { SmallVector order(trans.getOrder().begin(), trans.getOrder().end()); - steps.push_back(ViewStep{ViewStep::Transpose, {}, std::move(order), - trans.getLoc()}); - value = trans.getSrc(); + steps.push_back(ViewStep{ + ViewStep::Transpose, {}, std::move(order), trans.getLoc()}); + current = trans.getSrc(); continue; } break; } - return steps; + return {current, std::move(steps)}; } static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { @@ -426,6 +429,35 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return nullptr; } + static FailureOr + getRewrittenBaseMemDescType(LocalAllocOp localAlloc, MemDescType allocTy, + Value baseTensor) { + auto allocSharedEnc = cast(allocTy.getEncoding()); + auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); + bool allocIsZeroSwizzleLike = + isZeroSwizzleCompatibleEncoding(cast(allocSharedEnc)); + bool sourceIsZeroSwizzleLike = + sourceSharedEnc && + isZeroSwizzleCompatibleEncoding(cast(sourceSharedEnc)); + if (!allocIsZeroSwizzleLike && !sourceIsZeroSwizzleLike) + return failure(); + + auto baseTensorTy = dyn_cast(baseTensor.getType()); + if (!baseTensorTy) + return failure(); + + SharedEncodingTrait refSharedEnc = allocSharedEnc; + if (sourceIsZeroSwizzleLike) + refSharedEnc = sourceSharedEnc; + + auto baseEnc = + updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); + return MemDescType::get(baseTensorTy.getShape(), + baseTensorTy.getElementType(), + cast(baseEnc), allocTy.getMemorySpace(), + allocTy.getMutableMemory()); + } + LogicalResult rewriteOperand(OpOperand &operand, PatternRewriter &rewriter) const { Value orig = operand.get(); @@ -433,54 +465,32 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!origTy) return failure(); - Value beforeTrailing = orig; - SmallVector trailingMemDescSteps = - collectViewSteps(beforeTrailing); + auto [beforeTrailing, trailingMemDescSteps] = + collectViewSteps(orig); - auto localAlloc = beforeTrailing.getDefiningOp(); + auto localAlloc = beforeTrailing.template getDefiningOp(); if (!localAlloc || !localAlloc.getSrc()) return failure(); auto allocTy = cast(localAlloc.getType()); - auto allocSharedEnc = dyn_cast(allocTy.getEncoding()); - if (!allocSharedEnc) - return failure(); - Value baseTensor = localAlloc.getSrc(); - SmallVector reverseSteps = - collectViewSteps(baseTensor); + auto [baseTensor, reverseSteps] = + collectViewSteps( + localAlloc.getSrc()); if (reverseSteps.empty()) return failure(); - auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); - bool sourceIsZeroSwizzleLike = - sourceSharedEnc && - isZeroSwizzleCompatibleEncoding(cast(sourceSharedEnc)); - if (!isZeroSwizzleCompatibleEncoding(allocTy.getEncoding()) && - !sourceIsZeroSwizzleLike) - return failure(); - - auto baseTensorTy = dyn_cast(baseTensor.getType()); - if (!baseTensorTy) + FailureOr baseMemTy = + getRewrittenBaseMemDescType(localAlloc, allocTy, baseTensor); + if (failed(baseMemTy)) return failure(); - SharedEncodingTrait refSharedEnc = allocSharedEnc; - if (sourceIsZeroSwizzleLike && sourceSharedEnc) - refSharedEnc = sourceSharedEnc; - - auto baseEnc = - updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); - auto baseMemTy = - MemDescType::get(baseTensorTy.getShape(), baseTensorTy.getElementType(), - cast(baseEnc), allocTy.getMemorySpace(), - allocTy.getMutableMemory()); - PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(localAlloc); Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), - baseMemTy, baseTensor); + *baseMemTy, baseTensor); for (ViewStep &step : llvm::reverse(reverseSteps)) { if (step.kind == ViewStep::Reshape) { @@ -503,9 +513,9 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } auto rewrittenTy = cast(rewritten.getType()); - if (rewrittenTy.getShape() != origTy.getShape() || - rewrittenTy.getElementType() != origTy.getElementType()) - return failure(); + assert(rewrittenTy.getShape() == origTy.getShape() && + rewrittenTy.getElementType() == origTy.getElementType() && + "rewrite must preserve memdesc shape and element type"); operand.assign(rewritten); return success(); From 9dcce40b4e3697157119ffed2b1bd98864960691 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 07:45:56 +0900 Subject: [PATCH 24/54] refactor --- .../Transforms/OptimizeDotOperands.cpp | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index fee3daa36f39..7fc5bc4f748e 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -380,12 +380,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } } - static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { - if (auto nvmma = dyn_cast(encoding)) - return nvmma.getSwizzlingByteWidth() == 0; - return false; - } - struct ViewStep { enum Kind { Reshape, Transpose } kind; SmallVector shape; @@ -397,12 +391,12 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp static std::tuple> collectViewSteps(Value value) { Value current = value; - SmallVector steps; + SmallVector replaySteps; while (true) { if (auto reshape = current.template getDefiningOp()) { auto ty = reshape.getType(); SmallVector shape(ty.getShape().begin(), ty.getShape().end()); - steps.push_back(ViewStep{ + replaySteps.push_back(ViewStep{ ViewStep::Reshape, std::move(shape), {}, reshape.getLoc()}); current = reshape.getSrc(); continue; @@ -410,14 +404,20 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (auto trans = current.template getDefiningOp()) { SmallVector order(trans.getOrder().begin(), trans.getOrder().end()); - steps.push_back(ViewStep{ + replaySteps.push_back(ViewStep{ ViewStep::Transpose, {}, std::move(order), trans.getLoc()}); current = trans.getSrc(); continue; } break; } - return {current, std::move(steps)}; + return {current, llvm::to_vector(llvm::reverse(replaySteps))}; + } + + static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { + if (auto nvmma = dyn_cast(encoding)) + return nvmma.getSwizzlingByteWidth() == 0; + return false; } static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { @@ -429,9 +429,9 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return nullptr; } - static FailureOr - getRewrittenBaseMemDescType(LocalAllocOp localAlloc, MemDescType allocTy, - Value baseTensor) { + static FailureOr getRewrittenBaseMemDescType(LocalAllocOp localAlloc, + MemDescType allocTy, + Value baseTensor) { auto allocSharedEnc = cast(allocTy.getEncoding()); auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); bool allocIsZeroSwizzleLike = @@ -465,7 +465,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!origTy) return failure(); - auto [beforeTrailing, trailingMemDescSteps] = + auto [beforeTrailing, trailingMemDescReplaySteps] = collectViewSteps(orig); auto localAlloc = beforeTrailing.template getDefiningOp(); @@ -474,11 +474,11 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp auto allocTy = cast(localAlloc.getType()); - auto [baseTensor, reverseSteps] = + auto [baseTensor, tensorReplaySteps] = collectViewSteps( localAlloc.getSrc()); - if (reverseSteps.empty()) + if (tensorReplaySteps.empty()) return failure(); FailureOr baseMemTy = @@ -492,7 +492,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), *baseMemTy, baseTensor); - for (ViewStep &step : llvm::reverse(reverseSteps)) { + for (ViewStep &step : tensorReplaySteps) { if (step.kind == ViewStep::Reshape) { rewritten = MemDescReshapeOp::create(rewriter, step.loc, rewritten, step.shape); @@ -502,7 +502,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } } - for (ViewStep &step : llvm::reverse(trailingMemDescSteps)) { + for (ViewStep &step : trailingMemDescReplaySteps) { if (step.kind == ViewStep::Reshape) { rewritten = MemDescReshapeOp::create(rewriter, step.loc, rewritten, step.shape); From 914486063a9cfc1a533422fdccfbc79fec248dc8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 08:00:35 +0900 Subject: [PATCH 25/54] wip --- .../Transforms/OptimizeDotOperands.cpp | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 7fc5bc4f748e..d9dbbfc763df 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -429,9 +429,15 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return nullptr; } - static FailureOr getRewrittenBaseMemDescType(LocalAllocOp localAlloc, - MemDescType allocTy, - Value baseTensor) { + // Compute the memdesc type for the rewritten base `local_alloc`. The + // original alloc may already have a transformed 2D shared layout suitable for + // the final dot operand, while `baseTensor` is the pre-view tensor we want to + // allocate instead. This helper chooses the zero-swizzle-capable shared + // encoding we should preserve, retargets it to `baseTensor`'s shape, and + // builds the corresponding memdesc type. + static FailureOr + getRewrittenBaseMemDescType(LocalAllocOp localAlloc, Value baseTensor) { + auto allocTy = cast(localAlloc.getType()); auto allocSharedEnc = cast(allocTy.getEncoding()); auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); bool allocIsZeroSwizzleLike = @@ -442,14 +448,11 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!allocIsZeroSwizzleLike && !sourceIsZeroSwizzleLike) return failure(); - auto baseTensorTy = dyn_cast(baseTensor.getType()); - if (!baseTensorTy) - return failure(); - SharedEncodingTrait refSharedEnc = allocSharedEnc; if (sourceIsZeroSwizzleLike) refSharedEnc = sourceSharedEnc; + auto baseTensorTy = cast(baseTensor.getType()); auto baseEnc = updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); return MemDescType::get(baseTensorTy.getShape(), @@ -472,8 +475,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp if (!localAlloc || !localAlloc.getSrc()) return failure(); - auto allocTy = cast(localAlloc.getType()); - auto [baseTensor, tensorReplaySteps] = collectViewSteps( localAlloc.getSrc()); @@ -482,7 +483,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); FailureOr baseMemTy = - getRewrittenBaseMemDescType(localAlloc, allocTy, baseTensor); + getRewrittenBaseMemDescType(localAlloc, baseTensor); if (failed(baseMemTy)) return failure(); From da8d60c1e29d8da9d125a71a3e347d99592f0e62 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 08:02:12 +0900 Subject: [PATCH 26/54] fix --- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index d9dbbfc763df..3a7f581e46b6 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -422,9 +422,9 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { if (auto descLoad = baseTensor.getDefiningOp()) { - if (auto tensorTy = dyn_cast(descLoad.getType())) - return triton::nvidia_gpu::getEncodingFromDescriptor( - descLoad, tensorTy, descLoad.getDesc()); + auto descTy = cast(descLoad.getDesc().getType()); + auto descBlockTy = descTy.getBlockType(); + return dyn_cast_or_null(descBlockTy.getEncoding()); } return nullptr; } From a41052a32b772002b5d4ff252b28284266c5b462 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 08:41:16 +0900 Subject: [PATCH 27/54] more clean --- .../Transforms/OptimizeDotOperands.cpp | 51 +++++++------------ 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 3a7f581e46b6..9a91de8c2b9d 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -414,51 +414,34 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return {current, llvm::to_vector(llvm::reverse(replaySteps))}; } - static bool isZeroSwizzleCompatibleEncoding(Attribute encoding) { - if (auto nvmma = dyn_cast(encoding)) - return nvmma.getSwizzlingByteWidth() == 0; - return false; - } - - static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { + static SharedEncodingTrait getSourceSwizzle0SharedEncoding(Value baseTensor) { if (auto descLoad = baseTensor.getDefiningOp()) { auto descTy = cast(descLoad.getDesc().getType()); auto descBlockTy = descTy.getBlockType(); - return dyn_cast_or_null(descBlockTy.getEncoding()); + auto sourceSharedEnc = + dyn_cast_or_null(descBlockTy.getEncoding()); + if (auto nvmma = dyn_cast_or_null( + cast_or_null(sourceSharedEnc)); + nvmma && nvmma.getSwizzlingByteWidth() == 0) + return sourceSharedEnc; } return nullptr; } - // Compute the memdesc type for the rewritten base `local_alloc`. The - // original alloc may already have a transformed 2D shared layout suitable for - // the final dot operand, while `baseTensor` is the pre-view tensor we want to - // allocate instead. This helper chooses the zero-swizzle-capable shared - // encoding we should preserve, retargets it to `baseTensor`'s shape, and - // builds the corresponding memdesc type. - static FailureOr - getRewrittenBaseMemDescType(LocalAllocOp localAlloc, Value baseTensor) { - auto allocTy = cast(localAlloc.getType()); - auto allocSharedEnc = cast(allocTy.getEncoding()); - auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); - bool allocIsZeroSwizzleLike = - isZeroSwizzleCompatibleEncoding(cast(allocSharedEnc)); - bool sourceIsZeroSwizzleLike = - sourceSharedEnc && - isZeroSwizzleCompatibleEncoding(cast(sourceSharedEnc)); - if (!allocIsZeroSwizzleLike && !sourceIsZeroSwizzleLike) + // Build a new memdesc type for the rewritten `local_alloc` by taking the + // original MMA operand memdesc and replacing its shape and shared encoding + // with those from swizzle-0 `tt.descriptor_load` result + static FailureOr getSwizzle0MemDescType(MemDescType refTy, + Value baseTensor) { + auto sourceSharedEnc = getSourceSwizzle0SharedEncoding(baseTensor); + if (!sourceSharedEnc) return failure(); - SharedEncodingTrait refSharedEnc = allocSharedEnc; - if (sourceIsZeroSwizzleLike) - refSharedEnc = sourceSharedEnc; - auto baseTensorTy = cast(baseTensor.getType()); - auto baseEnc = - updateEncodingForShape(localAlloc, refSharedEnc, baseTensorTy); return MemDescType::get(baseTensorTy.getShape(), baseTensorTy.getElementType(), - cast(baseEnc), allocTy.getMemorySpace(), - allocTy.getMutableMemory()); + cast(sourceSharedEnc), + refTy.getMemorySpace(), refTy.getMutableMemory()); } LogicalResult rewriteOperand(OpOperand &operand, @@ -483,7 +466,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); FailureOr baseMemTy = - getRewrittenBaseMemDescType(localAlloc, baseTensor); + getSwizzle0MemDescType(origTy, baseTensor); if (failed(baseMemTy)) return failure(); From d3eee9663db10f832267fd84bf29a1288a724064 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 08:45:54 +0900 Subject: [PATCH 28/54] add comment --- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 9a91de8c2b9d..ac7d7ea2bdbb 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -323,6 +323,11 @@ class UseShmemForScales } }; +// Rewrite +// desc_load -> tt.reshape / tt.trans -> local_alloc -> memdesc +// reshape / trans +// into +// desc_load -> local_alloc -> memdesc reshape / trans template class RewriteSwizzle0OperandViewsToMemDescForDotOp : public OpRewritePattern { From b9b6eb4ac76cd51af7adc8a4feb61c1297fcfcf3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 08:58:43 +0900 Subject: [PATCH 29/54] remove stale include --- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index ac7d7ea2bdbb..a9eb0e86f0f7 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -4,21 +4,15 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" -#include #include -#include -#include namespace mlir::triton::gpu { From 2cda92bcb10fe5a43fe7578ea9d9c8e01577b2dc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Apr 2026 16:34:56 +0900 Subject: [PATCH 30/54] add comment describing the rewrite pattern --- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index a9eb0e86f0f7..5eb0ed74b4df 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -322,6 +322,22 @@ class UseShmemForScales // reshape / trans // into // desc_load -> local_alloc -> memdesc reshape / trans +// +// swizzle=0 in NVMMASharedEncodingAttr represents a flat, contiguous layout. +// This is valid as the destination encoding for TMA, but it is never the +// correct layout for an MMA operand which requires the special "core-matrices" +// layout even with swizzle=0. So if the result of swizzle-0 TMA is fed into MMA +// without smem layout conversion between them, the result would be incorrect. +// +// When using swizzle-0 TMA with MMA, it is a user's responsibility to have the +// source of TMA in global memory to be already in the core-matrices format, and +// insert a sequence of tt.reshape / tt.trans transformations between desc_load +// and MMA ops such that the MMA op sees the right core-matrices layout. + +// This rewrite pattern ensures that swizzle=0 in TMA and a sequence of +// tt.reshape / tt.trans ops are correctly propagated, via equivalent +// transformations on memdesc, into the right MMA SMEM operand layout with +// swizzle=0. template class RewriteSwizzle0OperandViewsToMemDescForDotOp : public OpRewritePattern { From dcf62c041aa04d6dce3c26f582cda043b5a18072 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2026 14:38:10 +0900 Subject: [PATCH 31/54] minor --- test/TritonGPU/dot-operands.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 677644133964..5a73c011c034 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -30,7 +30,6 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num- } } - // ----- #blocked = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> From 8aec72f8237d2f99d2cbf67bc59acf6ba76fbe3c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 6 Apr 2026 14:41:18 +0900 Subject: [PATCH 32/54] revert cmake change --- third_party/nvidia/CMakeLists.txt | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/third_party/nvidia/CMakeLists.txt b/third_party/nvidia/CMakeLists.txt index 506cba3b2cc6..22d6a9e397f4 100644 --- a/third_party/nvidia/CMakeLists.txt +++ b/third_party/nvidia/CMakeLists.txt @@ -18,34 +18,12 @@ if(TRITON_BUILD_PYTHON_MODULE) message(FATAL_ERROR "clang++ is required to build gsan.ll") endif() - if(DEFINED TRITON_CUDART_PATH AND NOT "${TRITON_CUDART_PATH}" STREQUAL "") - set(GSAN_RUNTIME_CUDA_PATH "${TRITON_CUDART_PATH}") - else() - set(GSAN_RUNTIME_CUDA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/backend") - endif() set(GSAN_RUNTIME_PLATFORM_FLAGS - "--cuda-path=${GSAN_RUNTIME_CUDA_PATH}") + "--cuda-path=${CMAKE_CURRENT_SOURCE_DIR}/backend") if(APPLE) list(APPEND GSAN_RUNTIME_PLATFORM_FLAGS -isysroot "${CMAKE_OSX_SYSROOT}") endif() - set(GSAN_RUNTIME_TOOLCHAIN_FLAGS) - set(GSAN_HOST_GNU_CXX "${CMAKE_CXX_COMPILER}") - if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - find_program(GSAN_HOST_GNU_CXX NAMES g++ c++) - endif() - if(GSAN_HOST_GNU_CXX) - execute_process( - COMMAND "${GSAN_HOST_GNU_CXX}" -print-file-name=libstdc++.so - OUTPUT_VARIABLE LIBSTDCXX_PATH - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - if(IS_ABSOLUTE "${LIBSTDCXX_PATH}") - get_filename_component(GCC_INSTALL_DIR "${LIBSTDCXX_PATH}" DIRECTORY) - list(APPEND GSAN_RUNTIME_TOOLCHAIN_FLAGS "--gcc-install-dir=${GCC_INSTALL_DIR}") - endif() - endif() - add_custom_command( OUTPUT "${GSAN_RUNTIME_IR}" COMMAND "${CMAKE_COMMAND}" -E make_directory @@ -60,7 +38,6 @@ if(TRITON_BUILD_PYTHON_MODULE) -fcuda-flush-denormals-to-zero --cuda-gpu-arch=sm_80 -Wno-unknown-cuda-version - ${GSAN_RUNTIME_TOOLCHAIN_FLAGS} ${GSAN_RUNTIME_PLATFORM_FLAGS} -isystem "${CMAKE_CURRENT_SOURCE_DIR}/clang_cuda_shims" -isystem "${CMAKE_CURRENT_SOURCE_DIR}/backend/include" From fbae09bfbc9c8eb59974e5e72fb1455cda6fdc83 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 7 Apr 2026 08:08:50 +0900 Subject: [PATCH 33/54] update comment to make it more accurate --- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 5eb0ed74b4df..8e767e498cab 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -324,10 +324,11 @@ class UseShmemForScales // desc_load -> local_alloc -> memdesc reshape / trans // // swizzle=0 in NVMMASharedEncodingAttr represents a flat, contiguous layout. -// This is valid as the destination encoding for TMA, but it is never the -// correct layout for an MMA operand which requires the special "core-matrices" -// layout even with swizzle=0. So if the result of swizzle-0 TMA is fed into MMA -// without smem layout conversion between them, the result would be incorrect. +// This is valid as the destination encoding for TMA, but unless the operand's +// contiguous dimension is <= 16 bytes, it is not the correct layout for an MMA +// operand which requires the special "core-matrices" layout even with +// swizzle=0. So if the result of swizzle-0 TMA is fed into MMA without smem +// layout conversion between them, the result would be incorrect. // // When using swizzle-0 TMA with MMA, it is a user's responsibility to have the // source of TMA in global memory to be already in the core-matrices format, and From e01ce66fcd49df15f1ce2e479828945e45d1dfa4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 8 Apr 2026 19:23:26 +0900 Subject: [PATCH 34/54] Make swizzle0 operand view rewrite sink-driven --- .../Transforms/OptimizeDotOperands.cpp | 119 ++++++++++++++---- test/TritonGPU/dot-operands.mlir | 41 +++--- 2 files changed, 123 insertions(+), 37 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 8e767e498cab..6b785e64f1c7 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -398,8 +398,10 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp struct ViewStep { enum Kind { Reshape, Transpose } kind; - SmallVector shape; + SmallVector srcShape; + SmallVector dstShape; SmallVector order; + Operation *op; Location loc; }; @@ -410,18 +412,29 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp SmallVector replaySteps; while (true) { if (auto reshape = current.template getDefiningOp()) { - auto ty = reshape.getType(); - SmallVector shape(ty.getShape().begin(), ty.getShape().end()); - replaySteps.push_back(ViewStep{ - ViewStep::Reshape, std::move(shape), {}, reshape.getLoc()}); + auto srcTy = reshape.getSrc().getType(); + auto dstTy = reshape.getType(); + replaySteps.push_back(ViewStep{ViewStep::Reshape, + SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), + {}, + reshape.getOperation(), + reshape.getLoc()}); current = reshape.getSrc(); continue; } if (auto trans = current.template getDefiningOp()) { SmallVector order(trans.getOrder().begin(), trans.getOrder().end()); - replaySteps.push_back(ViewStep{ - ViewStep::Transpose, {}, std::move(order), trans.getLoc()}); + auto srcTy = trans.getSrc().getType(); + auto dstTy = trans.getType(); + replaySteps.push_back( + ViewStep{ViewStep::Transpose, + SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), + std::move(order), + trans.getOperation(), + trans.getLoc()}); current = trans.getSrc(); continue; } @@ -444,20 +457,67 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return nullptr; } - // Build a new memdesc type for the rewritten `local_alloc` by taking the - // original MMA operand memdesc and replacing its shape and shared encoding - // with those from swizzle-0 `tt.descriptor_load` result - static FailureOr getSwizzle0MemDescType(MemDescType refTy, - Value baseTensor) { + static FailureOr inferViewStepBackward(MemDescType resultTy, + const ViewStep &step) { + if (resultTy.getShape() != ArrayRef(step.dstShape)) + return failure(); + + switch (step.kind) { + case ViewStep::Reshape: { + MemDescType srcTy; + if (failed(MemDescReshapeOp::inferReturnTypes( + resultTy.getContext(), step.loc, resultTy, step.srcShape, srcTy))) + return failure(); + return srcTy; + } + case ViewStep::Transpose: { + auto inverseOrder = triton::inversePermutation(step.order); + Attribute srcEnc = resultTy.getEncoding(); + if (srcEnc) { + auto inferLayoutInterface = + cast(&srcEnc.getDialect()); + if (failed(inferLayoutInterface->inferTransOpEncoding( + srcEnc, resultTy.getShape(), inverseOrder, srcEnc, step.loc))) + return failure(); + } + return MemDescType::get(step.srcShape, resultTy.getElementType(), srcEnc, + resultTy.getMemorySpace(), + resultTy.getMutableMemory()); + } + } + llvm_unreachable("unexpected view step"); + } + + static FailureOr + inferBackwardSourceType(MemDescType sinkTy, ArrayRef replaySteps) { + MemDescType currentTy = sinkTy; + for (const ViewStep &step : llvm::reverse(replaySteps)) { + auto srcTy = inferViewStepBackward(currentTy, step); + if (failed(srcTy)) + return failure(); + currentTy = *srcTy; + } + return currentTy; + } + + static bool layoutsEquivalent(ArrayRef shape, Attribute lhs, + Attribute rhs) { + if (lhs == rhs) + return true; + if (!lhs || !rhs) + return false; + return areLayoutsEquivalent(shape, cast(lhs), + cast(rhs)); + } + + static LogicalResult verifySourceSwizzle0Layout(MemDescType inferredBaseTy, + Value baseTensor) { auto sourceSharedEnc = getSourceSwizzle0SharedEncoding(baseTensor); if (!sourceSharedEnc) return failure(); - - auto baseTensorTy = cast(baseTensor.getType()); - return MemDescType::get(baseTensorTy.getShape(), - baseTensorTy.getElementType(), - cast(sourceSharedEnc), - refTy.getMemorySpace(), refTy.getMutableMemory()); + return success(layoutsEquivalent( + inferredBaseTy.getShape(), inferredBaseTy.getEncoding(), + cast(sourceSharedEnc))); } LogicalResult rewriteOperand(OpOperand &operand, @@ -482,30 +542,42 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return failure(); FailureOr baseMemTy = - getSwizzle0MemDescType(origTy, baseTensor); + inferBackwardSourceType(localAlloc.getType(), tensorReplaySteps); if (failed(baseMemTy)) return failure(); + if (failed(verifySourceSwizzle0Layout(*baseMemTy, baseTensor))) + return failure(); PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(localAlloc); Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), *baseMemTy, baseTensor); + auto sinkTy = localAlloc.getType(); for (ViewStep &step : tensorReplaySteps) { if (step.kind == ViewStep::Reshape) { - rewritten = - MemDescReshapeOp::create(rewriter, step.loc, rewritten, step.shape); + rewritten = MemDescReshapeOp::create(rewriter, step.loc, rewritten, + step.dstShape); } else { rewritten = MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); } } + auto rewrittenSinkTy = cast(rewritten.getType()); + if (rewrittenSinkTy.getShape() != sinkTy.getShape() || + rewrittenSinkTy.getElementType() != sinkTy.getElementType() || + !layoutsEquivalent(sinkTy.getShape(), rewrittenSinkTy.getEncoding(), + sinkTy.getEncoding())) { + return failure(); + } + for (ViewStep &step : trailingMemDescReplaySteps) { if (step.kind == ViewStep::Reshape) { rewritten = - MemDescReshapeOp::create(rewriter, step.loc, rewritten, step.shape); + MemDescReshapeOp::create(rewriter, step.loc, rewritten, + step.dstShape); } else { rewritten = MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); @@ -516,6 +588,9 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp assert(rewrittenTy.getShape() == origTy.getShape() && rewrittenTy.getElementType() == origTy.getElementType() && "rewrite must preserve memdesc shape and element type"); + if (!layoutsEquivalent(origTy.getShape(), rewrittenTy.getEncoding(), + origTy.getEncoding())) + return failure(); operand.assign(rewritten); return success(); diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 5a73c011c034..65d69ec01b1a 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -288,20 +288,25 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num #linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> #shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 4}> +#shared3 = #ttg.shared_linear<{offset = [[0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 4], [0, 0, 0, 8], [0, 1, 0, 0], [0, 2, 0, 0], [0, 4, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [4, 0, 0, 0], [8, 0, 0, 0], [16, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0], [0, 0, 4, 0]]}, alignment = 128> +#shared4 = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> +#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> #smem = #ttg.shared_memory #tmem0 = #ttng.tensor_memory_encoding // CHECK-DAG: #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -// CHECK-DAG: #shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 4}> +// CHECK-DAG: #shared2 = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #shared6 = #ttg.shared_linear<{{.*}}alignment = 128> module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #shared2, #smem{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #shared6, #smem{{.*}}> // CHECK-NOT: tt.reshape // CHECK-NOT: tt.trans // CHECK-NOT: ttg.local_load @@ -319,8 +324,8 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> - %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared0, #smem> - ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> + %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> + ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> tt.return } @@ -337,16 +342,22 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num #mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> #sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> #sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -#sharedB_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}> +#sharedB1_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 4}> +#sharedB2_desc = #ttg.shared_linear<{offset = [[0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 4], [0, 0, 0, 8], [0, 1, 0, 0], [0, 2, 0, 0], [0, 4, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [4, 0, 0, 0], [8, 0, 0, 0], [16, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0], [0, 0, 4, 0]]}, alignment = 128> +#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> +#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> #smem = #ttg.shared_memory +// CHECK-DAG: #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +// CHECK-DAG: #shared2 = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #shared6 = #ttg.shared_linear<{{.*}}alignment = 128> module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[ - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #shared6, #smem{{.*}}> // CHECK-NOT: tt.reshape // CHECK-NOT: tt.trans // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 @@ -362,11 +373,11 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> - %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> - %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedA_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> + %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> + %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} - : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> -> tensor<128x256xf32, #mma_desc> - %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB_desc, #smem> + : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> + %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> tt.return %w#0 : tensor<128x256xf32, #mma_desc> } } From c3884789dadfefabad853590f02008f75d5aca5e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 8 Apr 2026 19:45:48 +0900 Subject: [PATCH 35/54] Clean up sink-driven dot operand rewrite --- .../Transforms/OptimizeDotOperands.cpp | 101 +++++++----------- 1 file changed, 41 insertions(+), 60 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6b785e64f1c7..691c5c84c21e 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -13,6 +13,7 @@ #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" #include +#include namespace mlir::triton::gpu { @@ -353,7 +354,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp bool changedB = rewriteOperand(dotOp.getBMutable(), rewriter).succeeded(); if (changedA || changedB) { - updateDependentOps(dotOp, oldA, oldB, rewriter); + if constexpr (std::is_same_v) + updateWarpGroupDotWaitOperands(dotOp, oldA, oldB, rewriter); return success(); } @@ -361,10 +363,8 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } private: - template - static void updateDependentOps(T, Value, Value, PatternRewriter &) {} - - static void updateDependentOps(triton::nvidia_gpu::WarpGroupDotOp dotOp, + static void + updateWarpGroupDotWaitOperands(triton::nvidia_gpu::WarpGroupDotOp dotOp, Value oldA, Value oldB, PatternRewriter &rewriter) { // Keep warp_group_dot_wait operands consistent with rewritten dot @@ -401,7 +401,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp SmallVector srcShape; SmallVector dstShape; SmallVector order; - Operation *op; Location loc; }; @@ -418,7 +417,6 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp SmallVector(srcTy.getShape()), SmallVector(dstTy.getShape()), {}, - reshape.getOperation(), reshape.getLoc()}); current = reshape.getSrc(); continue; @@ -428,13 +426,10 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp trans.getOrder().end()); auto srcTy = trans.getSrc().getType(); auto dstTy = trans.getType(); - replaySteps.push_back( - ViewStep{ViewStep::Transpose, - SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), - std::move(order), - trans.getOperation(), - trans.getLoc()}); + replaySteps.push_back(ViewStep{ViewStep::Transpose, + SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), + std::move(order), trans.getLoc()}); current = trans.getSrc(); continue; } @@ -443,24 +438,20 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return {current, llvm::to_vector(llvm::reverse(replaySteps))}; } - static SharedEncodingTrait getSourceSwizzle0SharedEncoding(Value baseTensor) { + static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { if (auto descLoad = baseTensor.getDefiningOp()) { auto descTy = cast(descLoad.getDesc().getType()); auto descBlockTy = descTy.getBlockType(); - auto sourceSharedEnc = - dyn_cast_or_null(descBlockTy.getEncoding()); - if (auto nvmma = dyn_cast_or_null( - cast_or_null(sourceSharedEnc)); - nvmma && nvmma.getSwizzlingByteWidth() == 0) - return sourceSharedEnc; + return dyn_cast_or_null(descBlockTy.getEncoding()); } return nullptr; } static FailureOr inferViewStepBackward(MemDescType resultTy, const ViewStep &step) { - if (resultTy.getShape() != ArrayRef(step.dstShape)) - return failure(); + assert(resultTy.getShape() == ArrayRef(step.dstShape) && + "backward inference must start from the view step destination " + "shape"); switch (step.kind) { case ViewStep::Reshape: { @@ -473,13 +464,11 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp case ViewStep::Transpose: { auto inverseOrder = triton::inversePermutation(step.order); Attribute srcEnc = resultTy.getEncoding(); - if (srcEnc) { - auto inferLayoutInterface = - cast(&srcEnc.getDialect()); - if (failed(inferLayoutInterface->inferTransOpEncoding( - srcEnc, resultTy.getShape(), inverseOrder, srcEnc, step.loc))) - return failure(); - } + auto inferLayoutInterface = + cast(&srcEnc.getDialect()); + if (failed(inferLayoutInterface->inferTransOpEncoding( + srcEnc, resultTy.getShape(), inverseOrder, srcEnc, step.loc))) + return failure(); return MemDescType::get(step.srcShape, resultTy.getElementType(), srcEnc, resultTy.getMemorySpace(), resultTy.getMutableMemory()); @@ -500,24 +489,15 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return currentTy; } - static bool layoutsEquivalent(ArrayRef shape, Attribute lhs, - Attribute rhs) { - if (lhs == rhs) - return true; - if (!lhs || !rhs) - return false; - return areLayoutsEquivalent(shape, cast(lhs), - cast(rhs)); - } - - static LogicalResult verifySourceSwizzle0Layout(MemDescType inferredBaseTy, - Value baseTensor) { - auto sourceSharedEnc = getSourceSwizzle0SharedEncoding(baseTensor); + static LogicalResult verifySourceLayout(MemDescType inferredBaseTy, + Value baseTensor) { + auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); if (!sourceSharedEnc) return failure(); - return success(layoutsEquivalent( - inferredBaseTy.getShape(), inferredBaseTy.getEncoding(), - cast(sourceSharedEnc))); + return success(areLayoutsEquivalent( + inferredBaseTy.getShape(), + cast(inferredBaseTy.getEncoding()), + cast(cast(sourceSharedEnc)))); } LogicalResult rewriteOperand(OpOperand &operand, @@ -545,7 +525,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp inferBackwardSourceType(localAlloc.getType(), tensorReplaySteps); if (failed(baseMemTy)) return failure(); - if (failed(verifySourceSwizzle0Layout(*baseMemTy, baseTensor))) + if (failed(verifySourceLayout(*baseMemTy, baseTensor))) return failure(); PatternRewriter::InsertionGuard guard(rewriter); @@ -566,18 +546,18 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp } auto rewrittenSinkTy = cast(rewritten.getType()); - if (rewrittenSinkTy.getShape() != sinkTy.getShape() || - rewrittenSinkTy.getElementType() != sinkTy.getElementType() || - !layoutsEquivalent(sinkTy.getShape(), rewrittenSinkTy.getEncoding(), - sinkTy.getEncoding())) { - return failure(); - } + assert(rewrittenSinkTy.getShape() == sinkTy.getShape() && + rewrittenSinkTy.getElementType() == sinkTy.getElementType() && + areLayoutsEquivalent( + sinkTy.getShape(), + cast(rewrittenSinkTy.getEncoding()), + cast(sinkTy.getEncoding())) && + "rewrite must preserve the intermediate sink memdesc"); for (ViewStep &step : trailingMemDescReplaySteps) { if (step.kind == ViewStep::Reshape) { - rewritten = - MemDescReshapeOp::create(rewriter, step.loc, rewritten, - step.dstShape); + rewritten = MemDescReshapeOp::create(rewriter, step.loc, rewritten, + step.dstShape); } else { rewritten = MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); @@ -587,10 +567,11 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp auto rewrittenTy = cast(rewritten.getType()); assert(rewrittenTy.getShape() == origTy.getShape() && rewrittenTy.getElementType() == origTy.getElementType() && - "rewrite must preserve memdesc shape and element type"); - if (!layoutsEquivalent(origTy.getShape(), rewrittenTy.getEncoding(), - origTy.getEncoding())) - return failure(); + areLayoutsEquivalent( + origTy.getShape(), + cast(rewrittenTy.getEncoding()), + cast(origTy.getEncoding())) && + "rewrite must preserve the final memdesc"); operand.assign(rewritten); return success(); From b9bb708b297e2affb31fe1817dda3a034e8a6a96 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 8 Apr 2026 19:49:32 +0900 Subject: [PATCH 36/54] Refine sink-driven operand rewrite checks --- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 691c5c84c21e..95397ba5a94a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -438,7 +438,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return {current, llvm::to_vector(llvm::reverse(replaySteps))}; } - static SharedEncodingTrait getSourceSharedEncoding(Value baseTensor) { + static SharedEncodingTrait getDescriptorSharedEncoding(Value baseTensor) { if (auto descLoad = baseTensor.getDefiningOp()) { auto descTy = cast(descLoad.getDesc().getType()); auto descBlockTy = descTy.getBlockType(); @@ -489,15 +489,16 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp return currentTy; } - static LogicalResult verifySourceLayout(MemDescType inferredBaseTy, - Value baseTensor) { - auto sourceSharedEnc = getSourceSharedEncoding(baseTensor); - if (!sourceSharedEnc) + static LogicalResult + verifyBaseMatchesDescriptorLayout(MemDescType inferredBaseTy, + Value baseTensor) { + auto descriptorSharedEnc = getDescriptorSharedEncoding(baseTensor); + if (!descriptorSharedEnc) return failure(); return success(areLayoutsEquivalent( inferredBaseTy.getShape(), cast(inferredBaseTy.getEncoding()), - cast(cast(sourceSharedEnc)))); + cast(cast(descriptorSharedEnc)))); } LogicalResult rewriteOperand(OpOperand &operand, @@ -525,7 +526,7 @@ class RewriteSwizzle0OperandViewsToMemDescForDotOp inferBackwardSourceType(localAlloc.getType(), tensorReplaySteps); if (failed(baseMemTy)) return failure(); - if (failed(verifySourceLayout(*baseMemTy, baseTensor))) + if (failed(verifyBaseMatchesDescriptorLayout(*baseMemTy, baseTensor))) return failure(); PatternRewriter::InsertionGuard guard(rewriter); From 1133abd0de8ac45bed5bc73136ae0b0131179cf9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 8 Apr 2026 19:52:34 +0900 Subject: [PATCH 37/54] Generalize dot operand view rewrite naming --- .../Transforms/OptimizeDotOperands.cpp | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 95397ba5a94a..d5c8d1b0b59e 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -319,29 +319,21 @@ class UseShmemForScales }; // Rewrite -// desc_load -> tt.reshape / tt.trans -> local_alloc -> memdesc -// reshape / trans +// desc_load -> tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma // into -// desc_load -> local_alloc -> memdesc reshape / trans +// desc_load -> local_alloc -> memdesc reshape / trans -> [memdesc views] -> +// mma // -// swizzle=0 in NVMMASharedEncodingAttr represents a flat, contiguous layout. -// This is valid as the destination encoding for TMA, but unless the operand's -// contiguous dimension is <= 16 bytes, it is not the correct layout for an MMA -// operand which requires the special "core-matrices" layout even with -// swizzle=0. So if the result of swizzle-0 TMA is fed into MMA without smem -// layout conversion between them, the result would be incorrect. +// The MMA operand layout is determined by the sink memdesc already feeding the +// dot-like op. This pattern back-propagates that layout through the tensor +// reshape/transpose chain, hoists local_alloc to the descriptor_load result, +// and then replays the same views as memdesc reshape/transpose ops. // -// When using swizzle-0 TMA with MMA, it is a user's responsibility to have the -// source of TMA in global memory to be already in the core-matrices format, and -// insert a sequence of tt.reshape / tt.trans transformations between desc_load -// and MMA ops such that the MMA op sees the right core-matrices layout. - -// This rewrite pattern ensures that swizzle=0 in TMA and a sequence of -// tt.reshape / tt.trans ops are correctly propagated, via equivalent -// transformations on memdesc, into the right MMA SMEM operand layout with -// swizzle=0. +// The rewrite only applies when the backward-inferred base memdesc layout is +// equivalent to the descriptor block layout, so the hoisted local_alloc still +// represents the same underlying shared-memory view. template -class RewriteSwizzle0OperandViewsToMemDescForDotOp +class RewriteDotOperandViewsToMemDescForDotOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -604,11 +596,11 @@ class TritonGPUOptimizeDotOperandsPass patterns.add(context); patterns.add(context); patterns.add, - RewriteSwizzle0OperandViewsToMemDescForDotOp< + RewriteDotOperandViewsToMemDescForDotOp< triton::nvidia_gpu::TCGen5MMAScaledOp>, - RewriteSwizzle0OperandViewsToMemDescForDotOp< + RewriteDotOperandViewsToMemDescForDotOp< triton::nvidia_gpu::WarpGroupDotOp>>(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) From ae6782cf173483a8b9af92914bcfd7bb05e992da Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 8 Apr 2026 19:57:02 +0900 Subject: [PATCH 38/54] Remove stale swizzle0 host descriptor test --- .../unit/language/test_tensor_descriptor.py | 100 ------------------ 1 file changed, 100 deletions(-) diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index 98f90861f402..cbdfdef230d2 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -1734,64 +1734,6 @@ def matmul_kernel_host_tensor_descriptor(a_desc, b_desc, c_desc): c_desc.store([offs_am, offs_bn], accumulator) -@triton.jit -def matmul_kernel_host_tensor_descriptor_swizzle0_b(a_desc, b_desc, c_desc): - K = a_desc.shape[1] - BLOCK_M: tl.constexpr = a_desc.block_shape[0] - BLOCK_K: tl.constexpr = a_desc.block_shape[1] - BLOCK_N: tl.constexpr = c_desc.block_shape[1] - - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - offs_k = 0 - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k_tile in range(0, tl.cdiv(K, BLOCK_K)): - a = a_desc.load([offs_am, offs_k]) - # Inverse core-matrices reorder: - # [num_cm_k * num_cm_n, 8, 16] -> [num_cm_k, num_cm_n, 8, 16] - # -> [num_cm_n, 8, num_cm_k, 16] -> [BLOCK_N, BLOCK_K]. - b_cm = b_desc.load([pid_n, k_tile, 0, 0, 0]) - num_cm_k: tl.constexpr = BLOCK_K // 16 - num_cm_n: tl.constexpr = BLOCK_N // 8 - b = b_cm.reshape((num_cm_k, num_cm_n, 8, 16)) - b = tl.permute(b, (1, 2, 0, 3)) - b = b.reshape((BLOCK_N, BLOCK_K)) - accumulator = tl.dot(a, b.T, acc=accumulator) - offs_k += BLOCK_K - c_desc.store([offs_am, offs_bn], accumulator) - - -def transform_b_to_core_matrices_layout(B, BLOCK_N, BLOCK_K): - CM_ROWS = 8 - CM_COLS = 16 - - N, K = B.shape - assert N % BLOCK_N == 0 - assert K % BLOCK_K == 0 - assert BLOCK_N % CM_ROWS == 0 - assert BLOCK_K % CM_COLS == 0 - - num_blocks_n = N // BLOCK_N - num_blocks_k = K // BLOCK_K - num_cm_n = BLOCK_N // CM_ROWS - num_cm_k = BLOCK_K // CM_COLS - - # [N, K] -> [num_blocks_n, num_cm_n, CM_ROWS, num_blocks_k, num_cm_k, CM_COLS] - b_reshaped = B.reshape(num_blocks_n, num_cm_n, CM_ROWS, num_blocks_k, num_cm_k, CM_COLS) - # [num_blocks_n, num_cm_n, num_cm_k, num_blocks_k, CM_ROWS, CM_COLS] - b_perm = b_reshaped.permute(0, 1, 4, 3, 2, 5) - # N-major core-matrices: - # [num_blocks_n, num_blocks_k, num_cm_k, num_cm_n, CM_ROWS, CM_COLS] - b_perm = b_perm.permute(0, 3, 2, 1, 4, 5) - # Collapse cm-count dims to keep descriptor rank <= 5 while retaining - # explicit core-matrix rows/cols in the innermost axes. - b_transformed = b_perm.reshape(num_blocks_n, num_blocks_k, num_cm_k * num_cm_n, CM_ROWS, CM_COLS) - return b_transformed.contiguous() - - @pytest.mark.interpreter() @pytest.mark.parametrize("num_ctas", [1, 2]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, num_stages", [ @@ -1842,48 +1784,6 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B "ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"] -@pytest.mark.skipif(not (is_hopper() or is_blackwell()), reason="Requires Hopper or Blackwell") -def test_host_tensor_descriptor_matmul_fp8_swizzle0_b(device): - M = N = K = 512 - BLOCK_M = 128 - BLOCK_N = 256 - BLOCK_K = 128 - grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), 1) - - torch.manual_seed(0) - a_fp32 = torch.randn((M, K), dtype=torch.float32, device=device) - b_fp32 = torch.randn((N, K), dtype=torch.float32, device=device) - a = a_fp32.to(torch.float8_e4m3fn) - b = b_fp32.to(torch.float8_e4m3fn) - b_transformed = transform_b_to_core_matrices_layout(b, BLOCK_N, BLOCK_K) - c = torch.empty((M, N), dtype=torch.float32, device=device) - - a_desc = TensorDescriptor(a, a.shape, a.stride(), [BLOCK_M, BLOCK_K]) - num_cm_n = BLOCK_N // 8 - num_cm_k = BLOCK_K // 16 - b_desc = TensorDescriptor( - b_transformed, - b_transformed.shape, - b_transformed.stride(), - [1, 1, num_cm_k * num_cm_n, 8, 16], - ) - c_desc = TensorDescriptor(c, c.shape, c.stride(), [BLOCK_M, BLOCK_N]) - - kernel = matmul_kernel_host_tensor_descriptor_swizzle0_b[grid]( - a_desc, - b_desc, - c_desc, - num_warps=4, - num_stages=1, - ) - - ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32).T) - torch.testing.assert_close(ref_out, c, rtol=1.5e-1, atol=2.5e-1) - - ttgir = kernel.asm["ttgir"] - assert "swizzlingByteWidth = 0" in ttgir and "#ttg.shared_linear" in ttgir - - @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) def test_tensor_descriptor_store_downcast(dtype_str, device): From ffa4f6f648f5624fca1e5ef31cbb2ccac928e7a6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 8 Apr 2026 20:01:38 +0900 Subject: [PATCH 39/54] revert unnecessary test change --- python/test/unit/language/test_tensor_descriptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index cbdfdef230d2..64c714ab4571 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from triton._internal_testing import is_hopper, is_blackwell, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy +from triton._internal_testing import is_hopper, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor from typing import Optional from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3 From 96793597fea4912912cbf143c15ac51e034ba097 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 8 Apr 2026 20:04:03 +0900 Subject: [PATCH 40/54] Restore template dispatch for dot operand updates --- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index d5c8d1b0b59e..f9ebbd3b9a3a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -13,7 +13,6 @@ #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" #include -#include namespace mlir::triton::gpu { @@ -346,8 +345,7 @@ class RewriteDotOperandViewsToMemDescForDotOp bool changedB = rewriteOperand(dotOp.getBMutable(), rewriter).succeeded(); if (changedA || changedB) { - if constexpr (std::is_same_v) - updateWarpGroupDotWaitOperands(dotOp, oldA, oldB, rewriter); + updateDependentOps(dotOp, oldA, oldB, rewriter); return success(); } @@ -355,8 +353,10 @@ class RewriteDotOperandViewsToMemDescForDotOp } private: - static void - updateWarpGroupDotWaitOperands(triton::nvidia_gpu::WarpGroupDotOp dotOp, + template + static void updateDependentOps(T, Value, Value, PatternRewriter &) {} + + static void updateDependentOps(triton::nvidia_gpu::WarpGroupDotOp dotOp, Value oldA, Value oldB, PatternRewriter &rewriter) { // Keep warp_group_dot_wait operands consistent with rewritten dot From e315bf2cf45a6bac8bf07c023c190054ce883f12 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2026 04:43:52 +0900 Subject: [PATCH 41/54] Use inferSrcEncoding in dot operand rewrite --- .../Transforms/OptimizeDotOperands.cpp | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index f9ebbd3b9a3a..712422279f7e 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -393,6 +393,7 @@ class RewriteDotOperandViewsToMemDescForDotOp SmallVector srcShape; SmallVector dstShape; SmallVector order; + Operation *op; Location loc; }; @@ -409,6 +410,7 @@ class RewriteDotOperandViewsToMemDescForDotOp SmallVector(srcTy.getShape()), SmallVector(dstTy.getShape()), {}, + reshape.getOperation(), reshape.getLoc()}); current = reshape.getSrc(); continue; @@ -421,7 +423,9 @@ class RewriteDotOperandViewsToMemDescForDotOp replaySteps.push_back(ViewStep{ViewStep::Transpose, SmallVector(srcTy.getShape()), SmallVector(dstTy.getShape()), - std::move(order), trans.getLoc()}); + std::move(order), + trans.getOperation(), + trans.getLoc()}); current = trans.getSrc(); continue; } @@ -444,29 +448,19 @@ class RewriteDotOperandViewsToMemDescForDotOp assert(resultTy.getShape() == ArrayRef(step.dstShape) && "backward inference must start from the view step destination " "shape"); - - switch (step.kind) { - case ViewStep::Reshape: { + if (step.kind == ViewStep::Reshape) { MemDescType srcTy; if (failed(MemDescReshapeOp::inferReturnTypes( resultTy.getContext(), step.loc, resultTy, step.srcShape, srcTy))) return failure(); return srcTy; } - case ViewStep::Transpose: { - auto inverseOrder = triton::inversePermutation(step.order); - Attribute srcEnc = resultTy.getEncoding(); - auto inferLayoutInterface = - cast(&srcEnc.getDialect()); - if (failed(inferLayoutInterface->inferTransOpEncoding( - srcEnc, resultTy.getShape(), inverseOrder, srcEnc, step.loc))) - return failure(); - return MemDescType::get(step.srcShape, resultTy.getElementType(), srcEnc, - resultTy.getMemorySpace(), - resultTy.getMutableMemory()); - } - } - llvm_unreachable("unexpected view step"); + Attribute srcEnc = inferSrcEncoding(step.op, resultTy.getEncoding()); + if (!srcEnc) + return failure(); + return MemDescType::get(step.srcShape, resultTy.getElementType(), srcEnc, + resultTy.getMemorySpace(), + resultTy.getMutableMemory()); } static FailureOr From 02dcdba77be6b6abb367d6081718a7a4c8ef4af0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2026 04:51:03 +0900 Subject: [PATCH 42/54] Simplify dot operand rewiring after rewrite --- .../Transforms/OptimizeDotOperands.cpp | 61 +++++-------------- 1 file changed, 14 insertions(+), 47 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 712422279f7e..b0099a8e03fc 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -341,53 +341,22 @@ class RewriteDotOperandViewsToMemDescForDotOp PatternRewriter &rewriter) const override { Value oldA = dotOp.getA(); Value oldB = dotOp.getB(); - bool changedA = rewriteOperand(dotOp.getAMutable(), rewriter).succeeded(); - bool changedB = rewriteOperand(dotOp.getBMutable(), rewriter).succeeded(); + bool changed = false; - if (changedA || changedB) { - updateDependentOps(dotOp, oldA, oldB, rewriter); - return success(); + if (rewriteOperand(dotOp.getAMutable(), rewriter).succeeded()) { + oldA.replaceAllUsesExcept(dotOp.getA(), dotOp.getOperation()); + changed = true; } - return failure(); - } - -private: - template - static void updateDependentOps(T, Value, Value, PatternRewriter &) {} - - static void updateDependentOps(triton::nvidia_gpu::WarpGroupDotOp dotOp, - Value oldA, Value oldB, - PatternRewriter &rewriter) { - // Keep warp_group_dot_wait operands consistent with rewritten dot - // operands. The wait op is variadic and can carry [dot_result, A, B], so - // after changing dotOp's A/B we need to retarget corresponding wait - // operands. - Value newA = dotOp.getA(); - Value newB = dotOp.getB(); - for (Operation *user : dotOp.getResult().getUsers()) { - auto waitOp = dyn_cast(user); - if (!waitOp) - continue; - rewriter.modifyOpInPlace(waitOp, [&]() { - for (OpOperand &operand : waitOp->getOpOperands()) { - Value replacement; - if (operand.get() == oldA) - replacement = newA; - else if (operand.get() == oldB) - replacement = newB; - else - continue; - - operand.assign(replacement); - if (operand.getOperandNumber() < waitOp->getNumResults()) - waitOp->getResult(operand.getOperandNumber()) - .setType(replacement.getType()); - } - }); + if (rewriteOperand(dotOp.getBMutable(), rewriter).succeeded()) { + oldB.replaceAllUsesExcept(dotOp.getB(), dotOp.getOperation()); + changed = true; } + + return success(changed); } +private: struct ViewStep { enum Kind { Reshape, Transpose } kind; SmallVector srcShape; @@ -420,12 +389,10 @@ class RewriteDotOperandViewsToMemDescForDotOp trans.getOrder().end()); auto srcTy = trans.getSrc().getType(); auto dstTy = trans.getType(); - replaySteps.push_back(ViewStep{ViewStep::Transpose, - SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), - std::move(order), - trans.getOperation(), - trans.getLoc()}); + replaySteps.push_back(ViewStep{ + ViewStep::Transpose, SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), std::move(order), + trans.getOperation(), trans.getLoc()}); current = trans.getSrc(); continue; } From 68fe5acfbf6091f41fdaba5490696e3f8bda8545 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2026 15:26:36 +0900 Subject: [PATCH 43/54] Move MMA operand view rewrite into NVIDIA pass --- .../TritonNvidiaGPU/Transforms/Passes.td | 14 + .../Transforms/OptimizeDotOperands.cpp | 228 +--------------- .../TritonNvidiaGPU/Transforms/CMakeLists.txt | 1 + .../RewriteMmaOperandViewsToMemDesc.cpp | 245 ++++++++++++++++++ test/TritonGPU/dot-operands.mlir | 105 -------- .../rewrite-mma-operand-views-to-memdesc.mlir | 96 +++++++ third_party/nvidia/backend/compiler.py | 1 + third_party/nvidia/triton_nvidia.cc | 2 + 8 files changed, 364 insertions(+), 328 deletions(-) create mode 100644 lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp create mode 100644 test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td index a41b2e891434..c9abeadbde7d 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -142,6 +142,20 @@ def TritonNvidiaGPUOptimizeDescriptorEncodingPass : Pass<"triton-nvidia-optimize "mlir::triton::TritonDialect"]; } +def TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass : Pass<"triton-nvidia-rewrite-mma-operand-views-to-memdesc", "mlir::ModuleOp"> { + let summary = "Rewrite tensor MMA operand views into memdesc views"; + + let description = [{ + Rewrite tensor reshape/transpose chains feeding MMA shared-memory operands + into equivalent memdesc reshape/transpose chains once the operand layout is + fixed. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", + "mlir::triton::TritonDialect"]; +} + def TritonNvidiaGPUOptimizeTMemLayoutsPass : Pass<"triton-nvidia-optimize-tmem-layouts", "mlir::ModuleOp"> { let summary = "Optimize TMEM layouts."; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index b0099a8e03fc..9fd5e75602b1 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -108,7 +108,10 @@ class FuseTransMMAV3Plus : public OpRewritePattern { return failure(); MemDescType allocType = allocOp.getType(); - auto allocEncoding = cast(allocType.getEncoding()); + auto allocEncoding = + dyn_cast(allocType.getEncoding()); + if (!allocEncoding) + return failure(); RankedTensorType srcTy = trans.getSrc().getType(); auto ctx = getContext(); @@ -317,221 +320,6 @@ class UseShmemForScales } }; -// Rewrite -// desc_load -> tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma -// into -// desc_load -> local_alloc -> memdesc reshape / trans -> [memdesc views] -> -// mma -// -// The MMA operand layout is determined by the sink memdesc already feeding the -// dot-like op. This pattern back-propagates that layout through the tensor -// reshape/transpose chain, hoists local_alloc to the descriptor_load result, -// and then replays the same views as memdesc reshape/transpose ops. -// -// The rewrite only applies when the backward-inferred base memdesc layout is -// equivalent to the descriptor block layout, so the hoisted local_alloc still -// represents the same underlying shared-memory view. -template -class RewriteDotOperandViewsToMemDescForDotOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DotOpTy dotOp, - PatternRewriter &rewriter) const override { - Value oldA = dotOp.getA(); - Value oldB = dotOp.getB(); - bool changed = false; - - if (rewriteOperand(dotOp.getAMutable(), rewriter).succeeded()) { - oldA.replaceAllUsesExcept(dotOp.getA(), dotOp.getOperation()); - changed = true; - } - - if (rewriteOperand(dotOp.getBMutable(), rewriter).succeeded()) { - oldB.replaceAllUsesExcept(dotOp.getB(), dotOp.getOperation()); - changed = true; - } - - return success(changed); - } - -private: - struct ViewStep { - enum Kind { Reshape, Transpose } kind; - SmallVector srcShape; - SmallVector dstShape; - SmallVector order; - Operation *op; - Location loc; - }; - - template - static std::tuple> - collectViewSteps(Value value) { - Value current = value; - SmallVector replaySteps; - while (true) { - if (auto reshape = current.template getDefiningOp()) { - auto srcTy = reshape.getSrc().getType(); - auto dstTy = reshape.getType(); - replaySteps.push_back(ViewStep{ViewStep::Reshape, - SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), - {}, - reshape.getOperation(), - reshape.getLoc()}); - current = reshape.getSrc(); - continue; - } - if (auto trans = current.template getDefiningOp()) { - SmallVector order(trans.getOrder().begin(), - trans.getOrder().end()); - auto srcTy = trans.getSrc().getType(); - auto dstTy = trans.getType(); - replaySteps.push_back(ViewStep{ - ViewStep::Transpose, SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), std::move(order), - trans.getOperation(), trans.getLoc()}); - current = trans.getSrc(); - continue; - } - break; - } - return {current, llvm::to_vector(llvm::reverse(replaySteps))}; - } - - static SharedEncodingTrait getDescriptorSharedEncoding(Value baseTensor) { - if (auto descLoad = baseTensor.getDefiningOp()) { - auto descTy = cast(descLoad.getDesc().getType()); - auto descBlockTy = descTy.getBlockType(); - return dyn_cast_or_null(descBlockTy.getEncoding()); - } - return nullptr; - } - - static FailureOr inferViewStepBackward(MemDescType resultTy, - const ViewStep &step) { - assert(resultTy.getShape() == ArrayRef(step.dstShape) && - "backward inference must start from the view step destination " - "shape"); - if (step.kind == ViewStep::Reshape) { - MemDescType srcTy; - if (failed(MemDescReshapeOp::inferReturnTypes( - resultTy.getContext(), step.loc, resultTy, step.srcShape, srcTy))) - return failure(); - return srcTy; - } - Attribute srcEnc = inferSrcEncoding(step.op, resultTy.getEncoding()); - if (!srcEnc) - return failure(); - return MemDescType::get(step.srcShape, resultTy.getElementType(), srcEnc, - resultTy.getMemorySpace(), - resultTy.getMutableMemory()); - } - - static FailureOr - inferBackwardSourceType(MemDescType sinkTy, ArrayRef replaySteps) { - MemDescType currentTy = sinkTy; - for (const ViewStep &step : llvm::reverse(replaySteps)) { - auto srcTy = inferViewStepBackward(currentTy, step); - if (failed(srcTy)) - return failure(); - currentTy = *srcTy; - } - return currentTy; - } - - static LogicalResult - verifyBaseMatchesDescriptorLayout(MemDescType inferredBaseTy, - Value baseTensor) { - auto descriptorSharedEnc = getDescriptorSharedEncoding(baseTensor); - if (!descriptorSharedEnc) - return failure(); - return success(areLayoutsEquivalent( - inferredBaseTy.getShape(), - cast(inferredBaseTy.getEncoding()), - cast(cast(descriptorSharedEnc)))); - } - - LogicalResult rewriteOperand(OpOperand &operand, - PatternRewriter &rewriter) const { - Value orig = operand.get(); - auto origTy = dyn_cast(orig.getType()); - if (!origTy) - return failure(); - - auto [beforeTrailing, trailingMemDescReplaySteps] = - collectViewSteps(orig); - - auto localAlloc = beforeTrailing.template getDefiningOp(); - if (!localAlloc || !localAlloc.getSrc()) - return failure(); - - auto [baseTensor, tensorReplaySteps] = - collectViewSteps( - localAlloc.getSrc()); - - if (tensorReplaySteps.empty()) - return failure(); - - FailureOr baseMemTy = - inferBackwardSourceType(localAlloc.getType(), tensorReplaySteps); - if (failed(baseMemTy)) - return failure(); - if (failed(verifyBaseMatchesDescriptorLayout(*baseMemTy, baseTensor))) - return failure(); - - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(localAlloc); - - Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), - *baseMemTy, baseTensor); - auto sinkTy = localAlloc.getType(); - - for (ViewStep &step : tensorReplaySteps) { - if (step.kind == ViewStep::Reshape) { - rewritten = MemDescReshapeOp::create(rewriter, step.loc, rewritten, - step.dstShape); - } else { - rewritten = - MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); - } - } - - auto rewrittenSinkTy = cast(rewritten.getType()); - assert(rewrittenSinkTy.getShape() == sinkTy.getShape() && - rewrittenSinkTy.getElementType() == sinkTy.getElementType() && - areLayoutsEquivalent( - sinkTy.getShape(), - cast(rewrittenSinkTy.getEncoding()), - cast(sinkTy.getEncoding())) && - "rewrite must preserve the intermediate sink memdesc"); - - for (ViewStep &step : trailingMemDescReplaySteps) { - if (step.kind == ViewStep::Reshape) { - rewritten = MemDescReshapeOp::create(rewriter, step.loc, rewritten, - step.dstShape); - } else { - rewritten = - MemDescTransOp::create(rewriter, step.loc, rewritten, step.order); - } - } - - auto rewrittenTy = cast(rewritten.getType()); - assert(rewrittenTy.getShape() == origTy.getShape() && - rewrittenTy.getElementType() == origTy.getElementType() && - areLayoutsEquivalent( - origTy.getShape(), - cast(rewrittenTy.getEncoding()), - cast(origTy.getEncoding())) && - "rewrite must preserve the final memdesc"); - - operand.assign(rewritten); - return success(); - } -}; - } // namespace #define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS @@ -556,13 +344,7 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); - patterns.add, - RewriteDotOperandViewsToMemDescForDotOp< - triton::nvidia_gpu::TCGen5MMAScaledOp>, - RewriteDotOperandViewsToMemDescForDotOp< - triton::nvidia_gpu::WarpGroupDotOp>>(context); + patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) signalPassFailure(); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt index 7b0d2c626a1e..4059f549ec87 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_triton_library(TritonNvidiaGPUTransforms PlanCTA.cpp PromoteLHSToTMem.cpp ProxyFenceInsertion.cpp + RewriteMmaOperandViewsToMemDesc.cpp RemoveTMEMTokens.cpp TensorMemoryAllocation.cpp TMALowering.cpp diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp new file mode 100644 index 000000000000..26b1aa1f6f86 --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp @@ -0,0 +1,245 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" + +namespace mlir::triton::nvidia_gpu { + +namespace { + +// Rewrite +// desc_load -> tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma +// into +// desc_load -> local_alloc -> memdesc reshape / trans -> [memdesc views] -> +// mma +// +// The MMA operand layout is determined by the sink memdesc already feeding the +// dot-like op. This pattern back-propagates that layout through the tensor +// reshape/transpose chain, hoists local_alloc to the descriptor_load result, +// and then replays the same views as memdesc reshape/transpose ops. +// +// The rewrite only applies when the backward-inferred base memdesc layout is +// equivalent to the descriptor block layout, so the hoisted local_alloc still +// represents the same underlying shared-memory view. +template +class RewriteMmaOperandViewsToMemDescForDotOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DotOpTy dotOp, + PatternRewriter &rewriter) const override { + Value oldA = dotOp.getA(); + Value oldB = dotOp.getB(); + bool changed = false; + + if (rewriteOperand(dotOp.getAMutable(), rewriter).succeeded()) { + oldA.replaceAllUsesExcept(dotOp.getA(), dotOp.getOperation()); + changed = true; + } + + if (rewriteOperand(dotOp.getBMutable(), rewriter).succeeded()) { + oldB.replaceAllUsesExcept(dotOp.getB(), dotOp.getOperation()); + changed = true; + } + + return success(changed); + } + +private: + struct ViewStep { + enum Kind { Reshape, Transpose } kind; + SmallVector srcShape; + SmallVector dstShape; + SmallVector order; + Operation *op; + Location loc; + }; + + template + static std::tuple> + collectViewSteps(Value value) { + Value current = value; + SmallVector replaySteps; + while (true) { + if (auto reshape = current.template getDefiningOp()) { + auto srcTy = reshape.getSrc().getType(); + auto dstTy = reshape.getType(); + replaySteps.push_back(ViewStep{ViewStep::Reshape, + SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), + {}, + reshape.getOperation(), + reshape.getLoc()}); + current = reshape.getSrc(); + continue; + } + if (auto trans = current.template getDefiningOp()) { + SmallVector order(trans.getOrder().begin(), + trans.getOrder().end()); + auto srcTy = trans.getSrc().getType(); + auto dstTy = trans.getType(); + replaySteps.push_back(ViewStep{ + ViewStep::Transpose, SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), std::move(order), + trans.getOperation(), trans.getLoc()}); + current = trans.getSrc(); + continue; + } + break; + } + return {current, llvm::to_vector(llvm::reverse(replaySteps))}; + } + + static gpu::SharedEncodingTrait getDescriptorSharedEncoding(Value baseTensor) { + if (auto descLoad = baseTensor.getDefiningOp()) { + auto descTy = cast(descLoad.getDesc().getType()); + auto descBlockTy = descTy.getBlockType(); + return dyn_cast_or_null( + descBlockTy.getEncoding()); + } + return nullptr; + } + + static FailureOr + inferViewStepBackward(gpu::MemDescType resultTy, const ViewStep &step) { + assert(resultTy.getShape() == ArrayRef(step.dstShape) && + "backward inference must start from the view step destination " + "shape"); + if (step.kind == ViewStep::Reshape) { + gpu::MemDescType srcTy; + if (failed(gpu::MemDescReshapeOp::inferReturnTypes( + resultTy.getContext(), step.loc, resultTy, step.srcShape, + srcTy))) + return failure(); + return srcTy; + } + Attribute srcEnc = inferSrcEncoding(step.op, resultTy.getEncoding()); + if (!srcEnc) + return failure(); + return gpu::MemDescType::get(step.srcShape, resultTy.getElementType(), + srcEnc, resultTy.getMemorySpace(), + resultTy.getMutableMemory()); + } + + static LogicalResult + verifyBaseMatchesDescriptorLayout(gpu::MemDescType inferredBaseTy, + Value baseTensor) { + auto descriptorSharedEnc = getDescriptorSharedEncoding(baseTensor); + if (!descriptorSharedEnc) + return failure(); + return success(gpu::areLayoutsEquivalent( + inferredBaseTy.getShape(), + cast(inferredBaseTy.getEncoding()), + cast(cast(descriptorSharedEnc)))); + } + + LogicalResult rewriteOperand(OpOperand &operand, + PatternRewriter &rewriter) const { + Value orig = operand.get(); + auto origTy = dyn_cast(orig.getType()); + if (!origTy) + return failure(); + + auto [beforeTrailing, trailingMemDescReplaySteps] = + collectViewSteps(orig); + + auto localAlloc = beforeTrailing.template getDefiningOp(); + if (!localAlloc || !localAlloc.getSrc()) + return failure(); + + auto [baseTensor, tensorReplaySteps] = + collectViewSteps( + localAlloc.getSrc()); + if (tensorReplaySteps.empty()) + return failure(); + + gpu::MemDescType baseMemTy = localAlloc.getType(); + for (const ViewStep &step : llvm::reverse(tensorReplaySteps)) { + auto srcTy = inferViewStepBackward(baseMemTy, step); + if (failed(srcTy)) + return failure(); + baseMemTy = *srcTy; + } + if (failed(verifyBaseMatchesDescriptorLayout(baseMemTy, baseTensor))) + return failure(); + + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(localAlloc); + + Value rewritten = gpu::LocalAllocOp::create(rewriter, localAlloc.getLoc(), + baseMemTy, baseTensor); + auto sinkTy = localAlloc.getType(); + + for (ViewStep &step : tensorReplaySteps) { + if (step.kind == ViewStep::Reshape) { + rewritten = gpu::MemDescReshapeOp::create(rewriter, step.loc, rewritten, + step.dstShape); + } else { + rewritten = gpu::MemDescTransOp::create(rewriter, step.loc, rewritten, + step.order); + } + } + + auto rewrittenSinkTy = cast(rewritten.getType()); + assert(rewrittenSinkTy.getShape() == sinkTy.getShape() && + rewrittenSinkTy.getElementType() == sinkTy.getElementType() && + gpu::areLayoutsEquivalent( + sinkTy.getShape(), + cast(rewrittenSinkTy.getEncoding()), + cast(sinkTy.getEncoding())) && + "rewrite must preserve the intermediate sink memdesc"); + + for (ViewStep &step : trailingMemDescReplaySteps) { + if (step.kind == ViewStep::Reshape) { + rewritten = gpu::MemDescReshapeOp::create(rewriter, step.loc, rewritten, + step.dstShape); + } else { + rewritten = gpu::MemDescTransOp::create(rewriter, step.loc, rewritten, + step.order); + } + } + + auto rewrittenTy = cast(rewritten.getType()); + assert(rewrittenTy.getShape() == origTy.getShape() && + rewrittenTy.getElementType() == origTy.getElementType() && + gpu::areLayoutsEquivalent( + origTy.getShape(), + cast(rewrittenTy.getEncoding()), + cast(origTy.getEncoding())) && + "rewrite must preserve the final memdesc"); + + operand.assign(rewritten); + return success(); + } +}; + +} // namespace + +#define GEN_PASS_DEF_TRITONNVIDIAGPUREWRITEMMAOPERANDVIEWSTOMEMDESCPASS +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" + +class TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass + : public impl::TritonNvidiaGPURewriteMmaOperandViewsToMemDescPassBase< + TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass> { +public: + using BaseT = impl::TritonNvidiaGPURewriteMmaOperandViewsToMemDescPassBase< + TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass>; + using BaseT::BaseT; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add, + RewriteMmaOperandViewsToMemDescForDotOp, + RewriteMmaOperandViewsToMemDescForDotOp>( + &getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace mlir::triton::nvidia_gpu diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 65d69ec01b1a..cfdd2b4a84b4 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -279,111 +279,6 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num // ----- -#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> -#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 4}> -#shared3 = #ttg.shared_linear<{offset = [[0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 4], [0, 0, 0, 8], [0, 1, 0, 0], [0, 2, 0, 0], [0, 4, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [4, 0, 0, 0], [8, 0, 0, 0], [16, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0], [0, 0, 4, 0]]}, alignment = 128> -#shared4 = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> -#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> -#smem = #ttg.shared_memory -#tmem0 = #ttng.tensor_memory_encoding -// CHECK-DAG: #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -// CHECK-DAG: #shared2 = #ttg.shared_linear<{{.*}}alignment = 128> -// CHECK-DAG: #shared6 = #ttg.shared_linear<{{.*}}alignment = 128> -module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views - // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #shared6, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK-NOT: ttg.local_load - // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true - tt.func @swizzle0_operand_views( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, - %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { - %true = arith.constant true - %c0_i32 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked0> - %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> - %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> - %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> - %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> - %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> - %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> - ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> - tt.return - } - -} - - -// ----- - -#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> -#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -#sharedB1_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 4}> -#sharedB2_desc = #ttg.shared_linear<{offset = [[0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 4], [0, 0, 0, 8], [0, 1, 0, 0], [0, 2, 0, 0], [0, 4, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [4, 0, 0, 0], [8, 0, 0, 0], [16, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0], [0, 0, 4, 0]]}, alignment = 128> -#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> -#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> -#smem = #ttg.shared_memory -// CHECK-DAG: #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -// CHECK-DAG: #shared2 = #ttg.shared_linear<{{.*}}alignment = 128> -// CHECK-DAG: #shared6 = #ttg.shared_linear<{{.*}}alignment = 128> -module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared2, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #shared6, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 - // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} - tt.func @swizzle0_operand_views_warp_group_dot( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, - %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { - %c0 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> - %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> - %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> - %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> - %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> - %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> - %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> - %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} - : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> - %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> - tt.return %w#0 : tensor<128x256xf32, #mma_desc> - } -} - -// ----- - #blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 32], warpsPerCTA = [1, 2, 2, 1], order = [3, 2, 1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32}> diff --git a/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir b/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir new file mode 100644 index 000000000000..fff031c485fb --- /dev/null +++ b/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir @@ -0,0 +1,96 @@ +// RUN: triton-opt %s -split-input-file -triton-nvidia-optimize-descriptor-encoding -triton-nvidia-rewrite-mma-operand-views-to-memdesc -canonicalize | FileCheck %s + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> +#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +#smem = #ttg.shared_memory +#tmem0 = #ttng.tensor_memory_encoding +// CHECK-DAG: #{{.*}} = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views + // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK-NOT: ttg.local_load + // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true + tt.func @swizzle0_operand_views( + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked0> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> + %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> + %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> + %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> + %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> + %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> + ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> + tt.return + } +} + +// ----- + +#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> +#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> +#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +#smem = #ttg.shared_memory +// CHECK-DAG: #{{.*}} = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 + // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} + tt.func @swizzle0_operand_views_warp_group_dot( + %a_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, + %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { + %c0 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> + %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> + %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> + %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> + %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> + %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} + : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> + %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> + tt.return %w#0 : tensor<128x256xf32, #mma_desc> + } +} diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index fa027f416660..9e64b4a50f18 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -274,6 +274,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) + nvidia.passes.ttnvgpuir.add_rewrite_mma_operand_views_to_memdesc(pm) passes.ttir.add_loop_aware_cse(pm) if capability // 10 in [8, 9]: passes.ttgpuir.add_fuse_nested_loops(pm) diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 043c1a4fee6a..189d0cbfefe5 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -182,6 +182,8 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) { ttng::createTritonNvidiaGPUMMALoweringPass); ADD_PASS_WRAPPER_0("add_optimize_descriptor_encoding", ttng::createTritonNvidiaGPUOptimizeDescriptorEncodingPass); + ADD_PASS_WRAPPER_0("add_rewrite_mma_operand_views_to_memdesc", + ttng::createTritonNvidiaGPURewriteMmaOperandViewsToMemDescPass); ADD_PASS_WRAPPER_0("add_optimize_tmem_layouts", ttng::createTritonNvidiaGPUOptimizeTMemLayoutsPass); ADD_PASS_WRAPPER_0("add_interleave_tmem", From df2f6f950cab39a00a28f8b519bb87eb8a67ba82 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2026 19:39:26 +0900 Subject: [PATCH 44/54] Simplify MMA operand view rewrite --- .../RewriteMmaOperandViewsToMemDesc.cpp | 246 ++++++------------ 1 file changed, 81 insertions(+), 165 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp index 26b1aa1f6f86..e39eaa56bc0a 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp @@ -3,7 +3,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" @@ -17,203 +16,123 @@ namespace { // desc_load -> local_alloc -> memdesc reshape / trans -> [memdesc views] -> // mma // -// The MMA operand layout is determined by the sink memdesc already feeding the -// dot-like op. This pattern back-propagates that layout through the tensor -// reshape/transpose chain, hoists local_alloc to the descriptor_load result, -// and then replays the same views as memdesc reshape/transpose ops. +// The MMA operand layout is already determined by the shared-memory memdesc +// feeding the dot-like op. This pattern lifts a descriptor-load-backed tensor +// view chain into equivalent memdesc reshape/transpose ops, while keeping the +// chosen MMA sink layout unchanged. // -// The rewrite only applies when the backward-inferred base memdesc layout is -// equivalent to the descriptor block layout, so the hoisted local_alloc still -// represents the same underlying shared-memory view. -template -class RewriteMmaOperandViewsToMemDescForDotOp - : public OpRewritePattern { +// The optimization is intentionally narrow: +// - the chain must start at tt.descriptor_load +// - each tensor view along the lifted path must have one use +// - only shared-memory memdesc operands of dot-like ops are considered +class RewriteMmaOperandViewsToMemDesc + : public OpInterfaceRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpInterfaceRewritePattern< + triton::DotOpInterface>::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(DotOpTy dotOp, + LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, PatternRewriter &rewriter) const override { - Value oldA = dotOp.getA(); - Value oldB = dotOp.getB(); - bool changed = false; - - if (rewriteOperand(dotOp.getAMutable(), rewriter).succeeded()) { - oldA.replaceAllUsesExcept(dotOp.getA(), dotOp.getOperation()); - changed = true; - } + if (!isa(dotOp)) + return failure(); - if (rewriteOperand(dotOp.getBMutable(), rewriter).succeeded()) { - oldB.replaceAllUsesExcept(dotOp.getB(), dotOp.getOperation()); - changed = true; + bool changed = false; + for (OpOperand &operand : dotOp->getOpOperands()) { + auto memDesc = dyn_cast>(operand.get()); + if (!memDesc || + !isa(memDesc.getType().getMemorySpace())) + continue; + changed |= succeeded(rewriteOperand(memDesc, rewriter)); } return success(changed); } private: - struct ViewStep { - enum Kind { Reshape, Transpose } kind; - SmallVector srcShape; - SmallVector dstShape; - SmallVector order; - Operation *op; - Location loc; - }; - - template - static std::tuple> - collectViewSteps(Value value) { - Value current = value; - SmallVector replaySteps; - while (true) { - if (auto reshape = current.template getDefiningOp()) { - auto srcTy = reshape.getSrc().getType(); - auto dstTy = reshape.getType(); - replaySteps.push_back(ViewStep{ViewStep::Reshape, - SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), - {}, - reshape.getOperation(), - reshape.getLoc()}); - current = reshape.getSrc(); - continue; - } - if (auto trans = current.template getDefiningOp()) { - SmallVector order(trans.getOrder().begin(), - trans.getOrder().end()); - auto srcTy = trans.getSrc().getType(); - auto dstTy = trans.getType(); - replaySteps.push_back(ViewStep{ - ViewStep::Transpose, SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), std::move(order), - trans.getOperation(), trans.getLoc()}); - current = trans.getSrc(); - continue; - } - break; - } - return {current, llvm::to_vector(llvm::reverse(replaySteps))}; + static bool isTensorViewOp(Operation *op) { + return isa(op); } - static gpu::SharedEncodingTrait getDescriptorSharedEncoding(Value baseTensor) { - if (auto descLoad = baseTensor.getDefiningOp()) { - auto descTy = cast(descLoad.getDesc().getType()); - auto descBlockTy = descTy.getBlockType(); - return dyn_cast_or_null( - descBlockTy.getEncoding()); - } - return nullptr; + static bool isMemDescViewOp(Operation *op) { + return isa(op); } - static FailureOr - inferViewStepBackward(gpu::MemDescType resultTy, const ViewStep &step) { - assert(resultTy.getShape() == ArrayRef(step.dstShape) && - "backward inference must start from the view step destination " - "shape"); - if (step.kind == ViewStep::Reshape) { - gpu::MemDescType srcTy; - if (failed(gpu::MemDescReshapeOp::inferReturnTypes( - resultTy.getContext(), step.loc, resultTy, step.srcShape, - srcTy))) - return failure(); - return srcTy; + LogicalResult rewriteOperand(TypedValue memDesc, + PatternRewriter &rewriter) const { + Value current = memDesc; + + // Strip trailing memdesc views so we can rewrite the producing local_alloc. + while (Operation *def = current.getDefiningOp()) { + if (!isMemDescViewOp(def)) + break; + current = def->getOperand(0); } - Attribute srcEnc = inferSrcEncoding(step.op, resultTy.getEncoding()); - if (!srcEnc) - return failure(); - return gpu::MemDescType::get(step.srcShape, resultTy.getElementType(), - srcEnc, resultTy.getMemorySpace(), - resultTy.getMutableMemory()); - } - static LogicalResult - verifyBaseMatchesDescriptorLayout(gpu::MemDescType inferredBaseTy, - Value baseTensor) { - auto descriptorSharedEnc = getDescriptorSharedEncoding(baseTensor); - if (!descriptorSharedEnc) + auto localAlloc = current.getDefiningOp(); + if (!localAlloc || !localAlloc.getSrc()) return failure(); - return success(gpu::areLayoutsEquivalent( - inferredBaseTy.getShape(), - cast(inferredBaseTy.getEncoding()), - cast(cast(descriptorSharedEnc)))); - } - LogicalResult rewriteOperand(OpOperand &operand, - PatternRewriter &rewriter) const { - Value orig = operand.get(); - auto origTy = dyn_cast(orig.getType()); - if (!origTy) - return failure(); + Value localAllocSrc = localAlloc.getSrc(); + current = localAllocSrc; - auto [beforeTrailing, trailingMemDescReplaySteps] = - collectViewSteps(orig); + // Walk back to the base of the tensor view chain. + while (Operation *def = current.getDefiningOp()) { + if (!isTensorViewOp(def)) + break; + current = def->getOperand(0); + } - auto localAlloc = beforeTrailing.template getDefiningOp(); - if (!localAlloc || !localAlloc.getSrc()) + // If there are no tensor views, there is nothing to lift. + if (current == localAllocSrc) return failure(); - auto [baseTensor, tensorReplaySteps] = - collectViewSteps( - localAlloc.getSrc()); - if (tensorReplaySteps.empty()) + auto descLoad = current.getDefiningOp(); + if (!descLoad) return failure(); - gpu::MemDescType baseMemTy = localAlloc.getType(); - for (const ViewStep &step : llvm::reverse(tensorReplaySteps)) { - auto srcTy = inferViewStepBackward(baseMemTy, step); - if (failed(srcTy)) + RankedTensorType blockTy = descLoad.getDesc().getType().getBlockType(); + auto descriptorSharedEnc = + cast(blockTy.getEncoding()); + gpu::MemDescType descMemTy = gpu::MemDescType::get( + blockTy.getShape(), blockTy.getElementType(), descriptorSharedEnc, + localAlloc.getType().getMemorySpace(), + localAlloc.getType().getMutableMemory()); + + // Validate that the tensor view chain is a single-use path from the + // descriptor load to the original local_alloc source. + SmallVector tensorViewOps; + Value path = current; + while (path != localAllocSrc) { + if (!path.hasOneUse()) return failure(); - baseMemTy = *srcTy; + Operation *viewOp = *path.getUsers().begin(); + if (!isTensorViewOp(viewOp)) + return failure(); + tensorViewOps.push_back(viewOp); + path = viewOp->getResult(0); } - if (failed(verifyBaseMatchesDescriptorLayout(baseMemTy, baseTensor))) - return failure(); PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(localAlloc); Value rewritten = gpu::LocalAllocOp::create(rewriter, localAlloc.getLoc(), - baseMemTy, baseTensor); - auto sinkTy = localAlloc.getType(); + descMemTy, current); - for (ViewStep &step : tensorReplaySteps) { - if (step.kind == ViewStep::Reshape) { - rewritten = gpu::MemDescReshapeOp::create(rewriter, step.loc, rewritten, - step.dstShape); + for (Operation *viewOp : tensorViewOps) { + if (auto trans = dyn_cast(viewOp)) { + rewritten = gpu::MemDescTransOp::create(rewriter, viewOp->getLoc(), + rewritten, trans.getOrder()); } else { - rewritten = gpu::MemDescTransOp::create(rewriter, step.loc, rewritten, - step.order); + auto reshape = cast(viewOp); + rewritten = + gpu::MemDescReshapeOp::create(rewriter, reshape.getLoc(), rewritten, + reshape.getType().getShape()); } } - auto rewrittenSinkTy = cast(rewritten.getType()); - assert(rewrittenSinkTy.getShape() == sinkTy.getShape() && - rewrittenSinkTy.getElementType() == sinkTy.getElementType() && - gpu::areLayoutsEquivalent( - sinkTy.getShape(), - cast(rewrittenSinkTy.getEncoding()), - cast(sinkTy.getEncoding())) && - "rewrite must preserve the intermediate sink memdesc"); - - for (ViewStep &step : trailingMemDescReplaySteps) { - if (step.kind == ViewStep::Reshape) { - rewritten = gpu::MemDescReshapeOp::create(rewriter, step.loc, rewritten, - step.dstShape); - } else { - rewritten = gpu::MemDescTransOp::create(rewriter, step.loc, rewritten, - step.order); - } - } - - auto rewrittenTy = cast(rewritten.getType()); - assert(rewrittenTy.getShape() == origTy.getShape() && - rewrittenTy.getElementType() == origTy.getElementType() && - gpu::areLayoutsEquivalent( - origTy.getShape(), - cast(rewrittenTy.getEncoding()), - cast(origTy.getEncoding())) && - "rewrite must preserve the final memdesc"); - - operand.assign(rewritten); + assert(rewritten.getType() == localAlloc.getType() && + "rewrite must preserve local_alloc result type"); + localAlloc.replaceAllUsesWith(rewritten); return success(); } }; @@ -233,10 +152,7 @@ class TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.add, - RewriteMmaOperandViewsToMemDescForDotOp, - RewriteMmaOperandViewsToMemDescForDotOp>( - &getContext()); + patterns.add(&getContext()); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } From 52f2848281d7dcecfc59f0e251ec1e22786b27e7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 9 Apr 2026 10:49:45 +0000 Subject: [PATCH 45/54] precommit --- third_party/nvidia/triton_nvidia.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 189d0cbfefe5..eb508107b86e 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -182,8 +182,9 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) { ttng::createTritonNvidiaGPUMMALoweringPass); ADD_PASS_WRAPPER_0("add_optimize_descriptor_encoding", ttng::createTritonNvidiaGPUOptimizeDescriptorEncodingPass); - ADD_PASS_WRAPPER_0("add_rewrite_mma_operand_views_to_memdesc", - ttng::createTritonNvidiaGPURewriteMmaOperandViewsToMemDescPass); + ADD_PASS_WRAPPER_0( + "add_rewrite_mma_operand_views_to_memdesc", + ttng::createTritonNvidiaGPURewriteMmaOperandViewsToMemDescPass); ADD_PASS_WRAPPER_0("add_optimize_tmem_layouts", ttng::createTritonNvidiaGPUOptimizeTMemLayoutsPass); ADD_PASS_WRAPPER_0("add_interleave_tmem", From a77c439cdbcb4caf0d8a89462525fef41bf61032 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 05:38:24 +0900 Subject: [PATCH 46/54] Revert to the old backward inference impl, run the pass before ODE --- .../Transforms/DescriptorMemoryLayouts.h | 9 + .../Transforms/DescriptorMemoryLayouts.cpp | 32 ++- .../Transforms/OptimizeDescriptorEncoding.cpp | 47 +++- .../RewriteMmaOperandViewsToMemDesc.cpp | 237 +++++++++++------- .../optimize_descriptor_encoding.mlir | 22 ++ .../rewrite-mma-operand-views-to-memdesc.mlir | 20 +- third_party/nvidia/backend/compiler.py | 2 +- 7 files changed, 254 insertions(+), 115 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h b/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h index e91d4cc37879..dab48523f207 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h +++ b/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h @@ -40,6 +40,15 @@ class AssignDescriptorMemoryLayouts { CGAEncodingAttr cgaLayout, ArrayRef usageShape, unsigned numCTAs); + +protected: + virtual Attribute getCompatibleSharedEncoding(Attribute enc, + ArrayRef shape, + Type elementType) { + return isCompatibleSharedEncoding(enc) ? enc : Attribute(); + } + +private: // Override with backend specific implementation virtual Attribute buildFallbackSharedEncoding(mlir::MLIRContext *, ArrayRef, diff --git a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp index 52d9a30c3c27..df0ff3ec691f 100644 --- a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp @@ -250,23 +250,35 @@ EncodingInfo AssignDescriptorMemoryLayouts::combineEncodings( Attribute AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) { + auto getCompatibleEncodingForType = [&](Type type) -> Attribute { + if (auto memDescTy = dyn_cast(type)) { + return getCompatibleSharedEncoding(memDescTy.getEncoding(), + memDescTy.getShape(), + memDescTy.getElementType()); + } + if (auto tensorTy = dyn_cast(type)) { + return getCompatibleSharedEncoding(tensorTy.getEncoding(), + tensorTy.getShape(), + tensorTy.getElementType()); + } + return {}; + }; + // Check if there are any desired encodings available on the op if (auto attr = op->getDiscardableAttr("tt.desired_encoding")) { - if (auto enc = dyn_cast(attr)) { - if (isCompatibleSharedEncoding(enc)) - return enc; - } + if (auto resultTy = dyn_cast(op->getResult(0).getType())) + if (auto compatible = getCompatibleSharedEncoding( + attr, resultTy.getShape(), resultTy.getElementType())) + return compatible; } // Ignore multiple users and just pick the first compatible layout for (auto use : op->getUsers()) { if (auto alloc = dyn_cast(use)) { - auto enc = alloc.getType().getEncoding(); - if (isCompatibleSharedEncoding(enc)) - return enc; + if (auto compatible = getCompatibleEncodingForType(alloc.getType())) + return compatible; } else if (auto store = dyn_cast(use)) { - auto enc = store.getDst().getType().getEncoding(); - if (isCompatibleSharedEncoding(enc)) - return enc; + if (auto compatible = getCompatibleEncodingForType(store.getDst().getType())) + return compatible; } } return {}; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp index 2dbaab2d77ce..5f0eeab41fa7 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp @@ -20,15 +20,56 @@ class NvidiaGPUAssignDescriptorMemoryLayouts ArrayRef order, ttg::CGAEncodingAttr cgaLayout, Type elementType) override; + Attribute getCompatibleSharedEncoding(Attribute enc, ArrayRef shape, + Type elementType) override; bool isCompatibleSharedEncoding(Attribute enc) override; }; bool NvidiaGPUAssignDescriptorMemoryLayouts::isCompatibleSharedEncoding( Attribute enc) { - if (auto nvmma = dyn_cast(enc)) { - return !nvmma.getTransposed(); + return isa(enc); +} + +Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding( + Attribute enc, ArrayRef shape, Type elementType) { + if (isCompatibleSharedEncoding(enc)) + return enc; + + auto sharedLinear = dyn_cast(enc); + if (!sharedLinear) + return {}; + + auto *ctx = enc.getContext(); + auto cgaLayout = ttg::getCGALayout(sharedLinear); + auto order = ttg::getOrder(sharedLinear, shape); + + SmallVector preferredCandidates; + // Preserve Triton's default shape/order-based choice when it already matches + // this shared_linear layout. The full candidate scan below is only a + // fallback for equivalent layouts not selected by the heuristic builder. + for (bool fp4Padded : {false, true}) { + auto preferred = ttg::NVMMASharedEncodingAttr::get( + ctx, shape, order, cgaLayout, elementType, fp4Padded); + preferredCandidates.push_back(preferred); + if (ttg::areLayoutsEquivalent(shape, sharedLinear, preferred)) + return preferred; + } + + unsigned elementBitWidth = std::max(8u, elementType.getIntOrFloatBitWidth()); + for (bool transposed : {false, true}) { + for (bool fp4Padded : {false, true}) { + for (unsigned swizzle : {0u, 32u, 64u, 128u}) { + auto candidate = ttg::NVMMASharedEncodingAttr::get( + ctx, swizzle, transposed, elementBitWidth, fp4Padded, cgaLayout); + if (llvm::is_contained(preferredCandidates, candidate)) + continue; + if (ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)) + return candidate; + } + } } - return false; + + return {}; } // Build fallback encoding given shape, order, cga layout and element type diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp index e39eaa56bc0a..8fb981012633 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp @@ -3,6 +3,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" @@ -11,128 +12,183 @@ namespace mlir::triton::nvidia_gpu { namespace { // Rewrite -// desc_load -> tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma +// tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma // into -// desc_load -> local_alloc -> memdesc reshape / trans -> [memdesc views] -> -// mma +// local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma // -// The MMA operand layout is already determined by the shared-memory memdesc -// feeding the dot-like op. This pattern lifts a descriptor-load-backed tensor -// view chain into equivalent memdesc reshape/transpose ops, while keeping the -// chosen MMA sink layout unchanged. -// -// The optimization is intentionally narrow: -// - the chain must start at tt.descriptor_load -// - each tensor view along the lifted path must have one use -// - only shared-memory memdesc operands of dot-like ops are considered -class RewriteMmaOperandViewsToMemDesc - : public OpInterfaceRewritePattern { +// The MMA operand layout is determined by the sink memdesc already feeding the +// dot-like op. This pattern back-propagates that layout through the tensor +// reshape/transpose chain, hoists local_alloc to the base tensor feeding that +// view chain, and then replays the same views as memdesc reshape/transpose +// ops. +template +class RewriteMmaOperandViewsToMemDescForDotOp + : public OpRewritePattern { public: - using OpInterfaceRewritePattern< - triton::DotOpInterface>::OpInterfaceRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, + LogicalResult matchAndRewrite(DotOpTy dotOp, PatternRewriter &rewriter) const override { - if (!isa(dotOp)) - return failure(); - + Value oldA = dotOp.getA(); + Value oldB = dotOp.getB(); bool changed = false; - for (OpOperand &operand : dotOp->getOpOperands()) { - auto memDesc = dyn_cast>(operand.get()); - if (!memDesc || - !isa(memDesc.getType().getMemorySpace())) - continue; - changed |= succeeded(rewriteOperand(memDesc, rewriter)); + + if (rewriteOperand(dotOp.getAMutable(), rewriter).succeeded()) { + oldA.replaceAllUsesExcept(dotOp.getA(), dotOp.getOperation()); + changed = true; + } + + if (rewriteOperand(dotOp.getBMutable(), rewriter).succeeded()) { + oldB.replaceAllUsesExcept(dotOp.getB(), dotOp.getOperation()); + changed = true; } return success(changed); } private: - static bool isTensorViewOp(Operation *op) { - return isa(op); + struct ViewStep { + enum Kind { Reshape, Transpose } kind; + SmallVector srcShape; + SmallVector dstShape; + SmallVector order; + Operation *op; + Location loc; + }; + + template + static std::tuple> + collectViewSteps(Value value) { + Value current = value; + SmallVector replaySteps; + while (true) { + if (auto reshape = current.template getDefiningOp()) { + auto srcTy = reshape.getSrc().getType(); + auto dstTy = reshape.getType(); + replaySteps.push_back(ViewStep{ViewStep::Reshape, + SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), + {}, + reshape.getOperation(), + reshape.getLoc()}); + current = reshape.getSrc(); + continue; + } + if (auto trans = current.template getDefiningOp()) { + SmallVector order(trans.getOrder().begin(), + trans.getOrder().end()); + auto srcTy = trans.getSrc().getType(); + auto dstTy = trans.getType(); + replaySteps.push_back(ViewStep{ + ViewStep::Transpose, SmallVector(srcTy.getShape()), + SmallVector(dstTy.getShape()), std::move(order), + trans.getOperation(), trans.getLoc()}); + current = trans.getSrc(); + continue; + } + break; + } + return {current, llvm::to_vector(llvm::reverse(replaySteps))}; } - static bool isMemDescViewOp(Operation *op) { - return isa(op); + static FailureOr + inferViewStepBackward(gpu::MemDescType resultTy, const ViewStep &step) { + assert(resultTy.getShape() == ArrayRef(step.dstShape) && + "backward inference must start from the view step destination " + "shape"); + if (step.kind == ViewStep::Reshape) { + gpu::MemDescType srcTy; + if (failed(gpu::MemDescReshapeOp::inferReturnTypes( + resultTy.getContext(), step.loc, resultTy, step.srcShape, srcTy))) + return failure(); + return srcTy; + } + Attribute srcEnc = inferSrcEncoding(step.op, resultTy.getEncoding()); + if (!srcEnc) + return failure(); + return gpu::MemDescType::get(step.srcShape, resultTy.getElementType(), + srcEnc, resultTy.getMemorySpace(), + resultTy.getMutableMemory()); } - LogicalResult rewriteOperand(TypedValue memDesc, - PatternRewriter &rewriter) const { - Value current = memDesc; - - // Strip trailing memdesc views so we can rewrite the producing local_alloc. - while (Operation *def = current.getDefiningOp()) { - if (!isMemDescViewOp(def)) - break; - current = def->getOperand(0); + static Value replayViewSteps(PatternRewriter &rewriter, Value value, + ArrayRef steps) { + Value rewritten = value; + for (const ViewStep &step : steps) { + if (step.kind == ViewStep::Reshape) { + rewritten = gpu::MemDescReshapeOp::create(rewriter, step.loc, rewritten, + step.dstShape); + } else { + rewritten = gpu::MemDescTransOp::create(rewriter, step.loc, rewritten, + step.order); + } } + return rewritten; + } - auto localAlloc = current.getDefiningOp(); - if (!localAlloc || !localAlloc.getSrc()) - return failure(); + static void assertEquivalentMemDescType(gpu::MemDescType actualTy, + gpu::MemDescType expectedTy, + const char *message) { + assert(actualTy.getShape() == expectedTy.getShape() && + actualTy.getElementType() == expectedTy.getElementType() && + gpu::areLayoutsEquivalent( + expectedTy.getShape(), + cast(actualTy.getEncoding()), + cast(expectedTy.getEncoding())) && + message); + } - Value localAllocSrc = localAlloc.getSrc(); - current = localAllocSrc; + LogicalResult rewriteOperand(OpOperand &operand, + PatternRewriter &rewriter) const { + Value orig = operand.get(); + auto origTy = dyn_cast(orig.getType()); + if (!origTy) + return failure(); - // Walk back to the base of the tensor view chain. - while (Operation *def = current.getDefiningOp()) { - if (!isTensorViewOp(def)) - break; - current = def->getOperand(0); - } + auto [beforeTrailing, trailingMemDescReplaySteps] = + collectViewSteps(orig); - // If there are no tensor views, there is nothing to lift. - if (current == localAllocSrc) + auto localAlloc = + beforeTrailing.template getDefiningOp(); + if (!localAlloc || !localAlloc.getSrc()) return failure(); - auto descLoad = current.getDefiningOp(); - if (!descLoad) + auto [baseTensor, tensorReplaySteps] = + collectViewSteps( + localAlloc.getSrc()); + if (tensorReplaySteps.empty()) return failure(); - RankedTensorType blockTy = descLoad.getDesc().getType().getBlockType(); - auto descriptorSharedEnc = - cast(blockTy.getEncoding()); - gpu::MemDescType descMemTy = gpu::MemDescType::get( - blockTy.getShape(), blockTy.getElementType(), descriptorSharedEnc, - localAlloc.getType().getMemorySpace(), - localAlloc.getType().getMutableMemory()); - - // Validate that the tensor view chain is a single-use path from the - // descriptor load to the original local_alloc source. - SmallVector tensorViewOps; - Value path = current; - while (path != localAllocSrc) { - if (!path.hasOneUse()) - return failure(); - Operation *viewOp = *path.getUsers().begin(); - if (!isTensorViewOp(viewOp)) + gpu::MemDescType baseMemTy = localAlloc.getType(); + for (const ViewStep &step : llvm::reverse(tensorReplaySteps)) { + auto srcTy = inferViewStepBackward(baseMemTy, step); + if (failed(srcTy)) return failure(); - tensorViewOps.push_back(viewOp); - path = viewOp->getResult(0); + baseMemTy = *srcTy; } PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(localAlloc); Value rewritten = gpu::LocalAllocOp::create(rewriter, localAlloc.getLoc(), - descMemTy, current); + baseMemTy, baseTensor); + auto sinkTy = localAlloc.getType(); - for (Operation *viewOp : tensorViewOps) { - if (auto trans = dyn_cast(viewOp)) { - rewritten = gpu::MemDescTransOp::create(rewriter, viewOp->getLoc(), - rewritten, trans.getOrder()); - } else { - auto reshape = cast(viewOp); - rewritten = - gpu::MemDescReshapeOp::create(rewriter, reshape.getLoc(), rewritten, - reshape.getType().getShape()); - } - } + rewritten = replayViewSteps(rewriter, rewritten, tensorReplaySteps); + + auto rewrittenSinkTy = cast(rewritten.getType()); + assertEquivalentMemDescType( + rewrittenSinkTy, sinkTy, + "rewrite must preserve the intermediate sink memdesc"); + + rewritten = + replayViewSteps(rewriter, rewritten, trailingMemDescReplaySteps); + + auto rewrittenTy = cast(rewritten.getType()); + assertEquivalentMemDescType(rewrittenTy, origTy, + "rewrite must preserve the final memdesc"); - assert(rewritten.getType() == localAlloc.getType() && - "rewrite must preserve local_alloc result type"); - localAlloc.replaceAllUsesWith(rewritten); + operand.assign(rewritten); return success(); } }; @@ -152,7 +208,10 @@ class TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add, + RewriteMmaOperandViewsToMemDescForDotOp, + RewriteMmaOperandViewsToMemDescForDotOp>( + &getContext()); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } diff --git a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir index db185179ae88..a6131b962ad1 100644 --- a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir +++ b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir @@ -123,3 +123,25 @@ tt.func public @tma_load_while(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return } } + +// ----- + +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#shared_linear_base = #ttg.shared_linear<{offset = [[0, 0, 0, 0, 1], [0, 0, 0, 0, 2], [0, 0, 0, 0, 4], [0, 0, 0, 0, 8], [0, 0, 0, 1, 0], [0, 0, 0, 2, 0], [0, 0, 0, 4, 0], [0, 0, 1, 0, 0], [0, 0, 2, 0, 0], [0, 0, 4, 0, 0], [0, 0, 8, 0, 0], [0, 0, 16, 0, 0], [0, 0, 32, 0, 0], [0, 0, 64, 0, 0], [0, 0, 128, 0, 0]]}, alignment = 128> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { +// CHECK-DAG: #[[NVMMA_0:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +// CHECK-DAG: #[[BLOCKED5D:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +// CHECK-DAG: #[[SL_BASE:.*]] = #ttg.shared_linear<{{.*}}alignment = 128> +tt.func public @descriptor_arg_from_shared_linear_use(%b_desc: !tt.tensordesc>) { + // CHECK: %arg0: !tt.tensordesc> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0 + // CHECK-SAME: : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]> + // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #[[SL_BASE]], #smem> + %c0 = arith.constant 0 : i32 + %b = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked2> + %b_s = ttg.local_alloc %b : (tensor<1x1x256x8x16xf8E4M3FN, #blocked2>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared_linear_base, #smem> + tt.return +} +} diff --git a/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir b/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir index fff031c485fb..114f4b5986d4 100644 --- a/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir +++ b/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -triton-nvidia-optimize-descriptor-encoding -triton-nvidia-rewrite-mma-operand-views-to-memdesc -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -triton-nvidia-rewrite-mma-operand-views-to-memdesc -canonicalize | FileCheck %s #blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> @@ -7,18 +7,16 @@ #linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> #linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> #shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> #shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> #smem = #ttg.shared_memory #tmem0 = #ttng.tensor_memory_encoding -// CHECK-DAG: #{{.*}} = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> // CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> // CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views - // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> @@ -30,12 +28,12 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true tt.func @swizzle0_operand_views( %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked0> - %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> @@ -56,16 +54,14 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num #linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> #mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> #sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#sharedB0_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> #sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> #sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> #smem = #ttg.shared_memory -// CHECK-DAG: #{{.*}} = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> // CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> // CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> @@ -77,12 +73,12 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} tt.func @swizzle0_operand_views_warp_group_dot( %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, + %b_desc: !tt.tensordesc>, %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { %c0 = arith.constant 0 : i32 %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> - %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 9e64b4a50f18..9f7dc7541b56 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -273,8 +273,8 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_accelerate_matmul(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) - nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) nvidia.passes.ttnvgpuir.add_rewrite_mma_operand_views_to_memdesc(pm) + nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) passes.ttir.add_loop_aware_cse(pm) if capability // 10 in [8, 9]: passes.ttgpuir.add_fuse_nested_loops(pm) From 6e07bb6156d93598b100adc33cdac8e6f376e34f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 05:41:25 +0900 Subject: [PATCH 47/54] pre commit --- lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp index df0ff3ec691f..b7dd81b9a51e 100644 --- a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp @@ -277,7 +277,8 @@ AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) { if (auto compatible = getCompatibleEncodingForType(alloc.getType())) return compatible; } else if (auto store = dyn_cast(use)) { - if (auto compatible = getCompatibleEncodingForType(store.getDst().getType())) + if (auto compatible = + getCompatibleEncodingForType(store.getDst().getType())) return compatible; } } From e093192a363a69f7193f56af2fd0de1dece7bcf3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 06:14:04 +0900 Subject: [PATCH 48/54] Update descriptor rewrite for new tensordesc type --- .../Transforms/DescriptorMemoryLayouts.cpp | 13 ++++++----- .../optimize_descriptor_encoding.mlir | 8 +++---- .../rewrite-mma-operand-views-to-memdesc.mlir | 22 +++++++++---------- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp index 4513c193378a..6e513e954214 100644 --- a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp @@ -455,7 +455,9 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) { auto ctx = func.getContext(); auto numCTAs = triton::gpu::lookupNumCTAs(func); for (auto &[desc, einfo] : valueToEncodingInfo) { - auto existingTy = desc.getType().getBlockType(); + auto descTy = desc.getType(); + auto existingTy = + RankedTensorType::get(descTy.getShape(), descTy.getElementType()); Attribute newEncoding; if (einfo->desiredEncoding) { newEncoding = einfo->desiredEncoding; @@ -473,10 +475,11 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) { SmallVector resultTys(func.getResultTypes()); for (auto [i, resultTy] : llvm::enumerate(resultTys)) { if (auto descTy = dyn_cast(resultTy)) { - auto encoding = - getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs); - resultTys[i] = getTensorDescTypeWithEncoding( - nullptr, descTy.getBlockType(), encoding); + auto existingTy = + RankedTensorType::get(descTy.getShape(), descTy.getElementType()); + auto encoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs); + resultTys[i] = + getTensorDescTypeWithEncoding(nullptr, existingTy, encoding); } } func.setFunctionType(FunctionType::get(ctx, argTys, resultTys)); diff --git a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir index 6b59626ee409..ba95c5b8217e 100644 --- a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir +++ b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir @@ -134,13 +134,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-DAG: #[[NVMMA_0:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> // CHECK-DAG: #[[BLOCKED5D:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> // CHECK-DAG: #[[SL_BASE:.*]] = #ttg.shared_linear<{{.*}}alignment = 128> -tt.func public @descriptor_arg_from_shared_linear_use(%b_desc: !tt.tensordesc>) { - // CHECK: %arg0: !tt.tensordesc> +tt.func public @descriptor_arg_from_shared_linear_use(%b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>) { + // CHECK: %arg0: !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0 - // CHECK-SAME: : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]> + // CHECK-SAME: : !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> -> tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]> // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #[[SL_BASE]], #smem> %c0 = arith.constant 0 : i32 - %b = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked2> + %b = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked2> %b_s = ttg.local_alloc %b : (tensor<1x1x256x8x16xf8E4M3FN, #blocked2>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared_linear_base, #smem> tt.return } diff --git a/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir b/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir index 114f4b5986d4..0c86fb09cf89 100644 --- a/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir +++ b/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir @@ -14,9 +14,9 @@ // CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views - // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> @@ -27,13 +27,13 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num // CHECK-NOT: ttg.local_load // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true tt.func @swizzle0_operand_views( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, + %a_desc: !tt.tensordesc<128x128xf8E4M3FN>, + %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { %true = arith.constant true %c0_i32 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blocked0> - %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> + %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #blocked0> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> @@ -61,7 +61,7 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num // CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> @@ -72,13 +72,13 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} tt.func @swizzle0_operand_views_warp_group_dot( - %a_desc: !tt.tensordesc>, - %b_desc: !tt.tensordesc>, + %a_desc: !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc>, + %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { %c0 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> + %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> - %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> From 4f97dc122abe746228d60fa88e95dd9e8ddb501c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 07:46:42 +0900 Subject: [PATCH 49/54] Keep descriptor layouts non-transposed --- .../Transforms/OptimizeDescriptorEncoding.cpp | 30 ++++++++++--------- .../optimize_descriptor_encoding.mlir | 21 +++++++++++++ 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp index 5f0eeab41fa7..e8d88a9bdc8a 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp @@ -27,7 +27,9 @@ class NvidiaGPUAssignDescriptorMemoryLayouts bool NvidiaGPUAssignDescriptorMemoryLayouts::isCompatibleSharedEncoding( Attribute enc) { - return isa(enc); + if (auto nvmma = dyn_cast(enc)) + return !nvmma.getTransposed(); + return false; } Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding( @@ -44,9 +46,10 @@ Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding( auto order = ttg::getOrder(sharedLinear, shape); SmallVector preferredCandidates; - // Preserve Triton's default shape/order-based choice when it already matches - // this shared_linear layout. The full candidate scan below is only a - // fallback for equivalent layouts not selected by the heuristic builder. + // TMA descriptors only support non-transposed layouts. Preserve Triton's + // default shape/order-based choice when it already matches this + // shared_linear layout. The full candidate scan below is only a fallback for + // equivalent non-transposed layouts not selected by the heuristic builder. for (bool fp4Padded : {false, true}) { auto preferred = ttg::NVMMASharedEncodingAttr::get( ctx, shape, order, cgaLayout, elementType, fp4Padded); @@ -56,16 +59,15 @@ Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding( } unsigned elementBitWidth = std::max(8u, elementType.getIntOrFloatBitWidth()); - for (bool transposed : {false, true}) { - for (bool fp4Padded : {false, true}) { - for (unsigned swizzle : {0u, 32u, 64u, 128u}) { - auto candidate = ttg::NVMMASharedEncodingAttr::get( - ctx, swizzle, transposed, elementBitWidth, fp4Padded, cgaLayout); - if (llvm::is_contained(preferredCandidates, candidate)) - continue; - if (ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)) - return candidate; - } + for (bool fp4Padded : {false, true}) { + for (unsigned swizzle : {0u, 32u, 64u, 128u}) { + auto candidate = ttg::NVMMASharedEncodingAttr::get( + ctx, swizzle, /*transposed=*/false, elementBitWidth, fp4Padded, + cgaLayout); + if (llvm::is_contained(preferredCandidates, candidate)) + continue; + if (ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)) + return candidate; } } diff --git a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir index ba95c5b8217e..1376692737b3 100644 --- a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir +++ b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir @@ -86,6 +86,27 @@ tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc<64x64xf16>, %arg1: i // ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { +// CHECK-DAG: #[[BLOCKED_16x128:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-DAG: #[[NVMMA_128:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +// CHECK-DAG: #[[NVMMA_TRANSPOSED_32:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +tt.func public @descriptor_ignores_transposed_local_alloc(%arg0: !tt.tensordesc<16x128xf16>) { + // CHECK: %arg0: !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0{{.*}} : !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> -> tensor<16x128xf16, #[[BLOCKED_16x128]]> + // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<16x128xf16, #[[BLOCKED_16x128]]>) -> !ttg.memdesc<16x128xf16, #[[NVMMA_TRANSPOSED_32]], #smem> + %c0 = arith.constant 0 : i32 + %0 = tt.descriptor_load %arg0[%c0, %c0] : !tt.tensordesc<16x128xf16> -> tensor<16x128xf16, #blocked> + %1 = ttg.local_alloc %0 : (tensor<16x128xf16, #blocked>) -> !ttg.memdesc<16x128xf16, #shared, #smem> + tt.return +} +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> From 3dee2de2a4a532b731b9590ec42e56a7d832e276 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 08:43:23 +0900 Subject: [PATCH 50/54] Simplify MMA operand view replay steps --- .../RewriteMmaOperandViewsToMemDesc.cpp | 170 ++++++------------ 1 file changed, 56 insertions(+), 114 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp index 8fb981012633..66e2fecbdf2c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp @@ -6,6 +6,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include namespace mlir::triton::nvidia_gpu { @@ -19,8 +20,8 @@ namespace { // The MMA operand layout is determined by the sink memdesc already feeding the // dot-like op. This pattern back-propagates that layout through the tensor // reshape/transpose chain, hoists local_alloc to the base tensor feeding that -// view chain, and then replays the same views as memdesc reshape/transpose -// ops. +// view chain, and replays those tensor views as memdesc reshape/transpose +// ops so the original local_alloc type is preserved. template class RewriteMmaOperandViewsToMemDescForDotOp : public OpRewritePattern { @@ -29,166 +30,107 @@ class RewriteMmaOperandViewsToMemDescForDotOp LogicalResult matchAndRewrite(DotOpTy dotOp, PatternRewriter &rewriter) const override { - Value oldA = dotOp.getA(); - Value oldB = dotOp.getB(); bool changed = false; - if (rewriteOperand(dotOp.getAMutable(), rewriter).succeeded()) { - oldA.replaceAllUsesExcept(dotOp.getA(), dotOp.getOperation()); + if (rewriteOperand(dotOp.getA(), rewriter).succeeded()) changed = true; - } - if (rewriteOperand(dotOp.getBMutable(), rewriter).succeeded()) { - oldB.replaceAllUsesExcept(dotOp.getB(), dotOp.getOperation()); + if (rewriteOperand(dotOp.getB(), rewriter).succeeded()) changed = true; - } return success(changed); } private: - struct ViewStep { - enum Kind { Reshape, Transpose } kind; - SmallVector srcShape; - SmallVector dstShape; - SmallVector order; - Operation *op; - Location loc; - }; - - template - static std::tuple> - collectViewSteps(Value value) { - Value current = value; - SmallVector replaySteps; - while (true) { - if (auto reshape = current.template getDefiningOp()) { - auto srcTy = reshape.getSrc().getType(); - auto dstTy = reshape.getType(); - replaySteps.push_back(ViewStep{ViewStep::Reshape, - SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), - {}, - reshape.getOperation(), - reshape.getLoc()}); - current = reshape.getSrc(); - continue; - } - if (auto trans = current.template getDefiningOp()) { - SmallVector order(trans.getOrder().begin(), - trans.getOrder().end()); - auto srcTy = trans.getSrc().getType(); - auto dstTy = trans.getType(); - replaySteps.push_back(ViewStep{ - ViewStep::Transpose, SmallVector(srcTy.getShape()), - SmallVector(dstTy.getShape()), std::move(order), - trans.getOperation(), trans.getLoc()}); - current = trans.getSrc(); - continue; - } - break; - } - return {current, llvm::to_vector(llvm::reverse(replaySteps))}; - } - static FailureOr - inferViewStepBackward(gpu::MemDescType resultTy, const ViewStep &step) { - assert(resultTy.getShape() == ArrayRef(step.dstShape) && - "backward inference must start from the view step destination " - "shape"); - if (step.kind == ViewStep::Reshape) { + pushLayoutBackward(gpu::MemDescType resultTy, Operation *op) { + if (auto reshape = dyn_cast(op)) { gpu::MemDescType srcTy; if (failed(gpu::MemDescReshapeOp::inferReturnTypes( - resultTy.getContext(), step.loc, resultTy, step.srcShape, srcTy))) + resultTy.getContext(), reshape.getLoc(), resultTy, + reshape.getSrc().getType().getShape(), srcTy))) return failure(); return srcTy; } - Attribute srcEnc = inferSrcEncoding(step.op, resultTy.getEncoding()); + + auto trans = cast(op); + Attribute srcEnc = inferSrcEncoding(op, resultTy.getEncoding()); if (!srcEnc) return failure(); - return gpu::MemDescType::get(step.srcShape, resultTy.getElementType(), - srcEnc, resultTy.getMemorySpace(), - resultTy.getMutableMemory()); + return gpu::MemDescType::get( + trans.getSrc().getType().getShape(), resultTy.getElementType(), srcEnc, + resultTy.getMemorySpace(), resultTy.getMutableMemory()); } - static Value replayViewSteps(PatternRewriter &rewriter, Value value, - ArrayRef steps) { + static Value replayTensorViews(PatternRewriter &rewriter, Value value, + ArrayRef steps) { Value rewritten = value; - for (const ViewStep &step : steps) { - if (step.kind == ViewStep::Reshape) { - rewritten = gpu::MemDescReshapeOp::create(rewriter, step.loc, rewritten, - step.dstShape); + for (Operation *op : steps) { + if (auto reshape = dyn_cast(op)) { + rewritten = gpu::MemDescReshapeOp::create( + rewriter, op->getLoc(), rewritten, reshape.getType().getShape()); } else { - rewritten = gpu::MemDescTransOp::create(rewriter, step.loc, rewritten, - step.order); + auto trans = cast(op); + rewritten = gpu::MemDescTransOp::create(rewriter, op->getLoc(), + rewritten, trans.getOrder()); } } return rewritten; } - static void assertEquivalentMemDescType(gpu::MemDescType actualTy, - gpu::MemDescType expectedTy, - const char *message) { - assert(actualTy.getShape() == expectedTy.getShape() && - actualTy.getElementType() == expectedTy.getElementType() && - gpu::areLayoutsEquivalent( - expectedTy.getShape(), - cast(actualTy.getEncoding()), - cast(expectedTy.getEncoding())) && - message); + static Value peelMemDescViews(Value value) { + Value current = value; + while (auto view = current.getDefiningOp()) { + if (auto reshape = dyn_cast(view)) { + current = reshape.getSrc(); + continue; + } + if (auto trans = dyn_cast(view)) { + current = trans.getSrc(); + continue; + } + break; + } + return current; } - LogicalResult rewriteOperand(OpOperand &operand, - PatternRewriter &rewriter) const { - Value orig = operand.get(); - auto origTy = dyn_cast(orig.getType()); - if (!origTy) + LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { + if (!isa(operand.getType())) return failure(); - auto [beforeTrailing, trailingMemDescReplaySteps] = - collectViewSteps(orig); - - auto localAlloc = - beforeTrailing.template getDefiningOp(); + Value beforeTrailing = peelMemDescViews(operand); + auto localAlloc = beforeTrailing.getDefiningOp(); if (!localAlloc || !localAlloc.getSrc()) return failure(); - auto [baseTensor, tensorReplaySteps] = - collectViewSteps( - localAlloc.getSrc()); - if (tensorReplaySteps.empty()) - return failure(); - + Value baseTensor = localAlloc.getSrc(); + SmallVector tensorReplaySteps; gpu::MemDescType baseMemTy = localAlloc.getType(); - for (const ViewStep &step : llvm::reverse(tensorReplaySteps)) { - auto srcTy = inferViewStepBackward(baseMemTy, step); + while (auto view = baseTensor.getDefiningOp()) { + if (!isa(view)) + break; + auto srcTy = pushLayoutBackward(baseMemTy, view); if (failed(srcTy)) return failure(); + tensorReplaySteps.push_back(view); baseMemTy = *srcTy; + baseTensor = view->getOperand(0); } + if (tensorReplaySteps.empty()) + return failure(); + + std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end()); PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(localAlloc); Value rewritten = gpu::LocalAllocOp::create(rewriter, localAlloc.getLoc(), baseMemTy, baseTensor); - auto sinkTy = localAlloc.getType(); - - rewritten = replayViewSteps(rewriter, rewritten, tensorReplaySteps); + rewritten = replayTensorViews(rewriter, rewritten, tensorReplaySteps); auto rewrittenSinkTy = cast(rewritten.getType()); - assertEquivalentMemDescType( - rewrittenSinkTy, sinkTy, - "rewrite must preserve the intermediate sink memdesc"); - - rewritten = - replayViewSteps(rewriter, rewritten, trailingMemDescReplaySteps); - - auto rewrittenTy = cast(rewritten.getType()); - assertEquivalentMemDescType(rewrittenTy, origTy, - "rewrite must preserve the final memdesc"); - operand.assign(rewritten); + rewriter.replaceOp(localAlloc, rewritten); return success(); } }; From f70af5bbfe5bb0cd3f30980af7569e5a32e2f9b6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 08:51:00 +0900 Subject: [PATCH 51/54] Use DotOpInterface in MMA view rewrite --- .../RewriteMmaOperandViewsToMemDesc.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp index 66e2fecbdf2c..6942d89b6beb 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp @@ -22,14 +22,17 @@ namespace { // reshape/transpose chain, hoists local_alloc to the base tensor feeding that // view chain, and replays those tensor views as memdesc reshape/transpose // ops so the original local_alloc type is preserved. -template class RewriteMmaOperandViewsToMemDescForDotOp - : public OpRewritePattern { + : public OpInterfaceRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpInterfaceRewritePattern< + triton::DotOpInterface>::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(DotOpTy dotOp, + LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, PatternRewriter &rewriter) const override { + if (!isa(dotOp)) + return failure(); + bool changed = false; if (rewriteOperand(dotOp.getA(), rewriter).succeeded()) @@ -150,10 +153,7 @@ class TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); - patterns.add, - RewriteMmaOperandViewsToMemDescForDotOp, - RewriteMmaOperandViewsToMemDescForDotOp>( - &getContext()); + patterns.add(&getContext()); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } From 87eb143151c2978ec3159dba0bbbf3ce9b3f53a3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 09:11:54 +0900 Subject: [PATCH 52/54] Move MMA operand view rewrite into ODO --- .../TritonNvidiaGPU/Transforms/Passes.td | 14 -- .../Transforms/OptimizeDotOperands.cpp | 128 ++++++++++++++ .../TritonNvidiaGPU/Transforms/CMakeLists.txt | 1 - .../RewriteMmaOperandViewsToMemDesc.cpp | 162 ------------------ test/TritonGPU/dot-operands.mlir | 93 ++++++++++ .../rewrite-mma-operand-views-to-memdesc.mlir | 92 ---------- third_party/nvidia/backend/compiler.py | 1 - third_party/nvidia/triton_nvidia.cc | 3 - 8 files changed, 221 insertions(+), 273 deletions(-) delete mode 100644 lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp delete mode 100644 test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td index c9abeadbde7d..a41b2e891434 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td @@ -142,20 +142,6 @@ def TritonNvidiaGPUOptimizeDescriptorEncodingPass : Pass<"triton-nvidia-optimize "mlir::triton::TritonDialect"]; } -def TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass : Pass<"triton-nvidia-rewrite-mma-operand-views-to-memdesc", "mlir::ModuleOp"> { - let summary = "Rewrite tensor MMA operand views into memdesc views"; - - let description = [{ - Rewrite tensor reshape/transpose chains feeding MMA shared-memory operands - into equivalent memdesc reshape/transpose chains once the operand layout is - fixed. - }]; - - let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", - "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect", - "mlir::triton::TritonDialect"]; -} - def TritonNvidiaGPUOptimizeTMemLayoutsPass : Pass<"triton-nvidia-optimize-tmem-layouts", "mlir::ModuleOp"> { let summary = "Optimize TMEM layouts."; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 9fd5e75602b1..e6c5a588e3de 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -12,6 +12,7 @@ #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" +#include #include namespace mlir::triton::gpu { @@ -183,6 +184,132 @@ class ReshapeMemDesc : public OpRewritePattern { } }; +// Rewrite +// tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma +// into +// local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma +// +// The MMA operand layout is determined by the sink memdesc already feeding the +// dot-like op. This pattern back-propagates that layout through the tensor +// reshape/transpose chain, hoists local_alloc to the base tensor feeding that +// view chain, and replays those tensor views as memdesc reshape/transpose +// ops so the original local_alloc type is preserved. +class RewriteMmaOperandViewsToMemDescForDotOp + : public OpInterfaceRewritePattern { +public: + using OpInterfaceRewritePattern< + triton::DotOpInterface>::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, + PatternRewriter &rewriter) const override { + if (!isa(dotOp)) + return failure(); + + bool changed = false; + + if (rewriteOperand(dotOp.getA(), rewriter).succeeded()) + changed = true; + + if (rewriteOperand(dotOp.getB(), rewriter).succeeded()) + changed = true; + + return success(changed); + } + +private: + static FailureOr pushLayoutBackward(MemDescType resultTy, + Operation *op) { + if (auto reshape = dyn_cast(op)) { + MemDescType srcTy; + if (failed(MemDescReshapeOp::inferReturnTypes( + resultTy.getContext(), reshape.getLoc(), resultTy, + reshape.getSrc().getType().getShape(), srcTy))) + return failure(); + return srcTy; + } + + auto trans = cast(op); + Attribute srcEnc = inferSrcEncoding(op, resultTy.getEncoding()); + if (!srcEnc) + return failure(); + return MemDescType::get(trans.getSrc().getType().getShape(), + resultTy.getElementType(), srcEnc, + resultTy.getMemorySpace(), + resultTy.getMutableMemory()); + } + + static Value replayTensorViews(PatternRewriter &rewriter, Value value, + ArrayRef steps) { + Value rewritten = value; + for (Operation *op : steps) { + if (auto reshape = dyn_cast(op)) { + rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten, + reshape.getType().getShape()); + } else { + auto trans = cast(op); + rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten, + trans.getOrder()); + } + } + return rewritten; + } + + static Value peelMemDescViews(Value value) { + Value current = value; + while (auto view = current.getDefiningOp()) { + if (auto reshape = dyn_cast(view)) { + current = reshape.getSrc(); + continue; + } + if (auto trans = dyn_cast(view)) { + current = trans.getSrc(); + continue; + } + break; + } + return current; + } + + LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { + if (!isa(operand.getType())) + return failure(); + + Value beforeTrailing = peelMemDescViews(operand); + auto localAlloc = beforeTrailing.getDefiningOp(); + if (!localAlloc || !localAlloc.getSrc()) + return failure(); + + Value baseTensor = localAlloc.getSrc(); + SmallVector tensorReplaySteps; + MemDescType baseMemTy = localAlloc.getType(); + while (auto view = baseTensor.getDefiningOp()) { + if (!isa(view)) + break; + auto srcTy = pushLayoutBackward(baseMemTy, view); + if (failed(srcTy)) + return failure(); + tensorReplaySteps.push_back(view); + baseMemTy = *srcTy; + baseTensor = view->getOperand(0); + } + if (tensorReplaySteps.empty()) + return failure(); + + std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end()); + + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(localAlloc); + + Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), + baseMemTy, baseTensor); + rewritten = replayTensorViews(rewriter, rewritten, tensorReplaySteps); + rewriter.replaceOp(localAlloc, rewritten); + return success(); + } +}; + // Inject TMEM copy instructions into IR to efficiently load blocked scales for // scaled dot class UseShmemForScales @@ -344,6 +471,7 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); + patterns.add(context); patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt index 4059f549ec87..7b0d2c626a1e 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -10,7 +10,6 @@ add_triton_library(TritonNvidiaGPUTransforms PlanCTA.cpp PromoteLHSToTMem.cpp ProxyFenceInsertion.cpp - RewriteMmaOperandViewsToMemDesc.cpp RemoveTMEMTokens.cpp TensorMemoryAllocation.cpp TMALowering.cpp diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp deleted file mode 100644 index 6942d89b6beb..000000000000 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteMmaOperandViewsToMemDesc.cpp +++ /dev/null @@ -1,162 +0,0 @@ -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" -#include - -namespace mlir::triton::nvidia_gpu { - -namespace { - -// Rewrite -// tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma -// into -// local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma -// -// The MMA operand layout is determined by the sink memdesc already feeding the -// dot-like op. This pattern back-propagates that layout through the tensor -// reshape/transpose chain, hoists local_alloc to the base tensor feeding that -// view chain, and replays those tensor views as memdesc reshape/transpose -// ops so the original local_alloc type is preserved. -class RewriteMmaOperandViewsToMemDescForDotOp - : public OpInterfaceRewritePattern { -public: - using OpInterfaceRewritePattern< - triton::DotOpInterface>::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, - PatternRewriter &rewriter) const override { - if (!isa(dotOp)) - return failure(); - - bool changed = false; - - if (rewriteOperand(dotOp.getA(), rewriter).succeeded()) - changed = true; - - if (rewriteOperand(dotOp.getB(), rewriter).succeeded()) - changed = true; - - return success(changed); - } - -private: - static FailureOr - pushLayoutBackward(gpu::MemDescType resultTy, Operation *op) { - if (auto reshape = dyn_cast(op)) { - gpu::MemDescType srcTy; - if (failed(gpu::MemDescReshapeOp::inferReturnTypes( - resultTy.getContext(), reshape.getLoc(), resultTy, - reshape.getSrc().getType().getShape(), srcTy))) - return failure(); - return srcTy; - } - - auto trans = cast(op); - Attribute srcEnc = inferSrcEncoding(op, resultTy.getEncoding()); - if (!srcEnc) - return failure(); - return gpu::MemDescType::get( - trans.getSrc().getType().getShape(), resultTy.getElementType(), srcEnc, - resultTy.getMemorySpace(), resultTy.getMutableMemory()); - } - - static Value replayTensorViews(PatternRewriter &rewriter, Value value, - ArrayRef steps) { - Value rewritten = value; - for (Operation *op : steps) { - if (auto reshape = dyn_cast(op)) { - rewritten = gpu::MemDescReshapeOp::create( - rewriter, op->getLoc(), rewritten, reshape.getType().getShape()); - } else { - auto trans = cast(op); - rewritten = gpu::MemDescTransOp::create(rewriter, op->getLoc(), - rewritten, trans.getOrder()); - } - } - return rewritten; - } - - static Value peelMemDescViews(Value value) { - Value current = value; - while (auto view = current.getDefiningOp()) { - if (auto reshape = dyn_cast(view)) { - current = reshape.getSrc(); - continue; - } - if (auto trans = dyn_cast(view)) { - current = trans.getSrc(); - continue; - } - break; - } - return current; - } - - LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { - if (!isa(operand.getType())) - return failure(); - - Value beforeTrailing = peelMemDescViews(operand); - auto localAlloc = beforeTrailing.getDefiningOp(); - if (!localAlloc || !localAlloc.getSrc()) - return failure(); - - Value baseTensor = localAlloc.getSrc(); - SmallVector tensorReplaySteps; - gpu::MemDescType baseMemTy = localAlloc.getType(); - while (auto view = baseTensor.getDefiningOp()) { - if (!isa(view)) - break; - auto srcTy = pushLayoutBackward(baseMemTy, view); - if (failed(srcTy)) - return failure(); - tensorReplaySteps.push_back(view); - baseMemTy = *srcTy; - baseTensor = view->getOperand(0); - } - if (tensorReplaySteps.empty()) - return failure(); - - std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end()); - - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(localAlloc); - - Value rewritten = gpu::LocalAllocOp::create(rewriter, localAlloc.getLoc(), - baseMemTy, baseTensor); - rewritten = replayTensorViews(rewriter, rewritten, tensorReplaySteps); - - auto rewrittenSinkTy = cast(rewritten.getType()); - - rewriter.replaceOp(localAlloc, rewritten); - return success(); - } -}; - -} // namespace - -#define GEN_PASS_DEF_TRITONNVIDIAGPUREWRITEMMAOPERANDVIEWSTOMEMDESCPASS -#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc" - -class TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass - : public impl::TritonNvidiaGPURewriteMmaOperandViewsToMemDescPassBase< - TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass> { -public: - using BaseT = impl::TritonNvidiaGPURewriteMmaOperandViewsToMemDescPassBase< - TritonNvidiaGPURewriteMmaOperandViewsToMemDescPass>; - using BaseT::BaseT; - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) - signalPassFailure(); - } -}; - -} // namespace mlir::triton::nvidia_gpu diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index f292a2fcb211..4434061722a8 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -293,3 +293,96 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num tt.return %a: !ttg.memdesc<64x64xf32, #shared, #smem> } } + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> +#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +#smem = #ttg.shared_memory +#tmem0 = #ttng.tensor_memory_encoding +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views + // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK-NOT: ttg.local_load + // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true + tt.func @swizzle0_operand_views( + %a_desc: !tt.tensordesc<128x128xf8E4M3FN>, + %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, + %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #blocked0> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> + %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> + %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> + %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> + %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> + %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> + ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> + tt.return + } +} + +// ----- + +#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> +#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> +#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +#smem = #ttg.shared_memory +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 + // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} + tt.func @swizzle0_operand_views_warp_group_dot( + %a_desc: !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc>, + %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, + %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { + %c0 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> + %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> + %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> + %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> + %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> + %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} + : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> + %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> + tt.return %w#0 : tensor<128x256xf32, #mma_desc> + } +} diff --git a/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir b/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir deleted file mode 100644 index 0c86fb09cf89..000000000000 --- a/test/TritonNvidiaGPU/rewrite-mma-operand-views-to-memdesc.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: triton-opt %s -split-input-file -triton-nvidia-rewrite-mma-operand-views-to-memdesc -canonicalize | FileCheck %s - -#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> -#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> -#smem = #ttg.shared_memory -#tmem0 = #ttng.tensor_memory_encoding -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views - // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK-NOT: ttg.local_load - // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true - tt.func @swizzle0_operand_views( - %a_desc: !tt.tensordesc<128x128xf8E4M3FN>, - %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, - %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { - %true = arith.constant true - %c0_i32 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #blocked0> - %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> - %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> - %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> - %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> - %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> - %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> - ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> - tt.return - } -} - -// ----- - -#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> -#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> -#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> -#smem = #ttg.shared_memory -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 - // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} - tt.func @swizzle0_operand_views_warp_group_dot( - %a_desc: !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc>, - %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, - %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { - %c0 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> - %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> - %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> - %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> - %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> - %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> - %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> - %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} - : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> - %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> - tt.return %w#0 : tensor<128x256xf32, #mma_desc> - } -} diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index dbcfb9f0c331..818d4f338a6b 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -273,7 +273,6 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_accelerate_matmul(pm) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) - nvidia.passes.ttnvgpuir.add_rewrite_mma_operand_views_to_memdesc(pm) nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) passes.ttir.add_loop_aware_cse(pm) if capability // 10 in [8, 9]: diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index eb508107b86e..043c1a4fee6a 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -182,9 +182,6 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) { ttng::createTritonNvidiaGPUMMALoweringPass); ADD_PASS_WRAPPER_0("add_optimize_descriptor_encoding", ttng::createTritonNvidiaGPUOptimizeDescriptorEncodingPass); - ADD_PASS_WRAPPER_0( - "add_rewrite_mma_operand_views_to_memdesc", - ttng::createTritonNvidiaGPURewriteMmaOperandViewsToMemDescPass); ADD_PASS_WRAPPER_0("add_optimize_tmem_layouts", ttng::createTritonNvidiaGPUOptimizeTMemLayoutsPass); ADD_PASS_WRAPPER_0("add_interleave_tmem", From 72859e032404815b5d8c56c5d1a242c54244ffce Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 09:15:57 +0900 Subject: [PATCH 53/54] precommit --- lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index e6c5a588e3de..e401bbc4ed0f 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -234,10 +234,9 @@ class RewriteMmaOperandViewsToMemDescForDotOp Attribute srcEnc = inferSrcEncoding(op, resultTy.getEncoding()); if (!srcEnc) return failure(); - return MemDescType::get(trans.getSrc().getType().getShape(), - resultTy.getElementType(), srcEnc, - resultTy.getMemorySpace(), - resultTy.getMutableMemory()); + return MemDescType::get( + trans.getSrc().getType().getShape(), resultTy.getElementType(), srcEnc, + resultTy.getMemorySpace(), resultTy.getMutableMemory()); } static Value replayTensorViews(PatternRewriter &rewriter, Value value, From 3130b82d07e4d1e6bf69759ff599d9fb901cf90d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Apr 2026 16:00:01 +0900 Subject: [PATCH 54/54] inline helpers --- .../Transforms/OptimizeDotOperands.cpp | 90 ++++++++----------- 1 file changed, 35 insertions(+), 55 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index e401bbc4ed0f..134b8ea8154b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -3,8 +3,6 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -13,7 +11,7 @@ #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" #include -#include +#include namespace mlir::triton::gpu { @@ -219,63 +217,23 @@ class RewriteMmaOperandViewsToMemDescForDotOp } private: - static FailureOr pushLayoutBackward(MemDescType resultTy, - Operation *op) { - if (auto reshape = dyn_cast(op)) { - MemDescType srcTy; - if (failed(MemDescReshapeOp::inferReturnTypes( - resultTy.getContext(), reshape.getLoc(), resultTy, - reshape.getSrc().getType().getShape(), srcTy))) - return failure(); - return srcTy; - } - - auto trans = cast(op); - Attribute srcEnc = inferSrcEncoding(op, resultTy.getEncoding()); - if (!srcEnc) + LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { + if (!isa(operand.getType())) return failure(); - return MemDescType::get( - trans.getSrc().getType().getShape(), resultTy.getElementType(), srcEnc, - resultTy.getMemorySpace(), resultTy.getMutableMemory()); - } - - static Value replayTensorViews(PatternRewriter &rewriter, Value value, - ArrayRef steps) { - Value rewritten = value; - for (Operation *op : steps) { - if (auto reshape = dyn_cast(op)) { - rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten, - reshape.getType().getShape()); - } else { - auto trans = cast(op); - rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten, - trans.getOrder()); - } - } - return rewritten; - } - static Value peelMemDescViews(Value value) { - Value current = value; - while (auto view = current.getDefiningOp()) { + Value beforeTrailing = operand; + while (auto view = beforeTrailing.getDefiningOp()) { if (auto reshape = dyn_cast(view)) { - current = reshape.getSrc(); + beforeTrailing = reshape.getSrc(); continue; } if (auto trans = dyn_cast(view)) { - current = trans.getSrc(); + beforeTrailing = trans.getSrc(); continue; } break; } - return current; - } - - LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { - if (!isa(operand.getType())) - return failure(); - Value beforeTrailing = peelMemDescViews(operand); auto localAlloc = beforeTrailing.getDefiningOp(); if (!localAlloc || !localAlloc.getSrc()) return failure(); @@ -284,13 +242,26 @@ class RewriteMmaOperandViewsToMemDescForDotOp SmallVector tensorReplaySteps; MemDescType baseMemTy = localAlloc.getType(); while (auto view = baseTensor.getDefiningOp()) { - if (!isa(view)) + if (auto reshape = dyn_cast(view)) { + MemDescType srcTy; + auto inferred = MemDescReshapeOp::inferReturnTypes( + getContext(), reshape.getLoc(), baseMemTy, + reshape.getSrc().getType().getShape(), srcTy); + assert(succeeded(inferred) && "backward memdesc reshape inference " + "must succeed"); + (void)inferred; + baseMemTy = srcTy; + } else if (auto trans = dyn_cast(view)) { + Attribute srcEnc = inferSrcEncoding(view, baseMemTy.getEncoding()); + if (!srcEnc) + return failure(); + baseMemTy = MemDescType::get( + trans.getSrc().getType().getShape(), baseMemTy.getElementType(), + srcEnc, baseMemTy.getMemorySpace(), baseMemTy.getMutableMemory()); + } else { break; - auto srcTy = pushLayoutBackward(baseMemTy, view); - if (failed(srcTy)) - return failure(); + } tensorReplaySteps.push_back(view); - baseMemTy = *srcTy; baseTensor = view->getOperand(0); } if (tensorReplaySteps.empty()) @@ -303,7 +274,16 @@ class RewriteMmaOperandViewsToMemDescForDotOp Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), baseMemTy, baseTensor); - rewritten = replayTensorViews(rewriter, rewritten, tensorReplaySteps); + for (Operation *op : tensorReplaySteps) { + if (auto reshape = dyn_cast(op)) { + rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten, + reshape.getType().getShape()); + } else { + auto trans = cast(op); + rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten, + trans.getOrder()); + } + } rewriter.replaceOp(localAlloc, rewritten); return success(); }