diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 787118d37e45..6cbe32d7291d 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -127,6 +127,10 @@ def _build_test_op_cases(): Case(*even_shape, "ragged", "float8_e5m2", "float8_e5m2", epilogue_subtile=val) for val in (1, 2, 4) ]) + # fp32 + test_cases.extend([ + Case(1024, 1000, 2048, "ragged", "float32", "float32", b_transpose=True) + ]) # bfloat16 x mx for shape in [odd_shape2, even_shape]: test_cases.extend([ @@ -281,6 +285,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, pytest.skip("NYI: gamma and swiglu not supported together on AMD GPU") if split_k is not None and split_k > 1: pytest.skip("splitK hasn't been fully tested on AMD GPU.") + if "float32" in act_dtype_str: + pytest.skip("float32 not fully tested on AMD GPU") if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3(): pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py index b769ae92c052..7b85ddf66a58 100644 --- a/python/triton_kernels/triton_kernels/matmul.py +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -386,12 +386,16 @@ def matmul(a, b, bias, block_k = None if ragged_dimension == "K": block_k = a_ragged_metadata.slice_sizes_divisibility or b_ragged_metadata.slice_sizes_divisibility + a_uses_tma_when_persistent = a.stride(-1) != 1 or (a_ragged_metadata.slice_sizes_divisibility is not None) + else: + a_uses_tma_when_persistent = has_gather_tma or not has_gather opt_flags = make_opt_flags(out_dtype, a.dtype, b.dtype, precision_config, batch_size, M, N, b.shape[-2], a_ragged_metadata, can_use_tma, can_use_split_k, epilogue.effective_itemsize, a_transpose, c_acc_in is not None, block_k = block_k, mx_block_size = mx_block_size, + x_uses_tma_when_persistent = a_uses_tma_when_persistent, ) # there seems to be a bug on A100 # pytest -vs test_matmul.py::test_op[False-False-False-False-pad_b-16-768-512-1024-ragged-float16-float16-10-1-False-None-False-False-False-True-None] @@ -485,6 +489,7 @@ def matmul(a, b, bias, # create tma descriptor for y c_has_tma = ( opt_flags.is_persistent and (scatter_indx is None or has_scatter_tma) + and is_tma_compliant(c) and (c_acc_in is None or c_acc_is_c) and fused_comm is None and precision_config.c_value_pack_factor == 1 diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index b6d185829183..736c8f690413 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -190,6 +190,7 @@ def make_default_opt_flags_nvidia( x_transpose, has_y_acc_in, constraints, + x_uses_tma_when_persistent=True, mx_block_size=None, ): constraints_supported = {"block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn", "num_warps"} @@ -267,6 +268,13 @@ def _is_layout_strided(layout: Layout | None) -> bool: # adjust block_n based on is_persistent signal block_n = block_n_tma if is_persistent else block_n + if (is_persistent and constraints.get("block_n", None) is None + and cuda_capability_geq(10, 0) and (lhs_dtype == FP32 or rhs_dtype == FP32) + and not x_uses_tma_when_persistent): + # Blackwell's fp32/tf32 persistent dot stages an operand in TMEM in + # addition to the accumulator. A 128x256 accumulator already consumes + # the full 512-column TMEM budget, so leave headroom for that operand. + block_n = min(block_n, 128) # adjust block_m based on is_persistent signal if is_persistent and opt_flags_nvidia.is_x_scale_swizzled(precision_config): # a mx scale has been swizzled to BlackwellActMXScaleLayout, enforce block_m=128 to align with swizzling layout @@ -404,6 +412,7 @@ def make_opt_flags( has_y_acc_in, block_k, mx_block_size=None, + x_uses_tma_when_persistent=True, ): if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma: raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint") @@ -429,5 +438,9 @@ def make_opt_flags( if backend == "hip": return make_default_opt_flags_amd(*args) if backend == "cuda": - return make_default_opt_flags_nvidia(*args, mx_block_size=mx_block_size) + return make_default_opt_flags_nvidia( + *args, + x_uses_tma_when_persistent=x_uses_tma_when_persistent, + mx_block_size=mx_block_size, + ) assert False