267 changes: 263 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from typing import TYPE_CHECKING, Any

import torch
from torch.nn import Module
from torch.nn.parameter import UninitializedParameter
from torch.utils._python_dispatch import TorchDispatchMode

import vllm.envs as envs
@@ -38,6 +40,11 @@
LinearMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
UnquantizedEmbeddingMethod,
VocabParallelEmbedding,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
@@ -102,10 +109,14 @@ def __init__(
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
weight_block_size: list[int] | None = None,
lm_head_quantized: bool = False,
embeddings_quantized: bool = False,
) -> None:
super().__init__()

self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.lm_head_quantized = lm_head_quantized
self.embeddings_quantized = embeddings_quantized

if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
@@ -162,22 +173,35 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
embeddings_quantized = cls.get_from_keys_or(
config,
["embed_tokens"],
cls.get_from_keys_or(config, ["embeddings"], False),
)
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
lm_head_quantized=lm_head_quantized,
embeddings_quantized=embeddings_quantized,
)
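
For reference, a hypothetical checkpoint quantization_config that would enable both new paths through from_config above (the "lm_head", "embed_tokens", and "embeddings" key names are taken from the code; the remaining fields are illustrative):

quantization_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],
    "lm_head": True,        # read into lm_head_quantized
    "embed_tokens": True,   # read into embeddings_quantized ("embeddings" is also accepted)
}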

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
is_parallel_lm_head = isinstance(layer, ParallelLMHead)
if isinstance(layer, LinearBase) or (
is_parallel_lm_head and self.lm_head_quantized
):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
if is_parallel_lm_head:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()
if not self.is_checkpoint_fp8_serialized:
online_method = Fp8OnlineLinearMethod(self)
@@ -201,6 +225,21 @@ def get_quant_method(
return moe_quant_method
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
elif type(layer) is VocabParallelEmbedding:
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedEmbeddingMethod()
if not self.embeddings_quantized:
return UnquantizedEmbeddingMethod()
if not self.is_checkpoint_fp8_serialized:
raise NotImplementedError(
"Online FP8 quantization for VocabParallelEmbedding is "
"not implemented. Serialize embed_tokens to FP8 offline."
)
return Fp8EmbeddingMethod(self)
return None

def get_cache_scale(self, name: str) -> str | None:
@@ -254,6 +293,52 @@ def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
set_weight_attrs(new, attrs_to_set)



def _make_lm_head_block_scale_loader(layer, block_size):
"""Per-parameter weight_loader for FP8 block scale_inv on ParallelLMHead.

The default VocabParallelEmbedding.weight_loader assumes vocab-shaped
tensors and rejects companion params with a different leading dim
(e.g., weight_scale_inv has shape [vocab/block_out, hidden/block_in]).
This loader shards the scale tensor along the block-aligned vocab dim
using the layer's existing shard_indices and copies the local block rows
into the parameter.
"""
block_out = block_size[0]

def load(param, loaded_weight):
start = layer.shard_indices.org_vocab_start_index
assert start % block_out == 0, (
f"FP8 lm_head requires the vocab-parallel shard start "
f"({start}) to be divisible by weight_block_size[0] "
f"({block_out})"
)
start_idx = start // block_out
local_rows = param.shape[0]
assert loaded_weight.shape[0] >= start_idx + local_rows, (
f"loaded scale has {loaded_weight.shape[0]} rows, "
f"need at least {start_idx + local_rows} "
f"(start_idx={start_idx}, local_rows={local_rows})"
)
chunk = loaded_weight.narrow(0, start_idx, local_rows)
param.data.copy_(chunk)

return load
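
A worked example of the shard arithmetic above, with hypothetical sizes (vocab 131072, hidden 4096, TP 2, weight_block_size [128, 128]):

block_out = 128
start = 131072 // 2                      # org_vocab_start_index on rank 1
assert start % block_out == 0
start_idx = start // block_out           # 512: first block-row owned by rank 1
local_rows = (131072 // 2) // block_out  # 512 block-rows in the local scale param
# The full weight_scale_inv is [131072 // 128, 4096 // 128] == [1024, 32];
# rank 1 narrows rows 512..1023 out of it into its local [512, 32] parameter.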


def _make_lm_head_scalar_scale_loader():
"""Per-parameter weight_loader for FP8 per-tensor / input scale on
ParallelLMHead. Per-tensor scales are not vocab-parallel; just copy.
"""

def load(param, loaded_weight):
if loaded_weight.numel() > param.data.numel():
loaded_weight = loaded_weight.flatten()[0]
param.data.copy_(loaded_weight.reshape(param.data.shape))
Comment on lines +334 to +337
Contributor


high

This loader is not robust to changes in the tensor parallel (TP) degree. If a checkpoint was saved with a higher TP degree, loaded_weight may contain multiple per-tensor scales (one per shard). Attempting to reshape a multi-element tensor into a single-element param.data will raise a RuntimeError. It should handle this by taking the first element, which is consistent with how vLLM handles per-tensor scales during TP-degree-changing loads.

Suggested change
-    def load(param, loaded_weight):
-        param.data.copy_(loaded_weight.reshape(param.data.shape))
+    def load(param, loaded_weight):
+        if loaded_weight.numel() > param.data.numel():
+            loaded_weight = loaded_weight.flatten()[0]
+        param.data.copy_(loaded_weight.reshape(param.data.shape))

Author


Implemented in e3547df: this loader now checks loaded_weight.numel() > param.data.numel(), takes loaded_weight.flatten()[0], then reshapes into param.data.shape.


return load
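
A minimal standalone sketch of the TP-degree failure mode discussed in the review comment above, with hypothetical values:

import torch

param_data = torch.empty(1)          # per-tensor scale param on this rank
loaded = torch.tensor([0.02, 0.02])  # checkpoint saved at TP=2: one scale per shard

# loaded.reshape(param_data.shape) would raise RuntimeError (2 elements into 1),
# so take the first element, as the loader above now does.
if loaded.numel() > param_data.numel():
    loaded = loaded.flatten()[0]
param_data.copy_(loaded.reshape(param_data.shape))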


class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
@@ -344,14 +429,33 @@ def create_weights(
)
layer.register_parameter("weight", weight)

# WEIGHT / INPUT SCALES
# When this method is dispatched to a ParallelLMHead (opt-in via
# Fp8Config.lm_head_quantized), companion params can't share the
# default VocabParallelEmbedding.weight_loader (which assumes
# vocab-shaped tensors). Pick the right scale loader up front so
# we don't have to override it post-hoc -- set_weight_attrs() asserts
# against double-assignment of `weight_loader`.
if isinstance(layer, ParallelLMHead):
if self.block_quant:
scale_weight_loader = _make_lm_head_block_scale_loader(
layer, self.weight_block_size
)
else:
scale_weight_loader = _make_lm_head_scalar_scale_loader()
input_scale_weight_loader = _make_lm_head_scalar_scale_loader()
else:
scale_weight_loader = weight_loader
input_scale_weight_loader = weight_loader

# WEIGHT SCALE
if not self.block_quant:
scale = create_fp8_scale_parameter(
PerTensorScaleParameter,
output_partition_sizes,
input_size_per_partition,
None,
weight_loader,
scale_weight_loader,
)
layer.register_parameter("weight_scale", scale)
else:
Expand All @@ -362,15 +466,17 @@ def create_weights(
output_partition_sizes,
input_size_per_partition,
self.weight_block_size,
weight_loader,
scale_weight_loader,
scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
)
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)

# INPUT ACTIVATION SCALE
if self.act_q_static:
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
scale = create_fp8_input_scale(
output_partition_sizes, input_scale_weight_loader
)
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)

@@ -476,6 +582,159 @@ def apply(
return self.fp8_linear.apply_weights(layer, x, bias)



class Fp8EmbeddingMethod(Fp8LinearMethod):
"""Memory-first FP8 runtime for VocabParallelEmbedding.

The embedding table stays resident as FP8 and only gathered rows are
dequantized back to the model dtype during lookup. This is intended as a
correctness/VRAM test path, not the final fused performance design.
"""

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")

layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None

if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
validate_fp8_block_shape(
layer,
input_size,
output_size,
input_size_per_partition,
output_partition_sizes,
self.weight_block_size,
)

weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
)
layer.register_parameter("weight", weight)

if not self.block_quant:
scale = create_fp8_scale_parameter(
PerTensorScaleParameter,
output_partition_sizes,
input_size_per_partition,
None,
weight_loader,
)
layer.register_parameter("weight_scale", scale)
else:
assert not self.act_q_static
assert self.weight_block_size is not None
scale = create_fp8_scale_parameter(
BlockQuantScaleParameter,
output_partition_sizes,
input_size_per_partition,
self.weight_block_size,
weight_loader,
scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
)
layer.register_parameter("weight_scale_inv", scale)
Contributor


high

The Fp8EmbeddingMethod inherits from Fp8LinearMethod but its create_weights implementation fails to initialize self.fp8_linear. This is a critical issue for models with tied weights (tie_word_embeddings=True), such as some Qwen variants, because the lm_head will share the embed_tokens quantization method. When LogitsProcessor calls apply on the tied lm_head, it will crash with an AttributeError because self.fp8_linear is missing. Initializing the kernel ensures the method works correctly for both embedding lookups and linear projections.

            layer.register_parameter("weight_scale_inv", scale)

        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=self.activation_quant_key,
            weight_quant_key=self.weight_quant_key,
            weight_shape=layer.weight.shape,
            input_dtype=self.input_dtype,
            out_dtype=self.out_dtype,
            module_name=self.__class__.__name__,
        )
        self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)

Author


Implemented in e3547df: Fp8EmbeddingMethod.create_weights() now initializes self.fp8_linear with init_fp8_linear_kernel(...) and sets self.use_marlin, so tied embedding/lm_head paths can call apply() safely.


self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)

def load_embedding_companion(
self,
layer: VocabParallelEmbedding,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
) -> None:
output_dim = getattr(param, "output_dim", None)
if output_dim is None:
if isinstance(param, UninitializedParameter):
param.materialize(tuple(loaded_weight.shape), dtype=loaded_weight.dtype)
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return

if isinstance(param, UninitializedParameter):
shape = list(loaded_weight.shape)
if self.block_quant:
assert self.weight_block_size is not None
block_n = self.weight_block_size[0]
shape[output_dim] = math.ceil(
layer.num_embeddings_per_partition / block_n
)
else:
shape[output_dim] = layer.num_embeddings_per_partition
param.materialize(tuple(shape), dtype=loaded_weight.dtype)

if loaded_weight.shape == param.data.shape or loaded_weight.shape[output_dim] == 1:
param.data.copy_(loaded_weight)
return

if self.block_quant:
assert self.weight_block_size is not None
block_n = self.weight_block_size[0]
start = layer.shard_indices.org_vocab_start_index
assert start % block_n == 0, (
f"FP8 embedding requires the vocab-parallel shard start "
f"({start}) to be divisible by weight_block_size[0] "
f"({block_n})"
)
start_idx = start // block_n
else:
start_idx = layer.shard_indices.org_vocab_start_index

local_rows = param.data.shape[output_dim]
available_rows = max(0, loaded_weight.shape[output_dim] - start_idx)
copy_rows = min(local_rows, available_rows)

if copy_rows > 0:
shard = loaded_weight.narrow(output_dim, start_idx, copy_rows)
param.data.narrow(output_dim, 0, copy_rows).copy_(shard)
if copy_rows < local_rows:
param.data.narrow(output_dim, copy_rows, local_rows - copy_rows).fill_(0)

def process_weights_after_loading(self, layer: Module) -> None:
layer._already_called_process_weights_after_loading = True

def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
out_dtype = getattr(layer, "orig_dtype", torch.get_default_dtype())
flat_input = input_.reshape(-1).long()
weight_rows = layer.weight.index_select(0, flat_input).to(out_dtype)

if self.block_quant:
assert self.weight_block_size is not None
block_n, block_k = self.weight_block_size
scales = layer.weight_scale_inv.to(out_dtype)
row_blocks = torch.div(flat_input, block_n, rounding_mode="floor")
row_scales = scales.index_select(0, row_blocks)
row_scales = row_scales.repeat_interleave(block_k, dim=-1)
row_scales = row_scales[..., : layer.embedding_dim]
output = weight_rows * row_scales
else:
scales = layer.weight_scale.to(out_dtype)
output = weight_rows * scales

return output.reshape(*input_.shape, layer.embedding_dim)
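
A minimal sketch of the block-scale expansion performed in embedding() above, using tiny hypothetical shapes and a float stand-in for the FP8 table:

import torch

block_n, block_k, hidden = 4, 4, 8
weight_q = torch.randn(8, hidden)    # stands in for the 8-row FP8 embedding table
scale_inv = torch.rand(2, 2)         # one scale per [4, 4] block
token_ids = torch.tensor([1, 6])

rows = weight_q.index_select(0, token_ids)
row_blocks = torch.div(token_ids, block_n, rounding_mode="floor")         # tensor([0, 1])
row_scales = scale_inv.index_select(0, row_blocks)                        # [2, 2]
row_scales = row_scales.repeat_interleave(block_k, dim=-1)[..., :hidden]  # [2, 8]
dequant = rows * row_scales          # dequantized rows, one per token id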

# TODO(future PR): remove this class in favor of
# online/fp8.py::Fp8PerTensorOnlineLinearMethod
class Fp8OnlineLinearMethod(Fp8LinearMethod):