Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/test/unit/language/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ def test_mxfp(BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, nonKDim, NUM_WARPS, device)
pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4 or above")
if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256:
pytest.skip("Config requires too much shared memory")

if BLOCK_N == 256 and BLOCK_K == 256:
NUM_STAGES = min(NUM_STAGES, 2)
Expand Down Expand Up @@ -1157,6 +1159,8 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):
pytest.skip("Float4 without scale is tested in test_block_scale_fp4")
if (BLOCK_M == 256 or BLOCK_N == 256) and BLOCK_K == 256:
pytest.skip("Config requires too much shared memory")
if not PACK_B_ALONG_K and B_DATA_TYPE != "float4":
pytest.skip("Pack along K can only be False for float4")
if BLOCK_N == 256 and BLOCK_K == 256:
Expand Down Expand Up @@ -1295,7 +1299,7 @@ def batched_mxfp_matmul( #


@pytest.mark.parametrize("BATCH_SIZE, BLOCK_BATCH_SIZE", [(1, 1), (16, 1), (16, 4)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 64), (128, 64, 128), (64, 64, 128)])
@pytest.mark.parametrize("NUM_STAGES", [1, 2 if is_hip() else 3])
@pytest.mark.parametrize("NUM_WARPS", [4, 8])
@pytest.mark.parametrize("nonKDim", ([0, 16, 32] if (is_hip_cdna() or is_hip_gfx1250()) else [0]))
Expand All @@ -1311,6 +1315,8 @@ def test_batched_mxfp(BATCH_SIZE, BLOCK_BATCH_SIZE, BLOCK_M, BLOCK_N, BLOCK_K, N
pytest.skip("Scaled mxfp8 matmul is only natively supported on CDNA4 and above")
if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
if is_hip_cdna4() and NUM_STAGES > 1 and max(BLOCK_M, BLOCK_N) > 64:
pytest.skip("Config requires too much shared memory")

torch.manual_seed(42)
dtype_src_str = "float8e5"
Expand Down
2 changes: 1 addition & 1 deletion python/triton/knobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ class amd_knobs(base_knobs):
# We use strs so that we can have a default value based on other runtime info
use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")
use_async_copy: env_opt_bool = env_opt_bool("TRITON_HIP_USE_ASYNC_COPY")

use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,12 @@ def make_default_opt_flags_amd(
num_stages = 1

# specific configs for F16 x MXFP4 on CDNA4
# Note that these configs will exceed LDS usage with async copy enabled
if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.b_mx_scale is not None:
split_k = 1
if m <= 1024:
target_kernel_kwargs["waves_per_eu"] = 3
block_n = 128
block_k = 256
block_k = 128
num_warps = 4
else:
target_kernel_kwargs["waves_per_eu"] = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precisi
if n <= 128 and (n & (n - 1)) == 0:
block_n = n
else:
block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu)))
max_n = 64 if get_cdna_version() == 4 else 256
block_n = max(32, min(max_n, triton.next_power_of_2(grid_m * n * num_xcds // n_cu)))
elif block_m > 64:
block_n = 256
else:
block_n = 128

if get_cdna_version() == 4 and block_m == 128:
block_n = 512

if get_rdna_version() in (3, 4) and block_m == 64:
block_n = 256

Expand Down
10 changes: 7 additions & 3 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@ def get_min_dot_size(target: GPUTarget):


def is_pingpong_schedule_enabled(arch, use_async_copy):
return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)
) if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong
return (arch == "gfx942" or (arch == "gfx950" and use_async_copy is True)) \
if knobs.amd.use_block_pingpong is None else knobs.amd.use_block_pingpong


def is_in_thread_transpose_enabled(arch):
return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose


def is_async_copy_enabled(arch):
return (arch in ["gfx950", "gfx1250"]) if knobs.amd.use_async_copy is None else knobs.amd.use_async_copy


@dataclass(frozen=True)
class HIPOptions:
num_warps: int = 4
Expand Down Expand Up @@ -227,7 +231,7 @@ def make_ttgir(mod, metadata, options):
passes.ttir.add_triton_licm(pm)
passes.common.add_canonicalizer(pm)

use_async_copy = knobs.amd.use_async_copy
use_async_copy = is_async_copy_enabled(options.arch)
use_block_pingpong = is_pingpong_schedule_enabled(options.arch, use_async_copy)

amd.passes.ttgpuir.add_schedule_loops(pm, options.num_stages)
Expand Down
Loading