Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/triton_kernels/tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def _build_test_op_cases():
Case(*even_shape, "ragged", "float8_e5m2", "float8_e5m2", epilogue_subtile=val)
for val in (1, 2, 4)
])
# fp32
test_cases.extend([
Case(1024, 1000, 2048, "ragged", "float32", "float32", b_transpose=True)
])
# bfloat16 x mx
for shape in [odd_shape2, even_shape]:
test_cases.extend([
Expand Down Expand Up @@ -281,6 +285,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
pytest.skip("NYI: gamma and swiglu not supported together on AMD GPU")
if split_k is not None and split_k > 1:
pytest.skip("splitK hasn't been fully tested on AMD GPU.")
if "float32" in act_dtype_str:
pytest.skip("float32 not fully tested on AMD GPU")

if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")
Expand Down
5 changes: 5 additions & 0 deletions python/triton_kernels/triton_kernels/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,12 +386,16 @@ def matmul(a, b, bias,
block_k = None
if ragged_dimension == "K":
block_k = a_ragged_metadata.slice_sizes_divisibility or b_ragged_metadata.slice_sizes_divisibility
a_uses_tma_when_persistent = a.stride(-1) != 1 or (a_ragged_metadata.slice_sizes_divisibility is not None)
else:
a_uses_tma_when_persistent = has_gather_tma or not has_gather
opt_flags = make_opt_flags(out_dtype, a.dtype, b.dtype, precision_config,
batch_size, M, N, b.shape[-2], a_ragged_metadata,
can_use_tma, can_use_split_k, epilogue.effective_itemsize,
a_transpose, c_acc_in is not None,
block_k = block_k,
mx_block_size = mx_block_size,
x_uses_tma_when_persistent = a_uses_tma_when_persistent,
)
# there seems to be a bug on A100
# pytest -vs test_matmul.py::test_op[False-False-False-False-pad_b-16-768-512-1024-ragged-float16-float16-10-1-False-None-False-False-False-True-None]
Expand Down Expand Up @@ -485,6 +489,7 @@ def matmul(a, b, bias,
# create tma descriptor for y
c_has_tma = (
opt_flags.is_persistent and (scatter_indx is None or has_scatter_tma)
and is_tma_compliant(c)
and (c_acc_in is None or c_acc_is_c)
and fused_comm is None
and precision_config.c_value_pack_factor == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def make_default_opt_flags_nvidia(
x_transpose,
has_y_acc_in,
constraints,
x_uses_tma_when_persistent=True,
mx_block_size=None,
):
constraints_supported = {"block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn", "num_warps"}
Expand Down Expand Up @@ -267,6 +268,13 @@ def _is_layout_strided(layout: Layout | None) -> bool:

# adjust block_n based on is_persistent signal
block_n = block_n_tma if is_persistent else block_n
if (is_persistent and constraints.get("block_n", None) is None
and cuda_capability_geq(10, 0) and (lhs_dtype == FP32 or rhs_dtype == FP32)
and not x_uses_tma_when_persistent):
# Blackwell's fp32/tf32 persistent dot stages an operand in TMEM in
Comment thread
yongjik marked this conversation as resolved.
# addition to the accumulator. A 128x256 accumulator already consumes
# the full 512-column TMEM budget, so leave headroom for that operand.
block_n = min(block_n, 128)
# adjust block_m based on is_persistent signal
if is_persistent and opt_flags_nvidia.is_x_scale_swizzled(precision_config):
# a mx scale has been swizzled to BlackwellActMXScaleLayout, enforce block_m=128 to align with swizzling layout
Expand Down Expand Up @@ -404,6 +412,7 @@ def make_opt_flags(
has_y_acc_in,
block_k,
mx_block_size=None,
x_uses_tma_when_persistent=True,
):
if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
Expand All @@ -429,5 +438,9 @@ def make_opt_flags(
if backend == "hip":
return make_default_opt_flags_amd(*args)
if backend == "cuda":
return make_default_opt_flags_nvidia(*args, mx_block_size=mx_block_size)
return make_default_opt_flags_nvidia(
*args,
x_uses_tma_when_persistent=x_uses_tma_when_persistent,
mx_block_size=mx_block_size,
)
assert False
Loading