diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 1c9237d3f60a..a5260a8cda66 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import math
 from typing import TYPE_CHECKING, Any
 
 import torch
 from torch.nn import Module
+from torch.nn.parameter import UninitializedParameter
 from torch.utils._python_dispatch import TorchDispatchMode
 
 import vllm.envs as envs
@@ -38,6 +40,11 @@
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -102,10 +109,14 @@ def __init__(
         activation_scheme: str = "dynamic",
         ignored_layers: list[str] | None = None,
         weight_block_size: list[int] | None = None,
+        lm_head_quantized: bool = False,
+        embeddings_quantized: bool = False,
     ) -> None:
         super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.lm_head_quantized = lm_head_quantized
+        self.embeddings_quantized = embeddings_quantized
 
         if activation_scheme not in ACTIVATION_SCHEMES:
             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
@@ -162,22 +173,35 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
         ignored_layers = cls.get_from_keys_or(
             config, ["modules_to_not_convert"], None
         )
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        embeddings_quantized = cls.get_from_keys_or(
+            config,
+            ["embed_tokens"],
+            cls.get_from_keys_or(config, ["embeddings"], False),
+        )
         return cls(
             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
             activation_scheme=activation_scheme,
             ignored_layers=ignored_layers,
             weight_block_size=weight_block_size,
+            lm_head_quantized=lm_head_quantized,
+            embeddings_quantized=embeddings_quantized,
        )
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
-        if isinstance(layer, LinearBase):
+        is_parallel_lm_head = isinstance(layer, ParallelLMHead)
+        if isinstance(layer, LinearBase) or (
+            is_parallel_lm_head and self.lm_head_quantized
+        ):
             if is_layer_skipped(
                 prefix=prefix,
                 ignored_layers=self.ignored_layers,
                 fused_mapping=self.packed_modules_mapping,
             ):
+                if is_parallel_lm_head:
+                    return UnquantizedEmbeddingMethod()
                 return UnquantizedLinearMethod()
             if not self.is_checkpoint_fp8_serialized:
                 online_method = Fp8OnlineLinearMethod(self)
@@ -201,6 +225,21 @@
             return moe_quant_method
         elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
+        elif type(layer) is VocabParallelEmbedding:
+            if is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignored_layers,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedEmbeddingMethod()
+            if not self.embeddings_quantized:
+                return UnquantizedEmbeddingMethod()
+            if not self.is_checkpoint_fp8_serialized:
+                raise NotImplementedError(
+                    "Online FP8 quantization for VocabParallelEmbedding is "
+                    "not implemented. Serialize embed_tokens to FP8 offline."
+                )
+            return Fp8EmbeddingMethod(self)
         return None
 
     def get_cache_scale(self, name: str) -> str | None:
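Note on the config plumbing above: `from_config` keys off `lm_head` and `embed_tokens` (falling back to `embeddings`) in the checkpoint's quantization config, and `get_quant_method` only routes a `ParallelLMHead` or bare `VocabParallelEmbedding` to the FP8 paths when the matching flag is set. A minimal sketch of a config that would flip both flags on; every value here is an assumption for illustration, not taken from a real checkpoint:

    quantization_config = {
        "quant_method": "fp8",            # serialized FP8 checkpoint
        "activation_scheme": "dynamic",
        "weight_block_size": [128, 128],  # block-wise weight scales
        "lm_head": True,                  # -> Fp8Config.lm_head_quantized
        "embed_tokens": True,             # -> Fp8Config.embeddings_quantized
    }
    cfg = Fp8Config.from_config(quantization_config)
    assert cfg.lm_head_quantized and cfg.embeddings_quantized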
@@ -254,6 +293,52 @@ def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
     set_weight_attrs(new, attrs_to_set)
 
 
+def _make_lm_head_block_scale_loader(layer, block_size):
+    """Per-parameter weight_loader for FP8 block scale_inv on ParallelLMHead.
+
+    The default VocabParallelEmbedding.weight_loader assumes vocab-shaped
+    tensors and rejects companion params with a different leading dim
+    (e.g., weight_scale_inv has shape [vocab/block_out, hidden/block_in]).
+    This loader shards the scale tensor along the block-aligned vocab dim
+    using the layer's existing shard_indices, asserting that the checkpoint
+    covers every scale row the param was sized for.
+    """
+    block_out = block_size[0]
+
+    def load(param, loaded_weight):
+        start = layer.shard_indices.org_vocab_start_index
+        assert start % block_out == 0, (
+            f"FP8 lm_head requires the vocab-parallel shard start "
+            f"({start}) to be divisible by weight_block_size[0] "
+            f"({block_out})"
+        )
+        start_idx = start // block_out
+        local_rows = param.shape[0]
+        assert loaded_weight.shape[0] >= start_idx + local_rows, (
+            f"loaded scale has {loaded_weight.shape[0]} rows, "
+            f"need at least {start_idx + local_rows} "
+            f"(start_idx={start_idx}, local_rows={local_rows})"
+        )
+        chunk = loaded_weight.narrow(0, start_idx, local_rows)
+        param.data.copy_(chunk)
+
+    return load
+
+
+def _make_lm_head_scalar_scale_loader():
+    """Per-parameter weight_loader for FP8 per-tensor / input scale on
+    ParallelLMHead. Per-tensor scales are not vocab-parallel; just copy.
+    """
+
+    def load(param, loaded_weight):
+        if loaded_weight.numel() > param.data.numel():
+            loaded_weight = loaded_weight.flatten()[0]
+        param.data.copy_(loaded_weight.reshape(param.data.shape))
+
+    return load
+
+
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
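For intuition, the arithmetic `_make_lm_head_block_scale_loader` performs, replayed on toy tensors (every shape and rank layout below is invented for illustration):

    import torch

    block_out = 128                           # weight_block_size[0]
    # checkpoint-wide scale_inv: one row per 128 vocab rows
    scale_inv = torch.rand(1024 // block_out, 4096 // 128)   # (8, 32)

    # pretend this rank's vocab shard is rows [512, 1024)
    start = 512                               # org_vocab_start_index
    assert start % block_out == 0             # the loader's precondition
    start_idx = start // block_out            # scale row 4
    local_rows = 4                            # param.shape[0] on this rank
    chunk = scale_inv.narrow(0, start_idx, local_rows)
    assert chunk.shape == (4, 32)             # what gets copied into param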
@@ -344,6 +429,25 @@ def create_weights(
         )
         layer.register_parameter("weight", weight)
 
+        # WEIGHT / INPUT SCALES
+        # When this method is dispatched to a ParallelLMHead (opt-in via
+        # Fp8Config.lm_head_quantized), companion params can't share the
+        # default VocabParallelEmbedding.weight_loader (which assumes
+        # vocab-shaped tensors). Pick the right scale loader up front so
+        # we don't have to override it post-hoc -- set_weight_attrs() asserts
+        # against double-assignment of `weight_loader`.
+        if isinstance(layer, ParallelLMHead):
+            if self.block_quant:
+                scale_weight_loader = _make_lm_head_block_scale_loader(
+                    layer, self.weight_block_size
+                )
+            else:
+                scale_weight_loader = _make_lm_head_scalar_scale_loader()
+            input_scale_weight_loader = _make_lm_head_scalar_scale_loader()
+        else:
+            scale_weight_loader = weight_loader
+            input_scale_weight_loader = weight_loader
+
         # WEIGHT SCALE
         if not self.block_quant:
             scale = create_fp8_scale_parameter(
@@ -351,7 +455,7 @@
                 output_partition_sizes,
                 input_size_per_partition,
                 None,
-                weight_loader,
+                scale_weight_loader,
             )
             layer.register_parameter("weight_scale", scale)
         else:
@@ -362,7 +466,7 @@
                 output_partition_sizes,
                 input_size_per_partition,
                 self.weight_block_size,
-                weight_loader,
+                scale_weight_loader,
                 scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
             )
             # The weight_scale_inv name is intentional for deepseekv3
@@ -370,7 +474,9 @@
 
         # INPUT ACTIVATION SCALE
         if self.act_q_static:
-            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
+            scale = create_fp8_input_scale(
+                output_partition_sizes, input_scale_weight_loader
+            )
             set_weight_attrs(scale, {"scale_type": "input_scale"})
             layer.register_parameter("input_scale", scale)
 
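The scalar loader chosen above is deliberately shape-forgiving: per-tensor scales may arrive as 0-d tensors, 1-element vectors, or one replicated copy per logical shard. Its core logic replayed standalone (the example values are made up):

    import torch

    param = torch.zeros(1)                    # per-tensor weight_scale param
    loaded = torch.tensor([0.02, 0.02])       # e.g. replicated per-shard copies

    if loaded.numel() > param.numel():
        loaded = loaded.flatten()[0]          # collapse to a single scalar
    param.copy_(loaded.reshape(param.shape))  # param is now tensor([0.0200])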
@@ -476,6 +582,159 @@ def apply(
         return self.fp8_linear.apply_weights(layer, x, bias)
 
 
+class Fp8EmbeddingMethod(Fp8LinearMethod):
+    """Memory-first FP8 runtime for VocabParallelEmbedding.
+
+    The embedding table stays resident as FP8 and only gathered rows are
+    dequantized back to the model dtype during lookup. This is intended as a
+    correctness/VRAM test path, not the final fused performance design.
+    """
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+
+        if self.block_quant:
+            assert self.weight_block_size is not None
+            layer.weight_block_size = self.weight_block_size
+            validate_fp8_block_shape(
+                layer,
+                input_size,
+                output_size,
+                input_size_per_partition,
+                output_partition_sizes,
+                self.weight_block_size,
+            )
+
+        weight = create_fp8_weight_parameter(
+            output_size_per_partition, input_size_per_partition, weight_loader
+        )
+        layer.register_parameter("weight", weight)
+
+        if not self.block_quant:
+            scale = create_fp8_scale_parameter(
+                PerTensorScaleParameter,
+                output_partition_sizes,
+                input_size_per_partition,
+                None,
+                weight_loader,
+            )
+            layer.register_parameter("weight_scale", scale)
+        else:
+            assert not self.act_q_static
+            assert self.weight_block_size is not None
+            scale = create_fp8_scale_parameter(
+                BlockQuantScaleParameter,
+                output_partition_sizes,
+                input_size_per_partition,
+                self.weight_block_size,
+                weight_loader,
+                scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
+            )
+            layer.register_parameter("weight_scale_inv", scale)
+
+        self.fp8_linear = init_fp8_linear_kernel(
+            activation_quant_key=self.activation_quant_key,
+            weight_quant_key=self.weight_quant_key,
+            weight_shape=layer.weight.shape,
+            input_dtype=self.input_dtype,
+            out_dtype=self.out_dtype,
+            module_name=self.__class__.__name__,
+        )
+        self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
+
+    def load_embedding_companion(
+        self,
+        layer: VocabParallelEmbedding,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+    ) -> None:
+        output_dim = getattr(param, "output_dim", None)
+        if output_dim is None:
+            if isinstance(param, UninitializedParameter):
+                param.materialize(
+                    tuple(loaded_weight.shape), dtype=loaded_weight.dtype
+                )
+            assert param.data.shape == loaded_weight.shape
+            param.data.copy_(loaded_weight)
+            return
+
+        if isinstance(param, UninitializedParameter):
+            shape = list(loaded_weight.shape)
+            if self.block_quant:
+                assert self.weight_block_size is not None
+                block_n = self.weight_block_size[0]
+                shape[output_dim] = math.ceil(
+                    layer.num_embeddings_per_partition / block_n
+                )
+            else:
+                shape[output_dim] = layer.num_embeddings_per_partition
+            param.materialize(tuple(shape), dtype=loaded_weight.dtype)
+
+        if (
+            loaded_weight.shape == param.data.shape
+            or loaded_weight.shape[output_dim] == 1
+        ):
+            param.data.copy_(loaded_weight)
+            return
+
+        if self.block_quant:
+            assert self.weight_block_size is not None
+            block_n = self.weight_block_size[0]
+            start = layer.shard_indices.org_vocab_start_index
+            assert start % block_n == 0, (
+                f"FP8 embedding requires the vocab-parallel shard start "
+                f"({start}) to be divisible by weight_block_size[0] "
+                f"({block_n})"
+            )
+            start_idx = start // block_n
+        else:
+            start_idx = layer.shard_indices.org_vocab_start_index
+
+        local_rows = param.data.shape[output_dim]
+        available_rows = max(0, loaded_weight.shape[output_dim] - start_idx)
+        copy_rows = min(local_rows, available_rows)
+
+        if copy_rows > 0:
+            shard = loaded_weight.narrow(output_dim, start_idx, copy_rows)
+            param.data.narrow(output_dim, 0, copy_rows).copy_(shard)
+        if copy_rows < local_rows:
+            param.data.narrow(output_dim, copy_rows, local_rows - copy_rows).fill_(0)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        layer._already_called_process_weights_after_loading = True
+
+    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
+        out_dtype = getattr(layer, "orig_dtype", torch.get_default_dtype())
+        flat_input = input_.reshape(-1).long()
+        weight_rows = layer.weight.index_select(0, flat_input).to(out_dtype)
+
+        if self.block_quant:
+            assert self.weight_block_size is not None
+            block_n, block_k = self.weight_block_size
+            scales = layer.weight_scale_inv.to(out_dtype)
+            row_blocks = torch.div(flat_input, block_n, rounding_mode="floor")
+            row_scales = scales.index_select(0, row_blocks)
+            row_scales = row_scales.repeat_interleave(block_k, dim=-1)
+            row_scales = row_scales[..., : layer.embedding_dim]
+            output = weight_rows * row_scales
+        else:
+            scales = layer.weight_scale.to(out_dtype)
+            output = weight_rows * scales
+
+        return output.reshape(*input_.shape, layer.embedding_dim)
+
+
 # TODO(future PR): remove this class in favor of
 # online/fp8.py::Fp8PerTensorOnlineLinearMethod
 class Fp8OnlineLinearMethod(Fp8LinearMethod):
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index ddae01856da0..f330316f0821 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -442,6 +442,33 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             param.data.copy_(loaded_weight)
             return
 
+        if packed_dim is not None and packed_dim == output_dim:
+            packed_factor = (
+                param.packed_factor
+                if isinstance(param, BasevLLMParameter)
+                else param.pack_factor
+            )
+            expected_output_size = self.org_vocab_size // packed_factor
+        else:
+            expected_output_size = self.org_vocab_size
+
+        # Quantized embedding companion params (for example FP8 block
+        # scales) follow the vocab axis but do not have raw vocab shape.
+        # Delegate those to the quant method before the main weight
+        # sharding assertions below.
+        if loaded_weight.shape[output_dim] != expected_output_size:
+            companion_loader = getattr(
+                self.quant_method, "load_embedding_companion", None
+            )
+            if companion_loader is None:
+                raise AssertionError(
+                    "Unsupported quantized embedding companion parameter "
+                    f"shape={tuple(loaded_weight.shape)} for "
+                    f"{type(self.quant_method).__name__}"
+                )
+            companion_loader(self, param, loaded_weight)
+            return
+
         # Shard indexes for loading the weight
         start_idx = self.shard_indices.org_vocab_start_index
         shard_size = self.shard_indices.org_vocab_end_index - start_idx
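To sanity-check `Fp8EmbeddingMethod.embedding` above, the same gather-then-dequantize index math on tiny fp32 stand-ins (a real table would be FP8; the block sizes and shapes here are invented):

    import torch

    block_n, block_k, embed_dim = 2, 2, 5
    weight = torch.randn(8, embed_dim)            # stands in for the FP8 table
    scales = torch.rand(8 // block_n, 3)          # ceil(embed_dim / block_k) cols
    ids = torch.tensor([0, 3, 7])

    rows = weight.index_select(0, ids)
    row_blocks = torch.div(ids, block_n, rounding_mode="floor")
    row_scales = scales.index_select(0, row_blocks)
    row_scales = row_scales.repeat_interleave(block_k, dim=-1)[..., :embed_dim]
    out = rows * row_scales                       # (3, 5) dequantized rows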
diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
index 2449724cd04f..798e17a6e000 100644
--- a/vllm/model_executor/models/qwen3_5.py
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -224,6 +224,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size,
             config.hidden_size,
+            quant_config=vllm_config.quant_config,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
         )
 
         def get_layer(prefix: str):
@@ -501,6 +503,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
+                quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "lm_head"),
             )
         else:
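With the model wiring above in place, an FP8-serialized checkpoint whose quantization config sets `embed_tokens` and `lm_head` should load with the embedding table resident in FP8. A hypothetical smoke test; the model path is made up, and any checkpoint serialized with those flags would do:

    from vllm import LLM

    llm = LLM(model="my-org/Qwen3.5-FP8-quant-embeddings")  # hypothetical path
    print(llm.generate("Hello")[0].outputs[0].text)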