
[not ready for review] extend fp8 online quant with blockwise scaling #32485

Open
vkuzo wants to merge 1 commit into vllm-project:main from vkuzo:20260116_fp8_online_rowwise_quant

Conversation

@vkuzo vkuzo commented Jan 16, 2026

Summary:

Enables using float8 blockwise scaling with fp8.py online quantization.

For now, the UI part of this PR is a placeholder pending the discussions in #32412. The bulk of the PR is just wiring up kernels that already exist to fp8.py + online quant + blockwise scaling.
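To make the mechanism concrete, here is a minimal, self-contained sketch of per-block float8 scaling. It is an illustration only, not the vLLM kernel (`per_block_cast_to_fp8`): it keeps values in float32 rather than a real fp8 dtype, and assumes an e4m3 max of 448 and a 128x128 block size.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3


def blockwise_quantize(w: np.ndarray, block: int = 128):
    """Scale a 2D weight per (block x block) tile into the fp8 e4m3 range.

    Returns (q, scales): q is the scaled tensor (float32 here, standing in
    for fp8), scales holds one scale per tile.
    """
    rows, cols = w.shape
    assert rows % block == 0 and cols % block == 0
    nr, nc = rows // block, cols // block
    # View as (nr, block, nc, block) tiles and take per-tile max magnitude.
    tiles = w.reshape(nr, block, nc, block)
    amax = np.abs(tiles).max(axis=(1, 3))        # (nr, nc)
    scales = amax / FP8_E4M3_MAX                 # one scale per tile
    scales = np.where(scales == 0, 1.0, scales)  # avoid division by zero
    q = tiles / scales[:, None, :, None]         # scale each tile
    q = np.clip(q, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.reshape(rows, cols), scales


def blockwise_dequantize(q: np.ndarray, scales: np.ndarray, block: int = 128):
    """Invert blockwise_quantize by multiplying each tile by its scale."""
    nr, nc = scales.shape
    tiles = q.reshape(nr, block, nc, block)
    return (tiles * scales[:, None, :, None]).reshape(nr * block, nc * block)
```

Compared to tensorwise scaling (one scale for the whole weight), the per-tile scales keep outlier blocks from clipping the dynamic range of well-behaved blocks.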

This will need to be rebased after the following PRs land:

  • #32189
  • #31914

Test Plan:

// small dense model - meta-llama/Llama-3.2-1B

// bf16
> time lm_eval --model vllm --model_args "pretrained=meta-llama/Llama-3.2-1B,enforce_eager=true" --tasks gsm8k --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.0644|±  |0.0068|
|     |       |strict-match    |     5|exact_match|↑  |0.0607|±  |0.0066|

// fp8 tensorwise
> time lm_eval --model vllm --model_args "pretrained=meta-llama/Llama-3.2-1B,enforce_eager=true,quantization=fp8" --tasks gsm8k --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.0629|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.0576|±  |0.0064|

// fp8 blockwise
> time lm_eval --model vllm --model_args "pretrained=meta-llama/Llama-3.2-1B,enforce_eager=true,quantization=fp8" --tasks gsm8k --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.0644|±  |0.0068|
|     |       |strict-match    |     5|exact_match|↑  |0.0607|±  |0.0066|

// small moe model - Qwen/Qwen1.5-MoE-A2.7B

// bf16
> time lm_eval --model vllm --model_args "pretrained=Qwen/Qwen1.5-MoE-A2.7B,enforce_eager=true" --tasks gsm8k --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6096|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.1645|±  |0.0102|

// fp8 tensorwise
> time lm_eval --model vllm --model_args "pretrained=Qwen/Qwen1.5-MoE-A2.7B,enforce_eager=true,quantization=fp8" --tasks gsm8k --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5974|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.1622|±  |0.0102|

// fp8 blockwise
> time lm_eval --model vllm --model_args "pretrained=Qwen/Qwen1.5-MoE-A2.7B,enforce_eager=true,quantization=fp8" --tasks gsm8k --batch_size auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6073|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.2047|±  |0.0111|


TODO

Purpose

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends FP8 online quantization to support blockwise scaling. The changes primarily involve wiring up existing kernels for blockwise operations within the fp8.py quantization logic.

My review focuses on the implementation in vllm/model_executor/layers/quantization/fp8.py. The changes look mostly correct and consistent. I've identified a few areas for improvement:

  • A critical assertion that was removed should be restored to prevent potential misconfigurations.
  • There are a couple of local imports that should be moved to the top of the file for better code style and maintainability.
  • There is some repeated logic for checking the quantization type, which could be refactored into a helper property to improve code clarity and reduce duplication.

Overall, this is a good step towards enabling more flexible FP8 quantization schemes.

@@ -1089,7 +1156,21 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(quant_config, layer)
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
assert quant_config.weight_block_size is None


critical

The assertion assert quant_config.weight_block_size is None was removed. Fp8OnlineMoEMethod is used for online quantization (is_checkpoint_fp8_serialized=False), where weight_block_size in Fp8Config is expected to be None. This assertion is a crucial sanity check to prevent misconfiguration. Please restore it.

Suggested change
assert self.quant_config.weight_block_size is None

Comment on lines +361 to +365
if (
self.block_quant
or self.quant_config.online_quant_scaling_type
is OnlineQuantScalingType.BLOCKWISE
):

high

The condition self.block_quant or self.quant_config.online_quant_scaling_type is OnlineQuantScalingType.BLOCKWISE is repeated in several places within this class (e.g., in process_weights_after_loading and apply). To improve maintainability and reduce code duplication, consider creating a helper property within the Fp8LinearMethod class to encapsulate this logic. For example:

@property
def _is_blockwise_quant(self):
    return (self.block_quant or
            self.quant_config.online_quant_scaling_type is
            OnlineQuantScalingType.BLOCKWISE)

You could then use if self._is_blockwise_quant: in this and other locations.

is OnlineQuantScalingType.BLOCKWISE
):
# blockwise
from vllm.utils.deep_gemm import per_block_cast_to_fp8

high

Local imports should generally be at the top of the file. Please move from vllm.utils.deep_gemm import per_block_cast_to_fp8 to the top to follow standard Python style guidelines and improve maintainability.

is OnlineQuantScalingType.BLOCKWISE
):
# Blockwise quantization
from vllm.utils.deep_gemm import per_block_cast_to_fp8

high

This local import should be moved to the top of the file. This improves code readability and adheres to standard Python style guides.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 4 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

if self.block_quant:
assert self.weight_block_size is not None

if (

Marlin path accesses non-existent weight scale attribute

High Severity

The Marlin code path in apply checks only self.block_quant to decide between layer.weight_scale_inv and layer.weight_scale, but block_quant is False when doing online quantization (it only reflects checkpoint-provided weight_block_size). Since online_quant_scaling_type is hardcoded to BLOCKWISE, process_weights_after_loading creates weight_scale_inv, but the Marlin path tries to access weight_scale which doesn't exist, causing an AttributeError on GPUs without FP8 hardware support.
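A minimal sketch of the guard Bugbot is describing. The attribute names (`weight_scale_inv`, `weight_scale`) follow the comment above, but `pick_weight_scale` and the dummy layers are hypothetical stand-ins, not vLLM code:

```python
from types import SimpleNamespace


def pick_weight_scale(layer, block_quant: bool, online_blockwise: bool):
    """Select the scale attribute consistently with how it was created.

    Blockwise scales (whether from a checkpoint or from online
    quantization) are stored as `weight_scale_inv`; tensorwise scales
    as `weight_scale`. Checking only `block_quant`, as the Marlin path
    does, misses the online-blockwise case and raises AttributeError.
    """
    if block_quant or online_blockwise:
        return layer.weight_scale_inv
    return layer.weight_scale


# Online quantization: block_quant is False, yet scales are blockwise.
online_layer = SimpleNamespace(weight_scale_inv="blockwise scales")
tensorwise_layer = SimpleNamespace(weight_scale="tensorwise scales")
```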


@@ -150,7 +152,8 @@ def test_load_fp16_model(
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

with vllm_runner(
"facebook/opt-125m",
# "facebook/opt-125m",
"Qwen/Qwen1.5-MoE-A2.7B",

Test validation disabled with debug model change

Medium Severity

The test model was changed from facebook/opt-125m to Qwen/Qwen1.5-MoE-A2.7B, but the check_model validation function (which references opt-125m-specific layer paths like model.model.decoder.layers[0].fc1) was commented out instead of updated. The test now only runs inference without validating quantization was applied correctly. Additionally, test parameterization was reduced, decreasing coverage.


self.block_quant
or self.quant_config.online_quant_scaling_type
is OnlineQuantScalingType.BLOCKWISE
):

Blockwise condition breaks serialized per-tensor FP8 checkpoints

High Severity

The conditions checking online_quant_scaling_type is BLOCKWISE in the non-Marlin apply paths don't account for serialized checkpoints. For pre-quantized per-tensor FP8 checkpoints, create_weights creates weight_scale (not weight_scale_inv), but the hardcoded BLOCKWISE setting causes both the batch-invariant path and the main w8a8_block_fp8_linear.apply path to access layer.weight_scale_inv, which doesn't exist. This causes AttributeError when loading any serialized per-tensor FP8 checkpoint on FP8-capable hardware.


block_quant=self.block_quant,
tp_size=layer.moe_parallel_config.tp_size,
with_lora_support=self.moe.is_lora_enabled,
)

Parent validation rejects online blockwise before child overrides

Medium Severity

Fp8OnlineMoEMethod.__init__ calls super().__init__() before setting block_quant=True for online blockwise quantization. The parent class validates using its initial block_quant=False and activation_scheme="dynamic", computing dynamic_per_token=True. On SM90/SM100 GPUs with FlashInfer enabled, the parent selects a FlashInfer backend and raises NotImplementedError about "dynamic per token activation quantization" before the child can override block_quant=True. This causes online blockwise MoE quantization to fail on H100/Blackwell GPUs with a misleading error.
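The initialization-order issue reduces to a small sketch. `Parent`, `BuggyChild`, and `FixedChild` are hypothetical stand-ins for the MoE method classes, not the actual vLLM code:

```python
class Parent:
    """Stand-in for the base MoE method: backend selection happens in
    __init__ and depends on the value of block_quant at that moment."""

    def __init__(self):
        if getattr(self, "block_quant", False):
            self.selected_backend = "blockwise path"
        else:
            # In the real code this branch can raise NotImplementedError
            # on SM90/SM100 GPUs with FlashInfer enabled.
            self.selected_backend = "dynamic per-token path"


class BuggyChild(Parent):
    def __init__(self):
        super().__init__()       # parent validates with block_quant unset
        self.block_quant = True  # too late: backend already chosen


class FixedChild(Parent):
    def __init__(self):
        self.block_quant = True  # set before the parent validates
        super().__init__()
```

The fix is simply to establish the child's configuration before delegating to the parent constructor that consumes it.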


@vkuzo vkuzo force-pushed the 20260116_fp8_online_rowwise_quant branch from 5d141ae to 3b8c1f4 on January 16, 2026 at 17:48

mergify bot commented Jan 16, 2026

Hi @vkuzo, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@vkuzo vkuzo force-pushed the 20260116_fp8_online_rowwise_quant branch from 3b8c1f4 to 7e08f35 on January 16, 2026 at 17:56
Summary:

Enables using float8 blockwise scaling with `fp8.py` online quantization.

For now, the UI part of this PR is a placeholder pending the discussions
in vllm-project#32412 . The bulk of the
PR is just wiring up kernels that already exist to fp8.py + online quant
+ blockwise scaling.

This will need to be rebased after the following PRs land:
* vllm-project#32189
* vllm-project#31914

Test Plan:

TODO

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
@vkuzo vkuzo force-pushed the 20260116_fp8_online_rowwise_quant branch from 7e08f35 to 419634b on January 16, 2026 at 17:59

mergify bot commented Jan 16, 2026

Hi @vkuzo, the pre-commit checks have failed. Please run the pre-commit steps listed above, then commit the changes and push to your branch.


mergify bot commented Jan 20, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vkuzo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 20, 2026
