Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions tests/quantization/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import (
CopyNumelCounter,
Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod,
Expand Down Expand Up @@ -466,3 +467,146 @@ def test_fp8_reloading(
weight_loader(param, torch.zeros(shape)) # cannot use empty

method.process_weights_after_loading(layer)


def test_copy_numel_counter_tracks_narrowed_copies():
    """Verify CopyNumelCounter counts only elements actually written by copy_().

    With TP > 1 the weight_loader narrows the loaded weight to the local
    shard before calling copy_, so fewer elements are written than
    loaded_weight.numel() would suggest. Before PR #30900 the streaming
    quantization bookkeeping used loaded_weight.numel() to track progress,
    which over-counted and made process_weights_after_loading fire at the
    wrong time, producing empty/garbage output.

    Regression test for https://github.com/vllm-project/vllm/issues/36583.
    """
    # Model a fused gate_up_proj weight split across tp_size=2 ranks.
    # Full weight: [2 * intermediate, hidden]; each rank holds
    # [2 * (intermediate // tp), hidden].
    hidden = 64
    intermediate = 128
    tp = 2
    shard = intermediate // tp  # 64 rows per logical shard on this rank

    # The per-rank parameter that the loader writes into.
    sharded_param = torch.zeros(2 * shard, hidden)

    # Loading gate_proj: the loader narrows the full weight to the TP shard
    # before copying, exactly as the real weight_loader does.
    full_gate = torch.randn(intermediate, hidden)
    narrowed_gate = full_gate[:shard]

    gate_counter = CopyNumelCounter()
    with gate_counter:
        sharded_param[:shard].copy_(narrowed_gate)

    # Only the narrowed copy is counted — not full_gate.numel()
    # (which would be intermediate * hidden).
    assert gate_counter.copied_numel == shard * hidden

    # Loading up_proj: same narrowing, second half of the parameter.
    full_up = torch.randn(intermediate, hidden)
    narrowed_up = full_up[:shard]

    up_counter = CopyNumelCounter()
    with up_counter:
        sharded_param[shard:].copy_(narrowed_up)

    assert up_counter.copied_numel == shard * hidden

    # Both shards together cover the per-rank parameter exactly once.
    total = gate_counter.copied_numel + up_counter.copied_numel
    assert total == sharded_param.numel()


@pytest.mark.skipif(
    not is_quant_method_supported("fp8"),
    reason="FP8 is not supported on this GPU type.",
)
def test_online_fp8_streaming_quant_with_sharded_weights(
    dist_init,
    monkeypatch,
):
    """Verify streaming FP8 online quantization waits for every weight shard.

    For fused layers (QKV, gate_up_proj) the weight_loader loads each
    sub-weight separately, and with TP each sub-weight is narrowed to the
    local shard before copy_. The streaming quantization bookkeeping must
    therefore only call process_weights_after_loading once ALL shards have
    landed — quantizing after the first shard corrupts the output.

    Regression test for https://github.com/vllm-project/vllm/issues/36583.
    """
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")

    from vllm.model_executor.layers.quantization.fp8 import Fp8OnlineLinearMethod

    fp8_config = Fp8Config(
        is_checkpoint_fp8_serialized=False,
        activation_scheme="dynamic",
    )
    linear_method = Fp8OnlineLinearMethod(fp8_config)

    # Two logical output shards (gate, up) of equal size, mirroring a
    # MergedColumnParallelLinear after TP=2 narrowing has been applied.
    hidden = 64
    logical_shard = 32  # rows per logical shard (gate or up) on this rank
    partition_sizes = [logical_shard, logical_shard]
    partition_total = sum(partition_sizes)  # 64

    with torch.device("cuda:0"):
        layer = torch.nn.Module()

        linear_method.create_weights(
            layer=layer,
            input_size_per_partition=hidden,
            output_partition_sizes=partition_sizes,
            input_size=hidden,
            output_size=partition_total * 2,  # full output size before TP
            params_dtype=torch.bfloat16,
            weight_loader=default_weight_loader,
        )

    # Weights start out on the meta device until loading materializes them.
    assert layer.weight.device == torch.device("meta")

    # Load the first shard (gate). In real TP loading, the loaded weight is
    # the full tensor but only the shard portion gets copy_'d into the
    # parameter — here we hand the loader the already-narrowed shard.
    gate_shard = torch.randn(
        logical_shard, hidden, dtype=torch.bfloat16, device="cuda:0"
    )
    loader = layer.weight.weight_loader
    loader(layer.weight, gate_shard)

    # With only one of two shards loaded, quantization must not run yet.
    assert layer.weight.dtype == torch.bfloat16, (
        "Weight was prematurely quantized after loading only the first "
        "shard. This causes garbage output with FP8 + TP > 1."
    )
    assert not getattr(
        layer, "_already_called_process_weights_after_loading", False
    )
    assert layer._loaded_numel == logical_shard * hidden

    # Load the second shard (up).
    up_shard = torch.randn(
        logical_shard, hidden, dtype=torch.bfloat16, device="cuda:0"
    )
    loader(layer.weight, up_shard)

    # Every shard is in place: quantization must have fired exactly now.
    assert layer._loaded_numel == partition_total * hidden
    assert getattr(layer, "_already_called_process_weights_after_loading", False)
    # The weight is now stored in FP8 ...
    assert layer.weight.dtype == current_platform.fp8_dtype(), (
        "Weight was not quantized to FP8 after all shards were loaded."
    )
    # ... transposed to (input_dim, output_dim) ...
    assert layer.weight.shape == (hidden, partition_total)
    # ... and accompanied by a float32 scale.
    assert hasattr(layer, "weight_scale")
    assert layer.weight_scale.dtype == torch.float32
16 changes: 12 additions & 4 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,9 +569,15 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
# process_weights_after_loading.
# Note: we use >= instead of == as a defensive measure. With
# TP > 1, the weight_loader narrows loaded_weight to the
# relevant shard before copy_, and CopyNumelCounter tracks
# only the actually-copied elements. Using >= guards against
# edge cases where copy_ numel tracking slightly overshoots
# (e.g. from internal copy_ calls in the weight_loader).
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
if layer._loaded_numel >= target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call from doing
Expand Down Expand Up @@ -1089,9 +1095,11 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
# process_weights_after_loading.
# Note: we use >= instead of == as a defensive measure
# (see comment in Fp8OnlineLinearMethod for details).
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
if layer._loaded_numel >= target_loaded_numel:
self.process_weights_after_loading(layer)

# Prevent the usual `process_weights_after_loading` call
Expand Down
Loading