
[Quantization] Support pre-load online quantization for compressed-tensors W8A8 channel-wise schema#27280

Open
luccafong wants to merge 4 commits intovllm-project:mainfrom
luccafong:fp8_channelwise_quantization_online

Conversation

@luccafong
Collaborator

@luccafong luccafong commented Oct 21, 2025

Purpose

Support pre-load online quantization for the compressed-tensors W8A8 channel-wise schema on BF16 checkpoints.

  • Memory Optimization: This differs from the existing online quantization approach, which quantizes after loading via process_weights_after_loading. Here we quantize weights via process_weights_before_loading, for cases where the GPU cannot hold the original-dtype weights and CPU offloading is too slow. Each weight is quantized as it is loaded, which enables online dynamic quantization of raw BF16 Llama4 Maverick on H100 — previously not possible.
  • Extensible: The PR implements compressed-tensors fp8 channelwise (same as FP_dynamic in offline quantization), and the approach extends to other quantization methods that implement process_weights_before_loading.
  • MoE and Linear Support: Supports both MoE and linear layers.
  • Llama4-specific optimization: Llama4 transposes/chunks fused weights in its weight loader; this PR also improves Llama4 model loading by copying weights to device before the transpose/chunk happens, avoiding an expensive contiguous() call (model loading time drops from 40 minutes to 2 minutes).
 --quantization compressed-tensors \
--quantization-schema fp8_channelwise \
--hf-overrides '{"quantization_config":{"ignore":["re:.*self_attn","re:.*lm_head","re:.*router","re:.*vision_model","re:.*multi_modal_projector","re:.*feed_forward.gate_up_proj","re:.*feed_forward.down_proj", "re:.*shared_expert"]}}'
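The quantize-while-loading flow described above can be sketched in pure Python. This is a minimal illustration, not vLLM's actual API: `PreLoadQuantLinear`, `load_row`, and `quantize_channelwise` are invented names; the point is only that each weight is quantized the moment the loader delivers it, so the full-precision tensor never needs to stay resident.

```python
# Illustrative sketch of quantize-while-loading (not vLLM's actual API).

FP8_MAX = 448.0  # max magnitude representable in float8 e4m3

def quantize_channelwise(row):
    """Quantize one output channel to fp8-range integers plus a scale."""
    scale = max(abs(v) for v in row) / FP8_MAX or 1.0
    q = [max(-FP8_MAX, min(FP8_MAX, round(v / scale))) for v in row]
    return q, scale

class PreLoadQuantLinear:
    """Toy layer whose weight loader quantizes each incoming row
    immediately, mimicking process_weights_before_loading semantics."""

    def __init__(self, out_features, in_features):
        self.qweight = [None] * out_features
        self.scales = [None] * out_features

    def load_row(self, idx, row):
        # Quantize before storing -- the high-precision row is dropped here.
        self.qweight[idx], self.scales[idx] = quantize_channelwise(row)

layer = PreLoadQuantLinear(2, 3)
layer.load_row(0, [0.5, -1.0, 0.25])
layer.load_row(1, [2.0, 0.0, -2.0])
```

Dequantizing a value is then just `qweight[i][j] * scales[i]`.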

Test Plan

Online Serving

vllm serve /data/local/models/oss/Llama-4-Maverick-17B-128E-Instruct -tp 8 --quantization compressed-tensors --quantization-schema fp8_channelwise --hf-overrides '{"quantization_config":{"ignore":["re:.*self_attn","re:.*lm_head","re:.*router","re:.*vision_model","re:.*multi_modal_projector","re:.*feed_forward.gate_up_proj","re:.*feed_forward.down_proj", "re:.*shared_expert"]}}' --max_num_seqs 32 --max-model-len 32768

UT/CI Tests

pytest tests/quantization/test_fp8_channelwise.py
pytest tests/quantization/test_compressed_tensors.py -k "test_compressed_tensors_fp8_online_quantization_channelwise"

Added tests with both Llama3 and Qwen to guard linear and MoE models.

Test Result

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.940|±  |0.0168|
|     |       |strict-match    |     5|exact_match|↑  |0.945|±  |0.0162|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the llama Related to Llama models label Oct 21, 2025
@luccafong luccafong changed the title [Quantizon] Support compressed-tensors W8A8 channelwise online quantization [Quantizion] Support compressed-tensors W8A8 channelwise online quantization Oct 21, 2025
@@ -196,6 +197,9 @@ class ModelConfig:
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
determine the data type of the weights."""
quantization_schema: str | None = None
Contributor

can we just use quantization and hf_overrides to specify the config? like #23014

Collaborator Author

I feel the user experience might be too complicated for fp8_channelwise, so I added a schema for the pre-defined config here.

Contributor

I guess it's fine to use a string, maybe reuse hf_override to specify the string?

quantization == "compressed-tensors"
and quantization_schema == "fp8_channelwise"
):
return {
Contributor

yeah I guess this could be passed as a json string, or a file that stores the serialized json string

otherwise we'd need to invent quantization_schema name for each settings

Member

@hmellor hmellor Oct 23, 2025

JSON configs can also be passed with . notation. i.e. --quantization-scheme.format float-quantize --quantization-scheme.quant_method compressed-tensors ..., but this will be extremely verbose for a CLI.

Perhaps as a file is a good idea as @jerryzh168 suggested?

Collaborator Author

I feel it's not a good user experience if users have to define these settings in a file or pass them through args, so I think pre-defined schemas should be fine for frequently used configurations @hmellor @jerryzh168

Contributor

I think it's OK to have predefined strings, but can this live in hf_overrides, instead of adding a new quantization_schema in parallel to quantization?

Collaborator Author

@luccafong luccafong Nov 12, 2025

later we may add fp8_blockwise; without a separate field we could not differentiate. The current quantization values are all compressed-tensors, which is not enough to distinguish different schemas of the same quantization method

Collaborator Author

ah, I see, do you mean we specify a field in hf config to differentiate?

Contributor

ah, I see, do you mean we specify a field in hf config to differentiate?

yeah, just putting this in hf_override is cleaner I think, since you already have some other configs there, and this is specific to llm-compressor

@mergify

mergify bot commented Oct 21, 2025

Documentation preview: https://vllm--27280.org.readthedocs.build/en/27280/

@mergify mergify bot added the documentation Improvements or additions to documentation label Oct 21, 2025
@luccafong luccafong changed the title [Quantizion] Support compressed-tensors W8A8 channelwise online quantization [Quantizaion] Support pre-load online quantization for compressed-tensors W8A8 channel-wise schema Oct 21, 2025
@luccafong luccafong changed the title [Quantizaion] Support pre-load online quantization for compressed-tensors W8A8 channel-wise schema [Quantization] Support pre-load online quantization for compressed-tensors W8A8 channel-wise schema Oct 21, 2025
@luccafong luccafong marked this pull request as ready for review October 21, 2025 22:49
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
@luccafong luccafong force-pushed the fp8_channelwise_quantization_online branch from f2b24a4 to 27e66ac Compare November 11, 2025 19:53
hf_overrides_kw = {}
dict_overrides = {}
if quant_config_override:
dict_overrides["quantization_config"] = quant_config_override
Contributor

Should this warn if overriding non null quantization configuration? Otherwise may be unclear which takes precedence

Collaborator Author

@HDCharles thanks for the comments, I moved the get_default_quantization_hf_config call closer to this line and also added a warning to the original get_default_quantization_hf_config method

Signed-off-by: Lu Fang <fanglu@fb.com>
if isinstance(value, dict):
dict_overrides[key] = value
if dict_overrides.get(key):
dict_overrides[key] = {**dict_overrides[key], **value}
Contributor

@HDCharles HDCharles Nov 12, 2025

this seems rather convoluted.

We're assuming dict_overrides[key] is always a mapping as long as dict_overrides.get(key) and isinstance(value, dict) are true which doesn't seem obvious

as far as the logic:

we take the dict that we assume we get from dict_overrides[key] and add the value dict to it, letting value take precedence where there are any conflicts.

Feels like we should be taking one or the other i.e. overriding things rather than this merge operation.

also still no warning when one thing overrides another.
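One way to make precedence explicit, as asked above, is a merge helper that warns whenever an override replaces a non-null value. The `merge_overrides` below is a hypothetical sketch of that shape, not code from this PR:

```python
import warnings

def merge_overrides(base: dict, override: dict) -> dict:
    """Shallow-merge override into base, warning on every replaced key so
    precedence is visible to the user (hypothetical helper)."""
    merged = dict(base)
    for key, value in override.items():
        if key in merged and merged[key] is not None and merged[key] != value:
            warnings.warn(
                f"hf_overrides: replacing existing {key!r}={merged[key]!r} "
                f"with {value!r}"
            )
        merged[key] = value
    return merged

cfg = merge_overrides(
    {"quantization_config": {"quant_method": "fp8"}},
    {"quantization_config": {"quant_method": "compressed-tensors"}},
)
```

Here the override wins wholesale and the user sees a warning, instead of a silent key-by-key merge.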

enforce_eager=True,
dtype="bfloat16",
hf_overrides={
"quantization_config": {"ignore": ["re:.*self_attn", "re:.*lm_head"]}
Contributor

i.e. here:

hf_overrides={
    "llmc_quantization_schema": "fp8_channelwise",
    "llmc_quantization_config": {"ignore": ["re:.*self_attn", "re:.*lm_head"]}
}

llm = LLM(
"meta-llama/Llama-3-8B-Instruct",
quantization="compressed-tensors",
quantization_schema="fp8_channelwise"
Contributor

Shouldn't this match the per tensor API?

Super weird that

quantization="fp8" does per tensor while
quantization="compressed-tensors", quantization_schema="fp8_channelwise" does per channel

Why not just quantization="fp8_channelwise"?

Collaborator Author

good point, let's make the adjustment for the interface

Contributor

@HDCharles HDCharles left a comment

  1. At a high level, the API for online per tensor fp8 should match the one for online per channel fp8.

  2. It looks like this PR takes a completely different path than the existing per-tensor fp8 online support, creating entirely new preprocessing steps. Rather than having two unrelated paths for these techniques, it'd be much cleaner to either support both techniques with the new abstractions or just do whatever per-tensor did initially.

return unfused_matches[0] if all(unfused_matches) else None


def fp8_channelwise_quantize(x: Tensor, channel_dim: int = -1) -> tuple[Tensor, Tensor]:
Contributor

is there a higher level function in compressed-tensors to do this so we don't have to redefine it here? @kylesayrs since I'm very n00b in compressed-tensors
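For reference, the per-channel absmax scale computation such a helper typically performs can be written out as follows; `channel_scales` is illustrative only, not the compressed-tensors API:

```python
# Illustrative per-channel absmax scales; not the compressed-tensors API.
FP8_E4M3_MAX = 448.0

def channel_scales(matrix, channel_dim=-1):
    """channel_dim=0 treats each row as a channel; channel_dim=-1 treats
    each column as a channel."""
    groups = matrix if channel_dim == 0 else list(zip(*matrix))
    return [max(abs(v) for v in g) / FP8_E4M3_MAX for g in groups]

m = [[1.0, -4.0], [2.0, 0.5]]
row_scales = channel_scales(m, channel_dim=0)   # one scale per output channel
col_scales = channel_scales(m, channel_dim=-1)  # one scale per input channel
```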

@kylesayrs
Contributor

@luccafong Can this be closed now that #29196 has landed?

vkuzo added a commit to vkuzo/vllm that referenced this pull request Jan 23, 2026
Summary:

vllm-project#29196 implemented streaming
weight post-processing for online fp8 quant but did not actually reduce peak
memory, because the linear|moe weights were created in bf16 and references to
them were held for the entire `load_weights` loop in model loaders.

this PR fixes it by changing fp8 online quant to create zero-sized weights in
`create_weights`, and materialize them to the correct size just-in-time
in `patched_weight_loader`.

I would note that this PR is a bit hacky, and there are two more proper ways to
fix this that I can think of, both with a much wider blast radius:
- 1: change weight creation in vllm to be materialized just-in-time
  (same as this PR, just explicit instead of hacky callables)
- 2: or, add an extension point for post-processing the weight before
  loading it (similar to vllm-project#27280)

fixes vllm-project#31805

Test Plan:

inspect memory usage inside of `load_weights` and verify that it
increases ~monotonically as weights are loaded

```bash
// dense
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8
// moe
CUDA_VISIBLE_DEVICES=7 python3 examples/offline_inference/basic/generate.py --model Qwen/Qwen3-30B-A3B  --enforce-eager --dtype=bfloat16 --block-size=64 --max_model_len=2048 --gpu-memory-utilization=0.8 --trust-remote-code --quantization=fp8
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: vasiliy <vasiliy@fb.com>
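The create-zero-sized-then-materialize-just-in-time idea from the commit message above can be sketched like this; `LazyWeight` and its methods are invented for illustration, not the actual vLLM code:

```python
# Sketch of zero-sized weight creation with just-in-time materialization.

class LazyWeight:
    """Placeholder allocated with no storage; real data is attached only
    when the loader delivers it, then immediately quantized, so the bf16
    copy is never retained."""

    def __init__(self, shape):
        self.shape = shape
        self.data = None  # zero-sized until load time

    def materialize_and_quantize(self, values, scale):
        assert len(values) == self.shape[0] * self.shape[1]
        self.data = [round(v / scale) for v in values]  # quantized storage

w = LazyWeight((2, 2))
assert w.data is None  # nothing allocated before loading
w.materialize_and_quantize([0.5, 1.0, -0.5, 0.0], scale=0.5)
```

Because the placeholder holds no storage, peak memory tracks the quantized size rather than the bf16 size during the load loop.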
@github-actions

github-actions bot commented Mar 9, 2026

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Mar 9, 2026