From 2b918e6728b59bdda27b11f0692d5c81fc8be7eb Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 23 Nov 2025 16:54:31 -0800 Subject: [PATCH 01/11] [AMD] Turn TRITON_HIP_USE_ASYNC_COPY to be on by default --- python/triton/knobs.py | 2 +- third_party/amd/backend/compiler.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/triton/knobs.py b/python/triton/knobs.py index b55bd3de3218..e537f4123332 100644 --- a/python/triton/knobs.py +++ b/python/triton/knobs.py @@ -513,8 +513,8 @@ class amd_knobs(base_knobs): # We use strs so that we can have a default value based on other runtime info use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG") use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE") + use_async_copy: env_opt_bool = env_opt_bool("TRITON_HIP_USE_ASYNC_COPY") - use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY") scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS") diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index ff2f87c3bc2f..3ae7497f969e 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -19,14 +19,18 @@ def get_min_dot_size(target: GPUTarget): def is_pingpong_schedule_enabled(arch, use_async_copy): - return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True) - ) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong + return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)) \ + if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong def is_in_thread_transpose_enabled(arch): return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose +def is_async_copy_enabled(arch): + return (arch in ["gfx950", "gfx1250"]) if knobs.amd.use_async_copy is None else knobs.amd.use_async_copy + + @dataclass(frozen=True) class HIPOptions: num_warps: int = 4 @@ -227,7 +231,7 @@ def make_ttgir(mod, metadata, options): passes.ttir.add_triton_licm(pm) passes.common.add_canonicalizer(pm) - use_async_copy = knobs.amd.use_async_copy + use_async_copy = is_async_copy_enabled(options.arch) use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy) amd.passes.ttgpuir.add_schedule_loops(pm, options.num_stages) From 71182ae93f69a2c3084d82a37abaa0987224ffe1 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 31 Jul 2025 14:19:48 +0000 Subject: [PATCH 02/11] Disable matmul tests running out of LDS space --- python/test/unit/language/test_matmul.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index 3049cf9a9dc3..72daa1c38398 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -361,6 +361,8 @@ def test_mxfp(BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device) pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4") if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: + pytest.skip("Config requires too much shared memory") if BLOCK_N == 256 and BLOCK_K == 256: NUM_STAGES = min(NUM_STAGES, 2) @@ -1156,6 +1158,8 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE): pytest.skip("Float4 without scale is tested in test_block_scale_fp4") + if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: + pytest.skip("Config requires too much shared memory") if not PACK_B_ALONG_K and B_DATA_TYPE != "float4": pytest.skip("Pack along K can only be False for float4") if BLOCK_N == 256 and BLOCK_K == 256: From da42d2d70d84cf1a22a8f5c1e7fe37ee6731dd1a Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 12 Dec 2025 06:59:39 +0000 Subject: [PATCH 03/11] Fix async_copy_global_to_local to amdg.buffer_load_to_local --- .../TritonGPU/amd/amd-convert-buffer-ops.mlir | 22 +++++ .../amd/amd-pipeline-chained-dots.mlir | 3 +- .../amd/amd-update-async-wait-count.mlir | 28 ++++++ test/TritonGPU/loop-pipeline-hip.mlir | 2 +- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 86 +++---------------- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 76 +++++++++++++--- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 22 +++-- .../CoalesceAsyncCopy.cpp | 3 +- .../lib/TritonAMDGPUTransforms/LowerLoops.cpp | 25 ++++-- .../lib/TritonAMDGPUTransforms/Utility.cpp | 8 +- 10 files changed, 161 insertions(+), 114 deletions(-) diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 3d237f39866b..7aab9e83a331 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -942,3 +942,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +// Test that we don't generate buffer_load_to_local when the layout order is incompatible. +// The blocked layout has order [2, 1, 0] but the shared layout has order [2, 0, 1], +// which causes non-coalesced writes and cannot be lowered to direct-to-LDS. +#blocked_3d = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 8], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> +#shared_3d = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // COMMON-LABEL: @incompatible_order_no_buffer_load_to_local + // Check that we don't generate buffer_load_to_local for incompatible layouts + // COMMON-NOT: amdg.buffer_load_to_local + // COMMON: ttg.async_copy_global_to_local + tt.func @incompatible_order_no_buffer_load_to_local(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg1: !ttg.memdesc<1x128x64xf16, #shared_3d, #smem, mutable>) { + %0 = tt.splat %arg0 : !tt.ptr -> tensor<1x128x64x!tt.ptr, #blocked_3d> + %1 = ttg.async_copy_global_to_local %0, %arg1 : tensor<1x128x64x!tt.ptr, #blocked_3d> -> <1x128x64xf16, #shared_3d, #smem, mutable> + ttg.async_wait {num = 0 : i32} + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir index 01e92768d96c..d0a6e0820161 100644 --- a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir +++ b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir @@ -1,4 +1,5 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=4" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=4" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize +// FIXME: | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32, 16], isTransposed = true}> diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index d7b52bb041cf..7702d4eac10b 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -433,3 +433,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +// Test scf.if without else region in def chain + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: scf_if_without_else + tt.func public @scf_if_without_else(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %cond: i1) { + // Emits 1 direct to lds instruction + %0 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %1 = ttg.async_commit_group tokens %0 + + // For scf.if without else region, the else path contributes 0 instructions; + // so the minimum across both paths is 0. + scf.if %cond { + // Emits 1 direct to lds instruction inside the if + %inner = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %inner_commit = ttg.async_commit_group tokens %inner + } + + // CHECK: amdg.async_wait {{.*}} {num_inst = 0 + %10 = ttg.async_wait %1 {num = 0 : i32} + tt.return + } +} diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 9d42e550b400..45aa9dbfad45 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -1,5 +1,5 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops=num_stages=2 -tritonamdgpu-pipeline -canonicalize | FileCheck %s --check-prefixes=COMMON,SYNC -// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=2" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC +// FIXME: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=2" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a0e185ab36e7..467ccbc18a0a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -289,59 +289,6 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { ModuleAxisInfoAnalysis &axisAnalysisPass) : LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} - // direct to lds loads do not support per lane shared offsets. We need to - // ensure that we write coalesced into shared memory. This means we cannot - // exceed the supported load width because splitting them would cause strided - // (non coalesced) writes. Additionally: - // 1) For *non* swizzled shared encodings we check if they result in - // coalesced writes and can then lower them directly to the intrinsics. - // 2) For swizzled shared encodings we need to transfer the swizzling to the - // source pointers. For now this is done by swizzling the pointers - // between the lane of a warp via permute. This only works if the swizzle - // pattern does not exchange elements between warps which holds for all - // our swizzle patterns. There is still a check performed to not silently - // produce wrong results if we invalidate the condition in the future - LogicalResult canWriteCoalesced(RewriterBase &rewriter, Operation *op, - RankedTensorType srcTy, MemDescType dstTy, - unsigned vectorSize, - bool requiresSrcPtrSwizzling) const { - if (targetInfo.supportsDirectToLDSScattering()) { - return success(); - } - - int vecBits = vectorSize * dstTy.getElementTypeBitWidth(); - if (!targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)) { - LDBG(*op << " results in unsupported load bitwidth: " << vecBits); - return failure(); - } - // Compute the blocked -> shared linear layout to check preconditions - LinearLayout srcLayout = triton::gpu::toLinearLayout(srcTy); - LinearLayout sharedLayout; - if (auto paddedEnc = dyn_cast( - dstTy.getEncoding())) { - sharedLayout = paddedEnc.getLinearComponent(); - } else { - sharedLayout = triton::gpu::toLinearLayout(dstTy); - } - LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); - - unsigned threadsPerWarp = lookupThreadsPerWarp(rewriter); - if (!requiresSrcPtrSwizzling && - !LLVM::AMD::canCoalesceWriteIntoSharedMemory( - rewriter, srcToSharedLayout, threadsPerWarp, vectorSize)) { - LDBG(*op << " does not write coalesced into LDS and is not swizzled"); - return failure(); - } - - if (requiresSrcPtrSwizzling && - !LLVM::AMD::doesSwizzleInsideWarp(rewriter, srcToSharedLayout, - threadsPerWarp)) { - LDBG(*op << " does swizzle across warp boundaries"); - return failure(); - } - return success(); - } - // For each load emit the computation to get the lane id offset which holds // the source pointers/offsets we need to store to shared memory SmallVector @@ -819,17 +766,7 @@ struct BufferLoadToLocalOpConversion // If the op has a contiguity hint use it to increase the vector size. vec = std::max(vec, op.getContiguity()); - // For padded encodings restrict vec by the min interval - if (auto padEnc = dyn_cast(dstEnc)) { - vec = std::min(vec, padEnc.getMinInterval()); - } - - auto maybeSwizzledEnc = dyn_cast(dstEnc); - bool requiresSrcPtrSwizzling = - !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && - maybeSwizzledEnc.getMaxPhase() != 1; - if (failed(canWriteCoalesced(rewriter, op, ptrType, dstTy, vec, - requiresSrcPtrSwizzling))) { + if (!LLVM::AMD::canLoadDirectToLDS(targetInfo, ptrType, dstEnc, vec)) { return failure(); } @@ -838,6 +775,10 @@ struct BufferLoadToLocalOpConversion auto flatDstTy = dstTy; SmallVector swizzledLaneOffsets; + auto maybeSwizzledEnc = dyn_cast(dstEnc); + bool requiresSrcPtrSwizzling = + !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && + maybeSwizzledEnc.getMaxPhase() != 1; if (requiresSrcPtrSwizzling) { // TODO (alex): this is only correct as long as the lds view is a // contiguous block. So this can break if we slice along the 2 minor @@ -968,18 +909,7 @@ struct AsyncCopyGlobalToLocalOpConversion // If the op has a contiguity hint use it to increase the vector size. vec = std::max(vec, op.getContiguity()); - // For padded encodings restrict vec by the min interval - if (auto padEnc = dyn_cast(dstEnc)) { - vec = std::min(vec, padEnc.getMinInterval()); - } - - auto maybeSwizzledEnc = dyn_cast(dstEnc); - bool requiresSrcPtrSwizzling = - !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && - maybeSwizzledEnc.getMaxPhase() != 1; - - if (failed(canWriteCoalesced(rewriter, op, srcTy, dstTy, vec, - requiresSrcPtrSwizzling))) { + if (!LLVM::AMD::canLoadDirectToLDS(targetInfo, srcTy, dstEnc, vec)) { return failure(); } @@ -987,6 +917,10 @@ struct AsyncCopyGlobalToLocalOpConversion // the LDS addresses since we gather into LDS auto flatDstTy = dstTy; SmallVector swizzledLaneOffsets; + auto maybeSwizzledEnc = dyn_cast(dstEnc); + bool requiresSrcPtrSwizzling = + !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && + maybeSwizzledEnc.getMaxPhase() != 1; if (requiresSrcPtrSwizzling) { auto flatSharedEnc = SwizzledSharedEncodingAttr::get( op->getContext(), maybeSwizzledEnc.getVec(), 1, 1, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 1d14593bc717..70b13b2d5eb6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/PatternMatch.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" namespace tt = mlir::triton; using mlir::triton::ModuleAxisInfoAnalysis; @@ -591,7 +592,7 @@ Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t) { } } -bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, +bool canCoalesceWriteIntoSharedMemory(MLIRContext *ctx, const LinearLayout &srcToSharedLayout, unsigned threadsPerWarp, unsigned vecSize) { @@ -603,7 +604,7 @@ bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, return false; } - StringAttr kLane = rewriter.getStringAttr("lane"); + StringAttr kLane = StringAttr::get(ctx, "lane"); for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; unsigned expected = contig * (1 << inLane); @@ -621,7 +622,7 @@ bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, assert(llvm::isPowerOf2_32(threadsPerWarp)); assert(llvm::isPowerOf2_32(contig)); unsigned mask = (threadsPerWarp * contig) - 1; - StringAttr kWarp = rewriter.getStringAttr("warp"); + StringAttr kWarp = StringAttr::get(ctx, "warp"); for (int inWarp : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kWarp))) { auto basis = srcToSharedLayout.getBasis(kWarp, inWarp)[0]; if ((basis & mask) != 0) { @@ -635,7 +636,7 @@ bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, return true; } -bool doesSwizzleInsideWarp(RewriterBase &rewriter, +bool doesSwizzleInsideWarp(MLIRContext *ctx, const LinearLayout &srcToSharedLayout, unsigned threadsPerWarp) { auto contig = srcToSharedLayout.getNumConsecutiveInOut(); @@ -644,7 +645,7 @@ bool doesSwizzleInsideWarp(RewriterBase &rewriter, assert(llvm::isPowerOf2_32(threadsPerWarp)); unsigned upperLimit = threadsPerWarp * contig; - StringAttr kLane = rewriter.getStringAttr("lane"); + StringAttr kLane = StringAttr::get(ctx, "lane"); for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; if (basis >= upperLimit) { @@ -654,16 +655,63 @@ bool doesSwizzleInsideWarp(RewriterBase &rewriter, return true; } -bool isUsedByDotScaledOp(Operation *op) { - const ForwardSliceOptions fwdOpt; - SetVector forwardSliceSet; - getForwardSlice(op, &forwardSliceSet, fwdOpt); +// On gfx9, direct to LDS loads do not support per lane shared offsets. We need +// to ensure that we write coalesced into shared memory. This means we cannot +// exceed the supported load width because splitting them would cause strided +// (non coalesced) writes. Additionally: +// +// 1. For *non* swizzled shared encodings we check if they result in coalesced +// writes and can then lower them directly to the intrinsics. +// 2. For swizzled shared encodings we need to transfer the swizzling to the +// source pointers. For now this is done by swizzling the pointers +// between the lane of a warp via permute. This only works if the swizzle +// pattern does not exchange elements between warps which holds for all +// our swizzle patterns. There is still a check performed to not silently +// produce wrong results if we invalidate the condition in the future +bool canLoadDirectToLDS(const triton::AMD::TargetInfo &targetInfo, + RankedTensorType srcTy, Attribute dstEnc, + unsigned &vectorSize) { + // For padded encodings restrict vec by the min interval + auto paddedEnc = dyn_cast(dstEnc); + if (paddedEnc) + vectorSize = std::min(vectorSize, paddedEnc.getMinInterval()); + + int elemBitWidth = tt::getPointeeBitWidth(srcTy); + int vectorBits = vectorSize * elemBitWidth; + if (!targetInfo.supportsDirectToLdsLoadBitWidth(vectorBits)) + return false; + + if (targetInfo.supportsDirectToLDSScattering()) + return true; + + // Compute the blocked -> shared linear layout to check preconditions + LinearLayout srcLayout = triton::gpu::toLinearLayout(srcTy); + LinearLayout sharedLayout; + if (paddedEnc) { + sharedLayout = paddedEnc.getLinearComponent(); + } else { + sharedLayout = triton::gpu::toLinearLayout(srcTy.getShape(), dstEnc); + } + LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); + + auto swizzledEnc = dyn_cast(dstEnc); + bool requiresSrcPtrSwizzling = swizzledEnc && swizzledEnc.getMaxPhase() != 1; + unsigned warpSize = targetInfo.getWarpSize(); - return std::any_of( - forwardSliceSet.begin(), forwardSliceSet.end(), [](auto *operation) { - return isa( - operation); - }); + if (!requiresSrcPtrSwizzling && + !canCoalesceWriteIntoSharedMemory(srcTy.getContext(), srcToSharedLayout, + warpSize, vectorSize)) { + LDBG("Does not write coalesced into LDS"); + return false; + } + + if (requiresSrcPtrSwizzling && + !doesSwizzleInsideWarp(srcTy.getContext(), srcToSharedLayout, warpSize)) { + LDBG("Swizzles across warp boundaries"); + return false; + } + + return true; } bool isChainDotHead(tt::DotOpInterface dotOp, unsigned opIdx) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 75beb77f635b..9223ee20bf5f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -97,19 +97,23 @@ Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t); // Returns true if we can perform coalesced write from the source encoding to // the destination encoding for a given vec size. -bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, +bool canCoalesceWriteIntoSharedMemory(MLIRContext *ctx, const LinearLayout &srcToSharedLayout, unsigned threadsPerWarp, unsigned vecSize); -// Returns true if the swizzling pattern does only swizzle the shared memory -// offsets of a warp and does not exchange destination elements across warps -bool doesSwizzleInsideWarp(RewriterBase &rewriter, - const LinearLayout &srcToSharedLayout, - unsigned threadsPerWarp); - -// Return true if op is used by DotScaledOp or UpcastMXFPOp ops. -bool isUsedByDotScaledOp(Operation *op); +// Returns true if we load directly from global |srcTy| to shared memory +// |dstEnc| for the given target. +// +// This function expects caller to pass in as |vectorSize| the vector size +// reading from global memory, after factoring in axis information and alignment +// hints. It will be updated to factor in shared memory |dstEnc| constraints. +// +// This is used by both the LLVM lowering and the conversion pattern to ensure +// consistency. +bool canLoadDirectToLDS(const triton::AMD::TargetInfo &targetInfo, + RankedTensorType srcTy, Attribute dstEnc, + unsigned &vectorSize); // Check if the result of this tl.dot is used as opA or opB of another tl.dot // in the same region diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp index 12f22f5fc752..3160e32e560b 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp @@ -102,7 +102,8 @@ struct CoalesceAsyncCopyWrites ttg::DistributedEncodingTrait newDistEnc; if (LLVM::AMD::canCoalesceWriteIntoSharedMemory( - rewriter, regToSharedLayout, threadsPerWarp, loadContig)) { + copyOp.getContext(), regToSharedLayout, threadsPerWarp, + loadContig)) { return rewriter.notifyMatchFailure(copyOp, "already writes coalesced"); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp index 6ee112d93c75..2bb474982b1b 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp @@ -2,12 +2,14 @@ #include "Utility.h" #include "amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h" #include "amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" +#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" #include "amd/lib/TritonAMDGPUTransforms/PipelineUtility.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/Support/Debug.h" #include +#undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-pipeline-lower-loops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -273,13 +275,12 @@ bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp, using tt::AMD::ISAFamily; if (sharedEnc && llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4}, targetInfo.getISAFamily())) { - // Compute the final vecSize we can use for the combination of - // sourceEncoding and sharedEncoding. We can only use AsyncCopy if the - // target supports the requested or a smaller vecSize because we cannot - // stride when loading directly to lds on GFX9 + // Compute the final vecSize we can use for source to destination type and + // encoding. We can only use async copy if the target supports the requested + // or a smaller vecSize because we cannot stride when loading directly to + // lds on GFX9. auto srcTy = cast(loadOp.getPtr().getType()); auto regLayout = triton::gpu::toLinearLayout(srcTy); - // It's the allocation so we trim the multibuffer dimension auto srcShape = srcTy.getShape(); triton::LinearLayout sharedLayout; auto paddedEnc = dyn_cast(sharedEnc); @@ -290,13 +291,19 @@ bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp, } auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout); - unsigned elemBitWidth = tt::getPointeeBitWidth(srcTy); unsigned vecSize = regToSharedLayout.getNumConsecutiveInOut(); - if (paddedEnc) - vecSize = std::min(vecSize, paddedEnc.getMinInterval()); + LDBG("init global to shared vector size: " << vecSize); + int elemBitWidth = tt::getPointeeBitWidth(srcTy); + vecSize = fitToValidDirectToLdsVecSize(vecSize, elemBitWidth, targetInfo); + LDBG("vector size after fitting arch direct to LDS: " << vecSize); + if (vecSize == 0) { + return false; + } - if (fitToValidDirectToLdsVecSize(vecSize, elemBitWidth, targetInfo) == 0) + if (!LLVM::AMD::canLoadDirectToLDS(targetInfo, srcTy, sharedEnc, vecSize)) { + LDBG("cannot use direct to LDS due to arch constraints"); return false; + } } // Checks whether the global pointer's contiguity and mask alignment allows diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp index 051abe922ef4..3730adedd4f8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp @@ -22,11 +22,13 @@ int deduceMinCountBetweeOps(Operation *beginOp, Operation *endOp, int count = 0; for (auto op = beginOp; op != endOp; op = op->getNextNode()) { if (auto ifOp = llvm::dyn_cast(op)) { - assert(!ifOp.getThenRegion().empty() && !ifOp.getElseRegion().empty()); + assert(!ifOp.getThenRegion().empty()); auto minThen = deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc); - auto minElse = - deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); + int minElse = 0; + if (!ifOp.getElseRegion().empty()) + minElse = + deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); count += std::min(minThen, minElse); } else if (auto forOp = llvm::dyn_cast(op)) { if (std::optional tripCount = forOp.getStaticTripCount()) { From 6224bc88631ffe922f095211c52f9fd9967a7337 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 17 Dec 2025 22:46:34 +0000 Subject: [PATCH 04/11] Revert "Disable matmul tests running out of LDS space" This reverts commit 71182ae93f69a2c3084d82a37abaa0987224ffe1. --- python/test/unit/language/test_matmul.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index 72daa1c38398..3049cf9a9dc3 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -361,8 +361,6 @@ def test_mxfp(BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device) pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4") if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") - if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: - pytest.skip("Config requires too much shared memory") if BLOCK_N == 256 and BLOCK_K == 256: NUM_STAGES = min(NUM_STAGES, 2) @@ -1158,8 +1156,6 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE): pytest.skip("Float4 without scale is tested in test_block_scale_fp4") - if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: - pytest.skip("Config requires too much shared memory") if not PACK_B_ALONG_K and B_DATA_TYPE != "float4": pytest.skip("Pack along K can only be False for float4") if BLOCK_N == 256 and BLOCK_K == 256: From cd13b9ccc23614d0d85e55b0076895e117076ff2 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 17 Dec 2025 22:51:35 +0000 Subject: [PATCH 05/11] Reapply "Disable matmul tests running out of LDS space" This reverts commit 6224bc88631ffe922f095211c52f9fd9967a7337. --- python/test/unit/language/test_matmul.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index 3049cf9a9dc3..72daa1c38398 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -361,6 +361,8 @@ def test_mxfp(BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device) pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4") if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: + pytest.skip("Config requires too much shared memory") if BLOCK_N == 256 and BLOCK_K == 256: NUM_STAGES = min(NUM_STAGES, 2) @@ -1156,6 +1158,8 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE): pytest.skip("Float4 without scale is tested in test_block_scale_fp4") + if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: + pytest.skip("Config requires too much shared memory") if not PACK_B_ALONG_K and B_DATA_TYPE != "float4": pytest.skip("Pack along K can only be False for float4") if BLOCK_N == 256 and BLOCK_K == 256: From 79177c4beb6f5cce856360ee839e49811b5285fd Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 17 Dec 2025 22:51:42 +0000 Subject: [PATCH 06/11] Revert "Fix async_copy_global_to_local to amdg.buffer_load_to_local" This reverts commit da42d2d70d84cf1a22a8f5c1e7fe37ee6731dd1a. --- .../TritonGPU/amd/amd-convert-buffer-ops.mlir | 22 ----- .../amd/amd-pipeline-chained-dots.mlir | 3 +- .../amd/amd-update-async-wait-count.mlir | 28 ------ test/TritonGPU/loop-pipeline-hip.mlir | 2 +- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 86 ++++++++++++++++--- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 76 +++------------- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 22 ++--- .../CoalesceAsyncCopy.cpp | 3 +- .../lib/TritonAMDGPUTransforms/LowerLoops.cpp | 25 ++---- .../lib/TritonAMDGPUTransforms/Utility.cpp | 8 +- 10 files changed, 114 insertions(+), 161 deletions(-) diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 7aab9e83a331..3d237f39866b 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -942,25 +942,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } - -// ----- - -// Test that we don't generate buffer_load_to_local when the layout order is incompatible. -// The blocked layout has order [2, 1, 0] but the shared layout has order [2, 0, 1], -// which causes non-coalesced writes and cannot be lowered to direct-to-LDS. -#blocked_3d = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 8], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}> -#shared_3d = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - // COMMON-LABEL: @incompatible_order_no_buffer_load_to_local - // Check that we don't generate buffer_load_to_local for incompatible layouts - // COMMON-NOT: amdg.buffer_load_to_local - // COMMON: ttg.async_copy_global_to_local - tt.func @incompatible_order_no_buffer_load_to_local(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, - %arg1: !ttg.memdesc<1x128x64xf16, #shared_3d, #smem, mutable>) { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<1x128x64x!tt.ptr, #blocked_3d> - %1 = ttg.async_copy_global_to_local %0, %arg1 : tensor<1x128x64x!tt.ptr, #blocked_3d> -> <1x128x64xf16, #shared_3d, #smem, mutable> - ttg.async_wait {num = 0 : i32} - tt.return - } -} diff --git a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir index d0a6e0820161..01e92768d96c 100644 --- a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir +++ b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir @@ -1,5 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=4" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize -// FIXME: | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=4" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32, 16], isTransposed = true}> diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index 5b29672d3c5c..56e1eb00e609 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -487,31 +487,3 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } - -// ----- - -// Test scf.if without else region in def chain - -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [1, 0]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: scf_if_without_else - tt.func public @scf_if_without_else(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %cond: i1) { - // Emits 1 direct to lds instruction - %0 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> - %1 = ttg.async_commit_group tokens %0 - - // For scf.if without else region, the else path contributes 0 instructions; - // so the minimum across both paths is 0. - scf.if %cond { - // Emits 1 direct to lds instruction inside the if - %inner = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> - %inner_commit = ttg.async_commit_group tokens %inner - } - - // CHECK: amdg.async_wait {{.*}} {num_inst = 0 - %10 = ttg.async_wait %1 {num = 0 : i32} - tt.return - } -} diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 31ed914f478c..ffe51c195b14 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -1,5 +1,5 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops=num_stages=2 -tritonamdgpu-pipeline -canonicalize | FileCheck %s --check-prefixes=COMMON,SYNC -// FIXME: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=2" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC +// RUN: triton-opt %s -split-input-file -tritonamdgpu-schedule-loops="num_stages=2" -tritonamdgpu-pipeline="use_async_copy=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,ASYNC #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index b96dc068d463..c29319e58d20 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -289,6 +289,59 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { ModuleAxisInfoAnalysis &axisAnalysisPass) : LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} + // direct to lds loads do not support per lane shared offsets. We need to + // ensure that we write coalesced into shared memory. This means we cannot + // exceed the supported load width because splitting them would cause strided + // (non coalesced) writes. Additionally: + // 1) For *non* swizzled shared encodings we check if they result in + // coalesced writes and can then lower them directly to the intrinsics. + // 2) For swizzled shared encodings we need to transfer the swizzling to the + // source pointers. For now this is done by swizzling the pointers + // between the lane of a warp via permute. This only works if the swizzle + // pattern does not exchange elements between warps which holds for all + // our swizzle patterns. There is still a check performed to not silently + // produce wrong results if we invalidate the condition in the future + LogicalResult canWriteCoalesced(RewriterBase &rewriter, Operation *op, + RankedTensorType srcTy, MemDescType dstTy, + unsigned vectorSize, + bool requiresSrcPtrSwizzling) const { + if (targetInfo.supportsDirectToLDSScattering()) { + return success(); + } + + int vecBits = vectorSize * dstTy.getElementTypeBitWidth(); + if (!targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)) { + LDBG(*op << " results in unsupported load bitwidth: " << vecBits); + return failure(); + } + // Compute the blocked -> shared linear layout to check preconditions + LinearLayout srcLayout = triton::gpu::toLinearLayout(srcTy); + LinearLayout sharedLayout; + if (auto paddedEnc = dyn_cast( + dstTy.getEncoding())) { + sharedLayout = paddedEnc.getLinearComponent(); + } else { + sharedLayout = triton::gpu::toLinearLayout(dstTy); + } + LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); + + unsigned threadsPerWarp = lookupThreadsPerWarp(rewriter); + if (!requiresSrcPtrSwizzling && + !LLVM::AMD::canCoalesceWriteIntoSharedMemory( + rewriter, srcToSharedLayout, threadsPerWarp, vectorSize)) { + LDBG(*op << " does not write coalesced into LDS and is not swizzled"); + return failure(); + } + + if (requiresSrcPtrSwizzling && + !LLVM::AMD::doesSwizzleInsideWarp(rewriter, srcToSharedLayout, + threadsPerWarp)) { + LDBG(*op << " does swizzle across warp boundaries"); + return failure(); + } + return success(); + } + // For each load emit the computation to get the lane id offset which holds // the source pointers/offsets we need to store to shared memory SmallVector @@ -784,7 +837,17 @@ struct BufferLoadToLocalOpConversion // If the op has a contiguity hint use it to increase the vector size. vec = std::max(vec, op.getContiguity()); - if (!LLVM::AMD::canLoadDirectToLDS(targetInfo, ptrType, dstEnc, vec)) { + // For padded encodings restrict vec by the min interval + if (auto padEnc = dyn_cast(dstEnc)) { + vec = std::min(vec, padEnc.getMinInterval()); + } + + auto maybeSwizzledEnc = dyn_cast(dstEnc); + bool requiresSrcPtrSwizzling = + !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && + maybeSwizzledEnc.getMaxPhase() != 1; + if (failed(canWriteCoalesced(rewriter, op, ptrType, dstTy, vec, + requiresSrcPtrSwizzling))) { return failure(); } @@ -793,10 +856,6 @@ struct BufferLoadToLocalOpConversion auto flatDstTy = dstTy; SmallVector swizzledLaneOffsets; - auto maybeSwizzledEnc = dyn_cast(dstEnc); - bool requiresSrcPtrSwizzling = - !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && - maybeSwizzledEnc.getMaxPhase() != 1; if (requiresSrcPtrSwizzling) { // TODO (alex): this is only correct as long as the lds view is a // contiguous block. So this can break if we slice along the 2 minor @@ -927,7 +986,18 @@ struct AsyncCopyGlobalToLocalOpConversion // If the op has a contiguity hint use it to increase the vector size. vec = std::max(vec, op.getContiguity()); - if (!LLVM::AMD::canLoadDirectToLDS(targetInfo, srcTy, dstEnc, vec)) { + // For padded encodings restrict vec by the min interval + if (auto padEnc = dyn_cast(dstEnc)) { + vec = std::min(vec, padEnc.getMinInterval()); + } + + auto maybeSwizzledEnc = dyn_cast(dstEnc); + bool requiresSrcPtrSwizzling = + !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && + maybeSwizzledEnc.getMaxPhase() != 1; + + if (failed(canWriteCoalesced(rewriter, op, srcTy, dstTy, vec, + requiresSrcPtrSwizzling))) { return failure(); } @@ -935,10 +1005,6 @@ struct AsyncCopyGlobalToLocalOpConversion // the LDS addresses since we gather into LDS auto flatDstTy = dstTy; SmallVector swizzledLaneOffsets; - auto maybeSwizzledEnc = dyn_cast(dstEnc); - bool requiresSrcPtrSwizzling = - !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && - maybeSwizzledEnc.getMaxPhase() != 1; if (requiresSrcPtrSwizzling) { auto flatSharedEnc = SwizzledSharedEncodingAttr::get( op->getContext(), maybeSwizzledEnc.getVec(), 1, 1, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 70b13b2d5eb6..1d14593bc717 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -8,7 +8,6 @@ #include "mlir/IR/PatternMatch.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" namespace tt = mlir::triton; using mlir::triton::ModuleAxisInfoAnalysis; @@ -592,7 +591,7 @@ Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t) { } } -bool canCoalesceWriteIntoSharedMemory(MLIRContext *ctx, +bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, const LinearLayout &srcToSharedLayout, unsigned threadsPerWarp, unsigned vecSize) { @@ -604,7 +603,7 @@ bool canCoalesceWriteIntoSharedMemory(MLIRContext *ctx, return false; } - StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kLane = rewriter.getStringAttr("lane"); for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; unsigned expected = contig * (1 << inLane); @@ -622,7 +621,7 @@ bool canCoalesceWriteIntoSharedMemory(MLIRContext *ctx, assert(llvm::isPowerOf2_32(threadsPerWarp)); assert(llvm::isPowerOf2_32(contig)); unsigned mask = (threadsPerWarp * contig) - 1; - StringAttr kWarp = StringAttr::get(ctx, "warp"); + StringAttr kWarp = rewriter.getStringAttr("warp"); for (int inWarp : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kWarp))) { auto basis = srcToSharedLayout.getBasis(kWarp, inWarp)[0]; if ((basis & mask) != 0) { @@ -636,7 +635,7 @@ bool canCoalesceWriteIntoSharedMemory(MLIRContext *ctx, return true; } -bool doesSwizzleInsideWarp(MLIRContext *ctx, +bool doesSwizzleInsideWarp(RewriterBase &rewriter, const LinearLayout &srcToSharedLayout, unsigned threadsPerWarp) { auto contig = srcToSharedLayout.getNumConsecutiveInOut(); @@ -645,7 +644,7 @@ bool doesSwizzleInsideWarp(MLIRContext *ctx, assert(llvm::isPowerOf2_32(threadsPerWarp)); unsigned upperLimit = threadsPerWarp * contig; - StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kLane = rewriter.getStringAttr("lane"); for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) { auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0]; if (basis >= upperLimit) { @@ -655,63 +654,16 @@ bool doesSwizzleInsideWarp(MLIRContext *ctx, return true; } -// On gfx9, direct to LDS loads do not support per lane shared offsets. We need -// to ensure that we write coalesced into shared memory. This means we cannot -// exceed the supported load width because splitting them would cause strided -// (non coalesced) writes. Additionally: -// -// 1. For *non* swizzled shared encodings we check if they result in coalesced -// writes and can then lower them directly to the intrinsics. -// 2. For swizzled shared encodings we need to transfer the swizzling to the -// source pointers. For now this is done by swizzling the pointers -// between the lane of a warp via permute. This only works if the swizzle -// pattern does not exchange elements between warps which holds for all -// our swizzle patterns. There is still a check performed to not silently -// produce wrong results if we invalidate the condition in the future -bool canLoadDirectToLDS(const triton::AMD::TargetInfo &targetInfo, - RankedTensorType srcTy, Attribute dstEnc, - unsigned &vectorSize) { - // For padded encodings restrict vec by the min interval - auto paddedEnc = dyn_cast(dstEnc); - if (paddedEnc) - vectorSize = std::min(vectorSize, paddedEnc.getMinInterval()); - - int elemBitWidth = tt::getPointeeBitWidth(srcTy); - int vectorBits = vectorSize * elemBitWidth; - if (!targetInfo.supportsDirectToLdsLoadBitWidth(vectorBits)) - return false; - - if (targetInfo.supportsDirectToLDSScattering()) - return true; - - // Compute the blocked -> shared linear layout to check preconditions - LinearLayout srcLayout = triton::gpu::toLinearLayout(srcTy); - LinearLayout sharedLayout; - if (paddedEnc) { - sharedLayout = paddedEnc.getLinearComponent(); - } else { - sharedLayout = triton::gpu::toLinearLayout(srcTy.getShape(), dstEnc); - } - LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); - - auto swizzledEnc = dyn_cast(dstEnc); - bool requiresSrcPtrSwizzling = swizzledEnc && swizzledEnc.getMaxPhase() != 1; - unsigned warpSize = targetInfo.getWarpSize(); +bool isUsedByDotScaledOp(Operation *op) { + const ForwardSliceOptions fwdOpt; + SetVector forwardSliceSet; + getForwardSlice(op, &forwardSliceSet, fwdOpt); - if (!requiresSrcPtrSwizzling && - !canCoalesceWriteIntoSharedMemory(srcTy.getContext(), srcToSharedLayout, - warpSize, vectorSize)) { - LDBG("Does not write coalesced into LDS"); - return false; - } - - if (requiresSrcPtrSwizzling && - !doesSwizzleInsideWarp(srcTy.getContext(), srcToSharedLayout, warpSize)) { - LDBG("Swizzles across warp boundaries"); - return false; - } - - return true; + return std::any_of( + forwardSliceSet.begin(), forwardSliceSet.end(), [](auto *operation) { + return isa( + operation); + }); } bool isChainDotHead(tt::DotOpInterface dotOp, unsigned opIdx) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 9223ee20bf5f..75beb77f635b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -97,23 +97,19 @@ Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t); // Returns true if we can perform coalesced write from the source encoding to // the destination encoding for a given vec size. -bool canCoalesceWriteIntoSharedMemory(MLIRContext *ctx, +bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter, const LinearLayout &srcToSharedLayout, unsigned threadsPerWarp, unsigned vecSize); -// Returns true if we load directly from global |srcTy| to shared memory -// |dstEnc| for the given target. -// -// This function expects caller to pass in as |vectorSize| the vector size -// reading from global memory, after factoring in axis information and alignment -// hints. It will be updated to factor in shared memory |dstEnc| constraints. -// -// This is used by both the LLVM lowering and the conversion pattern to ensure -// consistency. -bool canLoadDirectToLDS(const triton::AMD::TargetInfo &targetInfo, - RankedTensorType srcTy, Attribute dstEnc, - unsigned &vectorSize); +// Returns true if the swizzling pattern does only swizzle the shared memory +// offsets of a warp and does not exchange destination elements across warps +bool doesSwizzleInsideWarp(RewriterBase &rewriter, + const LinearLayout &srcToSharedLayout, + unsigned threadsPerWarp); + +// Return true if op is used by DotScaledOp or UpcastMXFPOp ops. +bool isUsedByDotScaledOp(Operation *op); // Check if the result of this tl.dot is used as opA or opB of another tl.dot // in the same region diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp index 3160e32e560b..12f22f5fc752 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CoalesceAsyncCopy.cpp @@ -102,8 +102,7 @@ struct CoalesceAsyncCopyWrites ttg::DistributedEncodingTrait newDistEnc; if (LLVM::AMD::canCoalesceWriteIntoSharedMemory( - copyOp.getContext(), regToSharedLayout, threadsPerWarp, - loadContig)) { + rewriter, regToSharedLayout, threadsPerWarp, loadContig)) { return rewriter.notifyMatchFailure(copyOp, "already writes coalesced"); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp index d9fc72a35c42..1c8494f0ff07 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp @@ -2,14 +2,12 @@ #include "Utility.h" #include "amd/lib/TritonAMDGPUToLLVM/AsyncUtility.h" #include "amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" -#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" #include "amd/lib/TritonAMDGPUTransforms/PipelineUtility.h" #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/Support/Debug.h" #include -#undef DEBUG_TYPE #define DEBUG_TYPE "tritonamdgpu-pipeline-lower-loops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -294,12 +292,13 @@ bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp, using tt::AMD::ISAFamily; if (sharedEnc && llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4}, targetInfo.getISAFamily())) { - // Compute the final vecSize we can use for source to destination type and - // encoding. We can only use async copy if the target supports the requested - // or a smaller vecSize because we cannot stride when loading directly to - // lds on GFX9. + // Compute the final vecSize we can use for the combination of + // sourceEncoding and sharedEncoding. We can only use AsyncCopy if the + // target supports the requested or a smaller vecSize because we cannot + // stride when loading directly to lds on GFX9 auto srcTy = cast(loadOp.getPtr().getType()); auto regLayout = triton::gpu::toLinearLayout(srcTy); + // It's the allocation so we trim the multibuffer dimension auto srcShape = srcTy.getShape(); triton::LinearLayout sharedLayout; auto paddedEnc = dyn_cast(sharedEnc); @@ -310,19 +309,13 @@ bool canBeConvertedToAsyncLoad(unsigned numBuffers, tt::LoadOp loadOp, } auto regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + unsigned elemBitWidth = tt::getPointeeBitWidth(srcTy); unsigned vecSize = regToSharedLayout.getNumConsecutiveInOut(); - LDBG("init global to shared vector size: " << vecSize); - int elemBitWidth = tt::getPointeeBitWidth(srcTy); - vecSize = fitToValidDirectToLdsVecSize(vecSize, elemBitWidth, targetInfo); - LDBG("vector size after fitting arch direct to LDS: " << vecSize); - if (vecSize == 0) { - return false; - } + if (paddedEnc) + vecSize = std::min(vecSize, paddedEnc.getMinInterval()); - if (!LLVM::AMD::canLoadDirectToLDS(targetInfo, srcTy, sharedEnc, vecSize)) { - LDBG("cannot use direct to LDS due to arch constraints"); + if (fitToValidDirectToLdsVecSize(vecSize, elemBitWidth, targetInfo) == 0) return false; - } } // Checks whether the global pointer's contiguity and mask alignment allows diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp index 3730adedd4f8..051abe922ef4 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp @@ -22,13 +22,11 @@ int deduceMinCountBetweeOps(Operation *beginOp, Operation *endOp, int count = 0; for (auto op = beginOp; op != endOp; op = op->getNextNode()) { if (auto ifOp = llvm::dyn_cast(op)) { - assert(!ifOp.getThenRegion().empty()); + assert(!ifOp.getThenRegion().empty() && !ifOp.getElseRegion().empty()); auto minThen = deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc); - int minElse = 0; - if (!ifOp.getElseRegion().empty()) - minElse = - deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); + auto minElse = + deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); count += std::min(minThen, minElse); } else if (auto forOp = llvm::dyn_cast(op)) { if (std::optional tripCount = forOp.getStaticTripCount()) { From 76ec2f2f2e051f7413c7fb7fa30178412827c4d6 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 18 Dec 2025 02:00:05 +0000 Subject: [PATCH 07/11] Fix missing else case --- .../amd/amd-update-async-wait-count.mlir | 28 +++++++++++++++++++ .../lib/TritonAMDGPUTransforms/Utility.cpp | 8 ++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index 56e1eb00e609..5b29672d3c5c 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -487,3 +487,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +// Test scf.if without else region in def chain + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: scf_if_without_else + tt.func public @scf_if_without_else(%arg1: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %arg3: tensor<128x16x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>}, %cond: i1) { + // Emits 1 direct to lds instruction + %0 = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %1 = ttg.async_commit_group tokens %0 + + // For scf.if without else region, the else path contributes 0 instructions; + // so the minimum across both paths is 0. + scf.if %cond { + // Emits 1 direct to lds instruction inside the if + %inner = ttg.async_copy_global_to_local %arg3, %arg1 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %inner_commit = ttg.async_commit_group tokens %inner + } + + // CHECK: amdg.async_wait {{.*}} {num_inst = 0 + %10 = ttg.async_wait %1 {num = 0 : i32} + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp index 051abe922ef4..3730adedd4f8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp @@ -22,11 +22,13 @@ int deduceMinCountBetweeOps(Operation *beginOp, Operation *endOp, int count = 0; for (auto op = beginOp; op != endOp; op = op->getNextNode()) { if (auto ifOp = llvm::dyn_cast(op)) { - assert(!ifOp.getThenRegion().empty() && !ifOp.getElseRegion().empty()); + assert(!ifOp.getThenRegion().empty()); auto minThen = deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc); - auto minElse = - deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); + int minElse = 0; + if (!ifOp.getElseRegion().empty()) + minElse = + deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); count += std::min(minThen, minElse); } else if (auto forOp = llvm::dyn_cast(op)) { if (std::optional tripCount = forOp.getStaticTripCount()) { From 87a4f0cefffb0bbe26624775dc8d3044c917276d Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 19 Dec 2025 11:23:54 +0000 Subject: [PATCH 08/11] Fix block sizes --- .../triton_kernels/matmul_details/opt_flags.py | 3 +-- .../matmul_details/opt_flags_details/opt_flags_amd.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index ae643db67acf..79f32e94e611 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -127,13 +127,12 @@ def make_default_opt_flags_amd( num_stages = 1 # specific configs for F16 x MXFP4 on CDNA4 - # Note that these configs will exceed LDS usage with async copy enabled if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.b_mx_scale is not None: split_k = 1 if m <= 1024: target_kernel_kwargs["waves_per_eu"] = 3 block_n = 128 - block_k = 256 + block_k = 128 num_warps = 4 else: target_kernel_kwargs["waves_per_eu"] = 0 diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py index 593688a74d40..be4e6476de72 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py @@ -14,15 +14,13 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi if n <= 128 and (n & (n - 1)) == 0: block_n = n else: - block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))) + max_n = 128 if get_cdna_version() == 4 else 256 + block_n = max(32, min(max_n, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))) elif block_m > 64: block_n = 256 else: block_n = 128 - if get_cdna_version() == 4 and block_m == 128: - block_n = 512 - # block_k needs to match the cacheline size (128B) block_k = int(128 // min(lhs_width, rhs_width)) From cc9a14b6536150576ed17b8181f7e53a89ca8f7a Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 23 Dec 2025 09:55:01 +0000 Subject: [PATCH 09/11] Adjust more block sizes --- .../matmul_details/opt_flags_details/opt_flags_amd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py index be4e6476de72..7d1974b96ca4 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py @@ -14,7 +14,7 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi if n <= 128 and (n & (n - 1)) == 0: block_n = n else: - max_n = 128 if get_cdna_version() == 4 else 256 + max_n = 64 if get_cdna_version() == 4 else 256 block_n = max(32, min(max_n, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))) elif block_m > 64: block_n = 256 From fea23402b60f645b4467faeefa92447c4c3dcce0 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 23 Dec 2025 09:55:29 +0000 Subject: [PATCH 10/11] Adjust unti tests for LDS requirements --- python/test/unit/language/test_matmul.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index 47dabba219d8..d334f52467d7 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -1299,7 +1299,7 @@ def batched_mxfp_matmul( # @pytest.mark.parametrize("BATCH_SIZE, BLOCK_BATCH_SIZE", [(1, 1), (16, 1), (16, 4)]) -@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 64, 128)]) @pytest.mark.parametrize("NUM_STAGES", [1, 2 if is_hip() else 3]) @pytest.mark.parametrize("NUM_WARPS", [4, 8]) @pytest.mark.parametrize("nonKDim", ([0, 16, 32] if (is_hip_cdna() or is_hip_gfx1250()) else [0])) @@ -1315,6 +1315,8 @@ def test_batched_mxfp(BATCH_SIZE, BLOCK_BATCH_SIZE, BLOCK_M, BLOCK_N, BLOCK_K, N pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4 and above") if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + if is_hip_cdna4() and NUM_STAGES > 1 and max(BLOCK_M, BLOCK_N) > 64: + pytest.skip("Config requires too much shared memory") torch.manual_seed(42) dtype_src_str = "float8e5" From 56873a6e34c91b5837f41c3c318fe4842795b1e1 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 23 Dec 2025 10:14:57 +0000 Subject: [PATCH 11/11] Revert "Fix missing else case" This reverts commit 76ec2f2f2e051f7413c7fb7fa30178412827c4d6. --- third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp index 3a86f2ca10ca..e33f20b07721 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp @@ -28,10 +28,8 @@ int deduceMinCountBetweeOps(Operation *beginOp, Operation *endOp, assert(!ifOp.getThenRegion().empty() && !ifOp.getElseRegion().empty()); auto minThen = deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc); - int minElse = 0; - if (!ifOp.getElseRegion().empty()) - minElse = - deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); + auto minElse = + deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc); count += std::min(minThen, minElse); } else if (auto forOp = llvm::dyn_cast(op)) { if (std::optional tripCount = forOp.getStaticTripCount()) {