online fp8 quant with streaming weight post-processing#29196

Merged
robertgshaw2-redhat merged 1 commit into vllm-project:main from vkuzo:20251121_fp8_online_quant_hack on Dec 8, 2025
Conversation

@vkuzo (Contributor) commented on Nov 21, 2025:

Summary:

Updates the fp8.py vLLM online quant entry point to post-process weights as soon as they are ready, to avoid holding the entire high-precision model in memory. Specifically:

  1. during weight loading, keep track of how many elements we have loaded
  2. when we have loaded all the elements, call post-processing

I'm putting up this PR to demonstrate the approach. Since there are currently only two use cases (linear and MoE in fp8.py), the logic is kept simple. If this is extended to other online quant providers such as compressed-tensors/torchao in the future, it would be good to make the code more reusable in follow-up PRs.
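The two steps above can be sketched as follows. This is a simplified, hypothetical illustration, not the PR's actual code: `make_streaming_loader` and the callback wiring are invented for this sketch, real vLLM weight loaders take extra sharding arguments, and (as the commit message notes) it assumes each weight chunk is loaded exactly once.

```python
import torch
import torch.nn as nn


def make_streaming_loader(layer, weight_loader, process_weights_after_loading):
    """Wrap `weight_loader` so post-processing runs as soon as the last
    weight chunk for `layer` has been loaded."""
    # for linear, the only weight we need to load is `layer.weight`
    target_numel = layer.weight.numel()

    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        res = weight_loader(param, loaded_weight, *args, **kwargs)
        # 1. during weight loading, keep track of how many elements we have loaded
        loaded = getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()
        layer._loaded_numel = loaded
        # 2. when we have loaded all the elements, call post-processing so the
        #    high-precision copy can be freed immediately
        if loaded == target_numel:
            process_weights_after_loading(layer)
            del layer._loaded_numel
        return res

    return patched_weight_loader
```

For a layer whose weight is loaded in two chunks, the callback fires only after the second chunk arrives.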

Test Plan:

Tested locally with facebook/opt-125m (linear) and Qwen/Qwen1.5-MoE-A2.7B (MoE) with fp8 online quantization. Verified (via prints) that `layer.process_weights_after_loading` is called immediately after the last weight chunk is loaded in both cases.

Reviewers:

Subscribers:

Tasks:

Tags:



@chatgpt-codex-connector (bot) left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.


@gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces a proof-of-concept for online FP8 quantization with streaming weight post-processing. The approach is clever, patching the weight loader to trigger post-processing as soon as a weight tensor is fully loaded. This can help reduce peak memory usage during model loading.

My main feedback is to improve the robustness of the state management. The flag _already_called_process_weights_after_loading is currently stored on the Fp8LinearMethod instance. While this works with the current code structure, it's fragile. Attaching this state to the layer object instead would make the implementation more robust against future changes, such as instance reuse for optimization. I've added specific comments with code suggestions to address this.

del param._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
self._already_called_process_weights_after_loading = True
high

Storing _already_called_process_weights_after_loading on self (the Fp8LinearMethod instance) makes the design fragile. Although a new instance is currently created for each layer, this might change in the future (e.g., for optimization), which could lead to this flag persisting incorrectly across different layers.

To make this more robust, this state should be attached to the layer object, which is guaranteed to be unique. This change should be made in conjunction with the corresponding check in process_weights_after_loading.

Suggested change:
- self._already_called_process_weights_after_loading = True
+ layer._already_called_process_weights_after_loading = True

@@ -487,6 +516,9 @@ def create_weights(
layer.register_parameter("input_scale", None)

def process_weights_after_loading(self, layer: Module) -> None:
if getattr(self, "_already_called_process_weights_after_loading", False):

high

To make the state management more robust and in conjunction with the suggested change for setting this flag, this check should be on the layer object instead of self.

Suggested change:
- if getattr(self, "_already_called_process_weights_after_loading", False):
+ if getattr(layer, "_already_called_process_weights_after_loading", False):

@vkuzo force-pushed the 20251121_fp8_online_quant_hack branch from 5326892 to 9583e3b on November 21, 2025
@vkuzo force-pushed the 20251121_fp8_online_quant_hack branch from 9583e3b to 8956364 on November 26, 2025
@vkuzo changed the title from "[not for land] online fp8 quant with streaming weight post-processing" to "online fp8 quant with streaming weight post-processing" on Nov 26, 2025
@robertgshaw2-redhat added the `ready` label ("ONLY add when PR is ready to merge/full CI is needed") on Dec 3, 2025
@kylesayrs (Contributor) left a comment:

Thanks!

res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]

# add a counter to track how many elements we have updated
if not hasattr(layer, "_target_loaded_numel"):
@kylesayrs:

Are all of these `hasattr` checks necessary? It's a little hard to read right now; I think the attribute is guaranteed to exist if `not self.quant_config.is_checkpoint_fp8_serialized`.

@vkuzo (author) replied:

Makes sense, sounds good; we can delete the `_target_loaded_numel` var and just compute it every time.

@kylesayrs:

The same goes for `_loaded_numel`. If this attribute didn't exist, then something has gone terribly wrong and it might be appropriate to error.

# add a counter to track how many elements we have updated
if not hasattr(layer, "_target_loaded_numel"):
# for linear, the only weight we need to load is `layer.weight`
layer._target_loaded_numel = layer.weight.numel()
@kylesayrs:

Nit: avoid extra flags by deleting `_target_loaded_numel` and just calling `layer.weight.numel()`

@vkuzo (author) replied: sounds good.

Summary:

not for land, just a demo

1. during weight loading, keep track of how many elements we have loaded
2. when we have loaded all the elements, call post-processing

can be used to call weight post-processing in a streaming fashion
to minimize GPU memory usage. Will only work if we can assume we only
load each weight chunk once.

Test Plan:

tested locally with facebook/opt-125m and `fp8` online quantization

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: vasiliy <vasiliy@fb.com>
@vkuzo force-pushed the 20251121_fp8_online_quant_hack branch from fda3623 to 427f05d on December 4, 2025
@robertgshaw2-redhat enabled auto-merge (squash) on December 8, 2025
@robertgshaw2-redhat merged commit 0d402d2 into vllm-project:main on Dec 8, 2025
54 checks passed
vkuzo added a commit to vkuzo/vllm that referenced this pull request Dec 18, 2025
Summary:

When we added online fp8 quant with streaming weight post-processing in
vllm-project#29196, a bug was introduced
where the TP>1 case was not always handled correctly. Specifically:

* vllm-project#29196 assumed that
  `weight_loader` copies `loaded_weight` to `param`
* this is not true, as `weight_loader` can call arbitrary logic on both
  `param` and `loaded_weight` before eventually calling `copy_`. An
  example is here:
  https://github.com/vllm-project/vllm/blob/e3a0f21e6ce78268865cafcdc3dc58c7a80dbc57/vllm/model_executor/parameter.py#L195

A fix is to track exactly how many elements were updated with
`copy_`. This PR implements that fix with `TorchDispatchMode`.

Test Plan:

```bash
# tp 1 still works
CUDA_VISIBLE_DEVICES=6,7 with-proxy python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen1.5-MoE-A2.7B --enforce-eager --quantization fp8 -tp=1
# tp > 1 was broken before, now works
CUDA_VISIBLE_DEVICES=6,7 with-proxy python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen1.5-MoE-A2.7B --enforce-eager --quantization fp8 -tp=2
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: vasiliy <vasiliy@fb.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…29196)

Signed-off-by: vasiliy <vasiliy@fb.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
vkuzo added a commit to vkuzo/vllm that referenced this pull request Jan 23, 2026
Summary:

vllm-project#29196 implemented streaming
weight post-processing for online fp8 quant but did not actually reduce peak
memory, because the linear|moe weights were created in bf16 and references to
them were held for the entire `load_weights` loop in model loaders.

this PR fixes it by changing fp8 online quant to create zero-sized weights in
`create_weights`, and materialize them to the correct size just-in-time
in `patched_weight_loader`.

I would note that this PR is a bit hacky, and there are two more proper ways to
fix this that I can think of, both with a much wider blast radius:
- 1: change weight creation in vllm to be materialized just-in-time
  (same as this PR, just explicit instead of hacky callables)
- 2: or, add an extension point for post-processing the weight before
  loading it (similar to vllm-project#27280)

fixes vllm-project#31805
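The zero-sized-weight idea can be sketched as below. `create_weights` and `patched_weight_loader` here are simplified stand-ins for the functions named in the commit message, not vLLM's actual implementations: the parameter starts empty, and the first load call materializes it to its real shape just-in-time.

```python
import torch
import torch.nn as nn


def create_weights(layer: nn.Module, out_features: int, in_features: int):
    # zero-sized placeholder: holds no weight memory yet; the real shape
    # is remembered on the layer for later materialization
    layer.weight = nn.Parameter(torch.empty(0), requires_grad=False)
    layer._weight_shape = (out_features, in_features)


def patched_weight_loader(layer: nn.Module, loaded_weight: torch.Tensor):
    if layer.weight.numel() == 0:
        # materialize to the correct size just-in-time, only when data
        # for this layer actually arrives
        layer.weight.data = torch.empty(
            layer._weight_shape, dtype=loaded_weight.dtype
        )
    layer.weight.data.copy_(loaded_weight)
```

Because no full-size bf16 tensor exists until its data is loaded (and it can be quantized and freed right after), peak memory during `load_weights` stays close to the size of one layer's high-precision weights rather than the whole model's.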

Test Plan:

inspect memory usage inside of `load_weights` and verify that it
increases ~monotonically as weights are loaded

```bash
# dense
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8
# moe
CUDA_VISIBLE_DEVICES=7 python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen3-30B-A3B  --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: vasiliy <vasiliy@fb.com>
vkuzo added a commit to vkuzo/vllm that referenced this pull request Jan 30, 2026
(commit message identical to the Jan 23, 2026 commit above)