diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index f5bcf975186e..12f33d07f596 100755
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -44,6 +44,7 @@
     is_sm120_supported,
     offloader,
 )
+from sglang.srt.utils.custom_op import register_custom_op
 
 
 logger = logging.getLogger(__name__)
@@ -857,7 +858,40 @@ def _pack_mxfp8_scales(scale_u8: torch.Tensor) -> torch.Tensor:
     return packed.view(1, scale_m, scale_k, 2, 256)
 
 
-def triton_mxfp8_blockscaled_linear(
+@register_custom_op(
+    op_name="triton_mxfp8_block_scaled_matmul",
+    mutates_args=[],
+    fake_impl=lambda a, a_scale, b, b_scale, output_dtype, block_m=128, block_n=256, block_k=128, num_stages=None: (  # noqa: E501
+        a.new_empty((a.shape[0], b.shape[0]), dtype=output_dtype)
+    ),
+)
+def triton_mxfp8_block_scaled_matmul(
+    a: torch.Tensor,
+    a_scale: torch.Tensor,
+    b: torch.Tensor,
+    b_scale: torch.Tensor,
+    output_dtype: torch.dtype,
+    *,
+    block_m: int = 128,
+    block_n: int = 256,
+    block_k: int = 128,
+    num_stages: Optional[int] = None,
+) -> torch.Tensor:
+    """Opaque custom op wrapper to prevent Dynamo tracing Triton grid math."""
+    return mxfp8_block_scaled_matmul_triton(
+        a,
+        a_scale,
+        b,
+        b_scale,
+        output_dtype=output_dtype,
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_stages=num_stages,
+    )
+
+
+def _raw_triton_mxfp8_blockscaled_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
@@ -918,7 +952,7 @@ def triton_mxfp8_blockscaled_linear(
     b_scale_packed = _pack_mxfp8_scales(weight_scale)
 
     num_stages = 1 if _is_sm120_supported else (4 if _is_sm100_supported else 1)
-    output = mxfp8_block_scaled_matmul_triton(
+    output = triton_mxfp8_block_scaled_matmul(
         q_input,
         a_scale_packed,
         weight.contiguous(),
@@ -935,6 +969,35 @@
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
+@register_custom_op(
+    op_name="triton_mxfp8_blockscaled_linear",
+    mutates_args=[],
+    fake_impl=lambda input, weight, weight_scale, input_scale=None, bias=None, output_dtype=None: (
+        input.new_empty(
+            (*input.shape[:-1], weight.shape[0]),
+            dtype=(output_dtype if output_dtype is not None else input.dtype),
+        )
+    ),
+)
+def triton_mxfp8_blockscaled_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    """Opaque custom-op wrapper to prevent Dynamo guards on MXFP8 padding branches."""
+    return _raw_triton_mxfp8_blockscaled_linear(
+        input=input,
+        weight=weight,
+        weight_scale=weight_scale,
+        input_scale=input_scale,
+        bias=bias,
+        output_dtype=output_dtype,
+    )
+
+
 def flashinfer_mxfp8_blockscaled_linear(
     input: torch.Tensor,
     weight: torch.Tensor,