86 changes: 82 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -38,6 +38,10 @@
LinearMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
UnquantizedEmbeddingMethod,
)
Comment on lines +41 to +44
high

Importing ParallelLMHead and UnquantizedEmbeddingMethod at the top level of fp8.py from vllm.model_executor.layers.vocab_parallel_embedding may lead to circular import issues in the future, as quantization configs are often imported by the layers they configure. It is generally safer to perform these imports inside get_quant_method or use TYPE_CHECKING for type hints and importlib for runtime checks if necessary.
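
A minimal sketch of the deferred-import variant suggested here, assuming the rest of Fp8Config stays as in this diff (QuantizationConfig, LinearBase, and the existing torch import in fp8.py are already in scope); only the import placement changes:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-checking-only import: never executed at runtime, so it cannot
    # participate in an import cycle.
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead


class Fp8Config(QuantizationConfig):
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        # Deferred runtime import: resolved on the first call, after both
        # modules have finished importing, so no top-level cycle is possible.
        from vllm.model_executor.layers.vocab_parallel_embedding import (
            ParallelLMHead,
            UnquantizedEmbeddingMethod,
        )

        is_parallel_lm_head = isinstance(layer, ParallelLMHead)
        ...  # dispatch continues exactly as in the diff below
```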

from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
@@ -102,10 +106,12 @@ def __init__(
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
weight_block_size: list[int] | None = None,
lm_head_quantized: bool = False,
) -> None:
super().__init__()

self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.lm_head_quantized = lm_head_quantized

if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
@@ -162,22 +168,29 @@ def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
return cls(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
lm_head_quantized=lm_head_quantized,
)

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
is_parallel_lm_head = isinstance(layer, ParallelLMHead)
if isinstance(layer, LinearBase) or (
is_parallel_lm_head and self.lm_head_quantized
):
Comment on lines +183 to +186
high

The current implementation of get_quant_method returns Fp8LinearMethod (or Fp8OnlineLinearMethod) for ParallelLMHead. However, Fp8LinearMethod is designed for LinearBase modules and does not implement the embedding method required by VocabParallelEmbedding (the base class of ParallelLMHead). While ParallelLMHead overrides forward to raise a RuntimeError, any code path that might attempt to use it as a standard embedding layer (e.g., if weights are tied and accessed via the embedding interface) will fail with a NotImplementedError.

Furthermore, as noted in the PR description, VocabParallelEmbedding.weight_loader does not currently handle the companion parameters (like weight_scale) created by Fp8LinearMethod. Returning a linear method for an embedding-sharded layer without ensuring the loader and interface compatibility is a high-risk change.
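
For reference, a hedged sketch (not part of this PR) of roughly what the missing embedding() hook would need to look like before the quantized lm_head could also serve as a lookup table. The per-tensor dequantize-then-F.embedding approach is purely illustrative, assumes layer.weight and layer.weight_scale keep their as-loaded shapes, and ignores the block-quantized case:

```python
import torch
import torch.nn.functional as F


class Fp8LinearMethod(LinearMethodBase):
    ...

    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
        # Illustrative per-tensor path only: dequantize the FP8 table with
        # its scale, then fall back to a plain lookup. A real implementation
        # would also have to cover block scales and avoid materializing the
        # full dequantized vocabulary.
        dequant = layer.weight.to(torch.float32) * layer.weight_scale
        return F.embedding(input_, dequant)
```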

if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
if is_parallel_lm_head:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()
if not self.is_checkpoint_fp8_serialized:
online_method = Fp8OnlineLinearMethod(self)
@@ -254,6 +267,50 @@ def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
set_weight_attrs(new, attrs_to_set)



def _make_lm_head_block_scale_loader(layer, block_size):
"""Per-parameter weight_loader for FP8 block scale_inv on ParallelLMHead.

The default VocabParallelEmbedding.weight_loader assumes vocab-shaped
tensors and rejects companion params with a different leading dim
(e.g., weight_scale_inv has shape [vocab/block_out, hidden/block_in]).
This loader shards the scale tensor along the block-aligned vocab dim
using the layer's existing shard_indices, and zero-fills any padding
rows the param was sized for.
"""
block_out = block_size[0]

def load(param, loaded_weight):
start = layer.shard_indices.org_vocab_start_index
assert start % block_out == 0, (
f"FP8 lm_head requires the vocab-parallel shard start "
f"({start}) to be divisible by weight_block_size[0] "
f"({block_out})"
)
start_idx = start // block_out
local_rows = param.shape[0]
# The param may include padding rows beyond the checkpoint's scale tensor
# (ParallelLMHead pads the vocab); copy what the checkpoint provides and
# zero-fill the remainder, as documented above.
copy_rows = min(local_rows, max(loaded_weight.shape[0] - start_idx, 0))
param.data[:copy_rows].copy_(loaded_weight.narrow(0, start_idx, copy_rows))
param.data[copy_rows:].zero_()

return load


def _make_lm_head_scalar_scale_loader():
"""Per-parameter weight_loader for FP8 per-tensor / input scale on
ParallelLMHead. Per-tensor scales are not vocab-parallel; just copy.
"""

def load(param, loaded_weight):
param.data.copy_(loaded_weight.reshape(param.data.shape))

return load


class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
@@ -344,14 +401,33 @@ def create_weights(
)
layer.register_parameter("weight", weight)

# WEIGHT / INPUT SCALES
# When this method is dispatched to a ParallelLMHead (opt-in via
# Fp8Config.lm_head_quantized), companion params can't share the
# default VocabParallelEmbedding.weight_loader (which assumes
# vocab-shaped tensors). Pick the right scale loader up front so
# we don't have to override it post-hoc -- set_weight_attrs() asserts
# against double-assignment of `weight_loader`.
if isinstance(layer, ParallelLMHead):
if self.block_quant:
scale_weight_loader = _make_lm_head_block_scale_loader(
layer, self.weight_block_size
)
else:
scale_weight_loader = _make_lm_head_scalar_scale_loader()
input_scale_weight_loader = _make_lm_head_scalar_scale_loader()
else:
scale_weight_loader = weight_loader
input_scale_weight_loader = weight_loader

# WEIGHT SCALE
if not self.block_quant:
scale = create_fp8_scale_parameter(
PerTensorScaleParameter,
output_partition_sizes,
input_size_per_partition,
None,
weight_loader,
scale_weight_loader,
)
layer.register_parameter("weight_scale", scale)
else:
@@ -362,15 +438,17 @@
output_partition_sizes,
input_size_per_partition,
self.weight_block_size,
weight_loader,
scale_weight_loader,
scale_dtype=(torch.float8_e8m0fnu if self.is_scale_e8m0 else None),
)
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)

# INPUT ACTIVATION SCALE
if self.act_q_static:
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
scale = create_fp8_input_scale(
output_partition_sizes, input_scale_weight_loader
)
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)

1 change: 1 addition & 0 deletions vllm/model_executor/models/qwen3_5.py
@@ -501,6 +501,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else: