From f29600debe719b12279240bb7f561ec33e881abb Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 17 Oct 2025 14:22:04 -0700 Subject: [PATCH 1/3] support deepgemm + blackwell Signed-off-by: yewentao256 --- tests/v1/generation/test_batch_invariance.py | 2 +- .../test_rms_norm_batch_invariant.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 18 +++++++++++++++--- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py index 8c4e77fd8acf..f6705a8acaf2 100644 --- a/tests/v1/generation/test_batch_invariance.py +++ b/tests/v1/generation/test_batch_invariance.py @@ -11,7 +11,7 @@ from vllm.platforms import current_platform hopper_only = pytest.mark.skipif( - not (current_platform.is_cuda() and current_platform.is_device_capability(90)), + not (current_platform.is_cuda() and current_platform.has_device_capability(90)), reason="Requires CUDA and Hopper (SM90)", ) diff --git a/tests/v1/generation/test_rms_norm_batch_invariant.py b/tests/v1/generation/test_rms_norm_batch_invariant.py index 399965bbd734..9fcb63621908 100644 --- a/tests/v1/generation/test_rms_norm_batch_invariant.py +++ b/tests/v1/generation/test_rms_norm_batch_invariant.py @@ -15,7 +15,7 @@ from vllm.platforms import current_platform hopper_only = pytest.mark.skipif( - not (current_platform.is_cuda() and current_platform.is_device_capability(90)), + not (current_platform.is_cuda() and current_platform.has_device_capability(90)), reason="Requires CUDA and Hopper (SM90)", ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index bfd8fd7b9f7c..0c6154dfeb4a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -555,9 +555,21 @@ def apply( N, K = weight_fp8.shape - # Scale is stored transposed: [num_blocks_k, num_blocks_n] - # We need to transpose it to [num_blocks_n, num_blocks_k] first - weight_scale = weight_scale.t() + # determine expected number of blocks along N and K + num_blocks_n = (N + block_n - 1) // block_n + num_blocks_k = (K + block_k - 1) // block_k + + # scale layout may be [num_blocks_n, num_blocks_k] + # or [num_blocks_k, num_blocks_n] depending on backend + if weight_scale.dim() != 2: + raise RuntimeError( + f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}" + ) + + scale_rows, scale_cols = weight_scale.shape + if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n): + # stored transposed, bring to [num_blocks_n, num_blocks_k] + weight_scale = weight_scale.t() # Expand scale to match weight dimensions # scale_expanded should have shape [N, K] From 4caf7cf662b9baf6fbd75df6a23442422d12fce5 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 17 Oct 2025 15:22:31 -0700 Subject: [PATCH 2/3] support deep_gemm Signed-off-by: yewentao256 --- tests/v1/generation/test_batch_invariance.py | 14 +++--- .../test_rms_norm_batch_invariant.py | 16 +++---- .../model_executor/layers/quantization/fp8.py | 47 +++++++++++++++++-- 3 files changed, 59 insertions(+), 18 deletions(-) diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py index f6705a8acaf2..8e59b695ed57 100644 --- a/tests/v1/generation/test_batch_invariance.py +++ b/tests/v1/generation/test_batch_invariance.py @@ -10,9 +10,9 @@ from vllm import LLM, SamplingParams from vllm.platforms import current_platform -hopper_only = pytest.mark.skipif( +skip_unsupported = pytest.mark.skipif( not 
(current_platform.is_cuda() and current_platform.has_device_capability(90)), - reason="Requires CUDA and Hopper (SM90)", + reason="Requires CUDA and >= Hopper (SM90)", ) @@ -74,7 +74,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: return base_prompt -@hopper_only +@skip_unsupported @pytest.mark.timeout(1000) def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): """ @@ -219,7 +219,7 @@ def _extract_step_logprobs(request_output): return None, None -@hopper_only +@skip_unsupported @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.forked def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): @@ -434,7 +434,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): pytest.fail(msg) -@hopper_only +@skip_unsupported def test_simple_generation(): """ Simple test that runs the model with a basic prompt and prints the output. @@ -480,7 +480,7 @@ def test_simple_generation(): llm.shutdown() -@hopper_only +@skip_unsupported @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) @pytest.mark.forked def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): @@ -707,7 +707,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): os.environ["VLLM_BATCH_INVARIANT"] = old_value -@hopper_only +@skip_unsupported @pytest.mark.parametrize("backend", ["FLASH_ATTN"]) @pytest.mark.forked def test_decode_logprobs_match_prefill_logprobs(backend): diff --git a/tests/v1/generation/test_rms_norm_batch_invariant.py b/tests/v1/generation/test_rms_norm_batch_invariant.py index 9fcb63621908..f79eba58d6ef 100644 --- a/tests/v1/generation/test_rms_norm_batch_invariant.py +++ b/tests/v1/generation/test_rms_norm_batch_invariant.py @@ -14,13 +14,13 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform -hopper_only = pytest.mark.skipif( +skip_unsupported = pytest.mark.skipif( not (current_platform.is_cuda() and current_platform.has_device_capability(90)), - reason="Requires CUDA and Hopper (SM90)", + reason="Requires CUDA and >= Hopper (SM90)", ) -@hopper_only +@skip_unsupported @pytest.mark.parametrize("batch_size", [1, 4, 16, 64]) @pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -69,7 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard( ) -@hopper_only +@skip_unsupported @pytest.mark.parametrize("batch_size", [1, 16, 128]) @pytest.mark.parametrize("seq_len", [1, 32, 512]) @pytest.mark.parametrize("hidden_size", [2048, 4096]) @@ -111,7 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int): ) -@hopper_only +@skip_unsupported def test_rms_norm_numerical_stability(): """ Test RMS norm numerical stability with extreme values. @@ -171,7 +171,7 @@ def test_rms_norm_numerical_stability(): ) -@hopper_only +@skip_unsupported def test_rms_norm_formula(): """ Test that RMS norm follows the correct mathematical formula. @@ -204,7 +204,7 @@ def test_rms_norm_formula(): ) -@hopper_only +@skip_unsupported @pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384]) def test_rms_norm_different_hidden_sizes(hidden_size: int): """ @@ -242,7 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int): ) -@hopper_only +@skip_unsupported def test_rms_norm_determinism(): """ Test that batch-invariant RMS norm produces deterministic results. 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0c6154dfeb4a..d2128da5e490 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -41,6 +41,7 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -94,9 +95,11 @@ from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import ( + fp8_gemm_nt, get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, is_deep_gemm_supported, + should_use_deepgemm_for_fp8_linear, ) from vllm.utils.flashinfer import has_flashinfer_moe @@ -539,8 +542,37 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - # If batch invariant mode is enabled, dequantize and use BF16 compute + # if batch invariant mode is enabled, prefer DeepGEMM FP8 path + # we will use BF16 dequant when DeepGEMM is not supported. if vllm_is_batch_invariant(): + if self.block_quant and should_use_deepgemm_for_fp8_linear( + torch.bfloat16, layer.weight, None + ): + # quantize input to FP8 using the same op as runtime path + input_2d = x.view(-1, x.shape[-1]) + # use group quant consistent with block size across K + assert self.act_q_group_shape is not None + q_input, input_scale = QuantFP8( + False, + self.act_q_group_shape, + column_major_scales=True, + )(input_2d) + + output_2d = torch.empty( + (q_input.shape[0], layer.weight.shape[0]), + dtype=torch.bfloat16, + device=q_input.device, + ) + fp8_gemm_nt( + (q_input, input_scale), + (layer.weight, layer.weight_scale), + output_2d, + ) + output = output_2d.view(*x.shape[:-1], layer.weight.shape[0]) + if bias is not None: + output = output + bias + return output + # Dequantize FP8 weights to BF16 weight_fp8 = layer.weight.to(torch.bfloat16) weight_scale = layer.weight_scale.to(torch.bfloat16) @@ -568,8 +600,17 @@ def apply( scale_rows, scale_cols = weight_scale.shape if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n): - # stored transposed, bring to [num_blocks_n, num_blocks_k] - weight_scale = weight_scale.t() + if num_blocks_n == num_blocks_k: + # ambiguous square case, warn and skip transpose + logger.warning( + "Batch-invariant FP8: square block-scale %dx%d; " + "skipping transpose to avoid misorientation.", + scale_rows, + scale_cols, + ) + else: + # clear KN -> transpose to NK + weight_scale = weight_scale.t() # Expand scale to match weight dimensions # scale_expanded should have shape [N, K] From 5aeca7280b15919fc541be8802cb2611b0688642 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 17 Oct 2025 15:29:31 -0700 Subject: [PATCH 3/3] remove view Signed-off-by: yewentao256 --- vllm/model_executor/layers/quantization/fp8.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d2128da5e490..447b31b92d8f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -548,15 +548,13 @@ def apply( if self.block_quant and should_use_deepgemm_for_fp8_linear( torch.bfloat16, layer.weight, None ): - # quantize input to FP8 using the same op as runtime path - input_2d = x.view(-1, x.shape[-1]) # use group quant 
consistent with block size across K assert self.act_q_group_shape is not None q_input, input_scale = QuantFP8( False, self.act_q_group_shape, column_major_scales=True, - )(input_2d) + )(x) output_2d = torch.empty( (q_input.shape[0], layer.weight.shape[0]), @@ -568,10 +566,9 @@ def apply( (layer.weight, layer.weight_scale), output_2d, ) - output = output_2d.view(*x.shape[:-1], layer.weight.shape[0]) if bias is not None: - output = output + bias - return output + output_2d = output_2d + bias + return output_2d # Dequantize FP8 weights to BF16 weight_fp8 = layer.weight.to(torch.bfloat16)
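
Usage note (illustrative, not part of the diff): a minimal sketch of how the
batch-invariant FP8 path added in this series would be exercised, assuming a
CUDA device with compute capability >= 9.0 and an FP8 block-quantized
checkpoint; the model path below is a placeholder rather than a real
checkpoint name. With VLLM_BATCH_INVARIANT set, block-quantized linears take
the DeepGEMM fp8_gemm_nt route when should_use_deepgemm_for_fp8_linear()
allows it, and otherwise fall back to the BF16 dequant path shown above.

    import os

    # Enable batch-invariant mode before vLLM is imported; this is the same
    # environment variable toggled by the tests in this series.
    os.environ["VLLM_BATCH_INVARIANT"] = "1"

    from vllm import LLM, SamplingParams

    # Placeholder: any FP8 block-quantized checkpoint served on SM90 or newer.
    llm = LLM(model="<fp8-block-quantized-model>")
    params = SamplingParams(temperature=0.0, max_tokens=16, logprobs=1)

    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)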