Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
is_sm120_supported,
offloader,
)
from sglang.srt.utils.custom_op import register_custom_op

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -857,7 +858,40 @@ def _pack_mxfp8_scales(scale_u8: torch.Tensor) -> torch.Tensor:
return packed.view(1, scale_m, scale_k, 2, 256)


def triton_mxfp8_blockscaled_linear(
@register_custom_op(
    op_name="triton_mxfp8_block_scaled_matmul",
    mutates_args=[],
    # Fake (meta) impl: the real kernel produces an (M, N) output where M comes
    # from `a` and N from `b`, so allocate an empty tensor of that shape/dtype.
    fake_impl=lambda a, a_scale, b, b_scale, output_dtype, block_m=128, block_n=256, block_k=128, num_stages=None: (  # noqa: E501
        a.new_empty((a.shape[0], b.shape[0]), dtype=output_dtype)
    ),
)
def triton_mxfp8_block_scaled_matmul(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    *,
    block_m: int = 128,
    block_n: int = 256,
    block_k: int = 128,
    num_stages: Optional[int] = None,
) -> torch.Tensor:
    """Opaque custom op wrapper to prevent Dynamo tracing Triton grid math.

    Forwards all arguments unchanged to ``mxfp8_block_scaled_matmul_triton``;
    the custom-op boundary keeps torch.compile from tracing into the kernel's
    launch-grid computation.
    """
    # Collect the tile/pipeline tuning knobs once, then forward them together.
    tile_config = dict(
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_stages=num_stages,
    )
    return mxfp8_block_scaled_matmul_triton(
        a,
        a_scale,
        b,
        b_scale,
        output_dtype=output_dtype,
        **tile_config,
    )


def _raw_triton_mxfp8_blockscaled_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
Expand Down Expand Up @@ -918,7 +952,7 @@ def triton_mxfp8_blockscaled_linear(
b_scale_packed = _pack_mxfp8_scales(weight_scale)

num_stages = 1 if _is_sm120_supported else (4 if _is_sm100_supported else 1)
output = mxfp8_block_scaled_matmul_triton(
output = triton_mxfp8_block_scaled_matmul(
q_input,
a_scale_packed,
weight.contiguous(),
Expand All @@ -935,6 +969,35 @@ def triton_mxfp8_blockscaled_linear(
return output.to(dtype=output_dtype).view(*output_shape)


@register_custom_op(
    op_name="triton_mxfp8_blockscaled_linear",
    mutates_args=[],
    # Fake (meta) impl: a linear layer maps (..., K) -> (..., N) with N taken
    # from `weight`'s leading dim; dtype falls back to the input's when no
    # explicit output dtype is requested.
    fake_impl=lambda input, weight, weight_scale, input_scale=None, bias=None, output_dtype=None: (
        input.new_empty(
            (*input.shape[:-1], weight.shape[0]),
            dtype=input.dtype if output_dtype is None else output_dtype,
        )
    ),
)
def triton_mxfp8_blockscaled_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Opaque custom-op wrapper to prevent Dynamo guards on MXFP8 padding branches.

    Pure pass-through: every argument is handed unchanged to
    ``_raw_triton_mxfp8_blockscaled_linear``, which holds the actual
    quantize-pack-matmul logic.
    """
    # Bundle the arguments and delegate; the wrapper adds no logic of its own.
    raw_kwargs = dict(
        input=input,
        weight=weight,
        weight_scale=weight_scale,
        input_scale=input_scale,
        bias=bias,
        output_dtype=output_dtype,
    )
    return _raw_triton_mxfp8_blockscaled_linear(**raw_kwargs)


def flashinfer_mxfp8_blockscaled_linear(
input: torch.Tensor,
weight: torch.Tensor,
Expand Down
Loading