[Model] Optional FP8 lm_head compression for Llama and Mistral #35696

lucaspirola wants to merge 4 commits into vllm-project:main
Conversation
Add optional post-load FP8 compression of lm_head weights, enabled by the VLLM_FP8_LM_HEAD=1 environment variable. After load_weights completes, the lm_head weight tensor is cast from BF16/FP16 to float8_e4m3fn in-place, halving its VRAM footprint. For models with large vocabularies (e.g. 131K tokens), this saves ~640 MB of VRAM that can instead be used for KV cache. The FP8 weight is automatically cast back to the compute dtype during the forward pass by UnquantizedLinearMethod.apply().

This is opt-in and only activates when:
- VLLM_FP8_LM_HEAD=1 is set
- tie_word_embeddings is False
- lm_head has a weight attribute
- the weight is not already FP8

Signed-off-by: Lucas Pirola <lucaspirola@gmail.com>
Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
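A minimal usage sketch of the opt-in path described above (assumptions: any Llama/Mistral-family checkpoint with untied embeddings; the model name is only illustrative, and the variable must be set before the engine loads weights):

    # Rough usage sketch for the opt-in env var described above; the model name
    # is illustrative, not a requirement of the feature.
    import os

    os.environ["VLLM_FP8_LM_HEAD"] = "1"  # must be set before weights are loaded

    from vllm import LLM

    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3")
    print(llm.generate("The capital of France is")[0].outputs[0].text)

Since UnquantizedLinearMethod.apply() upcasts the weight back to the compute dtype on the fly, outputs should match an uncompressed run up to FP8 rounding of the lm_head weights.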
Code Review
This pull request introduces an optional FP8 compression for the lm_head in Llama and Mistral models, which can lead to significant VRAM savings. The implementation is gated behind the VLLM_FP8_LM_HEAD environment variable, making it a safe, opt-in feature. My main feedback is regarding code duplication between the llama.py and mistral.py model files. I've left specific comments on how to refactor this to improve maintainability.
# Compress lm_head to FP8 to save VRAM on large-vocab models.
# Saves ~640 MB for 131K vocab x 5120 dim.
# Enabled by VLLM_FP8_LM_HEAD=1 env var.
# The UnquantizedLinearMethod.apply() handles FP8->compute_dtype cast.
if (
    os.environ.get("VLLM_FP8_LM_HEAD")
    and not self.config.tie_word_embeddings
    and hasattr(self.lm_head, "weight")
    and self.lm_head.weight.dtype != torch.float8_e4m3fn
):
    saved_mb = self.lm_head.weight.numel() / 1024**2
    self.lm_head.weight = torch.nn.Parameter(
        self.lm_head.weight.data.to(torch.float8_e4m3fn),
        requires_grad=False,
    )
    logger.info(
        "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
        saved_mb,
    )
This logic for compressing lm_head is duplicated in vllm/model_executor/models/mistral.py. To improve maintainability and avoid code duplication, you can extract this logic into a new protected method within the LlamaForCausalLM class. This method can then be called from both LlamaForCausalLM.load_weights and MistralForCausalLM.load_weights.
Here's a suggested implementation for the new method to be added to LlamaForCausalLM:
def _maybe_compress_lm_head_to_fp8(self, logger):
    """Compresses lm_head to FP8 if VLLM_FP8_LM_HEAD is set."""
    if (
        os.environ.get("VLLM_FP8_LM_HEAD")
        and not self.config.tie_word_embeddings
        and hasattr(self.lm_head, "weight")
        and self.lm_head.weight.dtype != torch.float8_e4m3fn
    ):
        saved_mb = self.lm_head.weight.numel() / 1024**2
        self.lm_head.weight = torch.nn.Parameter(
            self.lm_head.weight.data.to(torch.float8_e4m3fn),
            requires_grad=False,
        )
        logger.info(
            "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
            saved_mb,
        )

Then you can replace this block with a call to the new method.
Suggested change:

-# Compress lm_head to FP8 to save VRAM on large-vocab models.
-# Saves ~640 MB for 131K vocab x 5120 dim.
-# Enabled by VLLM_FP8_LM_HEAD=1 env var.
-# The UnquantizedLinearMethod.apply() handles FP8->compute_dtype cast.
-if (
-    os.environ.get("VLLM_FP8_LM_HEAD")
-    and not self.config.tie_word_embeddings
-    and hasattr(self.lm_head, "weight")
-    and self.lm_head.weight.dtype != torch.float8_e4m3fn
-):
-    saved_mb = self.lm_head.weight.numel() / 1024**2
-    self.lm_head.weight = torch.nn.Parameter(
-        self.lm_head.weight.data.to(torch.float8_e4m3fn),
-        requires_grad=False,
-    )
-    logger.info(
-        "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
-        saved_mb,
-    )
+self._maybe_compress_lm_head_to_fp8(logger)
# Compress lm_head to FP8 to save VRAM on large-vocab models.
# Enabled by VLLM_FP8_LM_HEAD=1 env var.
if (
    os.environ.get("VLLM_FP8_LM_HEAD")
    and not self.config.tie_word_embeddings
    and hasattr(self.lm_head, "weight")
    and self.lm_head.weight.dtype != torch.float8_e4m3fn
):
    saved_mb = self.lm_head.weight.numel() / 1024**2
    self.lm_head.weight = torch.nn.Parameter(
        self.lm_head.weight.data.to(torch.float8_e4m3fn),
        requires_grad=False,
    )
    logger.info(
        "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
        saved_mb,
    )
This block of code is a duplicate of the logic in LlamaForCausalLM.load_weights. To follow the DRY (Don't Repeat Yourself) principle, please use the new _maybe_compress_lm_head_to_fp8 method from the base class LlamaForCausalLM as suggested in the review of vllm/model_executor/models/llama.py.
Suggested change:

-# Compress lm_head to FP8 to save VRAM on large-vocab models.
-# Enabled by VLLM_FP8_LM_HEAD=1 env var.
-if (
-    os.environ.get("VLLM_FP8_LM_HEAD")
-    and not self.config.tie_word_embeddings
-    and hasattr(self.lm_head, "weight")
-    and self.lm_head.weight.dtype != torch.float8_e4m3fn
-):
-    saved_mb = self.lm_head.weight.numel() / 1024**2
-    self.lm_head.weight = torch.nn.Parameter(
-        self.lm_head.weight.data.to(torch.float8_e4m3fn),
-        requires_grad=False,
-    )
-    logger.info(
-        "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
-        saved_mb,
-    )
+self._maybe_compress_lm_head_to_fp8(logger)
This is a useful feature, but please find another way to configure it besides an environment variable.
…xtract shared helper

Address two reviewer concerns:

1. robertgshaw2-redhat: replace environment variable with a proper CLI arg
   - Add fp8_lm_head: bool to ModelConfig with docstring
   - Add fp8_lm_head to EngineArgs and wire through create_model_config()
   - Add --fp8-lm-head / --no-fp8-lm-head flags to the argument parser
   - Store vllm_config.model_config.fp8_lm_head in __init__ as self._compress_lm_head

2. gemini-code-assist: extract duplicated lm_head compression logic
   - Add maybe_compress_lm_head_to_fp8() helper to models/utils.py
   - Both LlamaForCausalLM and MistralForCausalLM call the shared helper

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
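A minimal sketch, not the actual PR diff, of the shared helper this commit describes; the signature and the module path (vllm/model_executor/models/utils.py) are assumptions based on the commit message:

    # Sketch of maybe_compress_lm_head_to_fp8(); the exact signature in the PR may differ.
    import torch
    from torch import nn

    def maybe_compress_lm_head_to_fp8(lm_head: nn.Module, enabled: bool, logger) -> None:
        """Cast lm_head.weight to float8_e4m3fn in-place to halve its VRAM footprint."""
        if not enabled or not hasattr(lm_head, "weight"):
            return
        if lm_head.weight.dtype == torch.float8_e4m3fn:
            return  # already FP8, nothing to do
        # BF16/FP16 store 2 bytes per element, FP8 stores 1, so the saving in bytes equals numel().
        saved_mb = lm_head.weight.numel() / 1024**2
        lm_head.weight = nn.Parameter(
            lm_head.weight.data.to(torch.float8_e4m3fn),
            requires_grad=False,
        )
        logger.info("Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)", saved_mb)

Both load_weights implementations would then end with something like maybe_compress_lm_head_to_fp8(self.lm_head, self._compress_lm_head and not self.config.tie_word_embeddings, logger), with self._compress_lm_head populated from the new --fp8-lm-head flag via ModelConfig instead of the environment variable.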
Summary
- Adds optional post-load FP8 compression of lm_head weights for Llama and Mistral model families
- Enabled by the VLLM_FP8_LM_HEAD=1 environment variable (opt-in, no behavior change by default)
- After load_weights completes, casts lm_head.weight from BF16 to float8_e4m3fn in-place

Motivation
For models with large vocabularies (131K+ tokens), lm_head consumes ~1.2 GB in BF16. Compressing to FP8 saves ~640 MB of VRAM, which can be used for KV cache to extend context length. On VRAM-constrained GPUs (16 GB), this represents a significant context increase.

Example on Devstral-24B (131K vocab, 5120 dim) on RTX 5080:
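As a rough sanity check of these figures (a sketch assuming the 131,072 x 5120 BF16 lm_head quoted above):

    # Back-of-envelope check of the VRAM numbers above (131,072-token vocab, 5120 hidden dim).
    vocab, hidden = 131_072, 5_120
    params = vocab * hidden                   # 671,088,640 weights
    bf16_mib = params * 2 / 1024**2           # ~1280 MiB in BF16 (2 bytes/element)
    fp8_mib = params * 1 / 1024**2            # ~640 MiB in float8_e4m3fn (1 byte/element)
    print(f"{bf16_mib:.0f} MiB -> {fp8_mib:.0f} MiB, saves {bf16_mib - fp8_mib:.0f} MiB")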
Dependencies
Requires FP8 cast support in UnquantizedLinearMethod.apply() — see PR #35694.

Test plan

- VLLM_FP8_LM_HEAD=1 — correct outputs, 640 MB saved

🤖 Generated with Claude Code