[Core] Support FP8 weight storage in unquantized linear and embedding layers #35694
lucaspirola wants to merge 2 commits into vllm-project:main
Conversation
… layers

Add FP8 (float8_e4m3fn / float8_e5m2) dtype handling to UnquantizedLinearMethod and UnquantizedEmbeddingMethod so that layers with FP8-compressed weights work correctly without a dedicated quantization scheme.

Changes:
- UnquantizedLinearMethod.apply: cast FP8 weight to input dtype before GEMM
- UnquantizedEmbeddingMethod.apply: same cast for lm_head-as-linear
- UnquantizedEmbeddingMethod.embedding: cast FP8 lookup output to BF16
- VocabParallelEmbedding.weight_loader: preserve FP8 dtype when loading weights instead of silently upcasting to BF16

This enables post-training FP8 compression of embed_tokens and lm_head layers to save VRAM (~640 MB for 131K-vocab models), freeing memory for larger KV caches.

Signed-off-by: Lucas Pirola <lucaspirola@gmail.com>
Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
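As a rough sketch of the cast-before-GEMM pattern described in the first two bullets (the real UnquantizedLinearMethod.apply in vLLM has a different signature; the function and names below are illustrative, not the PR's code):

```python
# Illustrative sketch, not the actual vLLM method: shows the just-in-time cast
# of FP8-stored weights to the activation dtype before the GEMM.
import torch
import torch.nn.functional as F

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def apply_linear(x: torch.Tensor, weight: torch.Tensor,
                 bias: torch.Tensor | None = None) -> torch.Tensor:
    # Weights may be stored in FP8 purely to save VRAM; the GEMM still runs
    # in the activation dtype, so cast on the fly instead of at load time.
    if weight.dtype in FP8_DTYPES:
        weight = weight.to(x.dtype)
    return F.linear(x, weight, bias)
```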
Code Review
This pull request adds support for FP8 weight storage in unquantized linear and embedding layers. This is achieved by casting FP8 weights to the compute dtype before GEMM operations and after embedding lookups, while preserving the FP8 dtype for weights in memory to save VRAM. The changes look mostly correct, but I've found a critical issue where the compute dtype is hardcoded when casting embedding outputs, which could lead to errors or incorrect behavior for models not using bfloat16.
```python
output = F.embedding(input_, layer.weight)
# Support FP8 weight storage: cast to compute dtype after lookup
if output.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    output = output.to(torch.bfloat16)
```
Hardcoding the cast to torch.bfloat16 is incorrect as it assumes the model's compute dtype is always bfloat16. This will cause dtype mismatches or incorrect results if the model is running with a different compute dtype, such as float16 or float32.
To ensure correctness across different model configurations, you should use torch.get_default_dtype() to dynamically get the model's compute dtype, which is set during model loading.
```diff
- output = output.to(torch.bfloat16)
+ output = output.to(torch.get_default_dtype())
```
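A small standalone demonstration of why the suggested cast tracks the model's compute dtype; torch.set_default_dtype here stands in for whatever vLLM sets while loading the model:

```python
# Requires a recent PyTorch with FP8 dtypes and half-precision default dtypes.
# The cast follows whatever default dtype was set when the model was loaded.
import torch

w = torch.randn(4, 8).to(torch.float8_e4m3fn)  # FP8-stored embedding rows

torch.set_default_dtype(torch.float16)         # e.g. a float16 model
assert w.to(torch.get_default_dtype()).dtype == torch.float16

torch.set_default_dtype(torch.float32)         # e.g. a float32 model
assert w.to(torch.get_default_dtype()).dtype == torch.float32
```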
…cast

Replace hardcoded torch.bfloat16 with torch.get_default_dtype() so that float16 and float32 models are handled correctly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Lucas Pirola <lucaspirola@users.noreply.github.com>
Summary
- Add FP8 (float8_e4m3fn / float8_e5m2) dtype handling to UnquantizedLinearMethod and UnquantizedEmbeddingMethod
- Preserve the FP8 dtype in VocabParallelEmbedding.weight_loader instead of silently upcasting to BF16

Motivation
This enables post-training FP8 compression of embed_tokens and lm_head layers to save VRAM without requiring a dedicated quantization scheme. For models with large vocabularies (131K+), this frees ~640 MB of VRAM that can be used for KV cache, significantly extending context length on memory-constrained GPUs.

Test plan
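For illustration only, a rough sketch of the offline compression step this change is meant to support, with assumed Llama-style weight names and an assumed hidden size chosen to reproduce the ~640 MB figure quoted above:

```python
# Illustrative only: offline FP8 compression of a checkpoint's embedding weights.
# The key names below follow common Llama-style checkpoints and are assumptions,
# not something this PR prescribes.
import torch

COMPRESS_KEYS = ("model.embed_tokens.weight", "lm_head.weight")  # assumed names

def compress_embeddings(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Cast the vocabulary-sized weights to FP8; leave everything else untouched."""
    return {
        name: t.to(torch.float8_e4m3fn) if name in COMPRESS_KEYS else t
        for name, t in state_dict.items()
    }

# Rough arithmetic behind the ~640 MB figure, assuming a 131K vocab and a
# hidden size of 2560: 2 layers * 131_072 * 2_560 * 1 byte saved (BF16 -> FP8)
# = 671_088_640 bytes ≈ 640 MiB.
```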
🤖 Generated with Claude Code