diff --git a/vllm/config/model.py b/vllm/config/model.py
index 4e3568fa15b1..b7b3f0c41113 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -196,6 +196,12 @@ class ModelConfig:
     graph and always execute the model in eager mode. If False, we will use
     CUDA graph and eager execution in hybrid for maximal performance and
     flexibility."""
+    fp8_lm_head: bool = False
+    """Compress the lm_head weight to float8_e4m3fn after loading to reduce
+    VRAM usage on models with large vocabularies. Saves ~640 MB for a
+    131K-vocab × 5120-hidden-dim model loaded in bf16/fp16. The
+    UnquantizedLinearMethod handles the FP8-to-compute-dtype cast at runtime.
+    Has no effect when tie_word_embeddings is True or lm_head is quantized."""
     enable_return_routed_experts: bool = False
     """Whether to return routed experts."""
     max_logprobs: int = 20
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 93384fd78cd7..26cb80466a15 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -473,6 +473,7 @@ class EngineArgs:
     quantization: QuantizationMethods | str | None = ModelConfig.quantization
     allow_deprecated_quantization: bool = ModelConfig.allow_deprecated_quantization
     enforce_eager: bool = ModelConfig.enforce_eager
+    fp8_lm_head: bool = ModelConfig.fp8_lm_head
    disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     language_model_only: bool = MultiModalConfig.language_model_only
     limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field(
@@ -695,6 +696,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         **model_kwargs["allow_deprecated_quantization"],
     )
     model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
+    model_group.add_argument("--fp8-lm-head", **model_kwargs["fp8_lm_head"])
     model_group.add_argument(
         "--enable-return-routed-experts",
         **model_kwargs["enable_return_routed_experts"],
@@ -1354,6 +1356,7 @@ def create_model_config(self) -> ModelConfig:
             quantization=self.quantization,
             allow_deprecated_quantization=self.allow_deprecated_quantization,
             enforce_eager=self.enforce_eager,
+            fp8_lm_head=self.fp8_lm_head,
             enable_return_routed_experts=self.enable_return_routed_experts,
             max_logprobs=self.max_logprobs,
             logprobs_mode=self.logprobs_mode,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 16d3cf88a60b..a7850e4d9d36 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -57,6 +57,7 @@
     maybe_remap_kv_scale_name,
 )
+from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backend import AttentionType
 
 from .adapters import as_embedding_model, as_seq_cls_model
@@ -73,9 +74,12 @@
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
     make_layers,
+    maybe_compress_lm_head_to_fp8,
     maybe_prefix,
 )
 
+logger = init_logger(__name__)
+
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -555,6 +559,7 @@ def __init__(
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors
         )
+        self._compress_lm_head = vllm_config.model_config.fp8_lm_head
 
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         self.model.aux_hidden_state_layers = layers
@@ -603,7 +608,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             self,
             skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights)
+        loaded = loader.load_weights(weights)
+        if self._compress_lm_head:
+            maybe_compress_lm_head_to_fp8(self, logger)
+        return loaded
 
 
 class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py
index ce1332d0c9d1..dc04de8d82c6 100644
--- a/vllm/model_executor/models/mistral.py
+++ b/vllm/model_executor/models/mistral.py
@@ -24,9 +24,12 @@
     LlamaModel,
 )
+from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backend import AttentionType
 
-from .utils import AutoWeightsLoader
+from .utils import AutoWeightsLoader, maybe_compress_lm_head_to_fp8
+
+logger = init_logger(__name__)
 
 
 class MistralMLP(nn.Module):
@@ -281,10 +284,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             self,
             skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(
+        loaded = loader.load_weights(
             self.maybe_remap_mistral(name, loaded_weight)
             for name, loaded_weight in weights
         )
+        if self._compress_lm_head:
+            maybe_compress_lm_head_to_fp8(self, logger)
+        return loaded
 
     def maybe_remap_mistral(
         self,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index abc953b7f980..ca84b404300e 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import itertools
+import logging
 from collections.abc import Callable, Iterable, Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -875,3 +876,30 @@ def get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
     if feature_layer_index < 0:
         return num_hidden_layers + feature_layer_index + 1
     return feature_layer_index
+
+
+def maybe_compress_lm_head_to_fp8(
+    model: nn.Module,
+    logger: "logging.Logger",
+) -> None:
+    """Compress lm_head weight to float8_e4m3fn in-place to reduce VRAM.
+
+    Called from ``load_weights`` when ``--fp8-lm-head`` is set.
+    Has no effect if lm_head is already quantized or tied to embeddings.
+    The :class:`~vllm.model_executor.layers.linear.UnquantizedLinearMethod`
+    handles the FP8-to-compute-dtype cast at runtime.
+    """
+    lm_head = getattr(model, "lm_head", None)
+    if lm_head is None or not hasattr(lm_head, "weight"):
+        return
+    if getattr(getattr(model, "config", None), "tie_word_embeddings", False):
+        return  # weight is shared with embed_tokens; compressing would corrupt it
+    if lm_head.weight.dtype not in (torch.float16, torch.bfloat16, torch.float32):
+        return  # already FP8 or otherwise quantized
+    # Casting down to 1-byte FP8 saves (element_size - 1) bytes per element.
+    saved_mb = lm_head.weight.numel() * (lm_head.weight.element_size() - 1) / 1024**2
+    lm_head.weight = nn.Parameter(
+        lm_head.weight.data.to(torch.float8_e4m3fn),
+        requires_grad=False,
+    )
+    logger.info("Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)", saved_mb)
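
As a sanity check on the arithmetic in the `ModelConfig` docstring and on the load-time/runtime round-trip the helper relies on, here is a standalone sketch (not part of the patch; the small weight dimensions and the bf16 compute dtype are illustrative assumptions). It requires PyTorch ≥ 2.1 for `torch.float8_e4m3fn`.

```python
import torch

# Memory math for the quoted 131K-vocab x 5120-hidden example:
# bf16 stores 2 bytes/element, float8_e4m3fn stores 1, so casting down
# saves numel * 1 byte.
numel = 131_072 * 5120
print(numel * (2 - 1) / 1024**2)  # -> 640.0 (MB), matching the docstring

# Round-trip on a small stand-in weight: cast down once at load time, then
# upcast to the compute dtype before the logits matmul, as the helper's
# docstring describes for UnquantizedLinearMethod.
weight = torch.randn(1024, 256, dtype=torch.bfloat16)  # stand-in lm_head
fp8_weight = weight.to(torch.float8_e4m3fn)            # lossy 1-byte cast

hidden = torch.randn(4, 256, dtype=torch.bfloat16)
logits = hidden @ fp8_weight.to(torch.bfloat16).t()    # upcast, then matmul
print(logits.shape)                                    # torch.Size([4, 1024])
```

With the patch applied, the path would be exercised end to end through the new CLI flag, e.g. `vllm serve <model> --fp8-lm-head`; the flag comes from this diff, while any particular model choice is up to the user.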