fp8 online quant: split out Fp8OnlineLinearMethod #32189

Merged
mgoin merged 1 commit into vllm-project:main from vkuzo:20260112_fp8_online_refactor
Jan 20, 2026

Conversation

@vkuzo vkuzo commented Jan 12, 2026

Summary:

Split out Fp8OnlineLinearMethod from Fp8LinearMethod to more clearly separate online quant from offline quant logic, following a similar PR recently landed for Fp8OnlineMoEMethod.

In the same PR, strengthen integration-test coverage for online quant so that future online-quant functionality can rely on these tests. Specifically, extend the online fp8 quant test to also cover a small MoE model, and extend it to run inference on a couple of tokens.
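As a rough illustration of what the online path does (this is not vLLM's actual kernel), here is a pure-Python sketch of per-tensor dynamic fp8 scaling, assuming the e4m3 maximum magnitude of 448; `ops.scaled_fp8_quant` performs the equivalent on GPU with real fp8 rounding:

```python
# Hypothetical stand-in for ops.scaled_fp8_quant: per-tensor dynamic
# quantization to float8 e4m3 (max representable magnitude 448).
FP8_E4M3_MAX = 448.0

def scaled_fp8_quant_ref(weight):
    """Return (quantized values, per-tensor scale); pure-Python sketch."""
    scale = max(abs(w) for w in weight) / FP8_E4M3_MAX
    # Emulate fp8 coarsely by clipping to the representable range; a real
    # kernel also rounds each value to the nearest e4m3 number.
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, w / scale)) for w in weight]
    return q, scale

w = [0.5, -1.25, 3.0, -0.75]
q, s = scaled_fp8_quant_ref(w)
deq = [x * s for x in q]  # dequantize: recovers w up to fp8 rounding
```

In the real method this happens inside a patched weight loader, so the fp16/bf16 checkpoint tensor is replaced by its fp8 form at load time.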

Test Plan:

```
// on an NVIDIA B200
// run online quant test (dense + moe smoke tests)
with-proxy pytest tests/quantization/test_fp8.py -s -x -k online_quantization

// run entire fp8.py test suite
with-proxy pytest tests/quantization/test_fp8.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.



Note

Cleanly separates online FP8 linear quantization from offline (serialized) flow and strengthens smoke test coverage.

  • Adds Fp8OnlineLinearMethod that patches weight loading to quantize fp16/bf16 weights on-the-fly via ops.scaled_fp8_quant, with optional Marlin preparation
  • Updates Fp8Config.get_quant_method to choose Fp8OnlineLinearMethod vs Fp8LinearMethod based on is_checkpoint_fp8_serialized; simplifies Fp8LinearMethod to the serialized-FP8 path
  • Keeps MoE split (Fp8OnlineMoEMethod/Fp8MoEMethod) unchanged, integrates with selection logic
  • Expands tests: new test_online_quantization runs on facebook/opt-125m and Qwen/Qwen1.5-MoE-A2.7B, parameterizes KV cache (auto/fp8) and Marlin/ROCm flags, and performs short greedy generation

Written by Cursor Bugbot for commit d209af298f5d6601a9328c21c252c4fc11de1911.
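The dispatch described in the second bullet can be sketched as follows. The class names match the PR, but the bodies are placeholders; only the selection logic in `get_quant_method` is illustrated:

```python
class Fp8LinearMethod:
    """Serialized path: checkpoint weights are already stored in fp8."""

class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online path: quantize fp16/bf16 weights on the fly at load time."""

class Fp8Config:
    def __init__(self, is_checkpoint_fp8_serialized: bool):
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

    def get_quant_method(self):
        # Serialized checkpoints keep the parent method; everything else
        # goes through the new online subclass.
        if self.is_checkpoint_fp8_serialized:
            return Fp8LinearMethod()
        return Fp8OnlineLinearMethod()
```

Keeping the online method as a subclass means the serialized path stays the common base, while load-time quantization lives only in the child.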


# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
@vkuzo (PR author) commented:
this is outdated, deleting

assert input_scale is not None
input_scale = input_scale.max()
weight = weight.t()
weight = layer.weight
@vkuzo (PR author) commented:
All of these changes are just moving online quant code to the new child class.
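For context, the `input_scale.max()` line in the excerpt above collapses per-shard scales into one shared per-tensor scale. When quantized shards carry different scales, the same max-reduction requires requantizing each shard to the shared scale so the fused GEMM stays consistent. A hedged pure-Python sketch (illustrative names, not vLLM code):

```python
def requantize_to_max_scale(shard_values, shard_scales):
    """Collapse per-shard scales to one shared scale.

    shard_values: list of lists of quantized values, one list per shard.
    shard_scales: the per-shard scales being merged via max().
    """
    max_scale = max(shard_scales)
    out = []
    for vals, s in zip(shard_values, shard_scales):
        # Dequantize with the old scale, requantize with the shared one.
        ratio = s / max_scale
        out.append([v * ratio for v in vals])
    return out, max_scale

shards = [[100.0, -50.0], [200.0, 10.0]]
scales = [0.01, 0.02]
new_shards, new_scale = requantize_to_max_scale(shards, scales)
```

The shard that already owned the largest scale is unchanged (ratio 1.0); the others shrink proportionally.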

if self.quant_config.is_checkpoint_fp8_serialized:
weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
weight = create_fp8_weight_parameter(
@vkuzo (PR author) commented:
All of these changes are just moving online quant code to the new child class.

@gemini-code-assist bot left a comment:
Code Review

This pull request is a well-executed refactoring that splits the online quantization logic for FP8 linear layers into a new Fp8OnlineLinearMethod class. This change significantly improves code clarity and maintainability by separating the concerns of online and offline quantization, following the pattern established for MoE layers. The implementation is clean, and the logic has been moved correctly. The tests have also been updated to cover both dense and MoE models for online quantization, which is a great improvement. Overall, this is a solid contribution that enhances the codebase.

@robertgshaw2-redhat
(Collaborator)

This looks good. Can you make a new directory called online where this sits?

mergify bot commented Jan 12, 2026

Hi @vkuzo, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@kylesayrs (Contributor) left a comment:

Agree with @robertgshaw2-redhat, it'd be great to break these out into separate files/directories

Looks great in concept, thanks for looking at this

@vkuzo vkuzo force-pushed the 20260112_fp8_online_refactor branch from a59a81f to d209af2 on January 12, 2026 19:21

@mgoin (Member) left a comment:

LGTM as just structural changes, thanks!

@mgoin mgoin added the `ready` (ONLY add when PR is ready to merge/full CI is needed) and `quantization` labels Jan 13, 2026

vkuzo commented Jan 14, 2026

CI is failing with an OOM on Qwen 1.5B on a 24GB NVIDIA L4 machine; my best guess is that we actually need #31914 to land for memory usage to be sane. I'm going to remove Qwen 1.5B from this PR (since the fp8.py changes here do not touch MoEs) to unblock, and we can revisit online quant + MoE in CI once the memory issue is fixed.

@vkuzo vkuzo force-pushed the 20260112_fp8_online_refactor branch 2 times, most recently from 4806453 to 7df7c9a on January 14, 2026 17:41

vkuzo commented Jan 14, 2026

@mgoin thanks, CI is green after the latest fix

@vkuzo vkuzo force-pushed the 20260112_fp8_online_refactor branch from 7df7c9a to 2fe028f on January 16, 2026 14:54

vkuzo commented Jan 16, 2026

rebasing manually since automatic rebase failed on permissions

vkuzo added a commit to vkuzo/vllm that referenced this pull request Jan 16, 2026
Summary:

Enables using float8 blockwise scaling with `fp8.py` online quantization.

For now, the UI part of this PR is a placeholder pending the discussions
in vllm-project#32412. The bulk of the PR is just wiring up kernels that
already exist to fp8.py + online quant + blockwise scaling.

This will need to be rebased after the following PRs land:
* vllm-project#32189
* vllm-project#31914

Test Plan:

TODO

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
@vkuzo vkuzo force-pushed the 20260112_fp8_online_refactor branch from 2fe028f to 2debbac on January 16, 2026 19:56
Summary:

Split out `Fp8OnlineLinearMethod` from `Fp8LinearMethod` to more clearly
separate online quant from offline quant logic, following a similar PR
recently landed for `Fp8OnlineMoEMethod`.

Test Plan:

```
// run online quant test (dense + moe smoke tests)
with-proxy pytest tests/quantization/test_fp8.py -s -x -k online_quantization

// run entire fp8.py test suite
with-proxy pytest tests/quantization/test_fp8.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: vasiliy <vasiliy@fb.com>
@vkuzo vkuzo force-pushed the 20260112_fp8_online_refactor branch from 2debbac to 0ebce01 on January 20, 2026 14:54

vkuzo commented Jan 20, 2026

rebased on top of #27814 which just landed

@mgoin mgoin merged commit d2389c1 into vllm-project:main Jan 20, 2026
54 checks passed
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
monajafi-amd pushed a commit to monajafi-amd/vllm that referenced this pull request Jan 23, 2026
Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
yma11 pushed a commit to yma11/vllm that referenced this pull request Jan 27, 2026
yma11 added a commit to yma11/vllm that referenced this pull request Feb 3, 2026
* Revert "offload weights to cpu before fp8 online quant (vllm-project#225)"

This reverts commit fc5a0a6.

* fp8 online quant: split out Fp8OnlineLinearMethod (vllm-project#32189)

Signed-off-by: Yan Ma <yan.ma@intel.com>

* cherry-pick:fix memory for online fp8 quantization with streaming weight load

Signed-off-by: Yan Ma <yan.ma@intel.com>

* fix fp8

Signed-off-by: Yan Ma <yan.ma@intel.com>

---------

Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
yma11 pushed a commit to yma11/vllm that referenced this pull request Feb 3, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Labels

quantization, ready (ONLY add when PR is ready to merge/full CI is needed)


4 participants