diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index 28e7ff314cfd..d334f52467d7 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -361,6 +361,8 @@ def test_mxfp(BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device) pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4 or above") if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: + pytest.skip("Config requires too much shared memory") if BLOCK_N == 256 and BLOCK_K == 256: NUM_STAGES = min(NUM_STAGES, 2) @@ -1157,6 +1159,8 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE): pytest.skip("Float4 without scale is tested in test_block_scale_fp4") + if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256: + pytest.skip("Config requires too much shared memory") if not PACK_B_ALONG_K and B_DATA_TYPE != "float4": pytest.skip("Pack along K can only be False for float4") if BLOCK_N == 256 and BLOCK_K == 256: @@ -1295,7 +1299,7 @@ def batched_mxfp_matmul( # @pytest.mark.parametrize("BATCH_SIZE, BLOCK_BATCH_SIZE", [(1, 1), (16, 1), (16, 4)]) -@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 64, 128)]) @pytest.mark.parametrize("NUM_STAGES", [1, 2 if is_hip() else 3]) @pytest.mark.parametrize("NUM_WARPS", [4, 8]) @pytest.mark.parametrize("nonKDim", ([0, 16, 32] if (is_hip_cdna() or is_hip_gfx1250()) else [0])) @@ -1311,6 +1315,8 @@ def test_batched_mxfp(BATCH_SIZE, BLOCK_BATCH_SIZE, BLOCK_M, BLOCK_N, BLOCK_K, N pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4 and above") if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64): pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants") + if is_hip_cdna4() and NUM_STAGES > 1 and max(BLOCK_M, BLOCK_N) > 64: + pytest.skip("Config requires too much shared memory") torch.manual_seed(42) dtype_src_str = "float8e5" diff --git a/python/triton/knobs.py b/python/triton/knobs.py index b55bd3de3218..e537f4123332 100644 --- a/python/triton/knobs.py +++ b/python/triton/knobs.py @@ -513,8 +513,8 @@ class amd_knobs(base_knobs): # We use strs so that we can have a default value based on other runtime info use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG") use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE") + use_async_copy: env_opt_bool = env_opt_bool("TRITON_HIP_USE_ASYNC_COPY") - use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY") scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS") diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index 8e976fd8c208..3be552e75e9c 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -129,13 +129,12 @@ def make_default_opt_flags_amd( num_stages = 1 # specific configs for F16 x MXFP4 on CDNA4 - # Note that these configs will exceed LDS usage with async copy enabled if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.b_mx_scale is not None: split_k = 1 if m <= 1024: target_kernel_kwargs["waves_per_eu"] = 3 block_n = 128 - block_k = 256 + block_k = 128 num_warps = 4 else: target_kernel_kwargs["waves_per_eu"] = 0 diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py index d1eb0d66f4b6..01b2d36eee98 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_amd.py @@ -14,15 +14,13 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi if n <= 128 and (n & (n - 1)) == 0: block_n = n else: - block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))) + max_n = 64 if get_cdna_version() == 4 else 256 + block_n = max(32, min(max_n, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))) elif block_m > 64: block_n = 256 else: block_n = 128 - if get_cdna_version() == 4 and block_m == 128: - block_n = 512 - if get_rdna_version() in (3, 4) and block_m == 64: block_n = 256 diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index ff2f87c3bc2f..3ae7497f969e 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -19,14 +19,18 @@ def get_min_dot_size(target: GPUTarget): def is_pingpong_schedule_enabled(arch, use_async_copy): - return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True) - ) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong + return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)) \ + if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong def is_in_thread_transpose_enabled(arch): return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose +def is_async_copy_enabled(arch): + return (arch in ["gfx950", "gfx1250"]) if knobs.amd.use_async_copy is None else knobs.amd.use_async_copy + + @dataclass(frozen=True) class HIPOptions: num_warps: int = 4 @@ -227,7 +231,7 @@ def make_ttgir(mod, metadata, options): passes.ttir.add_triton_licm(pm) passes.common.add_canonicalizer(pm) - use_async_copy = knobs.amd.use_async_copy + use_async_copy = is_async_copy_enabled(options.arch) use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy) amd.passes.ttgpuir.add_schedule_loops(pm, options.num_stages)