Add opt-in FP8 vocab embedding support #41365
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from typing import TYPE_CHECKING, Any

import torch
from torch.nn import Module
from torch.nn.parameter import UninitializedParameter
from torch.utils._python_dispatch import TorchDispatchMode

import vllm.envs as envs

@@ -38,6 +40,11 @@
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    UnquantizedEmbeddingMethod,
    VocabParallelEmbedding,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,

@@ -102,10 +109,14 @@ def __init__(
        activation_scheme: str = "dynamic",
        ignored_layers: list[str] | None = None,
        weight_block_size: list[int] | None = None,
        lm_head_quantized: bool = False,
        embeddings_quantized: bool = False,
    ) -> None:
        super().__init__()

        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.lm_head_quantized = lm_head_quantized
        self.embeddings_quantized = embeddings_quantized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")

@@ -162,22 +173,35 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        ignored_layers = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        embeddings_quantized = cls.get_from_keys_or(
            config,
            ["embed_tokens"],
            cls.get_from_keys_or(config, ["embeddings"], False),
        )
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
            lm_head_quantized=lm_head_quantized,
            embeddings_quantized=embeddings_quantized,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        if isinstance(layer, LinearBase):
        is_parallel_lm_head = isinstance(layer, ParallelLMHead)
        if isinstance(layer, LinearBase) or (
            is_parallel_lm_head and self.lm_head_quantized
        ):
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                if is_parallel_lm_head:
                    return UnquantizedEmbeddingMethod()
                return UnquantizedLinearMethod()
            if not self.is_checkpoint_fp8_serialized:
                online_method = Fp8OnlineLinearMethod(self)

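For context (not part of the diff): the hunk above reads the opt-in flags "lm_head" and "embed_tokens" (or "embeddings") from the checkpoint's quantization config. A minimal sketch of such a config, with the surrounding keys being illustrative assumptions rather than values taken from this PR:

    # Hypothetical quantization_config opting in to FP8 lm_head and embeddings.
    # Only "lm_head" and "embed_tokens"/"embeddings" are the keys parsed by
    # from_config() above; the remaining entries are illustrative.
    quantization_config = {
        "quant_method": "fp8",
        "activation_scheme": "dynamic",
        "weight_block_size": [128, 128],
        "lm_head": True,       # existing opt-in for ParallelLMHead
        "embed_tokens": True,  # new opt-in for VocabParallelEmbedding
    }
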
@@ -201,6 +225,21 @@ def get_quant_method(
            return moe_quant_method
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        elif type(layer) is VocabParallelEmbedding:
            if is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedEmbeddingMethod()
            if not self.embeddings_quantized:
                return UnquantizedEmbeddingMethod()
            if not self.is_checkpoint_fp8_serialized:
                raise NotImplementedError(
                    "Online FP8 quantization for VocabParallelEmbedding is "
                    "not implemented. Serialize embed_tokens to FP8 offline."
                )
            return Fp8EmbeddingMethod(self)
        return None

    def get_cache_scale(self, name: str) -> str | None:

@@ -254,6 +293,52 @@ def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
    set_weight_attrs(new, attrs_to_set)


def _make_lm_head_block_scale_loader(layer, block_size):
    """Per-parameter weight_loader for FP8 block scale_inv on ParallelLMHead.

    The default VocabParallelEmbedding.weight_loader assumes vocab-shaped
    tensors and rejects companion params with a different leading dim
    (e.g., weight_scale_inv has shape [vocab/block_out, hidden/block_in]).
    This loader shards the scale tensor along the block-aligned vocab dim
    using the layer's existing shard_indices, and zero-fills any padding
    rows the param was sized for.
    """
    block_out = block_size[0]

    def load(param, loaded_weight):
        start = layer.shard_indices.org_vocab_start_index
        assert start % block_out == 0, (
            f"FP8 lm_head requires the vocab-parallel shard start "
            f"({start}) to be divisible by weight_block_size[0] "
            f"({block_out})"
        )
        start_idx = start // block_out
        local_rows = param.shape[0]
        assert loaded_weight.shape[0] >= start_idx + local_rows, (
            f"loaded scale has {loaded_weight.shape[0]} rows, "
            f"need at least {start_idx + local_rows} "
            f"(start_idx={start_idx}, local_rows={local_rows})"
        )
        chunk = loaded_weight.narrow(0, start_idx, local_rows)
        param.data.copy_(chunk)

    return load


def _make_lm_head_scalar_scale_loader():
    """Per-parameter weight_loader for FP8 per-tensor / input scale on
    ParallelLMHead. Per-tensor scales are not vocab-parallel; just copy.
    """

    def load(param, loaded_weight):
        if loaded_weight.numel() > param.data.numel():
            loaded_weight = loaded_weight.flatten()[0]
        param.data.copy_(loaded_weight.reshape(param.data.shape))

    return load


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and

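A minimal sketch of what the block scale loader above does on one rank (not part of the diff; toy sizes, and a SimpleNamespace stands in for a real ParallelLMHead that exposes shard_indices):

    import torch
    from types import SimpleNamespace

    # Toy setup: full vocab 1024, hidden 256, weight_block_size [128, 128], TP rank 1 of 2.
    full_scale = torch.rand(1024 // 128, 256 // 128)  # checkpoint scale_inv, shape [8, 2]
    layer = SimpleNamespace(
        shard_indices=SimpleNamespace(org_vocab_start_index=512)  # this rank's vocab shard start
    )
    param = torch.nn.Parameter(torch.empty(512 // 128, 256 // 128))  # local scale rows, [4, 2]

    loader = _make_lm_head_block_scale_loader(layer, [128, 128])
    loader(param, full_scale)  # copies scale rows 4..7 into this rank's param
    assert torch.equal(param.data, full_scale[4:8])
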
@@ -344,14 +429,33 @@ def create_weights(
        )
        layer.register_parameter("weight", weight)

        # WEIGHT / INPUT SCALES
        # When this method is dispatched to a ParallelLMHead (opt-in via
        # Fp8Config.lm_head_quantized), companion params can't share the
        # default VocabParallelEmbedding.weight_loader (which assumes
        # vocab-shaped tensors). Pick the right scale loader up front so
        # we don't have to override it post-hoc -- set_weight_attrs() asserts
        # against double-assignment of `weight_loader`.
        if isinstance(layer, ParallelLMHead):
            if self.block_quant:
                scale_weight_loader = _make_lm_head_block_scale_loader(
                    layer, self.weight_block_size
                )
            else:
                scale_weight_loader = _make_lm_head_scalar_scale_loader()
            input_scale_weight_loader = _make_lm_head_scalar_scale_loader()
        else:
            scale_weight_loader = weight_loader
            input_scale_weight_loader = weight_loader

        # WEIGHT SCALE
        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
                scale_weight_loader,
            )
            layer.register_parameter("weight_scale", scale)
        else:

@@ -362,15 +466,17 @@ def create_weights(
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
                scale_weight_loader,
                scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
            )
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            scale = create_fp8_input_scale(
                output_partition_sizes, input_scale_weight_loader
            )
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)

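A sketch of the double-assignment problem the comment above refers to, assuming set_weight_attrs keeps its current no-overwrite assertion; the loader functions here are stand-ins for illustration only:

    import torch
    from vllm.model_executor.utils import set_weight_attrs

    def default_loader(param, loaded_weight):  # stand-in for the default loader
        param.data.copy_(loaded_weight)

    def lm_head_scale_loader(param, loaded_weight):  # stand-in for the lm_head loader
        param.data.copy_(loaded_weight)

    scale = torch.nn.Parameter(torch.empty(1))
    set_weight_attrs(scale, {"weight_loader": default_loader})
    # Re-assigning the same attribute later trips the assertion, which is why
    # create_weights() picks the right loader before registering the parameter:
    set_weight_attrs(scale, {"weight_loader": lm_head_scale_loader})  # AssertionError
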
@@ -476,6 +582,159 @@ def apply(
        return self.fp8_linear.apply_weights(layer, x, bias)


class Fp8EmbeddingMethod(Fp8LinearMethod):
    """Memory-first FP8 runtime for VocabParallelEmbedding.

    The embedding table stays resident as FP8 and only gathered rows are
    dequantized back to the model dtype during lookup. This is intended as a
    correctness/VRAM test path, not the final fused performance design.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        if not self.block_quant:
            scale = create_fp8_scale_parameter(
                PerTensorScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                None,
                weight_loader,
            )
            layer.register_parameter("weight_scale", scale)
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
                scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
            )
            layer.register_parameter("weight_scale_inv", scale)

Contributor comment (suggesting that create_weights also initialize the FP8 linear kernel after registering the block scale):

    layer.register_parameter("weight_scale_inv", scale)

    self.fp8_linear = init_fp8_linear_kernel(
        activation_quant_key=self.activation_quant_key,
        weight_quant_key=self.weight_quant_key,
        weight_shape=layer.weight.shape,
        input_dtype=self.input_dtype,
        out_dtype=self.out_dtype,
        module_name=self.__class__.__name__,
    )
    self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)

Author reply: Implemented in e3547df.

        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=self.activation_quant_key,
            weight_quant_key=self.weight_quant_key,
            weight_shape=layer.weight.shape,
            input_dtype=self.input_dtype,
            out_dtype=self.out_dtype,
            module_name=self.__class__.__name__,
        )
        self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)

    def load_embedding_companion(
        self,
        layer: VocabParallelEmbedding,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
    ) -> None:
        output_dim = getattr(param, "output_dim", None)
        if output_dim is None:
            if isinstance(param, UninitializedParameter):
                param.materialize(tuple(loaded_weight.shape), dtype=loaded_weight.dtype)
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        if isinstance(param, UninitializedParameter):
            shape = list(loaded_weight.shape)
            if self.block_quant:
                assert self.weight_block_size is not None
                block_n = self.weight_block_size[0]
                shape[output_dim] = math.ceil(
                    layer.num_embeddings_per_partition / block_n
                )
            else:
                shape[output_dim] = layer.num_embeddings_per_partition
            param.materialize(tuple(shape), dtype=loaded_weight.dtype)

        if (
            loaded_weight.shape == param.data.shape
            or loaded_weight.shape[output_dim] == 1
        ):
            param.data.copy_(loaded_weight)
            return

        if self.block_quant:
            assert self.weight_block_size is not None
            block_n = self.weight_block_size[0]
            start = layer.shard_indices.org_vocab_start_index
            assert start % block_n == 0, (
                f"FP8 embedding requires the vocab-parallel shard start "
                f"({start}) to be divisible by weight_block_size[0] "
                f"({block_n})"
            )
            start_idx = start // block_n
        else:
            start_idx = layer.shard_indices.org_vocab_start_index

        local_rows = param.data.shape[output_dim]
        available_rows = max(0, loaded_weight.shape[output_dim] - start_idx)
        copy_rows = min(local_rows, available_rows)

        if copy_rows > 0:
            shard = loaded_weight.narrow(output_dim, start_idx, copy_rows)
            param.data.narrow(output_dim, 0, copy_rows).copy_(shard)
        if copy_rows < local_rows:
            param.data.narrow(output_dim, copy_rows, local_rows - copy_rows).fill_(0)

    def process_weights_after_loading(self, layer: Module) -> None:
        layer._already_called_process_weights_after_loading = True

    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        out_dtype = getattr(layer, "orig_dtype", torch.get_default_dtype())
        flat_input = input_.reshape(-1).long()
        weight_rows = layer.weight.index_select(0, flat_input).to(out_dtype)

        if self.block_quant:
            assert self.weight_block_size is not None
            block_n, block_k = self.weight_block_size
            scales = layer.weight_scale_inv.to(out_dtype)
            row_blocks = torch.div(flat_input, block_n, rounding_mode="floor")
            row_scales = scales.index_select(0, row_blocks)
            row_scales = row_scales.repeat_interleave(block_k, dim=-1)
            row_scales = row_scales[..., : layer.embedding_dim]
            output = weight_rows * row_scales
        else:
            scales = layer.weight_scale.to(out_dtype)
            output = weight_rows * scales

        return output.reshape(*input_.shape, layer.embedding_dim)


# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
class Fp8OnlineLinearMethod(Fp8LinearMethod):

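To make the dequantize-on-gather path of embedding() above concrete, here is a standalone sketch with made-up shapes (block_n=2, block_k=4) that does not use the vLLM classes, assuming a torch build with float8_e4m3fn support:

    import torch

    block_n, block_k, hidden, vocab = 2, 4, 8, 6
    weight = torch.randn(vocab, hidden).to(torch.float8_e4m3fn)   # FP8 embedding table
    scale_inv = torch.rand(vocab // block_n, hidden // block_k)   # per-block scales, [3, 2]

    ids = torch.tensor([[0, 5]])                                  # (batch, seq) token ids
    flat = ids.reshape(-1)
    rows = weight.index_select(0, flat).to(torch.bfloat16)        # gather rows, then upcast
    row_scales = scale_inv.index_select(0, flat // block_n)       # pick each row's block scales
    row_scales = row_scales.repeat_interleave(block_k, dim=-1)[..., :hidden]
    out = rows * row_scales.to(torch.bfloat16)
    print(out.reshape(*ids.shape, hidden).shape)                  # torch.Size([1, 2, 8])
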
Reviewer comment (on _make_lm_head_scalar_scale_loader):

This loader is not robust to changes in the tensor parallel (TP) degree. If a checkpoint was saved with a higher TP degree, loaded_weight may contain multiple per-tensor scales (one per shard). Attempting to reshape a multi-element tensor into a single-element param.data will raise a RuntimeError. It should handle this by taking the first element, which is consistent with how vLLM handles per-tensor scales during TP-degree-changing loads.

Author reply: Implemented in e3547df: the loader now checks loaded_weight.numel() > param.data.numel(), takes loaded_weight.flatten()[0], then reshapes into param.data.shape.
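
A toy illustration of the case discussed above (hypothetical values): a checkpoint saved at TP=2 carries two copies of the per-tensor scale, while the current param holds a single element, so the fixed loader keeps only the first copy.

    import torch

    param = torch.nn.Parameter(torch.empty(1))
    loaded_weight = torch.tensor([0.0125, 0.0125])  # one copy per old TP shard

    load = _make_lm_head_scalar_scale_loader()
    load(param, loaded_weight)  # numel mismatch -> takes flatten()[0], then reshapes
    print(param.data)           # tensor([0.0125])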