[Model] Optional FP8 lm_head compression for Llama and Mistral #35696

Open

lucaspirola wants to merge 4 commits into vllm-project:main from lucaspirola:feat/fp8-lm-head-compress

Conversation

@lucaspirola

Summary

  • Add optional post-load FP8 compression of lm_head weights for Llama and Mistral model families
  • Enabled by the VLLM_FP8_LM_HEAD=1 environment variable (opt-in, no behavior change by default); a usage sketch follows this list
  • After load_weights completes, casts lm_head.weight from BF16 to float8_e4m3fn in-place
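
For reference, a minimal usage sketch of the original env-var form (later commits in this PR replace the variable with a --fp8-lm-head engine argument); the model path is a placeholder:

    import os

    # Assumed: must be set before the model is loaded, since the check runs in load_weights.
    os.environ["VLLM_FP8_LM_HEAD"] = "1"

    from vllm import LLM

    llm = LLM(model="path/to/devstral-24b")  # placeholder model id
    print(llm.generate("Hello")[0].outputs[0].text)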

Motivation

For models with large vocabularies (131K+ tokens), lm_head consumes ~1.2 GB in BF16. Compressing to FP8 saves ~640 MB of VRAM, which can be used for KV cache to extend context length. On VRAM-constrained GPUs (16 GB), this represents a significant context increase.

Example on Devstral-24B (131K vocab, 5120 dim) on an RTX 5080:

| lm_head dtype  | VRAM saved | Extra context |
|----------------|------------|---------------|
| BF16 (default) | —          | —             |
| FP8            | ~640 MB    | +~10K tokens  |
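
A back-of-the-envelope check of the numbers above (assuming a 131072-token vocab and 5120 hidden dim):

    vocab, hidden = 131072, 5120
    params = vocab * hidden                # 671,088,640 lm_head weights
    bf16_mib = params * 2 / 1024**2        # 1280 MiB stored as BF16 (2 bytes/weight)
    fp8_mib = params * 1 / 1024**2         # 640 MiB stored as float8_e4m3fn (1 byte/weight)
    print(f"saved: {bf16_mib - fp8_mib:.0f} MiB")  # ~640 MiB

The extra-context figure depends on the model's per-token KV-cache cost, so the +~10K estimate should be read as specific to this Devstral-24B setup.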

Dependencies

Requires FP8 cast support in UnquantizedLinearMethod.apply() — see PR #35694.
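
A rough illustration of the kind of cast that dependency provides, as a sketch only: it assumes the unquantized path simply upcasts FP8-stored weights to the activation dtype before the matmul, and the helper name below is made up for illustration (see PR #35694 for the actual change):

    import torch
    import torch.nn.functional as F

    def linear_with_fp8_stored_weight(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # FP8 here is only a storage format: compute still happens in the
        # activation dtype, so upcast the weight on the fly if needed.
        if weight.dtype == torch.float8_e4m3fn:
            weight = weight.to(x.dtype)
        return F.linear(x, weight, bias)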

Test plan

  • Verified on Devstral-24B with VLLM_FP8_LM_HEAD=1 — correct outputs, 640 MB saved
  • Verified default behavior unchanged (env var not set)
  • Pre-commit passes (ruff, mypy, typos)

🤖 Generated with Claude Code

Add optional post-load FP8 compression of lm_head weights,
enabled by VLLM_FP8_LM_HEAD=1 environment variable.

After load_weights completes, the lm_head weight tensor is cast
from BF16/FP16 to float8_e4m3fn in-place, halving its VRAM
footprint. For models with large vocabularies (e.g. 131K tokens),
this saves ~640 MB of VRAM that can be used for KV cache.

The FP8 weight is automatically cast back to the compute dtype
during the forward pass by UnquantizedLinearMethod.apply().

This is opt-in and only activates when:
- VLLM_FP8_LM_HEAD=1 is set
- tie_word_embeddings is False
- lm_head has a weight attribute
- weight is not already FP8

Signed-off-by: Lucas Pirola <lucaspirola@gmail.com>

Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
@mergify bot added the llama (Related to Llama models) label on Mar 2, 2026
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces an optional FP8 compression for the lm_head in Llama and Mistral models, which can lead to significant VRAM savings. The implementation is gated behind the VLLM_FP8_LM_HEAD environment variable, making it a safe, opt-in feature. My main feedback is regarding code duplication between the llama.py and mistral.py model files. I've left specific comments on how to refactor this to improve maintainability.

Comment thread: vllm/model_executor/models/llama.py (Outdated)
Comment on lines +612 to +630
    # Compress lm_head to FP8 to save VRAM on large-vocab models.
    # Saves ~640 MB for 131K vocab x 5120 dim.
    # Enabled by VLLM_FP8_LM_HEAD=1 env var.
    # The UnquantizedLinearMethod.apply() handles FP8->compute_dtype cast.
    if (
        os.environ.get("VLLM_FP8_LM_HEAD")
        and not self.config.tie_word_embeddings
        and hasattr(self.lm_head, "weight")
        and self.lm_head.weight.dtype != torch.float8_e4m3fn
    ):
        saved_mb = self.lm_head.weight.numel() / 1024**2
        self.lm_head.weight = torch.nn.Parameter(
            self.lm_head.weight.data.to(torch.float8_e4m3fn),
            requires_grad=False,
        )
        logger.info(
            "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
            saved_mb,
        )
Contributor


Severity: high

This logic for compressing lm_head is duplicated in vllm/model_executor/models/mistral.py. To improve maintainability and avoid code duplication, you can extract this logic into a new protected method within the LlamaForCausalLM class. This method can then be called from both LlamaForCausalLM.load_weights and MistralForCausalLM.load_weights.

Here's a suggested implementation for the new method to be added to LlamaForCausalLM:

    def _maybe_compress_lm_head_to_fp8(self, logger):
        """Compresses lm_head to FP8 if VLLM_FP8_LM_HEAD is set."""
        if (
            os.environ.get("VLLM_FP8_LM_HEAD")
            and not self.config.tie_word_embeddings
            and hasattr(self.lm_head, "weight")
            and self.lm_head.weight.dtype != torch.float8_e4m3fn
        ):
            saved_mb = self.lm_head.weight.numel() / 1024**2
            self.lm_head.weight = torch.nn.Parameter(
                self.lm_head.weight.data.to(torch.float8_e4m3fn),
                requires_grad=False,
            )
            logger.info(
                "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
                saved_mb,
            )

Then you can replace this block with a call to the new method.

Suggested change (remove the block quoted above and call the new method instead):

    self._maybe_compress_lm_head_to_fp8(logger)

Comment thread: vllm/model_executor/models/mistral.py (Outdated)
Comment on lines +293 to +309
    # Compress lm_head to FP8 to save VRAM on large-vocab models.
    # Enabled by VLLM_FP8_LM_HEAD=1 env var.
    if (
        os.environ.get("VLLM_FP8_LM_HEAD")
        and not self.config.tie_word_embeddings
        and hasattr(self.lm_head, "weight")
        and self.lm_head.weight.dtype != torch.float8_e4m3fn
    ):
        saved_mb = self.lm_head.weight.numel() / 1024**2
        self.lm_head.weight = torch.nn.Parameter(
            self.lm_head.weight.data.to(torch.float8_e4m3fn),
            requires_grad=False,
        )
        logger.info(
            "Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)",
            saved_mb,
        )
Contributor


Severity: high

This block of code is a duplicate of the logic in LlamaForCausalLM.load_weights. To follow the DRY (Don't Repeat Yourself) principle, please use the new _maybe_compress_lm_head_to_fp8 method from the base class LlamaForCausalLM as suggested in the review of vllm/model_executor/models/llama.py.

Suggested change (remove the block quoted above and call the shared helper instead):

    self._maybe_compress_lm_head_to_fp8(logger)

@robertgshaw2-redhat
Collaborator

this is a useful feature, but please find another way to configure it besides an environment variable

lucaspirola and others added 3 commits March 5, 2026 14:45
…xtract shared helper

Address two reviewer concerns:
1. robertgshaw2-redhat: replace environment variable with a proper CLI arg
   - Add fp8_lm_head: bool to ModelConfig with docstring
   - Add fp8_lm_head to EngineArgs and wire through create_model_config()
   - Add --fp8-lm-head / --no-fp8-lm-head flags to argument parser
   - Store vllm_config.model_config.fp8_lm_head in __init__ as self._compress_lm_head

2. gemini-code-assist: extract duplicated lm_head compression logic
   - Add maybe_compress_lm_head_to_fp8() helper to models/utils.py
   - Both LlamaForCausalLM and MistralForCausalLM call the shared helper

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
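
A minimal sketch of how the shared helper described in this commit might look; the signature and location are assumptions based on the commit message, not the actual PR code:

    # vllm/model_executor/models/utils.py (sketch)
    import torch
    from torch import nn

    def maybe_compress_lm_head_to_fp8(lm_head: nn.Module, enabled: bool, logger) -> None:
        """Cast an untied lm_head weight to float8_e4m3fn in place, if enabled."""
        if not enabled or not hasattr(lm_head, "weight"):
            return
        if lm_head.weight.dtype == torch.float8_e4m3fn:
            return
        saved_mb = lm_head.weight.numel() / 1024**2  # one byte saved per element (BF16 -> FP8)
        lm_head.weight = nn.Parameter(
            lm_head.weight.data.to(torch.float8_e4m3fn), requires_grad=False
        )
        logger.info("Compressed lm_head to float8_e4m3fn (saved %.0f MB VRAM)", saved_mb)

Both load_weights implementations would then call it with the stored flag, e.g. maybe_compress_lm_head_to_fp8(self.lm_head, self._compress_lm_head and not self.config.tie_word_embeddings, logger), and the feature would presumably be enabled via vllm serve <model> --fp8-lm-head rather than the environment variable.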

Labels

llama (Related to Llama models), mistral (Related to Mistral models)
