diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index dddbd119392a..2ee97b9578b6 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -15,6 +15,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.fp8 import ( + CopyNumelCounter, Fp8Config, Fp8KVCacheMethod, Fp8LinearMethod, @@ -466,3 +467,146 @@ def test_fp8_reloading( weight_loader(param, torch.zeros(shape)) # cannot use empty method.process_weights_after_loading(layer) + + +def test_copy_numel_counter_tracks_narrowed_copies(): + """Test that CopyNumelCounter correctly tracks only the elements that are + actually copied via copy_(), not the full source tensor size. + + This is critical for FP8 online quantization with TP > 1, where the + weight_loader narrows the loaded weight to the relevant TP shard before + calling copy_. Previously (before PR #30900), the code used + loaded_weight.numel() to track progress, which over-counted elements + and caused process_weights_after_loading to trigger at the wrong time, + resulting in empty/garbage output. + + Regression test for: https://github.com/vllm-project/vllm/issues/36583 + """ + # Simulate a fused weight parameter (e.g., gate_up_proj) with TP=2 + # Full weight: [2 * intermediate_size, hidden_size] + # Sharded weight per rank: [2 * intermediate_size / 2, hidden_size] + hidden_size = 64 + intermediate_size = 128 + tp_size = 2 + shard_size = intermediate_size // tp_size # 64 + + # Create the sharded parameter (what each TP rank holds) + param_data = torch.zeros(2 * shard_size, hidden_size) + + # Simulate loading gate_proj (first shard) from the full weight + full_gate_weight = torch.randn(intermediate_size, hidden_size) + # The weight_loader would narrow to the TP shard + gate_shard = full_gate_weight[:shard_size] + + counter = CopyNumelCounter() + with counter: + # This simulates what weight_loader does: narrow param and copy + param_data[:shard_size].copy_(gate_shard) + + # The counter should track only the narrowed copy, not the full weight + assert counter.copied_numel == shard_size * hidden_size + # NOT full_gate_weight.numel() == intermediate_size * hidden_size + + # Simulate loading up_proj (second shard) from the full weight + full_up_weight = torch.randn(intermediate_size, hidden_size) + up_shard = full_up_weight[:shard_size] + + counter2 = CopyNumelCounter() + with counter2: + param_data[shard_size:].copy_(up_shard) + + assert counter2.copied_numel == shard_size * hidden_size + + # Total copied should equal the parameter size + total_copied = counter.copied_numel + counter2.copied_numel + assert total_copied == param_data.numel() + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) +def test_online_fp8_streaming_quant_with_sharded_weights( + dist_init, + monkeypatch, +): + """Test that FP8 online quantization correctly handles weight loading + when weights are sharded (simulating TP > 1 behavior). + + The key scenario: for fused layers like QKV or gate_up_proj, the + weight_loader loads each sub-weight (q, k, v or gate, up) separately. + With TP, each sub-weight is narrowed to the TP shard before copy_. + The streaming quantization must wait until ALL shards are loaded before + calling process_weights_after_loading. + + Regression test for: https://github.com/vllm-project/vllm/issues/36583 + """ + monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0") + + from vllm.model_executor.layers.quantization.fp8 import Fp8OnlineLinearMethod + + config = Fp8Config( + is_checkpoint_fp8_serialized=False, + activation_scheme="dynamic", + ) + method = Fp8OnlineLinearMethod(config) + + # Simulate a MergedColumnParallelLinear-like setup (gate_up_proj) + # with 2 logical shards, as if TP=2 narrowing had already occurred + hidden_size = 64 + shard_size = 32 # each logical shard (gate, up) per TP rank + output_partition_sizes = [shard_size, shard_size] + output_size_per_partition = sum(output_partition_sizes) # 64 + + with torch.device("cuda:0"): + layer = torch.nn.Module() + + method.create_weights( + layer=layer, + input_size_per_partition=hidden_size, + output_partition_sizes=output_partition_sizes, + input_size=hidden_size, + output_size=output_size_per_partition * 2, # full size before TP + params_dtype=torch.bfloat16, + weight_loader=default_weight_loader, + ) + + assert layer.weight.device == torch.device("meta") + + # Simulate loading the first shard (gate) + # In real TP loading, the loaded_weight is the full weight but + # only the shard portion gets copy_'d into the parameter + gate_weight = torch.randn( + shard_size, hidden_size, dtype=torch.bfloat16, device="cuda:0" + ) + patched_loader = layer.weight.weight_loader + patched_loader(layer.weight, gate_weight) + + # After loading just the first shard, weight should NOT be quantized yet + assert layer.weight.dtype == torch.bfloat16, ( + "Weight was prematurely quantized after loading only the first " + "shard. This causes garbage output with FP8 + TP > 1." + ) + assert not getattr( + layer, "_already_called_process_weights_after_loading", False + ) + assert layer._loaded_numel == shard_size * hidden_size + + # Now load the second shard (up) + up_weight = torch.randn( + shard_size, hidden_size, dtype=torch.bfloat16, device="cuda:0" + ) + patched_loader(layer.weight, up_weight) + + # After loading both shards, weight should be quantized + assert layer._loaded_numel == output_size_per_partition * hidden_size + assert getattr(layer, "_already_called_process_weights_after_loading", False) + # Weight should now be FP8 + assert layer.weight.dtype == current_platform.fp8_dtype(), ( + "Weight was not quantized to FP8 after all shards were loaded." + ) + # Weight should be transposed (input_dim x output_dim) + assert layer.weight.shape == (hidden_size, output_size_per_partition) + # Weight scale should exist + assert hasattr(layer, "weight_scale") + assert layer.weight_scale.dtype == torch.float32 diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5101347cd02a..6b2eae371f42 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -569,9 +569,15 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): layer._loaded_numel += copy_numel_counter.copied_numel # if we have loaded all of the elements, call - # process_weights_after_loading + # process_weights_after_loading. + # Note: we use >= instead of == as a defensive measure. With + # TP > 1, the weight_loader narrows loaded_weight to the + # relevant shard before copy_, and CopyNumelCounter tracks + # only the actually-copied elements. Using >= guards against + # edge cases where copy_ numel tracking slightly overshoots + # (e.g. from internal copy_ calls in the weight_loader). target_loaded_numel = layer.weight.numel() - if layer._loaded_numel == target_loaded_numel: + if layer._loaded_numel >= target_loaded_numel: self.process_weights_after_loading(layer) # Prevent the usual `process_weights_after_loading` call from doing @@ -1089,9 +1095,11 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): layer._loaded_numel += copy_numel_counter.copied_numel # if we have loaded all of the elements, call - # process_weights_after_loading + # process_weights_after_loading. + # Note: we use >= instead of == as a defensive measure + # (see comment in Fp8OnlineLinearMethod for details). target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() - if layer._loaded_numel == target_loaded_numel: + if layer._loaded_numel >= target_loaded_numel: self.process_weights_after_loading(layer) # Prevent the usual `process_weights_after_loading` call