Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
import torch
import triton
import triton.language as tl
from triton._internal_testing import is_cuda

from triton_kernels.matmul import matmul, matmul_torch, PrecisionConfig
from triton_kernels.matmul_details._matmul import _compute_packed_n_w
from triton_kernels.matmul_details.opt_flags import InapplicableConstraint, scoped_opt_flags_constraints
from triton_kernels.matmul_details.opt_flags_details import opt_flags_nvidia
from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, downcast_to_mxfp
Expand Down Expand Up @@ -104,6 +107,48 @@ def test_matmul_hopper_mxfp4_rhs_scale_padding_is_masked(device, constraints):
torch.testing.assert_close(actual, expected, rtol=0, atol=0)


@triton.jit
def _hopper_rhs_packed_n_extent(out, n: tl.constexpr):
    """Write the Hopper packed-N extent computed for `n` into `out`.

    Thin device-side wrapper so a host test can observe the value returned by
    `_compute_packed_n_w` with an N divisor of 4 and the "HOPPER_VALUE"
    swizzle mode.
    """
    packed_extent = _compute_packed_n_w(n, 4, "HOPPER_VALUE")
    tl.store(out, packed_extent)


def test_matmul_hopper_mxfp4_rhs_packed_n_padding(device):
    """Check the packed-N extent and the matmul result for an N that needs padding.

    Runs only on a CUDA device whose major compute capability is 9; skipped
    otherwise.
    """
    cuda_ready = device == "cuda" and torch.cuda.is_available() and is_cuda()
    if not cuda_ready:
        pytest.skip("requires CUDA")
    capability_major = torch.cuda.get_device_capability()[0]
    if capability_major != 9:
        pytest.skip("requires Hopper")

    torch.manual_seed(0)

    # On Hopper the MXFP4 RHS values keep N packed by 4, and the padding is
    # applied afterwards, in packed space. The generic kernel therefore has to
    # ceil-divide first, pad second, and wrap its offsets with the padded
    # packed width.
    # With n = 258: cdiv(258, 4) = 65, rounded up to a multiple of 64 -> 128.
    n = 258
    extent_out = torch.empty((1, ), dtype=torch.int32, device=device)
    _hopper_rhs_packed_n_extent[(1, )](extent_out, n)
    assert extent_out.item() == 128

    m = 64
    k = 2048
    lhs = torch.randn((m, k), device=device, dtype=torch.bfloat16)
    rhs_fp = torch.randn((n, k), device=device, dtype=torch.bfloat16).T
    rhs_val, rhs_scale = downcast_to_mxfp(rhs_fp, torch.uint8, axis=-2)

    # Move both the packed values and their scales into the default MXFP4
    # weight layouts before handing them to the matmul.
    w_layout = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2)
    w_scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=8)
    rhs = convert_layout(wrap_torch_tensor(rhs_val, dtype=FP4), w_layout)
    rhs_mx_scale = convert_layout(wrap_torch_tensor(rhs_scale, dtype=UINT8), w_scale_layout)
    cfg = PrecisionConfig(
        b_mx_scale=rhs_mx_scale,
        b_microblock_size=MXFP_BLOCK_SIZE.value,
        out_dtype=lhs.dtype,
    )

    # Pin the non-persistent kernel with block_n = 256 for both the reference
    # and the kernel under test.
    forced_flags = {"is_persistent": False, "block_n": 256}
    with scoped_opt_flags_constraints(forced_flags):
        expected = matmul_torch(lhs, rhs, None, precision_config=cfg)
        actual = matmul(lhs, rhs, None, precision_config=cfg)

    assert torch.isfinite(actual).all()
    assert_close(expected, actual, maxtol=3e-2, rmstol=None)


@pytest.mark.parametrize("n, expected", [(64, 128), (200, 256)])
def test_compute_block_n_blackwell_scale_aligns_to_128(n, expected):
precision_config = PrecisionConfig(
Expand Down
14 changes: 10 additions & 4 deletions python/triton_kernels/triton_kernels/matmul_details/_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def round_f32_to_tf32(x: tl.tensor):
ASM: tl.constexpr = "cvt.rn.tf32.f32 $0, $1;" if cuda_capability_geq(9, 0) else "cvt.rna.tf32.f32 $0, $1;"
return tl.inline_asm_elementwise(ASM, "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1)


@triton.jit
def _compute_packed_n_w(N, W_N_DIVISOR: tl.constexpr, SWIZZLE_MX_VALUE: tl.constexpr):
    """Return the extent of the weight's N dimension in packed units.

    N is ceil-divided by W_N_DIVISOR. When SWIZZLE_MX_VALUE is
    "HOPPER_VALUE" the quotient is then rounded up to the next multiple
    of 64; for any other value it is returned as is.
    """
    n_packed = tl.cdiv(N, W_N_DIVISOR)
    if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
        # Pad after packing: the rounding applies to the packed width, not to N.
        n_packed = tl.cdiv(n_packed, 64) * 64
    return n_packed

_matmul_repr = make_matmul_repr("_matmul", [0, 1, 2])
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
repr=_matmul_repr, launch_metadata=matmul_launch_metadata)
Expand Down Expand Up @@ -338,10 +346,8 @@ def _matmul(

# B pointers
offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
N_W = N
if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
N_W = tl.cdiv(N_W, 64) * 64
offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N_W // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
packed_n_w = _compute_packed_n_w(N, W_N_DIVISOR, SWIZZLE_MX_VALUE)
offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % packed_n_w, PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)

if is_x_microscaled:
XMxScale += start_z.to(index_type) * stride_x_mx_z
Expand Down
Loading