Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions python/triton_kernels/tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from triton_kernels.matmul import matmul_set_idle_sms, matmul, matmul_torch
# numerics utilities
from triton_kernels.numerics import InFlexData, OutFlexData
from triton_kernels.numerics_details.mxfp import upcast_from_mxfp, quantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE, NVFP_BLOCK_SIZE
from triton_kernels.numerics_details.mxfp import upcast_from_mxfp, quantize_mxfp8_fn, quantize_nvfp4_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE, NVFP_BLOCK_SIZE
# testing utilities
from triton_kernels.testing import assert_close, make_random_tensor
# target-specific utilities
Expand All @@ -30,6 +30,8 @@
class DType:

def __init__(self, dtype_str):
# This tracks the regular fp8 flex scale path. NVFP4 has a tensor scale,
# but it is handled separately because it also has MX microscale storage.
self.has_global_scale = dtype_str.startswith("float8")
self.is_nvfp4 = dtype_str == "nvfp4_e2m1"
self.has_mx_scale = dtype_str.startswith("mx") or self.is_nvfp4
Expand Down Expand Up @@ -218,6 +220,13 @@ def _build_test_op_cases():
Case(*shape, mode, "mxfloat8_e4m3fn", "mxfloat4_e2m1", a_hbm_swizzling=True, b_hbm_swizzling=True, split_k=split_k, swiglu_opts=(1.1, 7))
for shape in [odd_shape2, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5]
])
# swiglu together with nvfp4 downcast epilogue
test_cases.extend([
Case(*shape, mode, "bfloat16", "bfloat16", "nvfp4_e2m1", swiglu_opts=(1.1, 7.0))
for shape in [even_shape]
for mode in ["ragged", "batched"]
])
test_cases.append(Case(256, 2048, 1024, "plain", "bfloat16", "bfloat16", "nvfp4_e2m1", swiglu_opts=(1.1, 7.0)))

return test_cases

Expand Down Expand Up @@ -268,6 +277,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
if is_cuda():
if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
pytest.skip("Float8 not tested on A100")
if output_dtype_str == "nvfp4_e2m1" and torch.cuda.get_device_capability()[0] < 9:
pytest.skip("NVFP4 output scales use fp8e4nv, which is not supported on A100")
if act_dtype_str == "float16" and weight_uses_mx and torch.cuda.get_device_capability()[0] >= 10:
pytest.skip("float16 x mx not supported with cuda capability >= 10")
if weight_uses_mx:
Expand Down Expand Up @@ -437,7 +448,12 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
wrap_list = lambda vals: torch.tensor(vals, dtype=torch.float32, device=device)
flex_a = InFlexData(c_dtype.torch_dtype, wrap_list([1.25])) if c_dtype.has_global_scale else InFlexData()
flex_b = InFlexData(b_dtype.torch_dtype, wrap_list([1.25])) if b_dtype.has_global_scale else InFlexData()
flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([4.00]), wrap_list([0]), None) if c_dtype.has_global_scale else OutFlexData()
if c_dtype.has_global_scale:
flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([4.00]), wrap_list([0]), None)
elif c_dtype.is_nvfp4:
flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([0.125]), None, None)
else:
flex_c = OutFlexData(c_dtype.torch_dtype, None, None, None)
precision_opt = PrecisionConfig(
flex_ctx=FlexCtx(flex_a, flex_b, flex_c),
acc_scale=2.0 if c_dtype.has_global_scale or b_dtype.has_global_scale else 1.0,
Expand All @@ -456,7 +472,11 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
precision_opt.c_mx_scale = c_scale
precision_opt.c_microblock_size = c_dtype.microblock_size
precision_opt.c_value_pack_factor = 2 if c_dtype.is_mxfloat4 else 1
epilogue_spec = FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ())
epilogue_spec = (
FnSpecs(FnName.QUANTIZE_NVFP4.name, quantize_nvfp4_fn, (), ())
if c_dtype.is_nvfp4
else FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ())
)
epilogue = Epilogue(epilogue_spec, tuple(), tuple(), effective_itemsize=6.0)


