
fix memory for online fp8 quantization with streaming weight load#31914

Merged
mgoin merged 1 commit into vllm-project:main from vkuzo:20260107_streaming_quant_memory_fix on Feb 2, 2026

Conversation

@vkuzo commented Jan 7, 2026

Summary:

#29196 implemented streaming weight post-processing for online fp8 quant, but did not actually reduce peak memory, because the linear and MoE weights were created in bf16 and references to them were held for the entire load_weights loop in the model loaders.

This PR fixes it by changing fp8 online quant to create weights on the meta device in create_weights, and to materialize them just-in-time in patched_weight_loader. We also add a log of peak memory usage directly after the weight-loading loop, and a unit test on an MoE model (with both linear and MoE layers) to ensure that peak memory usage is as expected. Finally, we add a workaround to ensure that --load_format dummy still works with online quant.
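The meta-device pattern described above can be sketched in plain PyTorch. This is an illustrative sketch, not the actual vLLM code; the function names are made up, but the mechanism (zero-cost placeholder at creation time, real allocation deferred to the first weight-load call) is the same:

```python
import torch
import torch.nn as nn

def create_placeholder(out_features: int, in_features: int) -> nn.Parameter:
    # Meta-device tensor: carries shape and dtype but allocates no storage,
    # so building the model adds nothing to peak GPU memory.
    return nn.Parameter(
        torch.empty(out_features, in_features, dtype=torch.bfloat16, device="meta"),
        requires_grad=False,
    )

def materialize(placeholder: nn.Parameter, loaded: torch.Tensor) -> nn.Parameter:
    # Called the first time a checkpoint shard targets this parameter (the
    # role patched_weight_loader plays in this PR): allocate real storage
    # just-in-time, so only one full-size copy is live at once.
    assert placeholder.device.type == "meta"
    assert placeholder.shape == loaded.shape
    return nn.Parameter(loaded.to(placeholder.dtype), requires_grad=False)

w = create_placeholder(4, 8)
real = materialize(w, torch.randn(4, 8))
print(w.device.type, real.dtype, tuple(real.shape))
```

With the placeholder held by the module and the bf16 data only appearing inside the loader call, the full-precision copy can be quantized and dropped before the next layer loads.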

Peak memory usage before/after this PR on Qwen/Qwen1.5-MoE-A2.7B:

# baseline (bfloat16 without quantization)
GPU memory used after loading weights: 26.67 GiB                                                                                                                                                                                
Peak GPU memory usage while loading weights: 26.73 GiB

# fp8, before this PR
GPU memory used after loading weights: 13.94 GiB                                                                                                                                                                                
Peak GPU memory usage while loading weights: 40.48 GiB 

# fp8, after this PR
GPU memory used after loading weights: 13.94 GiB
Peak GPU memory usage while loading weights: 14.98 GiB
# for additional context,
# with moe memory optimization only (no linear): peak GPU memory 17.27 GiB
# with linear memory optimization only (no moe): peak GPU memory 38.18 GiB
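Numbers like the ones above come from PyTorch's CUDA memory statistics. A minimal helper along those lines (illustrative only, not the exact logging code this PR adds; it returns 0.0 on CPU-only hosts so it is safe to call anywhere):

```python
import torch

def peak_gpu_memory_gib() -> float:
    # High-water mark of CUDA allocations since the last reset;
    # 0.0 when no GPU is present.
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / 2**30

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

# ... the weight-loading loop would run here ...

peak = peak_gpu_memory_gib()
print(f"Peak GPU memory usage while loading weights: {peak:.2f} GiB")
```

Resetting the peak counter right before the load loop is what makes "peak while loading weights" distinguishable from the post-load "memory used" number.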

fixes #31805

Test Plan:

Inspect memory usage inside of load_weights and verify that it increases roughly monotonically as weights are loaded.

# dense
VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8
# moe
VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen3-30B-A3B  --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8
# moe with tp on
CUDA_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen3-30B-A3B  --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8 -tp 2 -ep

# also, run the new test to enforce that peak memory is sane
$ pytest tests/quantization/test_fp8.py -s -x -k online_quant_peak_mem
...
GPU memory used after loading weights: 1.29 GiB
Peak GPU memory usage while loading weights: 1.48 GiB

# ensure --load_format dummy works
$ pytest tests/quantization/test_fp8.py -s -x -k online_quant_load_format_dummy


@gemini-code-assist bot left a comment

Code Review

This pull request addresses a memory issue in online fp8 quantization with streaming weight loading by deferring weight tensor materialization. The approach is to create zero-sized placeholder tensors initially and then materialize them to their full size just-in-time during weight loading. This is a clever way to reduce peak memory usage. My review focuses on improving the robustness of this implementation. I've identified a few places where the placeholder tensors are created as 1D tensors, while the original tensors are multi-dimensional. While this might work currently, it's fragile. I've suggested changes to preserve the tensor dimensionality for better correctness and to prevent potential issues in the future.
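The dimensionality point raised in this review can be seen with a plain-PyTorch comparison (a generic illustration, not the PR's actual code): a zero-sized placeholder can keep the rank of the real tensor at no extra cost.

```python
import torch

in_features = 11008  # example width of a real 2-D weight

# Fragile: a 1-D zero-sized placeholder drops the original rank, so any
# weight-loader logic that inspects .dim() or a specific axis can break.
flat = torch.empty(0, dtype=torch.bfloat16)

# Safer: keep the rank and zero out only one dimension; the tensor still
# holds no data, but transposes and shape checks behave like the real weight.
shaped = torch.empty(0, in_features, dtype=torch.bfloat16)

print(flat.dim(), shaped.dim(), shaped.numel(), shaped.shape[1])
```

Both placeholders occupy no storage (`numel() == 0`), so preserving the dimensionality is free.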

@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch from 168dcdb to c8f4c2f on January 7, 2026
@kylesayrs left a comment

Looks good to me. This is essentially the strategy that I'd like to adopt moving forward.

Comment on lines +510 to +520

    weight = ModelWeightParameter(
        data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=params_dtype,
            device=layer._load_device,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=patched_weight_loader,
    )

This is technically duplicated code with previous logic, but it's nbd.

@mgoin (Member) left a comment

LGTM overall, but two things:

  1. We should add an integration test measuring peak memory during load with online FP8 quantization. We should always be below BF16 size now
  2. I would like #32189 to land first, then we land this, so it is easier to see the code separated from the serialized cases
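The integration test requested in point 1 could look roughly like the sketch below. This is a hypothetical shape only; the threshold and setup are placeholders, not vLLM's actual values, and the real test lives in tests/quantization/test_fp8.py.

```python
import pytest
import torch

# Model-specific bound, chosen below the BF16 footprint of the test model.
PEAK_GIB_LIMIT = 1.6

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a GPU")
def test_online_quant_peak_mem():
    torch.cuda.reset_peak_memory_stats()
    # ... build the LLM here with quantization="fp8" so weights stream in ...
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    assert peak_gib < PEAK_GIB_LIMIT
```

Asserting against a fixed fp8 threshold sidesteps the need to load the BF16 model in the same process for comparison.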

@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch from c8f4c2f to 563c725 on January 14, 2026
@vkuzo requested a review from 22quinn as a code owner on January 14, 2026
@@ -56,6 +56,19 @@ def load_model(
         logger.debug("Loading weights on %s ...", load_device)
         # Quantization does not happen in `load_weights` but after it
         self.load_weights(model, model_config)

+        # Log peak GPU memory after loading weights. This is needed
@vkuzo (author) commented:

Note that the actual peak as logged here is not visible when just measuring peak memory after the LLM object is initialized; it seems we need this extra logging. Open to suggestions on where to put this if there is a better place.

A maintainer replied:

We shouldn't add this log by default... could you make it a debug_once log and just set the logging level within the test?

@vkuzo replied:

makes sense, fixed!
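The debug-once behavior agreed on above can be sketched with a generic once-guard. This is an illustrative stand-in, not vLLM's actual logger API; the handler here only captures messages so the dedup effect is visible.

```python
import functools
import logging

logger = logging.getLogger("weights.demo")
logger.setLevel(logging.DEBUG)

records = []
handler = logging.Handler()
handler.emit = lambda record: records.append(record.getMessage())
logger.addHandler(handler)

@functools.lru_cache(maxsize=None)
def debug_once(msg: str) -> None:
    # Each distinct message is emitted at most once per process, so the
    # peak-memory line does not spam logs when loading repeats.
    logger.debug(msg)

debug_once("Peak GPU memory usage while loading weights: 14.98 GiB")
debug_once("Peak GPU memory usage while loading weights: 14.98 GiB")
print(len(records))
```

The `lru_cache` on a `None`-returning function is a compact way to remember which messages have already been logged.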


vkuzo commented Jan 14, 2026

> We should add an integration test measuring peak memory during load with online FP8 quantization. We should always be below BF16 size now

Done. I looked into running the bf16 model and then the fp8 model to compare, but ran into various issues with resources not being released properly. Probably solvable, but it seemed easier to test against a model-specific threshold with fp8 on.

@vkuzo requested a review from mgoin on January 14, 2026
vkuzo added a commit to vkuzo/vllm that referenced this pull request Jan 16, 2026
Summary:

Enables using float8 blockwise scaling with `fp8.py` online quantization.

For now, the UI part of this PR is a placeholder pending the discussions
in vllm-project#32412 . The bulk of the
PR is just wiring up kernels that already exist to fp8.py + online quant
+ blockwise scaling.

This will need to be rebased after the following PRs land:
* vllm-project#32189
* vllm-project#31914

Test Plan:

TODO

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>

mergify bot commented Jan 21, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vkuzo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on January 21, 2026
@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch from 563c725 to b33f6c2 on January 21, 2026
@mergify bot removed the needs-rebase label on January 21, 2026
@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch from b33f6c2 to 926fe18 on January 21, 2026
    # when the first `loaded_weight` is about to be
    # loaded to `param`, materialize `param` just-in-time
    weight = ModelWeightParameter(
        data=torch.empty(
@vkuzo (author) commented:

note: duplicating this weight creation between outside and inside of patched_weight_loader to keep it simple, since there is only one copy and some args are different. I'm flexible though.


vkuzo commented Jan 28, 2026

use pytest.mark.forked instead of pytest --forked
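Per-test forking with the pytest-forked plugin can be shown in a minimal sketch (hypothetical test name; the plugin must be installed for the marker to take effect under pytest):

```python
import pytest

# Marking only the memory-sensitive test keeps the rest of the suite
# in-process, whereas `pytest --forked` would fork every single test.

@pytest.mark.forked
def test_needs_clean_process():
    # Runs in a forked child, so CUDA state and peak-memory counters from
    # sibling tests cannot leak into this measurement.
    ...
```

This matters for a peak-memory assertion, since `torch.cuda.max_memory_allocated()` is process-wide.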


vkuzo commented Jan 28, 2026

Almost there; just need to get the new test that enforces peak memory for online quant to run properly in CI.

@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch 5 times, most recently from 5d97909 to 434baea, on January 29, 2026
    set_weight_attrs(w2_weight, extra_weight_attrs)
    _copy_missing_attrs(layer.w2_weight, w2_weight)
    layer.register_parameter("w2_weight", w2_weight)
    del layer._load_device
@yma11 commented:

I found that the del here causes an error in the DP + EP case:

AttributeError: 'FusedMoE' object has no attribute '_load_device'

@vkuzo replied:

@yma11 thanks! I will take a look directly after I fix the logging issue in CI. Just in case I don't repro right away, if you can share your repro command that would be great.

@vkuzo:

Never mind, I can repro; looking into it.

@vkuzo:

I fixed the issue by making sure we do not incorrectly reinitialize weights when EP is on. Please let me know if there are any further issues.
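The merged fix avoids re-running weight initialization under EP. Independently of that, the AttributeError reported above is the classic double-cleanup failure mode; a generic way to make such one-shot cleanup idempotent (purely illustrative, not the PR's actual fix):

```python
class FakeLayer:
    """Illustrative stand-in for a layer holding a transient _load_device
    attribute during weight loading (attribute name taken from this PR)."""

    def __init__(self) -> None:
        self._load_device = "cuda:0"

    def finish_loading(self) -> None:
        # A bare `del self._load_device` raises AttributeError if this runs
        # twice; popping the attribute makes the cleanup safe to repeat.
        self.__dict__.pop("_load_device", None)

layer = FakeLayer()
layer.finish_loading()
layer.finish_loading()  # safe: no AttributeError on the second call
print(hasattr(layer, "_load_device"))
```

Guarding the cleanup would mask the double call rather than fix it, which is why the PR instead prevents the second initialization pass.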

@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch from 434baea to b15a0a4 on January 30, 2026

vkuzo commented Jan 30, 2026

Moving the test_online_quant_peak_mem test to use caplog_vllm and caplog_mp_fork to see if that works in CI.

@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch from b15a0a4 to ffc08da on January 30, 2026

vkuzo commented Jan 30, 2026

OK, finally got test_online_quant_peak_mem to run properly in CI. Now I need to fix the EP issue.

@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch 2 times, most recently from 0484aed to 700e616, on January 30, 2026
Summary:

vllm-project#29196 implemented streaming
weight post-processing for online fp8 quant but did not actually reduce peak
memory, because the linear|moe weights were created in bf16 and references to
them were held for the entire `load_weights` loop in model loaders.

This PR fixes it by changing fp8 online quant to create zero-sized weights in
`create_weights`, and to materialize them to the correct size just-in-time
in `patched_weight_loader`.

I would note that this PR is a bit hacky, and there are two more proper ways to
fix this that I can think of, both with a much wider blast radius:
- 1: change weight creation in vllm to be materialized just-in-time
  (same as this PR, just explicit instead of hacky callables)
- 2: or, add an extension point for post-processing the weight before
  loading it (similar to vllm-project#27280)

fixes vllm-project#31805

Test Plan:

inspect memory usage inside of `load_weights` and verify that it
increases ~monotonically as weights are loaded

```bash
# dense
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8
# moe
CUDA_VISIBLE_DEVICES=7 python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen3-30B-A3B  --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8
```


Signed-off-by: vasiliy <vasiliy@fb.com>
@vkuzo force-pushed the 20260107_streaming_quant_memory_fix branch from 700e616 to 5bf77bd on January 30, 2026

vkuzo commented Jan 30, 2026

CI on the previous revision looked good, but rebasing one more time to be on top of #33432.


vkuzo commented Jan 30, 2026

OK, this should be ready now.

Changes since last review:

  1. Made the test_online_quant_peak_mem test properly run in CI: chose a model already used by other tests (to avoid increasing disk-space usage on runners), and used the right abstractions to enable and capture debug logs from within the test when multiprocessing uses spawn.
  2. Ensured that EP works (not tested in CI, but verified locally and added to the test plan). Let me know if this is worth adding to CI; if so, I can look into it, although it would need runners with 2 GPUs.

cc @mgoin

@fxmarty-amd left a comment

LGTM

@mgoin merged commit 0130223 into vllm-project:main on Feb 2, 2026
49 checks passed
yma11 pushed a commit to yma11/vllm that referenced this pull request Feb 3, 2026
PiratePai pushed a commit to PiratePai/epd_shm that referenced this pull request Feb 3, 2026
…lm-project#31914)

Signed-off-by: vasiliy <vasiliy@fb.com>
Signed-off-by: Pai <416932041@qq.com>
gameofdimension pushed a commit to gameofdimension/vllm that referenced this pull request Feb 5, 2026
…lm-project#31914)

Signed-off-by: vasiliy <vasiliy@fb.com>
Signed-off-by: felix01.yu <felix01.yu@vipshop.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
@yma11 mentioned this pull request Mar 4, 2026

Labels

ci/build, quantization, ready (ONLY add when PR is ready to merge / full CI is needed)


Successfully merging this pull request may close these issues.

[Bug]: streaming quantization cause higher peak memory used during model loading and post process

6 participants