From 511a7f5a04c5f701e99481af0fb734ee9ff952ea Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Tue, 26 May 2026 22:55:11 -0700 Subject: [PATCH 1/2] [KERNELS] fix hopper mxfp4 swizzle bug --- .../test_opt_flags_nvidia.py | 45 +++++++++++++++++++ .../triton_kernels/matmul_details/_matmul.py | 15 +++++-- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py index 45fbb40d59a0..5587fec8c80a 100644 --- a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py +++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py @@ -1,8 +1,11 @@ import pytest import torch +import triton +import triton.language as tl from triton._internal_testing import is_cuda from triton_kernels.matmul import matmul, matmul_torch, PrecisionConfig +from triton_kernels.matmul_details._matmul import _compute_packed_n_w from triton_kernels.matmul_details.opt_flags import InapplicableConstraint, scoped_opt_flags_constraints from triton_kernels.matmul_details.opt_flags_details import opt_flags_nvidia from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, downcast_to_mxfp @@ -104,6 +107,48 @@ def test_matmul_hopper_mxfp4_rhs_scale_padding_is_masked(device, constraints): torch.testing.assert_close(actual, expected, rtol=0, atol=0) +@triton.jit +def _hopper_rhs_packed_n_extent(out, n: tl.constexpr): + tl.store(out, _compute_packed_n_w(n, 4, "HOPPER_VALUE")) + + +def test_matmul_hopper_mxfp4_rhs_packed_n_padding(device): + if device != "cuda" or not torch.cuda.is_available() or not is_cuda(): + pytest.skip("requires CUDA") + if torch.cuda.get_device_capability()[0] != 9: + pytest.skip("requires Hopper") + + torch.manual_seed(0) + # Hopper MXFP4 RHS values are stored with N packed by 4 and then padded in + # packed space. The generic kernel must wrap using that padded packed width, + # not by padding logical N first and dividing afterward. + n = 3456 + packed_n = torch.empty((1,), dtype=torch.int32, device=device) + _hopper_rhs_packed_n_extent[(1,)](packed_n, n) + assert packed_n.item() == 896 + + m, k = 64, 2048 + a = torch.randn((m, k), device=device, dtype=torch.bfloat16) + weight_fp = torch.randn((n, k), device=device, dtype=torch.bfloat16).T + weight_val, weight_scale = downcast_to_mxfp(weight_fp, torch.uint8, axis=-2) + value_layout = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2) + scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=8) + b = convert_layout(wrap_torch_tensor(weight_val, dtype=FP4), value_layout) + b_scale = convert_layout(wrap_torch_tensor(weight_scale, dtype=UINT8), scale_layout) + precision_config = PrecisionConfig( + b_mx_scale=b_scale, + b_microblock_size=MXFP_BLOCK_SIZE.value, + out_dtype=a.dtype, + ) + + with scoped_opt_flags_constraints({"is_persistent": False, "block_n": 256}): + expected = matmul_torch(a, b, None, precision_config=precision_config) + actual = matmul(a, b, None, precision_config=precision_config) + + assert torch.isfinite(actual).all() + assert_close(expected, actual, maxtol=3e-2, rmstol=None) + + @pytest.mark.parametrize("n, expected", [(64, 128), (200, 256)]) def test_compute_block_n_blackwell_scale_aligns_to_128(n, expected): precision_config = PrecisionConfig( diff --git a/python/triton_kernels/triton_kernels/matmul_details/_matmul.py b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py index c30b3fb518d2..2d10bb353477 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/_matmul.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py @@ -29,6 +29,15 @@ def round_f32_to_tf32(x: tl.tensor): ASM: tl.constexpr = "cvt.rn.tf32.f32 $0, $1;" if cuda_capability_geq(9, 0) else "cvt.rna.tf32.f32 $0, $1;" return tl.inline_asm_elementwise(ASM, "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1) + +@triton.jit +def _compute_packed_n_w(N, W_N_DIVISOR: tl.constexpr, SWIZZLE_MX_VALUE: tl.constexpr): + packed_n_w = N // W_N_DIVISOR + if SWIZZLE_MX_VALUE == "HOPPER_VALUE": + packed_n_w = tl.cdiv(packed_n_w, 64) * 64 + return packed_n_w + + _matmul_repr = make_matmul_repr("_matmul", [0, 1, 2]) @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], repr=_matmul_repr, launch_metadata=matmul_launch_metadata) @@ -338,10 +347,8 @@ def _matmul( # B pointers offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W) - N_W = N - if SWIZZLE_MX_VALUE == "HOPPER_VALUE": - N_W = tl.cdiv(N_W, 64) * 64 - offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N_W // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W) + packed_n_w = _compute_packed_n_w(N, W_N_DIVISOR, SWIZZLE_MX_VALUE) + offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % packed_n_w, PACKED_BLOCK_N_W), PACKED_BLOCK_N_W) if is_x_microscaled: XMxScale += start_z.to(index_type) * stride_x_mx_z From d32f6b400fa0002907f0e4b977f47857d64366a8 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Wed, 27 May 2026 09:47:31 -0700 Subject: [PATCH 2/2] [KERNELS] ceil hopper packed N before padding --- .../test_matmul_details/test_opt_flags_nvidia.py | 12 ++++++------ .../triton_kernels/matmul_details/_matmul.py | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py index 5587fec8c80a..d04f4e60902d 100644 --- a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py +++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py @@ -120,12 +120,12 @@ def test_matmul_hopper_mxfp4_rhs_packed_n_padding(device): torch.manual_seed(0) # Hopper MXFP4 RHS values are stored with N packed by 4 and then padded in - # packed space. The generic kernel must wrap using that padded packed width, - # not by padding logical N first and dividing afterward. - n = 3456 - packed_n = torch.empty((1,), dtype=torch.int32, device=device) - _hopper_rhs_packed_n_extent[(1,)](packed_n, n) - assert packed_n.item() == 896 + # packed space. The generic kernel must ceil-divide before padding and wrap + # using that padded packed width. + n = 258 + packed_n = torch.empty((1, ), dtype=torch.int32, device=device) + _hopper_rhs_packed_n_extent[(1, )](packed_n, n) + assert packed_n.item() == 128 m, k = 64, 2048 a = torch.randn((m, k), device=device, dtype=torch.bfloat16) diff --git a/python/triton_kernels/triton_kernels/matmul_details/_matmul.py b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py index 2d10bb353477..5f638b9d833a 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/_matmul.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py @@ -32,12 +32,11 @@ def round_f32_to_tf32(x: tl.tensor): @triton.jit def _compute_packed_n_w(N, W_N_DIVISOR: tl.constexpr, SWIZZLE_MX_VALUE: tl.constexpr): - packed_n_w = N // W_N_DIVISOR + packed_n_w = tl.cdiv(N, W_N_DIVISOR) if SWIZZLE_MX_VALUE == "HOPPER_VALUE": packed_n_w = tl.cdiv(packed_n_w, 64) * 64 return packed_n_w - _matmul_repr = make_matmul_repr("_matmul", [0, 1, 2]) @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], repr=_matmul_repr, launch_metadata=matmul_launch_metadata)