diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index d6fa29a704a..aed8ec2f112 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -50,6 +50,7 @@ def dispatch_forward(self):
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
     use_per_token_if_dynamic: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
@@ -59,6 +60,8 @@ def scaled_fp8_quant(
         input (torch.Tensor): Input tensor to be quantized
         scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
             If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
         use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
             determines the quantization granularity:
             - True: compute scale per token
@@ -75,6 +78,8 @@ def scaled_fp8_quant(
     assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
     shape = input.shape
     out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
     if scale is None:
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 9ba62a6f654..7a7eb8884e0 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -457,12 +457,9 @@ def apply(
                 qinput, x_scale = sgl_scaled_fp8_quant(
                     input_2d,
                     input_scale,
+                    num_token_padding=self.output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
-                if self.output_padding:
-                    pad_size = max(self.output_padding - qinput.shape[0], 0)
-                    if pad_size > 0:
-                        qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
             else:
                 qinput, x_scale = ops.scaled_fp8_quant(
                     input_2d,
diff --git a/python/sglang/test/test_custom_ops.py b/python/sglang/test/test_custom_ops.py
index 6fa343dcb09..72b9f5ab31e 100644
--- a/python/sglang/test/test_custom_ops.py
+++ b/python/sglang/test/test_custom_ops.py
@@ -82,6 +82,61 @@ def dequantize_per_token(tensor, inv_scale, dtype):
             dequantize_per_token(ref_y, scale, dtype),
         )
 
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_scaled_fp8_quant_with_padding(dtype) -> None:
+        original_rows = 5
+        x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype)
+
+        padding_size = 10
+
+        # Test with dynamic quantization
+        y_dynamic, scale_dynamic = scaled_fp8_quant(
+            x, None, num_token_padding=padding_size
+        )
+
+        # Verify output shape has the padded size
+        assert y_dynamic.shape[0] == padding_size
+        assert y_dynamic.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_without_padding, scale_without_padding = scaled_fp8_quant(x, None)
+        torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding)
+
+        # Test with static quantization
+        # First get a scale
+        _, scale = scaled_fp8_quant(x, None)
+
+        # Then use it for static quantization with padding
+        y_static, _ = scaled_fp8_quant(x, scale, num_token_padding=padding_size)
+
+        # Verify output shape has the padded size
+        assert y_static.shape[0] == padding_size
+        assert y_static.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_static_without_padding, _ = scaled_fp8_quant(x, scale)
+        torch.testing.assert_close(y_static[:original_rows], y_static_without_padding)
+
+        # Test with per-token dynamic quantization
+        y_per_token, scale_per_token = scaled_fp8_quant(
+            x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True
+        )
+
+        # Verify output shape has the padded size
+        assert y_per_token.shape[0] == padding_size
+        assert y_per_token.shape[1] == x.shape[1]
+
+        # Verify that the actual data in the non-padded region is correctly quantized
+        y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant(
+            x, None, use_per_token_if_dynamic=True
+        )
+        torch.testing.assert_close(
+            y_per_token[:original_rows], y_per_token_without_padding
+        )
+        torch.testing.assert_close(
+            scale_per_token[:original_rows], scale_per_token_without_padding
+        )
+
 if __name__ == "__main__":
     # Run the specific test function directly
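
A minimal usage sketch of the new num_token_padding argument, assuming a CUDA device is available; the 5x16 input and the padding target of 32 are illustrative values, not taken from the patch.

    import torch

    from sglang.srt.custom_op import scaled_fp8_quant

    # 5 tokens with a hidden size of 16, in bfloat16 (illustrative shapes).
    x = torch.randn(5, 16, device="cuda", dtype=torch.bfloat16)

    # Dynamic per-tensor quantization; the output is padded to at least 32 rows.
    q, scale = scaled_fp8_quant(x, None, num_token_padding=32)

    assert q.shape == (32, 16)  # first dimension padded up to num_token_padding
    # Only q[:5] carries quantized data for the original tokens; the remaining
    # rows exist solely to satisfy a downstream kernel's minimum row count.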