[Core] Support FP8 weight storage in unquantized linear and embedding layers#35694

Open
lucaspirola wants to merge 2 commits into vllm-project:main from lucaspirola:feat/fp8-unquantized-weights

Conversation

@lucaspirola

Summary

  • Add FP8 (float8_e4m3fn / float8_e5m2) dtype handling to UnquantizedLinearMethod and UnquantizedEmbeddingMethod
  • Cast FP8 weights to compute dtype before GEMM operations and after embedding lookup (a minimal sketch follows this list)
  • Preserve FP8 dtype in VocabParallelEmbedding.weight_loader instead of silently upcasting to BF16
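
A minimal sketch of the linear-path cast, assuming simplified signatures (the real UnquantizedLinearMethod.apply in vLLM receives the layer object and covers more cases; unquantized_linear_apply and FP8_DTYPES are names introduced here for illustration):

import torch
import torch.nn.functional as F

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def unquantized_linear_apply(weight: torch.Tensor,
                             x: torch.Tensor,
                             bias: torch.Tensor | None = None) -> torch.Tensor:
    # FP8 is a storage format here, not a compute format: standard GEMM
    # kernels reject float8 operands, so upcast the weight to the
    # activation dtype right before the matmul.
    if weight.dtype in FP8_DTYPES:
        weight = weight.to(x.dtype)
    return F.linear(x, weight, bias)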

Motivation

This enables post-training FP8 compression of embed_tokens and lm_head layers to save VRAM without requiring a dedicated quantization scheme. For models with large vocabularies (131K+), this frees ~640 MB of VRAM that can be used for KV cache, significantly extending context length on memory-constrained GPUs.
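
Back-of-envelope for the ~640 MB figure, assuming a 131072-token vocabulary and hidden size 5120 (Devstral-24B-like dimensions, an assumption here); going from BF16 (2 bytes) to FP8 (1 byte) saves one byte per weight:

vocab, hidden = 131072, 5120            # assumed dims for a 131K-vocab model
bytes_saved = vocab * hidden * (2 - 1)  # BF16 (2 B) -> FP8 (1 B) per element
print(bytes_saved / 2**20)              # 640.0 MiB per compressed matrix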

Test plan

  • Verified on GLM-4.7-Flash-REAP-23B-A3B with FP8 embed_tokens (saves 464 MB)
  • Verified on Devstral-24B with FP8 embed_tokens (saves 640 MB)
  • Pre-commit passes (ruff, mypy, typos)
  • Unit tests for FP8 dtype path in both methods

🤖 Generated with Claude Code

Support FP8 weight storage in unquantized linear and embedding layers

Add FP8 (float8_e4m3fn / float8_e5m2) dtype handling to
UnquantizedLinearMethod and UnquantizedEmbeddingMethod so that
layers with FP8-compressed weights work correctly without a
dedicated quantization scheme.

Changes:
- UnquantizedLinearMethod.apply: cast FP8 weight to input dtype
  before GEMM
- UnquantizedEmbeddingMethod.apply: same cast for lm_head-as-linear
- UnquantizedEmbeddingMethod.embedding: cast FP8 lookup output to
  BF16
- VocabParallelEmbedding.weight_loader: preserve FP8 dtype when
  loading weights instead of silently upcasting to BF16

This enables post-training FP8 compression of embed_tokens and
lm_head layers to save VRAM (~640 MB for 131K-vocab models),
freeing memory for larger KV caches.

Signed-off-by: Lucas Pirola <lucaspirola@gmail.com>

Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
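
A hedged sketch of the weight_loader change described in this commit, not the actual vLLM implementation (the real loader also handles tensor-parallel sharding; load_embedding_weight and FP8_DTYPES are names introduced for illustration):

import torch

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def load_embedding_weight(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    if loaded_weight.dtype in FP8_DTYPES and param.dtype != loaded_weight.dtype:
        # Preserve the checkpoint's FP8 dtype: rebind the parameter's storage
        # instead of copy_(), which would silently cast up to param.dtype.
        param.data = loaded_weight.to(param.device)
    else:
        param.data.copy_(loaded_weight)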

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds support for FP8 weight storage in unquantized linear and embedding layers. This is achieved by casting FP8 weights to the compute dtype before GEMM operations and after embedding lookups, while preserving the FP8 dtype for weights in memory to save VRAM. The changes look mostly correct, but I've found a critical issue where the compute dtype is hardcoded when casting embedding outputs, which could lead to errors or incorrect behavior for models not using bfloat16.

output = F.embedding(input_, layer.weight)
# Support FP8 weight storage: cast to compute dtype after lookup
if output.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    output = output.to(torch.bfloat16)

critical

Hardcoding the cast to torch.bfloat16 is incorrect as it assumes the model's compute dtype is always bfloat16. This will cause dtype mismatches or incorrect results if the model is running with a different compute dtype, such as float16 or float32.

To ensure correctness across different model configurations, you should use torch.get_default_dtype() to dynamically get the model's compute dtype, which is set during model loading.

Suggested change
-    output = output.to(torch.bfloat16)
+    output = output.to(torch.get_default_dtype())

…cast

Replace hardcoded torch.bfloat16 with torch.get_default_dtype() so that
float16 and float32 models are handled correctly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
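
A quick check of why torch.get_default_dtype() is the right target, assuming (as the review comment states) that vLLM sets the default dtype to the model's compute dtype during loading:

import torch

torch.set_default_dtype(torch.float16)     # e.g. a float16 model
fp8 = torch.zeros(4, dtype=torch.float8_e4m3fn)
out = fp8.to(torch.get_default_dtype())
assert out.dtype == torch.float16          # follows the model, not hardcoded BF16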
