online fp8 quant with streaming weight post-processing #29196

robertgshaw2-redhat merged 1 commit into vllm-project:main
Conversation
💡 Codex Review

Here are some automated review suggestions for this pull request.
Code Review
This pull request introduces a proof-of-concept for online FP8 quantization with streaming weight post-processing. The approach is clever, patching the weight loader to trigger post-processing as soon as a weight tensor is fully loaded. This can help reduce peak memory usage during model loading.
My main feedback is to improve the robustness of the state management. The flag `_already_called_process_weights_after_loading` is currently stored on the `Fp8LinearMethod` instance. While this works with the current code structure, it's fragile: attaching this state to the layer object instead would make the implementation more robust against future changes, such as instance reuse for optimization. I've added specific comments with code suggestions to address this.
```python
del param._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
self._already_called_process_weights_after_loading = True
```
Storing `_already_called_process_weights_after_loading` on `self` (the `Fp8LinearMethod` instance) makes the design fragile. Although a new instance is currently created for each layer, this might change in the future (e.g., for optimization), which could lead to this flag persisting incorrectly across different layers.
To make this more robust, this state should be attached to the layer object, which is guaranteed to be unique. This change should be made in conjunction with the corresponding check in process_weights_after_loading.
Suggested change:

```diff
- self._already_called_process_weights_after_loading = True
+ layer._already_called_process_weights_after_loading = True
```
```python
@@ -487,6 +516,9 @@ def create_weights(
        layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        if getattr(self, "_already_called_process_weights_after_loading", False):
```
To make the state management more robust, and in conjunction with the suggested change for setting this flag, this check should be on the layer object instead of `self`.
Suggested change:

```diff
- if getattr(self, "_already_called_process_weights_after_loading", False):
+ if getattr(layer, "_already_called_process_weights_after_loading", False):
```
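The reviewer's point can be shown with a small torch-free sketch (class and method names here are illustrative stand-ins, not the actual vLLM code): when the "already processed" flag lives on the layer object, a single quant-method instance can safely serve multiple layers.

```python
# Hypothetical sketch: keep the "already processed" flag on the layer
# object rather than on the quant-method instance, so reusing one
# Fp8LinearMethod across layers stays correct.

class Layer:  # stand-in for a torch.nn.Module layer
    pass

class Fp8LinearMethod:
    def _finish_streaming_load(self, layer):
        # ... streaming post-processing would happen here ...
        layer._already_called_process_weights_after_loading = True

    def process_weights_after_loading(self, layer):
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return "skipped"  # streaming path already handled this layer
        return "processed"   # normal post-processing path

method = Fp8LinearMethod()       # one instance shared by two layers
a, b = Layer(), Layer()
method._finish_streaming_load(a)

# the flag is per-layer: `a` is marked done, `b` is not
assert method.process_weights_after_loading(a) == "skipped"
assert method.process_weights_after_loading(b) == "processed"
```

Had the flag lived on `method` instead, the second call would incorrectly skip post-processing for `b`.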
```python
res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

# add a counter to track how many elements we have updated
if not hasattr(layer, "_target_loaded_numel"):
```
Are all of these `hasattr` checks necessary? It's a little hard to read right now. I think the attribute is guaranteed to exist if `not self.quant_config.is_checkpoint_fp8_serialized`.
Makes sense, sg. We can delete the `_target_loaded_numel` var and just compute it every time.
The same goes for `_loaded_numel`. If this attribute didn't exist, then something has gone terribly wrong and it might be appropriate to error.
```python
# add a counter to track how many elements we have updated
if not hasattr(layer, "_target_loaded_numel"):
    # for linear, the only weight we need to load is `layer.weight`
    layer._target_loaded_numel = layer.weight.numel()
```
Nit: avoid extra flags by deleting `_target_loaded_numel` and just calling `layer.weight.numel()`.
Summary:

not for land, just a demo

1. during weight loading, keep track of how many elements we have loaded
2. when we have loaded all the elements, call post-processing

can be used to call weight post-processing in a streaming fashion to minimize GPU memory usage. Will only work if we can assume we only load each weight chunk once.

Test Plan: tested locally with facebook/opt-125m and `fp8` online quantization

Reviewers: Subscribers: Tasks: Tags:

Signed-off-by: vasiliy <vasiliy@fb.com>
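The two steps in the summary can be sketched without torch; all names here (`make_streaming_loader`, `post_process`, etc.) are illustrative, not the actual vLLM API. The wrapper counts elements as each shard is loaded and fires post-processing exactly when the count reaches the target, which is why it only works if each weight chunk is loaded once.

```python
# Minimal sketch of the streaming idea: wrap the existing weight_loader,
# count loaded elements on the layer, and trigger post-processing as
# soon as the whole weight has been seen.

def make_streaming_loader(weight_loader, layer, target_numel, post_process):
    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        res = weight_loader(param, loaded_weight, *args, **kwargs)
        loaded = getattr(layer, "_loaded_numel", 0) + len(loaded_weight)
        layer._loaded_numel = loaded
        if loaded == target_numel:
            # all chunks loaded exactly once: post-process immediately,
            # so the high-precision weight can be freed right away
            del layer._loaded_numel
            post_process(layer)
        return res
    return patched_weight_loader

class Layer:  # stand-in for a module layer
    pass

calls = []
layer = Layer()
loader = make_streaming_loader(
    weight_loader=lambda param, w: param.extend(w),  # toy "copy" into param
    layer=layer,
    target_numel=6,
    post_process=lambda l: calls.append("done"))

param = []
loader(param, [1, 2, 3])   # first shard: nothing fires yet
loader(param, [4, 5, 6])   # last shard: post_process fires
assert calls == ["done"] and param == [1, 2, 3, 4, 5, 6]
```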
Summary:

When we added online fp8 quant with streaming weight post-processing in vllm-project#29196, a bug was introduced where the TP>1 case was not always handled correctly. Specifically:

* vllm-project#29196 assumed that `weight_loader` copies `loaded_weight` to `param`
* this is not true, as `weight_loader` can call arbitrary logic on both `param` and `loaded_weight` before eventually calling `copy_`. An example is here: https://github.com/vllm-project/vllm/blob/e3a0f21e6ce78268865cafcdc3dc58c7a80dbc57/vllm/model_executor/parameter.py#L195

A fix is to track exactly how many elements were updated with `copy_`. This PR implements the fix with `TorchDispatchMode`.

Test Plan:

```bash
# tp 1 still works
CUDA_VISIBLE_DEVICES=6,7 with-proxy python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen1.5-MoE-A2.7B --enforce-eager --quantization fp8 -tp=1
# tp > 1 was broken before, now works
CUDA_VISIBLE_DEVICES=6,7 with-proxy python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen1.5-MoE-A2.7B --enforce-eager --quantization fp8 -tp=2
```

Reviewers: Subscribers: Tasks: Tags:

Signed-off-by: vasiliy <vasiliy@fb.com>
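A minimal sketch of the `TorchDispatchMode` technique the fix describes (this is an assumption-laden illustration, not the PR's actual code): intercept every aten op, and whenever `aten.copy_` writes into the tracked parameter's storage, add the destination's element count. This counts real writes even when `weight_loader` copies into a narrowed view of `param`, which is exactly the TP>1 case.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class CopyCounter(TorchDispatchMode):
    """Counts elements written by aten.copy_ into one tracked tensor."""

    def __init__(self, tracked: torch.Tensor):
        super().__init__()
        self.tracked = tracked
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        if func is torch.ops.aten.copy_.default:
            dst = args[0]
            # count only writes that land in the tracked storage
            # (views of the tracked tensor share its data pointer)
            if dst.untyped_storage().data_ptr() == self.tracked.untyped_storage().data_ptr():
                self.copied_numel += dst.numel()
        return out

param = torch.zeros(4, 8)
with CopyCounter(param) as counter:
    # simulate a sharded load: each shard is copied into a view of param
    param[:2].copy_(torch.ones(2, 8))
    param[2:].copy_(torch.ones(2, 8))

# all 32 elements were written, even though no copy_ targeted param directly
assert counter.copied_numel == param.numel() == 32
```

Once `copied_numel` reaches `param.numel()`, the weight is known to be fully loaded and post-processing can fire.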
…29196) Signed-off-by: vasiliy <vasiliy@fb.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Summary:

vllm-project#29196 implemented streaming weight post-processing for online fp8 quant but did not actually reduce peak memory, because the linear|moe weights were created in bf16 and references to them were held for the entire `load_weights` loop in model loaders.

this PR fixes it by changing fp8 online quant to create zero-sized weights in `create_weights`, and materialize them to the correct size just-in-time in `patched_weight_loader`.

I would note that this PR is a bit hacky, and there are two more proper ways to fix this that I can think of, both with a much wider blast radius:

1. change weight creation in vllm to be materialized just-in-time (same as this PR, just explicit instead of hacky callables)
2. or, add an extension point for post-processing the weight before loading it (similar to vllm-project#27280)

fixes vllm-project#31805

Test Plan:

inspect memory usage inside of `load_weights` and verify that it increases ~monotonically as weights are loaded

```bash
# dense
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8
# moe
CUDA_VISIBLE_DEVICES=7 python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen3-30B-A3B --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8
```

Reviewers: Subscribers: Tasks: Tags:

Signed-off-by: vasiliy <vasiliy@fb.com>
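The zero-sized-then-materialize trick can be sketched as follows (function names are hypothetical, chosen for illustration; the actual vLLM change lives in `create_weights` and `patched_weight_loader`): register an empty placeholder up front, then swap in a full-size buffer only when the first shard for that weight arrives.

```python
import torch

# Sketch, assuming a bf16 online-quant path: a zero-sized placeholder is
# registered at create_weights time, so the full-precision buffer only
# exists while its shards are actually being loaded.

def create_placeholder(dtype=torch.bfloat16):
    # zero elements: holds no weight memory up front
    return torch.nn.Parameter(torch.empty(0, dtype=dtype), requires_grad=False)

def materialize_if_needed(param, shape, dtype=torch.bfloat16):
    # just-in-time materialization inside the patched weight loader
    if param.numel() == 0:
        param.data = torch.empty(shape, dtype=dtype)
    return param

w = create_placeholder()
assert w.numel() == 0            # nothing allocated yet
materialize_if_needed(w, (4, 8)) # first shard arrives: allocate for real
assert w.shape == (4, 8)
```

After streaming post-processing quantizes the layer, the bf16 buffer can be dropped, so at most one full-precision weight is alive at a time.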
Summary:
updates the `fp8.py` vllm online quant entry point to post-process weights as soon as they are ready, to avoid holding the entire high-precision model in memory.

I'm putting up this PR to demonstrate the approach. Since for now the number of use cases is two (linear and MoE in `fp8.py`), the logic is kept simple. If this is extended to other online quant providers such as compressed-tensors/torchao/etc in the future, it might be good to make the code more reusable in follow-up PRs.
Test Plan:
tested locally with `facebook/opt-125m` for linear and `Qwen/Qwen1.5-MoE-A2.7B` for MoE, with `fp8` online quantization. Verified (via prints) that `layer.process_weights_after_loading` is called immediately after the last weight chunk is loaded in both cases.

Reviewers:
Subscribers:
Tasks:
Tags: