diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index f7542501da0e..739134b069ac 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -13,7 +13,7 @@ from triton_kernels.matmul import matmul_set_idle_sms, matmul, matmul_torch # numerics utilities from triton_kernels.numerics import InFlexData, OutFlexData -from triton_kernels.numerics_details.mxfp import upcast_from_mxfp, quantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE, NVFP_BLOCK_SIZE +from triton_kernels.numerics_details.mxfp import upcast_from_mxfp, quantize_mxfp8_fn, quantize_nvfp4_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE, NVFP_BLOCK_SIZE # testing utilities from triton_kernels.testing import assert_close, make_random_tensor # target-specific utilities @@ -30,6 +30,8 @@ class DType: def __init__(self, dtype_str): + # This tracks the regular fp8 flex scale path. NVFP4 has a tensor scale, + # but it is handled separately because it also has MX microscale storage. self.has_global_scale = dtype_str.startswith("float8") self.is_nvfp4 = dtype_str == "nvfp4_e2m1" self.has_mx_scale = dtype_str.startswith("mx") or self.is_nvfp4 @@ -218,6 +220,13 @@ def _build_test_op_cases(): Case(*shape, mode, "mxfloat8_e4m3fn", "mxfloat4_e2m1", a_hbm_swizzling=True, b_hbm_swizzling=True, split_k=split_k, swiglu_opts=(1.1, 7)) for shape in [odd_shape2, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5] ]) + # swiglu together with nvfp4 downcast epilogue + test_cases.extend([ + Case(*shape, mode, "bfloat16", "bfloat16", "nvfp4_e2m1", swiglu_opts=(1.1, 7.0)) + for shape in [even_shape] + for mode in ["ragged", "batched"] + ]) + test_cases.append(Case(256, 2048, 1024, "plain", "bfloat16", "bfloat16", "nvfp4_e2m1", swiglu_opts=(1.1, 7.0))) return test_cases @@ -268,6 +277,8 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, if is_cuda(): if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: pytest.skip("Float8 not tested on A100") + if output_dtype_str == "nvfp4_e2m1" and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("NVFP4 output scales use fp8e4nv, which is not supported on A100") if act_dtype_str == "float16" and weight_uses_mx and torch.cuda.get_device_capability()[0] >= 10: pytest.skip("float16 x mx not supported with cuda capability >= 10") if weight_uses_mx: @@ -437,7 +448,12 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, wrap_list = lambda vals: torch.tensor(vals, dtype=torch.float32, device=device) flex_a = InFlexData(c_dtype.torch_dtype, wrap_list([1.25])) if c_dtype.has_global_scale else InFlexData() flex_b = InFlexData(b_dtype.torch_dtype, wrap_list([1.25])) if b_dtype.has_global_scale else InFlexData() - flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([4.00]), wrap_list([0]), None) if c_dtype.has_global_scale else OutFlexData() + if c_dtype.has_global_scale: + flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([4.00]), wrap_list([0]), None) + elif c_dtype.is_nvfp4: + flex_c = OutFlexData(c_dtype.torch_dtype, wrap_list([0.125]), None, None) + else: + flex_c = OutFlexData(c_dtype.torch_dtype, None, None, None) precision_opt = PrecisionConfig( flex_ctx=FlexCtx(flex_a, flex_b, flex_c), acc_scale=2.0 if c_dtype.has_global_scale or b_dtype.has_global_scale else 1.0, @@ -456,7 +472,11 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, precision_opt.c_mx_scale = c_scale precision_opt.c_microblock_size = c_dtype.microblock_size precision_opt.c_value_pack_factor = 2 if c_dtype.is_mxfloat4 else 1 - epilogue_spec = FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ()) + epilogue_spec = ( + FnSpecs(FnName.QUANTIZE_NVFP4.name, quantize_nvfp4_fn, (), ()) + if c_dtype.is_nvfp4 + else FnSpecs(FnName.QUANTIZE_MXFP8.name, quantize_mxfp8_fn, (), ()) + ) epilogue = Epilogue(epilogue_spec, tuple(), tuple(), effective_itemsize=6.0) @@ -472,10 +492,21 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, except (opt_flags.InapplicableConstraint, NotImplementedError) as e: pytest.skip(f"inapplicable opt_flags constraint {e}") # --- torch implementation --- - ref_y = matmul_torch(a, b, bias, # - a_ragged_metadata, b_ragged_metadata, - gather_indx, scatter_indx, precision_opt, - gammas=gammas) + # Fused NVFP4 output quantizes the float32 activation result and applies + # expected_scale inside downcast_to_mxfp_torch, so keep the reference in + # float32 until that final downcast instead of letting matmul_torch + # return bf16 and apply the output scale early. + ref_y = matmul_torch( + a.float() if c_dtype.is_nvfp4 else a, + b.float() if c_dtype.is_nvfp4 else b, + bias, + a_ragged_metadata, + b_ragged_metadata, + gather_indx, + scatter_indx, + PrecisionConfig() if c_dtype.is_nvfp4 else precision_opt, + gammas=gammas, + ) if swiglu_opts is not None: ref_y = swiglu(ref_y, alpha=swiglu_opts[0], precision_config=SwiGLUPrecisionConfig(swiglu_opts[1])) if c_dtype.has_global_scale: @@ -491,6 +522,7 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, axis=-1, scale_dtype=c_dtype.scale_dtype, microblock_size=c_dtype.microblock_size, + expected_scale=precision_opt.flex_ctx.out_data.expected_scale, ) ref_y = upcast_from_mxfp_torch(ref_y, ref_scale, target_dtype=ref_target_dtype, axis=-1) maxtol, rmstol = None, None diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py index 25b0fdff90d4..80c01330d978 100644 --- a/python/triton_kernels/triton_kernels/matmul.py +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -335,7 +335,16 @@ def matmul(a, b, bias, assert a_microblock_size == b_microblock_size, ( f"Microscaled operands must share a block size. Got {a_microblock_size} and {b_microblock_size}" ) - mx_block_size = b_microblock_size or a_microblock_size or int(MXFP_BLOCK_SIZE) + mx_block_size = ( + a_microblock_size or b_microblock_size or precision_config.c_microblock_size or int(MXFP_BLOCK_SIZE) + ) + assert all( + size is None or size == mx_block_size + for size in (a_microblock_size, b_microblock_size, precision_config.c_microblock_size) + ), ( + "Microscaled operands/output must share a block size. " + f"Got a={a_microblock_size}, b={b_microblock_size}, c={precision_config.c_microblock_size}" + ) if precision_config.c_mx_scale is not None and precision_config.c_microblock_size is None: precision_config.c_microblock_size = mx_block_size precision_config.c_value_pack_factor = 2 if precision_config.c_mx_scale is not None and epilogue.specs.name in ( @@ -588,7 +597,7 @@ def matmul(a, b, bias, n_valid_slices = b_tensor_or_tma.shape[0] if ragged_dimension == "M" else n_slices (kernels._p_matmul if opt_flags.is_persistent else kernels._matmul)[(grid,)]( c_tensor_or_tma, c.storage.data, *out_matmul.stride(), - *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex), + *((out_matmul_flex.expected_scale, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex), *out_matmul_scale_strides[-4:], a_tensor_or_tma, a.storage.data, *a_strides, a_transpose, flex.lhs_data.scale, diff --git a/python/triton_kernels/triton_kernels/matmul_details/_matmul.py b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py index 629939f1982e..4c06483f2a0d 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/_matmul.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_matmul.py @@ -522,6 +522,12 @@ def _matmul( MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MX_BLOCK_SIZE N_MX_BLOCK = tl.cdiv(N, MX_BLOCK_SIZE) tl.static_assert(EPILOGUE_FN is not None) + if PER_BATCH_OUT_SCALE: + YExpectedScale = YExpectedScale + start_z_out + # OCP MX outputs leave YExpectedScale unset, so this is an identity there. + # NVFP4 uses YExpectedScale to precondition the dense output before the + # microscaling epilogue writes direct e4m3 block scales. + out = float_to_flex(out, YExpectedScale, None, None, mask, Y, False) out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args) tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "") offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N) diff --git a/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py b/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py index 41387a3fc8db..b28110781443 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py +++ b/python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py @@ -581,6 +581,23 @@ def _p_matmul( tl.static_assert(EPILOGUE_FN is not None) offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N) mask_n = offs_y_n < yN + if PER_BATCH_OUT_SCALE: + ExpectedScale = YExpectedScale + start_z1 + else: + ExpectedScale = YExpectedScale + # OCP MX outputs leave YExpectedScale unset, so this is an + # identity there. NVFP4 uses YExpectedScale to precondition the + # dense output before the microscaling epilogue writes direct + # e4m3 block scales. + out = float_to_flex( + out, + ExpectedScale, + None, + None, + mask_m[:, None] & mask_n[None, :], + YPtr, + False, + ) out, out_scale = EPILOGUE_FN(out, mask_m[:, None] & mask_n[None, :], *epilogue_fn_args) tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "") offs_y_n_scale = off_n1 // ACTIVATION_REDUCTION_N // MX_BLOCK_SIZE + a_i * MX_SCALE_BLOCK_N + tl.arange(0, MX_SCALE_BLOCK_N) diff --git a/python/triton_kernels/triton_kernels/numerics_details/mxfp.py b/python/triton_kernels/triton_kernels/numerics_details/mxfp.py index 04f2204969b3..50fc720a4989 100644 --- a/python/triton_kernels/triton_kernels/numerics_details/mxfp.py +++ b/python/triton_kernels/triton_kernels/numerics_details/mxfp.py @@ -180,6 +180,7 @@ def get_max_quant_val(dtype: torch.dtype): def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int, scale_dtype: torch.dtype = torch.uint8, microblock_size: int = MXFP_BLOCK_SIZE.value, + expected_scale: torch.Tensor | None = None, DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP): """ Converts the src tensor to the output format specified by out_quant_type. @@ -191,6 +192,9 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype • For mxfp8, the output has the same shape as src_tensor. • For FP4, the size along the axis is halved, and the tensor is returned as a torch.uint8. scale: Scale tensor computed per microblock along the axis. + + If expected_scale is provided, apply the same preconditioning as float_to_flex + before microscaled output quantization by dividing the source tensor by it. """ # This should probably be packed into its own tiny class ndim = src_tensor.ndim @@ -217,6 +221,12 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype # Permute the tensor so that the contiguous axis becomes the last dimension. src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32) + if expected_scale is not None: + if expected_scale.numel() == 1: + src = src / expected_scale.to(device=device, dtype=src.dtype) + else: + assert src.ndim == 3 + src = src / expected_scale.to(device=device, dtype=src.dtype)[:, None, None] axis_shape = src.shape[-1] # Pad the axis to be divisible by the microblock size, in case it is not. @@ -251,6 +261,9 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype ds_int_rounded = ds_int & 0x7F800000 dequant_scale_rounded = ds_int_rounded.view(torch.float32) else: + # Direct fp8 scales keep the existing plain cast semantics here: the + # stored scale is rounded by the fp8 conversion rather than forced + # upward despite the ROUND_UP mode name. dequant_scale_rounded = dequant_scale.to(scale_dtype).to(torch.float32) # Compute the quantization scale. diff --git a/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py b/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py index 19d65470c26a..d05bb0844b4c 100644 --- a/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py +++ b/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py @@ -65,6 +65,9 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.con else: tl.static_assert(mx_scale_dtype == tl.float8e4nv, f"Unsupported {mx_scale_dtype=}") tl.static_assert(DEQUANT_SCALE_ROUNDING_MODE == 0, "Direct float8 scales only support ROUND_UP") + # Direct fp8 scales keep the existing plain cast semantics here: the + # stored scale is rounded by the fp8 conversion rather than forced + # upward despite the ROUND_UP mode name. scale_tensor = (max_val / _get_max_quant_val(mx_tensor_dtype)).to(tl.float8e4nv) dequant_scale_rounded = scale_tensor.to(tl.float32) scale_tensor = scale_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE]) diff --git a/python/triton_kernels/triton_kernels/reduce.py b/python/triton_kernels/triton_kernels/reduce.py index 63b01aa445ac..7bdbb9802195 100644 --- a/python/triton_kernels/triton_kernels/reduce.py +++ b/python/triton_kernels/triton_kernels/reduce.py @@ -206,14 +206,16 @@ def _reduce_forward(X, stride_xr: tl.int64, stride_x0: tl.int64, stride_x1, # x valid_y_s1 = offs_y_s1 < Y_S1 valid_y_smx1 = offs_y_smx1 < tl.cdiv(Y_S1, 1 if Y_MX_BLOCK_SIZE is None else Y_MX_BLOCK_SIZE) is_out_fp4: tl.constexpr = YMx is not None and Y_VALUE_PACK_FACTOR == 2 - y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF) - # TODO (phil): keeping for backward compatibility, but will remove ! - if YMx is None and POSTPROCESS_FN2 is not None: - y = POSTPROCESS_FN2(y, *postprocess_fn2_args, target_dtype=Y.dtype.element_ty) if YMx is not None: + y = float_to_flex(y, YFlexExpected, None, None, None, Y, False) y, y_scale = POSTPROCESS_MX_FN(y, valid_y_s1[None, :], *postprocess_mx_fn_args) y_mx_ptrs = YMx + offs_s0[:, None] * stride_ymx0 + offs_y_smx1[None, :] * stride_ymx1 tl.store(y_mx_ptrs, y_scale, mask=valid_s0[:, None] & valid_y_smx1[None, :]) + else: + y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF) + # TODO (phil): keeping for backward compatibility, but will remove ! + if POSTPROCESS_FN2 is not None: + y = POSTPROCESS_FN2(y, *postprocess_fn2_args, target_dtype=Y.dtype.element_ty) if is_out_fp4: offs_y_s1 = pid_s1 * (BLOCK_Y_S1 // 2) + tl.arange(0, BLOCK_Y_S1 // 2) valid_y_s1 = offs_y_s1 < tl.cdiv(Y_S1, 2)