diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 1c9237d3f60a..dd5705374f1b 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -38,6 +38,10 @@
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -102,10 +106,12 @@ def __init__(
         activation_scheme: str = "dynamic",
         ignored_layers: list[str] | None = None,
         weight_block_size: list[int] | None = None,
+        lm_head_quantized: bool = False,
     ) -> None:
         super().__init__()
 
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.lm_head_quantized = lm_head_quantized
         if activation_scheme not in ACTIVATION_SCHEMES:
             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
 
@@ -162,22 +168,29 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
         ignored_layers = cls.get_from_keys_or(
             config, ["modules_to_not_convert"], None
         )
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
         return cls(
             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
             activation_scheme=activation_scheme,
             ignored_layers=ignored_layers,
             weight_block_size=weight_block_size,
+            lm_head_quantized=lm_head_quantized,
         )
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
-        if isinstance(layer, LinearBase):
+        is_parallel_lm_head = isinstance(layer, ParallelLMHead)
+        if isinstance(layer, LinearBase) or (
+            is_parallel_lm_head and self.lm_head_quantized
+        ):
             if is_layer_skipped(
                 prefix=prefix,
                 ignored_layers=self.ignored_layers,
                 fused_mapping=self.packed_modules_mapping,
             ):
+                if is_parallel_lm_head:
+                    return UnquantizedEmbeddingMethod()
                 return UnquantizedLinearMethod()
             if not self.is_checkpoint_fp8_serialized:
                 online_method = Fp8OnlineLinearMethod(self)
@@ -254,6 +267,50 @@ def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
     set_weight_attrs(new, attrs_to_set)
 
 
+
+def _make_lm_head_block_scale_loader(layer, block_size):
+    """Per-parameter weight_loader for FP8 block scale_inv on ParallelLMHead.
+
+    The default VocabParallelEmbedding.weight_loader assumes vocab-shaped
+    tensors and rejects companion params with a different leading dim
+    (e.g., weight_scale_inv has shape [vocab/block_out, hidden/block_in]).
+    This loader shards the scale tensor along the block-aligned vocab dim
+    using the layer's existing shard_indices, and zero-fills any padding
+    rows the param was sized for.
+    """
+    block_out = block_size[0]
+
+    def load(param, loaded_weight):
+        start = layer.shard_indices.org_vocab_start_index
+        assert start % block_out == 0, (
+            f"FP8 lm_head requires the vocab-parallel shard start "
+            f"({start}) to be divisible by weight_block_size[0] "
+            f"({block_out})"
+        )
+        start_idx = start // block_out
+        local_rows = param.shape[0]
+        assert loaded_weight.shape[0] >= start_idx + local_rows, (
+            f"loaded scale has {loaded_weight.shape[0]} rows, "
+            f"need at least {start_idx + local_rows} "
+            f"(start_idx={start_idx}, local_rows={local_rows})"
+        )
+        chunk = loaded_weight.narrow(0, start_idx, local_rows)
+        param.data.copy_(chunk)
+
+    return load
+
+
+def _make_lm_head_scalar_scale_loader():
+    """Per-parameter weight_loader for FP8 per-tensor / input scale on
+    ParallelLMHead. Per-tensor scales are not vocab-parallel; just copy.
+    """
+
+    def load(param, loaded_weight):
+        param.data.copy_(loaded_weight.reshape(param.data.shape))
+
+    return load
+
+
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
@@ -344,6 +401,25 @@ def create_weights(
         )
         layer.register_parameter("weight", weight)
 
+        # WEIGHT / INPUT SCALES
+        # When this method is dispatched to a ParallelLMHead (opt-in via
+        # Fp8Config.lm_head_quantized), companion params can't share the
+        # default VocabParallelEmbedding.weight_loader (which assumes
+        # vocab-shaped tensors). Pick the right scale loader up front so
+        # we don't have to override it post-hoc -- set_weight_attrs() asserts
+        # against double-assignment of `weight_loader`.
+        if isinstance(layer, ParallelLMHead):
+            if self.block_quant:
+                scale_weight_loader = _make_lm_head_block_scale_loader(
+                    layer, self.weight_block_size
+                )
+            else:
+                scale_weight_loader = _make_lm_head_scalar_scale_loader()
+            input_scale_weight_loader = _make_lm_head_scalar_scale_loader()
+        else:
+            scale_weight_loader = weight_loader
+            input_scale_weight_loader = weight_loader
+
         # WEIGHT SCALE
         if not self.block_quant:
             scale = create_fp8_scale_parameter(
@@ -351,7 +427,7 @@ def create_weights(
                 output_partition_sizes,
                 input_size_per_partition,
                 None,
-                weight_loader,
+                scale_weight_loader,
             )
             layer.register_parameter("weight_scale", scale)
         else:
@@ -362,7 +438,7 @@ def create_weights(
                 output_partition_sizes,
                 input_size_per_partition,
                 self.weight_block_size,
-                weight_loader,
+                scale_weight_loader,
                 scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
             )
             # The weight_scale_inv name is intentional for deepseekv3
@@ -370,7 +446,9 @@ def create_weights(
 
         # INPUT ACTIVATION SCALE
         if self.act_q_static:
-            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
+            scale = create_fp8_input_scale(
+                output_partition_sizes, input_scale_weight_loader
+            )
             set_weight_attrs(scale, {"scale_type": "input_scale"})
             layer.register_parameter("input_scale", scale)
 
diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
index 2449724cd04f..091c394500ca 100644
--- a/vllm/model_executor/models/qwen3_5.py
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -501,6 +501,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
+                quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
         else:
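
Reviewer note (not part of the patch): a minimal, self-contained sketch of the shard arithmetic the new block-scale loader relies on. All sizes below (vocab_size, hidden_size, the [128, 128] block size, and the tensor-parallel shard bounds) are made-up example values, not taken from the diff, and the slice is computed directly instead of going through a real ParallelLMHead layer.

import torch

# Hypothetical example sizes: vocab 1024, hidden 512, weight_block_size [128, 128],
# and a tensor-parallel rank that owns vocab rows [512, 1024).
vocab_size, hidden_size = 1024, 512
block_out, block_in = 128, 128
shard_start, shard_rows = 512, 512

# Full-checkpoint scale_inv: one scale per [block_out x block_in] weight tile.
full_scale = torch.rand(vocab_size // block_out, hidden_size // block_in)

# Map the vocab-row shard onto block-aligned scale rows and copy only that
# slice, mirroring the narrow() call in _make_lm_head_block_scale_loader.
assert shard_start % block_out == 0          # same precondition the patch asserts
start_idx = shard_start // block_out         # 512 // 128 = 4
local_rows = shard_rows // block_out         # 512 // 128 = 4
local_scale = full_scale.narrow(0, start_idx, local_rows)
assert local_scale.shape == (local_rows, hidden_size // block_in)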