Expand All @@ -472,10 +492,21 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
except (opt_flags.InapplicableConstraint, NotImplementedError) as e:
pytest.skip(f"inapplicable opt_flags constraint {e}")
# --- torch implementation ---
ref_y = matmul_torch(a, b, bias, #
a_ragged_metadata, b_ragged_metadata,
gather_indx, scatter_indx, precision_opt,
gammas=gammas)
# Fused NVFP4 output quantizes the float32 activation result and applies
# expected_scale inside downcast_to_mxfp_torch, so keep the reference in
# float32 until that final downcast instead of letting matmul_torch
# return bf16 and apply the output scale early.
ref_y = matmul_torch(
a.float() if c_dtype.is_nvfp4 else a,
b.float() if c_dtype.is_nvfp4 else b,
bias,
a_ragged_metadata,
b_ragged_metadata,
gather_indx,
scatter_indx,
PrecisionConfig() if c_dtype.is_nvfp4 else precision_opt,
gammas=gammas,
)
if swiglu_opts is not None:
ref_y = swiglu(ref_y, alpha=swiglu_opts[0], precision_config=SwiGLUPrecisionConfig(swiglu_opts[1]))
if c_dtype.has_global_scale:
Expand All @@ -491,6 +522,7 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
axis=-1,
scale_dtype=c_dtype.scale_dtype,
microblock_size=c_dtype.microblock_size,
expected_scale=precision_opt.flex_ctx.out_data.expected_scale,
)
ref_y = upcast_from_mxfp_torch(ref_y, ref_scale, target_dtype=ref_target_dtype, axis=-1)
maxtol, rmstol = None, None
Expand Down
13 changes: 11 additions & 2 deletions python/triton_kernels/triton_kernels/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,16 @@ def matmul(a, b, bias,
assert a_microblock_size == b_microblock_size, (
f"Microscaled operands must share a block size. Got {a_microblock_size} and {b_microblock_size}"
)
mx_block_size = b_microblock_size or a_microblock_size or int(MXFP_BLOCK_SIZE)
mx_block_size = (
a_microblock_size or b_microblock_size or precision_config.c_microblock_size or int(MXFP_BLOCK_SIZE)
)
assert all(
size is None or size == mx_block_size
for size in (a_microblock_size, b_microblock_size, precision_config.c_microblock_size)
), (
"Microscaled operands/output must share a block size. "
f"Got a={a_microblock_size}, b={b_microblock_size}, c={precision_config.c_microblock_size}"
)
if precision_config.c_mx_scale is not None and precision_config.c_microblock_size is None:
precision_config.c_microblock_size = mx_block_size
precision_config.c_value_pack_factor = 2 if precision_config.c_mx_scale is not None and epilogue.specs.name in (
Expand Down Expand Up @@ -588,7 +597,7 @@ def matmul(a, b, bias,
n_valid_slices = b_tensor_or_tma.shape[0] if ragged_dimension == "M" else n_slices
(kernels._p_matmul if opt_flags.is_persistent else kernels._matmul)[(grid,)](
c_tensor_or_tma, c.storage.data, *out_matmul.stride(),
*((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
*((out_matmul_flex.expected_scale, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
*out_matmul_scale_strides[-4:],
a_tensor_or_tma, a.storage.data, *a_strides, a_transpose,
flex.lhs_data.scale,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,12 @@ def _matmul(
MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MX_BLOCK_SIZE
N_MX_BLOCK = tl.cdiv(N, MX_BLOCK_SIZE)
tl.static_assert(EPILOGUE_FN is not None)
if PER_BATCH_OUT_SCALE:
YExpectedScale = YExpectedScale + start_z_out
# OCP MX outputs leave YExpectedScale unset, so this is an identity there.
# NVFP4 uses YExpectedScale to precondition the dense output before the
# microscaling epilogue writes direct e4m3 block scales.
out = float_to_flex(out, YExpectedScale, None, None, mask, Y, False)
out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,23 @@ def _p_matmul(
tl.static_assert(EPILOGUE_FN is not None)
offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
mask_n = offs_y_n < yN
if PER_BATCH_OUT_SCALE:
ExpectedScale = YExpectedScale + start_z1
else:
ExpectedScale = YExpectedScale
# OCP MX outputs leave YExpectedScale unset, so this is an
# identity there. NVFP4 uses YExpectedScale to precondition the
# dense output before the microscaling epilogue writes direct
# e4m3 block scales.
out = float_to_flex(
out,
ExpectedScale,
None,
None,
mask_m[:, None] & mask_n[None, :],
YPtr,
False,
)
out, out_scale = EPILOGUE_FN(out, mask_m[:, None] & mask_n[None, :], *epilogue_fn_args)
tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
offs_y_n_scale = off_n1 // ACTIVATION_REDUCTION_N // MX_BLOCK_SIZE + a_i * MX_SCALE_BLOCK_N + tl.arange(0, MX_SCALE_BLOCK_N)
Expand Down
13 changes: 13 additions & 0 deletions python/triton_kernels/triton_kernels/numerics_details/mxfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def get_max_quant_val(dtype: torch.dtype):
def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
scale_dtype: torch.dtype = torch.uint8,
microblock_size: int = MXFP_BLOCK_SIZE.value,
expected_scale: torch.Tensor | None = None,
DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
"""
Converts the src tensor to the output format specified by out_quant_type.
Expand All @@ -191,6 +192,9 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype
• For mxfp8, the output has the same shape as src_tensor.
• For FP4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
scale: Scale tensor computed per microblock along the axis.

If expected_scale is provided, apply the same preconditioning as float_to_flex
before microscaled output quantization by dividing the source tensor by it.
"""
# This should probably be packed into its own tiny class
ndim = src_tensor.ndim
Expand All @@ -217,6 +221,12 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype

# Permute the tensor so that the contiguous axis becomes the last dimension.
src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
if expected_scale is not None:
if expected_scale.numel() == 1:
src = src / expected_scale.to(device=device, dtype=src.dtype)
else:
assert src.ndim == 3
src = src / expected_scale.to(device=device, dtype=src.dtype)[:, None, None]
axis_shape = src.shape[-1]

# Pad the axis to be divisible by the microblock size, in case it is not.
Expand Down Expand Up @@ -251,6 +261,9 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype
ds_int_rounded = ds_int & 0x7F800000
dequant_scale_rounded = ds_int_rounded.view(torch.float32)
else:
# Direct fp8 scales keep the existing plain cast semantics here: the
# stored scale is rounded by the fp8 conversion rather than forced
# upward despite the ROUND_UP mode name.
dequant_scale_rounded = dequant_scale.to(scale_dtype).to(torch.float32)

# Compute the quantization scale.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.con
else:
tl.static_assert(mx_scale_dtype == tl.float8e4nv, f"Unsupported {mx_scale_dtype=}")
tl.static_assert(DEQUANT_SCALE_ROUNDING_MODE == 0, "Direct float8 scales only support ROUND_UP")
# Direct fp8 scales keep the existing plain cast semantics here: the
# stored scale is rounded by the fp8 conversion rather than forced
# upward despite the ROUND_UP mode name.
scale_tensor = (max_val / _get_max_quant_val(mx_tensor_dtype)).to(tl.float8e4nv)
dequant_scale_rounded = scale_tensor.to(tl.float32)
scale_tensor = scale_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
Expand Down
10 changes: 6 additions & 4 deletions python/triton_kernels/triton_kernels/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,16 @@ def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x
valid_y_s1 = offs_y_s1 < Y_S1
valid_y_smx1 = offs_y_smx1 < tl.cdiv(Y_S1, 1 if Y_MX_BLOCK_SIZE is None else Y_MX_BLOCK_SIZE)
is_out_fp4: tl.constexpr = YMx is not None and Y_VALUE_PACK_FACTOR == 2
y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF)
# TODO (phil): keeping for backward compatibility, but will remove !
if YMx is None and POSTPROCESS_FN2 is not None:
y = POSTPROCESS_FN2(y, *postprocess_fn2_args, target_dtype=Y.dtype.element_ty)
if YMx is not None:
y = float_to_flex(y, YFlexExpected, None, None, None, Y, False)
y, y_scale = POSTPROCESS_MX_FN(y, valid_y_s1[None, :], *postprocess_mx_fn_args)
y_mx_ptrs = YMx + offs_s0[:, None] * stride_ymx0 + offs_y_smx1[None, :] * stride_ymx1
tl.store(y_mx_ptrs, y_scale, mask=valid_s0[:, None] & valid_y_smx1[None, :])
else:
y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF)
# TODO (phil): keeping for backward compatibility, but will remove !
if POSTPROCESS_FN2 is not None:
y = POSTPROCESS_FN2(y, *postprocess_fn2_args, target_dtype=Y.dtype.element_ty)
if is_out_fp4:
offs_y_s1 = pid_s1 * (BLOCK_Y_S1 // 2) + tl.arange(0, BLOCK_Y_S1 // 2)
valid_y_s1 = offs_y_s1 < tl.cdiv(Y_S1, 2)
Expand Down
Loading