[FP8] Add opt-in ParallelLMHead dispatch to Fp8Config #41000
base: main
Changes from all commits:

- 72238d4
- 3520a71
- 5ca163e
- 511412b
- 49a8a6b
- de77662
The diff, in `fp8.py`:

```diff
@@ -38,6 +38,10 @@
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
```
```diff
@@ -102,10 +106,12 @@ def __init__(
         activation_scheme: str = "dynamic",
         ignored_layers: list[str] | None = None,
         weight_block_size: list[int] | None = None,
+        lm_head_quantized: bool = False,
     ) -> None:
         super().__init__()

         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.lm_head_quantized = lm_head_quantized

         if activation_scheme not in ACTIVATION_SCHEMES:
             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
```
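For orientation, a minimal sketch of constructing the config with the new flag directly; in practice it is built via `from_config` (next hunk), and every argument value here other than the parameter names is illustrative:

```python
# Illustrative values; lm_head_quantized is the new opt-in knob from this PR.
config = Fp8Config(
    is_checkpoint_fp8_serialized=True,
    activation_scheme="dynamic",
    ignored_layers=None,
    weight_block_size=[128, 128],
    lm_head_quantized=True,
)
```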
```diff
@@ -162,22 +168,29 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
         ignored_layers = cls.get_from_keys_or(
             config, ["modules_to_not_convert"], None
         )
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
         return cls(
             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
             activation_scheme=activation_scheme,
             ignored_layers=ignored_layers,
             weight_block_size=weight_block_size,
+            lm_head_quantized=lm_head_quantized,
         )

     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> "QuantizeMethodBase | None":
-        if isinstance(layer, LinearBase):
+        is_parallel_lm_head = isinstance(layer, ParallelLMHead)
+        if isinstance(layer, LinearBase) or (
+            is_parallel_lm_head and self.lm_head_quantized
+        ):
```
Contributor comment on lines +183 to +186:

> The current implementation of [...] Furthermore, as noted in the PR description, [...]
The hunk continues below the comment:

```diff
             if is_layer_skipped(
                 prefix=prefix,
                 ignored_layers=self.ignored_layers,
                 fused_mapping=self.packed_modules_mapping,
             ):
+                if is_parallel_lm_head:
+                    return UnquantizedEmbeddingMethod()
                 return UnquantizedLinearMethod()
             if not self.is_checkpoint_fp8_serialized:
                 online_method = Fp8OnlineLinearMethod(self)
```
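Per the `from_config` change above, the flag is read from the checkpoint's quantization config under the `lm_head` key. A hedged sketch of such a config dict; aside from `"lm_head"`, the keys and values shown are assumptions about a typical FP8 checkpoint, not taken from this PR:

```python
# Hypothetical quantization_config from a checkpoint. Only "lm_head" is what
# the new get_from_keys_or(config, ["lm_head"], default=False) call reads;
# the exact set of other required keys depends on Fp8Config.from_config.
quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],
    "lm_head": True,
}
fp8_config = Fp8Config.from_config(quant_config)
assert fp8_config.lm_head_quantized
```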
```diff
@@ -254,6 +267,50 @@ def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
     set_weight_attrs(new, attrs_to_set)


+def _make_lm_head_block_scale_loader(layer, block_size):
+    """Per-parameter weight_loader for FP8 block scale_inv on ParallelLMHead.
+
+    The default VocabParallelEmbedding.weight_loader assumes vocab-shaped
+    tensors and rejects companion params with a different leading dim
+    (e.g., weight_scale_inv has shape [vocab/block_out, hidden/block_in]).
+    This loader shards the scale tensor along the block-aligned vocab dim
+    using the layer's existing shard_indices, and zero-fills any padding
+    rows the param was sized for.
+    """
+    block_out = block_size[0]
+
+    def load(param, loaded_weight):
+        start = layer.shard_indices.org_vocab_start_index
+        assert start % block_out == 0, (
+            f"FP8 lm_head requires the vocab-parallel shard start "
+            f"({start}) to be divisible by weight_block_size[0] "
+            f"({block_out})"
+        )
+        start_idx = start // block_out
+        local_rows = param.shape[0]
+        assert loaded_weight.shape[0] >= start_idx + local_rows, (
+            f"loaded scale has {loaded_weight.shape[0]} rows, "
+            f"need at least {start_idx + local_rows} "
+            f"(start_idx={start_idx}, local_rows={local_rows})"
+        )
+        chunk = loaded_weight.narrow(0, start_idx, local_rows)
+        param.data.copy_(chunk)
+
+    return load
+
+
+def _make_lm_head_scalar_scale_loader():
+    """Per-parameter weight_loader for FP8 per-tensor / input scale on
+    ParallelLMHead. Per-tensor scales are not vocab-parallel; just copy.
+    """
+
+    def load(param, loaded_weight):
+        param.data.copy_(loaded_weight.reshape(param.data.shape))
+
+    return load
+
+
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
```
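To make the block-scale sharding arithmetic in `_make_lm_head_block_scale_loader` concrete, a small worked example; all the numbers are illustrative, not from the PR:

```python
# Illustrative shapes: vocab 152064, hidden 4096, weight_block_size [128, 128],
# tensor-parallel size 2. The checkpoint's weight_scale_inv tensor has shape
# [vocab / block_out, hidden / block_in] = [1188, 32].
vocab_size, hidden_size = 152064, 4096
block_out, block_in = 128, 128
tp_rank, tp_size = 1, 2

org_vocab_start_index = (vocab_size // tp_size) * tp_rank  # 76032, divisible by 128
start_idx = org_vocab_start_index // block_out             # row 594 of the scale tensor
local_rows = (vocab_size // tp_size) // block_out          # 594 scale rows on this rank

# load() then copies loaded_weight.narrow(0, 594, 594): scale rows [594, 1188)
# land on rank 1, matching its weight shard rows [76032, 152064).
```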
```diff
@@ -344,14 +401,33 @@ def create_weights(
         )
         layer.register_parameter("weight", weight)

+        # WEIGHT / INPUT SCALES
+        # When this method is dispatched to a ParallelLMHead (opt-in via
+        # Fp8Config.lm_head_quantized), companion params can't share the
+        # default VocabParallelEmbedding.weight_loader (which assumes
+        # vocab-shaped tensors). Pick the right scale loader up front so
+        # we don't have to override it post-hoc -- set_weight_attrs() asserts
+        # against double-assignment of `weight_loader`.
+        if isinstance(layer, ParallelLMHead):
+            if self.block_quant:
+                scale_weight_loader = _make_lm_head_block_scale_loader(
+                    layer, self.weight_block_size
+                )
+            else:
+                scale_weight_loader = _make_lm_head_scalar_scale_loader()
+            input_scale_weight_loader = _make_lm_head_scalar_scale_loader()
+        else:
+            scale_weight_loader = weight_loader
+            input_scale_weight_loader = weight_loader
+
         # WEIGHT SCALE
         if not self.block_quant:
             scale = create_fp8_scale_parameter(
                 PerTensorScaleParameter,
                 output_partition_sizes,
                 input_size_per_partition,
                 None,
-                weight_loader,
+                scale_weight_loader,
             )
             layer.register_parameter("weight_scale", scale)
         else:
```
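The note about double-assignment refers to the guard in vLLM's `set_weight_attrs`, which is roughly the following (paraphrased from memory, not verbatim vLLM source):

```python
def set_weight_attrs(weight, weight_attrs):
    # Each attribute (including weight_loader) may be set at most once per
    # parameter, which is why the loader must be chosen before registration
    # rather than overridden afterwards.
    for key, value in weight_attrs.items():
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
        setattr(weight, key, value)
```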
```diff
@@ -362,15 +438,17 @@ def create_weights(
                 output_partition_sizes,
                 input_size_per_partition,
                 self.weight_block_size,
-                weight_loader,
+                scale_weight_loader,
                 scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
             )
             # The weight_scale_inv name is intentional for deepseekv3
             layer.register_parameter("weight_scale_inv", scale)

         # INPUT ACTIVATION SCALE
         if self.act_q_static:
-            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
+            scale = create_fp8_input_scale(
+                output_partition_sizes, input_scale_weight_loader
+            )
             set_weight_attrs(scale, {"scale_type": "input_scale"})
             layer.register_parameter("input_scale", scale)
```
Reviewer comment:

> Importing `ParallelLMHead` and `UnquantizedEmbeddingMethod` at the top level of `fp8.py` from `vllm.model_executor.layers.vocab_parallel_embedding` may lead to circular import issues in the future, as quantization configs are often imported by the layers they configure. It is generally safer to perform these imports inside `get_quant_method`, or to use `TYPE_CHECKING` for type hints and `importlib` for runtime checks if necessary.
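A minimal sketch of the deferred-import pattern the reviewer suggests (illustrative, not part of this PR):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-checkers resolve the name without a runtime import.
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead


def get_quant_method(self, layer, prefix):
    # Import inside the method to break the potential cycle:
    # vocab_parallel_embedding -> quantization -> fp8 -> vocab_parallel_embedding
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead,
        UnquantizedEmbeddingMethod,
    )

    is_parallel_lm_head = isinstance(layer, ParallelLMHead)
    ...
```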