fix memory for online fp8 quantization with streaming weight load #31914

mgoin merged 1 commit into vllm-project:main
Conversation
Code Review
This pull request addresses a memory issue in online fp8 quantization with streaming weight loading by deferring weight tensor materialization. The approach is to create zero-sized placeholder tensors initially and then materialize them to their full size just-in-time during weight loading. This is a clever way to reduce peak memory usage. My review focuses on improving the robustness of this implementation. I've identified a few places where the placeholder tensors are created as 1D tensors, while the original tensors are multi-dimensional. While this might work currently, it's fragile. I've suggested changes to preserve the tensor dimensionality for better correctness and to prevent potential issues in the future.
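The deferred-materialization pattern described above can be sketched without torch at all. This is an illustrative, torch-free sketch under stated assumptions: the class and method names (`LazyWeight`, `materialize`, `load`) are hypothetical, not vLLM's API. The key points it demonstrates are that the placeholder holds no storage at creation time, and that the buffer allocated just-in-time keeps the original 2-D shape rather than collapsing to 1-D, per the review's dimensionality concern.

```python
class LazyWeight:
    """Zero-sized placeholder that materializes its full buffer on first load."""

    def __init__(self, shape):
        self.full_shape = shape  # e.g. (output_size, input_size)
        self.data = None         # nothing allocated yet

    def materialize(self):
        # Preserve the original rank: allocate a 2-D buffer, not a
        # flat 1-D one, so downstream shape checks keep working.
        rows, cols = self.full_shape
        self.data = [[0.0] * cols for _ in range(rows)]

    def load(self, row_index, row):
        if self.data is None:
            self.materialize()  # just-in-time, at first load
        self.data[row_index] = list(row)


w = LazyWeight((2, 3))
assert w.data is None  # no memory held after create_weights-style setup
w.load(0, [1.0, 2.0, 3.0])
w.load(1, [4.0, 5.0, 6.0])
assert w.data == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```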
force-pushed from 168dcdb to c8f4c2f
kylesayrs
left a comment
Looks good to me. This is essentially the strategy that I'd like to adopt moving forward.
```python
weight = ModelWeightParameter(
    data=torch.empty(
        output_size_per_partition,
        input_size_per_partition,
        dtype=params_dtype,
        device=layer._load_device,
    ),
    input_dim=1,
    output_dim=0,
    weight_loader=patched_weight_loader,
)
```
This is technically duplicated code with previous logic, but it's nbd.
mgoin
left a comment
LGTM overall but two things
- We should add an integration test measuring peak memory during load with online FP8 quantization. We should always be below BF16 size now
- I would like #32189 to land first, then we land this, so it is easier to see the code separated from the serialized cases
force-pushed from c8f4c2f to 563c725
```python
@@ -56,6 +56,19 @@ def load_model(
logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)

# Log peak GPU memory after loading weights. This is needed
```
note that the actual peak as logged here is not visible when just measuring peak memory after the llm object is initialized; it seems we need extra logging. Open to suggestions on where to put this if there is a better place.
We shouldn't add this log by default... could you make it a debug_once log and just set the logging level within the test?
Done. I looked into running a bf16 model and then an fp8 model to compare, but ran into various issues with resources not being released properly. Probably solvable, but it seemed easier to just test against a model-specific threshold with fp8 on.
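The "make it a debug_once log" request above can be sketched with a minimal stand-in helper. This is an assumption-laden illustration, not vLLM's actual logging API: the helper name, signature, and return value are all hypothetical; the real `debug_once` in vLLM may differ.

```python
import logging

# Hypothetical stand-in for a "log only once per key" debug helper.
_seen_keys: set = set()


def debug_once(logger: logging.Logger, key: str, msg: str, *args) -> bool:
    """Emit msg at DEBUG level only the first time key is seen.

    Returns True if the message was emitted, False if suppressed.
    """
    if key in _seen_keys:
        return False
    _seen_keys.add(key)
    logger.debug(msg, *args)
    return True


logger = logging.getLogger("weights")
first = debug_once(logger, "peak_mem", "peak GPU memory: %.1f GiB", 12.3)
second = debug_once(logger, "peak_mem", "peak GPU memory: %.1f GiB", 12.3)
assert first is True
assert second is False  # suppressed on repeat
```

A test can then lower the logger's level to DEBUG to make the one-time message visible, keeping the default output quiet.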
Summary: Enables using float8 blockwise scaling with `fp8.py` online quantization. For now, the UI part of this PR is a placeholder pending the discussions in vllm-project#32412. The bulk of the PR is just wiring up kernels that already exist to `fp8.py` + online quant + blockwise scaling. This will need to be rebased after the following PRs land:

* vllm-project#32189
* vllm-project#31914

Test Plan: TODO

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
This pull request has merge conflicts that must be resolved before it can be merged.
force-pushed from 563c725 to b33f6c2
force-pushed from b33f6c2 to 926fe18
```python
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
    data=torch.empty(
```
note: I'm duplicating this weight creation between the outside and the inside of `patched_weight_loader` to keep things simple, since there is only one copy and some args are different. I'm flexible though.
Almost there; just need to get the new test that enforces peak memory for online quant to run properly in CI.
force-pushed from 5d97909 to 434baea
```python
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device
```
I found that the `del` here causes an error in the DP + EP case:

`AttributeError: 'FusedMoE' object has no attribute '_load_device'`
@yma11 thanks! I will take a look directly after I fix the logging issue in CI. Just in case I don't repro right away, if you can share your repro command that would be great.
nm, I can repro, looking
I fixed the issue by making sure we do not incorrectly reinitialize weights when EP is on. Please let me know if there are any further issues.
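For illustration, the AttributeError reported above can also be avoided with a guarded delete. This is a hedged sketch with hypothetical names (`FakeLayer`, `cleanup_load_device`); note the actual fix described in this thread avoided reinitializing weights when EP is on, rather than guarding the delete.

```python
class FakeLayer:
    """Hypothetical stand-in for a layer carrying a temporary load-device attr."""
    pass


def cleanup_load_device(layer):
    # Guarded delete: if the same layer is visited twice (as can happen
    # under DP + EP), the second pass is a no-op instead of raising
    # AttributeError.
    if hasattr(layer, "_load_device"):
        del layer._load_device


layer = FakeLayer()
layer._load_device = "cuda:0"
cleanup_load_device(layer)
cleanup_load_device(layer)  # second call does not raise
assert not hasattr(layer, "_load_device")
```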
force-pushed from 434baea to b15a0a4
force-pushed from b15a0a4 to ffc08da
force-pushed from 0484aed to 700e616
Summary:

vllm-project#29196 implemented streaming weight post-processing for online fp8 quant but did not actually reduce peak memory, because the linear/moe weights were created in bf16 and references to them were held for the entire `load_weights` loop in model loaders.

This PR fixes it by changing fp8 online quant to create zero-sized weights in `create_weights`, and materialize them to the correct size just-in-time in `patched_weight_loader`.

I would note that this PR is a bit hacky; there are two more proper ways to fix this that I can think of, both with a much wider blast radius:

1. change weight creation in vLLM to be materialized just-in-time (same as this PR, just explicit instead of hacky callables)
2. or, add an extension point for post-processing the weight before loading it (similar to vllm-project#27280)

fixes vllm-project#31805

Test Plan: inspect memory usage inside of `load_weights` and verify that it increases ~monotonically as weights are loaded

```bash
# dense
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8

# moe
CUDA_VISIBLE_DEVICES=7 python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen3-30B-A3B --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8
```

Signed-off-by: vasiliy <vasiliy@fb.com>
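The "increases ~monotonically" check from the test plan can be sketched as a simple assertion over memory samples taken inside the load loop. The function name and tolerance below are illustrative assumptions, not the PR's actual test code.

```python
def assert_nearly_monotonic(samples, tolerance=0.05):
    """Assert that samples never drop more than `tolerance` below the
    running peak; return the final peak. A large drop would indicate a
    big transient allocation (e.g. a full bf16 copy) was freed."""
    peak = 0.0
    for s in samples:
        assert s >= peak * (1.0 - tolerance), (s, peak)
        peak = max(peak, s)
    return peak


# Memory that only grows as weights stream in passes the check.
assert assert_nearly_monotonic([0.5, 1.0, 1.0, 2.0, 3.5]) == 3.5
```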
force-pushed from 700e616 to 5bf77bd
CI on the previous revision looked good, but rebasing one more time to be on top of #33432.
ok, this should be ready now. Changes since last review:

cc @mgoin
…lm-project#31914) Signed-off-by: vasiliy <vasiliy@fb.com>
…lm-project#31914) Signed-off-by: vasiliy <vasiliy@fb.com> Signed-off-by: Pai <416932041@qq.com>
…lm-project#31914) Signed-off-by: vasiliy <vasiliy@fb.com> Signed-off-by: felix01.yu <felix01.yu@vipshop.com>
Summary:
#29196 implemented streaming weight post-processing for online fp8 quant but did not actually reduce peak memory, because the linear/moe weights were created in bf16 and references to them were held for the entire `load_weights` loop in model loaders.

This PR fixes it by changing fp8 online quant to create weights on device `meta` in `create_weights`, and materialize them just-in-time in `patched_weight_loader`. We also add a log for peak memory usage directly after the weight loading loop, and a unit test on an MoE model (with linear and MoE layers) to ensure that peak memory usage is as expected. Finally, we add a workaround to ensure `--load_format dummy` functionality still works with online quant.

Peak memory usage before/after this PR on `Qwen/Qwen1.5-MoE-A2.7B`:

fixes #31805
Test Plan:
inspect memory usage inside of `load_weights` and verify that it increases ~monotonically as weights are loaded