refactor fp8.py online quant weight loading to use layerwise reload utils #33814
vkuzo wants to merge 1 commit into vllm-project:main
Conversation
Code Review
This pull request introduces an experimental implementation of layer-wise reloading for FP8 online quantization. The changes are clearly a work in progress, with several temporary code blocks, hardcoded flags, and debug statements. My review focuses on identifying these temporary elements and suggesting their removal or proper implementation for the final version. Key areas of feedback include removing dead code under `if False` blocks, replacing hardcoded feature flags with configuration options, and removing debug print/log statements. These changes are crucial for making the code production-ready.
kylesayrs
left a comment
I think these changes look great! I think the documentation makes it clear how and where the two flows differ. Just small nits/cleanups from my side
```python
setattr(layer, name, materialize_meta_tensor(tensor))


def materialize_layer_tensors_with_device_meta(layer: torch.nn.Module) -> None:
```
Do you want to have this implementation just replace materialize_layer? It should be safe to do so, as the assumption of materialize_layer is that it should only be relevant for meta tensors.
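For context, a minimal sketch of what a merged materialization helper could look like. The names `materialize_layer` and the meta-device check come from the diff; the body here is assumed for illustration, not the PR's actual code:

```python
import torch


def materialize_layer(layer: torch.nn.Module, device: str = "cpu") -> None:
    """Replace any meta-device parameters of `layer` with empty tensors on `device`.

    Sketch only: the real helpers are not shown in full in the diff. Because
    only meta tensors are touched, calling this on an already-materialized
    layer is a no-op, which is why merging the two helpers should be safe.
    """
    for name, param in list(layer.named_parameters(recurse=False)):
        if param.device.type == "meta":
            materialized = torch.nn.Parameter(
                torch.empty_like(param, device=device), requires_grad=False
            )
            setattr(layer, name, materialized)
```

Since the function only ever acts on meta tensors, it is idempotent, which supports the suggestion that one implementation can cover both call sites.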
| """ | ||
| if is_reload: | ||
| # Materialize layer tensors onto device | ||
| materialize_layer(layer) |
I think that, with your `if device == "meta"` guard, it should be safe to call this function in all cases, right? That would ensure that the entire layer is materialized at this point, including any scales, etc.
I think it's better to be explicit; materialization is not relevant to the initial load path in this section of the code, so it's simpler to just skip it.
```python
num_loaded, ret = get_numel_loaded(original_loader, bound_args)


else:
    if info.load_numel == 0:
```
I think, theoretically, you don't need this check, as the `if device == "meta"` check guards against double materialization anyway, right?
I think it's easier to understand this way
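To illustrate the two guards being discussed, here is a hypothetical sketch. The `LoadInfo` type and `load_tensor` name are invented for illustration; only the two checks mirror what the diff is doing:

```python
from dataclasses import dataclass

import torch


@dataclass
class LoadInfo:
    # number of elements still expected for this tensor (invented for this sketch)
    load_numel: int


def load_tensor(
    layer: torch.nn.Module, name: str, info: LoadInfo, device: str = "cpu"
) -> None:
    param = getattr(layer, name)
    # Explicit guard from the diff: nothing left to load for this tensor.
    if info.load_numel == 0:
        return
    # Device guard: only a meta tensor is ever materialized, so even without
    # the check above, a repeated call cannot double-materialize.
    if param.device.type == "meta":
        setattr(
            layer,
            name,
            torch.nn.Parameter(
                torch.empty_like(param, device=device), requires_grad=False
            ),
        )
```

The device guard alone makes the path idempotent; the numel check is redundant for safety but documents intent at the call site, which is the trade-off discussed above.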
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @vkuzo, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
```python
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, orig_extra_weight_attrs)
set_weight_attrs(w13_bias, extra_weight_attrs)
```
need to verify GPT-OSS 120B still works, as this changes the code added by #34906 and there is no CI coverage
following up on this, GPT-OSS bf16 is not expected to work with fp8.py online quant because:

1. fp8.py online quant (and future online quant backends in vLLM) requires weight_loaders, because we use weight_loaders to inject the streaming weight loading functionality
2. the gpt_oss.py model definition for the bf16 weights case does not use weight loaders: vllm/vllm/model_executor/models/gpt_oss.py, line 1009 in 234a65b

I'm not exactly sure how #34906 worked given 1 and 2 above. Going to skip this for now, as gpt-oss + online quant seems low priority because the official weights are in mxfp4, and we can follow up if needed.
for posterity, the easiest way to test this is using the 20b model from unsloth, which goes through the same path as the 120b:

```shell
VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/basic/chat.py --model unsloth/gpt-oss-20b-BF16 --enforce-eager --dtype=bfloat16 --quantization=fp8
```
```python
# Note: this is currently broken for gpt-oss because it
# does not use weight loaders at all in the bf16 weights
# path
device="meta",
```
gpt-oss bf16 is broken whether biases are initialized on GPU or on meta; going with meta to be consistent with the layerwise loading infra.
If we want gpt-oss to work with fp8.py, we should refactor gpt_oss.py to use weight loaders.
rebased on latest main
Hi @vkuzo, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
rebased on latest main
Summary:

WIP, for now just getting a POC to see what is needed for the real version.

Test Plan:

```bash
# example with facebook/opt-125m
VLLM_LOGGING_LEVEL=DEBUG VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8

# before:
# DEBUG 02-04 18:49:50 [model_executor/model_loader/base_loader.py:66] Peak GPU memory after loading weights: 0.18 GiB

# after:
# DEBUG 02-04 18:49:08 [model_executor/model_loader/base_loader.py:83] Peak GPU memory after loading weights: 0.25 GiB
```

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
This pull request has merge conflicts that must be resolved before it can be merged.
rebased and re-ran the test plan on 2xB200
after a conversation with @kylesayrs, abandoning in favor of #38032, which we expect will have an easier time passing PR review
Summary:
Moves fp8.py's online quantization to be more consistent with the QERL abstractions introduced in #32133. The main benefit is the removal of custom logic in fp8.py in favor of a more generalized and composable path. The peak memory usage is unchanged with this PR. The new high-level way fp8.py streaming weight loading works:

1. the `create_weights` method can create weights on device `meta` to opt in to saving memory with streaming weight loading + quantization
2. in `ModelLoader`'s `load_model` function, a new API can be called to turn on streaming weight loading for weights created on device `meta`.
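The two steps above can be sketched roughly as follows. Only `create_weights` and `load_model` are names from the description; the classes, helper names, and bodies here are illustrative, not the actual vLLM interfaces:

```python
import torch


class Fp8LinearMethodSketch:
    """Illustrative stand-in for a quant method's create_weights."""

    def create_weights(self, layer: torch.nn.Module, out_f: int, in_f: int) -> None:
        # Step 1: allocate on the meta device, opting in to streaming
        # weight loading + quantization with no real memory used yet.
        weight = torch.nn.Parameter(
            torch.empty(out_f, in_f, device="meta"), requires_grad=False
        )
        layer.register_parameter("weight", weight)


def load_model_sketch(model: torch.nn.Module, device: str = "cpu") -> None:
    # Step 2: inside the loader, stream layer by layer: materialize, load,
    # and quantize each meta-created weight before moving on, so peak memory
    # stays near one layer's worth (loading/quantization details elided).
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            if param.device.type == "meta":
                setattr(
                    module,
                    name,
                    torch.nn.Parameter(
                        torch.empty_like(param, device=device), requires_grad=False
                    ),
                )
```

Weights not created on `meta` are untouched by step 2, which is how layers that do not opt in keep the existing load path.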
Test Plan:

```bash
# unit and integration tests; this includes testing for peak memory after weight loading
pytest tests/quantization/test_fp8.py -s

# dense
VLLM_LOGGING_LEVEL=DEBUG python3 examples/basic/offline_inference/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8

# moe
VLLM_LOGGING_LEVEL=DEBUG python3 examples/basic/offline_inference/generate.py --model Qwen/Qwen3-30B-A3B --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8

# moe with tp on
CUDA_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/basic/offline_inference/generate.py --model Qwen/Qwen3-30B-A3B --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8 -tp 2
```