diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index f0d06e179f33..eb6ae7df158b 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -217,9 +217,13 @@ def apply(
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        weight = layer.weight
+        # Support FP8 weight storage: cast to compute dtype for GEMM
+        if weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+            weight = weight.to(x.dtype)
         if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
-            return linear_batch_invariant(x, layer.weight, bias)
-        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
+            return linear_batch_invariant(x, weight, bias)
+        return dispatch_unquantized_gemm()(layer, x, weight, bias)
 
 
 class LinearBase(PluggableLayer):
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index daaa86bed478..fb9fa0c98e19 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -66,10 +66,18 @@ def apply(
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
+        weight = layer.weight
+        # Support FP8 weight storage: cast to compute dtype for GEMM
+        if weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+            weight = weight.to(x.dtype)
+        return dispatch_unquantized_gemm()(layer, x, weight, bias)
 
     def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-        return F.embedding(input_, layer.weight)
+        output = F.embedding(input_, layer.weight)
+        # Support FP8 weight storage: cast to compute dtype after lookup
+        if output.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+            output = output.to(torch.get_default_dtype())
+        return output
 
 
 def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
@@ -429,6 +437,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
                 shape[output_dim] = self.num_embeddings_per_partition
             param.materialize(tuple(shape), dtype=loaded_weight.dtype)
 
+        # If loaded weight is FP8, cast parameter to match so FP8 is
+        # preserved in memory (saves VRAM for e.g. embed_tokens).
+        if (
+            loaded_weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+            and param.data.dtype != loaded_weight.dtype
+        ):
+            param.data = torch.empty_like(param.data, dtype=loaded_weight.dtype)
+
         # If parameter does not have output dim, then it should
         # be copied onto all gpus (e.g. g_idx for act_order gptq).
         if output_dim is None:
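
For reference, a minimal standalone sketch (not part of the patch) of the upcast-before-GEMM pattern that the linear.py change applies. It assumes PyTorch >= 2.1 for the float8 dtypes; the tensor names are illustrative and not vLLM APIs.

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 128, dtype=torch.bfloat16)            # activations in the compute dtype
    w_fp8 = torch.randn(256, 128).to(torch.float8_e4m3fn)    # weight stored in FP8 (1 byte/element)

    # A plain GEMM does not operate on FP8 tensors, so upcast to the activation
    # dtype first, mirroring the `weight = weight.to(x.dtype)` branch added above.
    w = w_fp8.to(x.dtype) if w_fp8.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) else w_fp8
    y = F.linear(x, w)                                        # GEMM runs in bf16
    print(y.shape, w_fp8.element_size(), w.element_size())    # torch.Size([4, 256]) 1 2

The weight stays at one byte per element at rest; only the transient upcast copy lives in the compute dtype.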