From c94480d155e24f46c5e210b0be83dd321059ba9d Mon Sep 17 00:00:00 2001 From: Lucas Pirola Date: Mon, 2 Mar 2026 03:11:51 +0100 Subject: [PATCH 1/2] [Core] Support FP8 weight storage in unquantized linear and embedding layers Add FP8 (float8_e4m3fn / float8_e5m2) dtype handling to UnquantizedLinearMethod and UnquantizedEmbeddingMethod so that layers with FP8-compressed weights work correctly without a dedicated quantization scheme. Changes: - UnquantizedLinearMethod.apply: cast FP8 weight to input dtype before GEMM - UnquantizedEmbeddingMethod.apply: same cast for lm_head-as-linear - UnquantizedEmbeddingMethod.embedding: cast FP8 lookup output to BF16 - VocabParallelEmbedding.weight_loader: preserve FP8 dtype when loading weights instead of silently upcasting to BF16 This enables post-training FP8 compression of embed_tokens and lm_head layers to save VRAM (~640 MB for 131K-vocab models), freeing memory for larger KV caches. Signed-off-by: Lucas Pirola --- vllm/model_executor/layers/linear.py | 8 ++++++-- .../layers/vocab_parallel_embedding.py | 20 +++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f0d06e179f33..eb6ae7df158b 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -217,9 +217,13 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + weight = layer.weight + # Support FP8 weight storage: cast to compute dtype for GEMM + if weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + weight = weight.to(x.dtype) if vllm_is_batch_invariant() and current_platform.is_cuda_alike(): - return linear_batch_invariant(x, layer.weight, bias) - return dispatch_unquantized_gemm()(layer, x, weight, bias) + return linear_batch_invariant(x, weight, bias) + return dispatch_unquantized_gemm()(layer, x, weight, bias) class LinearBase(PluggableLayer): diff 
--git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index daaa86bed478..8d4d7128807f 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -66,10 +66,18 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) + weight = layer.weight + # Support FP8 weight storage: cast to compute dtype for GEMM + if weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + weight = weight.to(x.dtype) + return dispatch_unquantized_gemm()(layer, x, weight, bias) def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: - return F.embedding(input_, layer.weight) + output = F.embedding(input_, layer.weight) + # Support FP8 weight storage: cast to compute dtype after lookup + if output.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + output = output.to(torch.bfloat16) + return output def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: @@ -429,6 +437,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): shape[output_dim] = self.num_embeddings_per_partition param.materialize(tuple(shape), dtype=loaded_weight.dtype) + # If loaded weight is FP8, cast parameter to match so FP8 is + # preserved in memory (saves VRAM for e.g. embed_tokens). + if ( + loaded_weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + and param.data.dtype != loaded_weight.dtype + ): + param.data = torch.empty_like(param.data, dtype=loaded_weight.dtype) + # If parameter does not have output dim, then it should # be copied onto all gpus (e.g. g_idx for act_order gptq). 
if output_dim is None: From 8c523dec3f62f568779092a168e25f825ec3e82d Mon Sep 17 00:00:00 2001 From: Lucas Pirola Date: Sun, 8 Mar 2026 18:42:33 +0100 Subject: [PATCH 2/2] fixup: use get_default_dtype() instead of hardcoded bfloat16 for FP8 cast Replace hardcoded torch.bfloat16 with torch.get_default_dtype() so that float16 and float32 models are handled correctly. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Lucas Pirola --- vllm/model_executor/layers/vocab_parallel_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 8d4d7128807f..fb9fa0c98e19 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -76,7 +76,7 @@ def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tenso output = F.embedding(input_, layer.weight) # Support FP8 weight storage: cast to compute dtype after lookup if output.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - output = output.to(torch.bfloat16) + output = output.to(torch.get_default_dtype()) return output