diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index e8abe0d41a16..3eaa5c28fd5a 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -126,9 +126,21 @@ def check_model(model): not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.", ) -@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +# @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize( - "force_marlin", [False] if current_platform.is_rocm() else [False, True] + "kv_cache_dtype", + [ + "auto", + ], +) +@pytest.mark.parametrize( + # "force_marlin", [False] if current_platform.is_rocm() else [False, True] + "force_marlin", + [False] + if current_platform.is_rocm() + else [ + False, + ], ) @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] @@ -150,7 +162,8 @@ def test_load_fp16_model( monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") with vllm_runner( - "facebook/opt-125m", + # "facebook/opt-125m", + "Qwen/Qwen1.5-MoE-A2.7B", quantization="fp8", enforce_eager=True, kv_cache_dtype=kv_cache_dtype, @@ -189,7 +202,10 @@ def check_model(model): "It only runs on CUDA and ROCm platform." ) - llm.apply_model(check_model) + # below currently hardcodes opt-125m layers, skip for now + # llm.apply_model(check_model) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=20) + print(outputs[0][1]) @pytest.mark.skipif( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1c0c35bf6f41..6b424104ecfa 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum from typing import TYPE_CHECKING, Any, Optional import torch @@ -103,6 +104,13 @@ logger = init_logger(__name__) +class OnlineQuantScalingType(Enum): + # TODO(before land): align on naming and add descriptive comments + # to each enum value + TENSORWISE = "tensorwise" + BLOCKWISE = "blockwise" + + class Fp8Config(QuantizationConfig): """Config class for FP8.""" @@ -140,6 +148,11 @@ def __init__( ) self.weight_block_size = weight_block_size + # TODO(before land): hook this up to user UI, for now hardcode it here + self.online_quant_scaling_type = OnlineQuantScalingType.BLOCKWISE + # self.online_quant_scaling_type = OnlineQuantScalingType.TENSORWISE + self.online_block_size = [128, 128] # [block_n, block_k] + @classmethod def get_name(cls) -> QuantizationMethods: return "fp8" @@ -328,8 +341,16 @@ def __init__(self, quant_config: Fp8Config): self.weight_block_size = self.quant_config.weight_block_size self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" + if self.weight_block_size: self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) + elif ( + self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): + self.act_q_group_shape = GroupShape( + 1, self.quant_config.online_block_size[0] + ) else: # Use per-token quantization for better perf if dynamic and cutlass if not self.act_q_static and cutlass_fp8_supported(): @@ -337,11 +358,16 @@ def __init__(self, quant_config: Fp8Config): else: self.act_q_group_shape = GroupShape.PER_TENSOR - if self.block_quant: + if ( + self.block_quant + or self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): + block_size = self.weight_block_size or self.quant_config.online_block_size + assert block_size is not None assert not self.act_q_static - assert self.weight_block_size is not None self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( - weight_group_shape=GroupShape(*self.weight_block_size), + weight_group_shape=GroupShape(*block_size), act_quant_group_shape=self.act_q_group_shape, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, @@ -487,8 +513,38 @@ def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint not serialized fp8, quantize the weights. else: if not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) - weight = qweight.t() + # Online quantization + if ( + self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): + # blockwise + from vllm.utils.deep_gemm import per_block_cast_to_fp8 + + block_size = self.quant_config.online_block_size + # layer.weight is [N, K] where N=output_size, K=input_size + qweight, weight_scale_inv = per_block_cast_to_fp8( + layer.weight, block_size=block_size + ) + # qweight: [N, K] in FP8 + # weight_scale_inv: [N/block_n, K/block_k] - inverse scales + # Note: block ops expect [N, K] format (no transpose) + replace_parameter(layer, "weight", qweight.data) + replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) + layer.weight_block_size = block_size + size_k_first = False + else: + # tensorwise + assert ( + self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.TENSORWISE + ) + qweight, weight_scale = ops.scaled_fp8_quant( + layer.weight, scale=None + ) + weight = qweight.t() + replace_parameter(layer, "weight", weight.data) + replace_parameter(layer, "weight_scale", weight_scale.data) # If checkpoint is fp8 per-tensor, handle that there are N scales for N # shards in a fused module @@ -512,9 +568,9 @@ def process_weights_after_loading(self, layer: Module) -> None: input_scale = input_scale.max() weight = weight.t() - # Update layer with new values. - replace_parameter(layer, "weight", weight.data) - replace_parameter(layer, "weight_scale", weight_scale.data) + # Update layer with new values. + replace_parameter(layer, "weight", weight.data) + replace_parameter(layer, "weight_scale", weight_scale.data) if input_scale is not None: replace_parameter(layer, "input_scale", input_scale) @@ -529,7 +585,11 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.input_scale return - if self.block_quant: + if ( + self.block_quant + or self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): maybe_post_process_fp8_weight_block(layer) def apply( @@ -541,8 +601,11 @@ def apply( # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. if vllm_is_batch_invariant(): - if self.block_quant: - assert self.weight_block_size is not None + if ( + self.block_quant + or self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, @@ -591,9 +654,11 @@ def apply( bias=bias, ) - if self.block_quant: - assert self.weight_block_size is not None - + if ( + self.block_quant + or self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, @@ -1089,7 +1154,21 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): super().__init__(quant_config, layer) assert not quant_config.is_checkpoint_fp8_serialized assert quant_config.activation_scheme == "dynamic" - assert quant_config.weight_block_size is None + + # Override parent class attributes for online blockwise quantization + if ( + self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): + self.weight_block_size = self.quant_config.online_block_size + self.block_quant = True + self.weight_scale_name = "weight_scale_inv" + # Re-select backend with correct block_quant flag + self.fp8_backend = select_fp8_moe_backend( + block_quant=self.block_quant, + tp_size=layer.moe_parallel_config.tp_size, + with_lora_support=self.moe.is_lora_enabled, + ) def create_weights( self, @@ -1168,16 +1247,43 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) + if ( + self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): + # For blockwise, scales are per block (typically 128x128) + block_size = self.quant_config.online_block_size + block_n, block_k = block_size[0], block_size[1] + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + else: + # For tensorwise, scales are per expert + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) @@ -1192,16 +1298,39 @@ def process_weights_after_loading(self, layer: Module) -> None: fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) - w13_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale - for expert in range(layer.local_num_experts): - w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant( - layer.w13_weight[expert, :, :] - ) - w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant( - layer.w2_weight[expert, :, :] - ) + if ( + self.quant_config.online_quant_scaling_type + is OnlineQuantScalingType.BLOCKWISE + ): + # Blockwise quantization + from vllm.utils.deep_gemm import per_block_cast_to_fp8 + + block_size = self.quant_config.online_block_size + w13_scale = layer.w13_weight_scale_inv + w2_scale = layer.w2_weight_scale_inv + + for expert in range(layer.local_num_experts): + w13[expert, :, :], w13_scale[expert, :, :] = per_block_cast_to_fp8( + layer.w13_weight[expert, :, :], block_size=block_size + ) + w2[expert, :, :], w2_scale[expert, :, :] = per_block_cast_to_fp8( + layer.w2_weight[expert, :, :], block_size=block_size + ) + + layer.weight_block_size = block_size + else: + # Tensorwise quantization + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + + for expert in range(layer.local_num_experts): + w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant( + layer.w13_weight[expert, :, :] + ) + w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant( + layer.w2_weight[expert, :, :] + ) # Shuffle weights to runtime format and setup kernel. self._setup_kernel(