
[Quantization] feat: add FP8 for Omnigen2#2441

Merged
lishunyang12 merged 18 commits into vllm-project:main from zhangj1an:jian/omnigen2_fp8
Apr 16, 2026

Conversation

@zhangj1an
Contributor

@zhangj1an zhangj1an commented Apr 2, 2026


Purpose

Follows #1854.

Add FP8 online quantization support for the OmniGen2 model, covering the attention and MLP layers.

On top of that, OmniGen2 has hidden_size=2520, which is not a multiple of 16 (2520 % 16 == 8). This causes vLLM's cutlass_scaled_mm to fall back to a slow Triton scaled_mm_kernel for every FP8 linear layer (QKV, attn output, gate_up_proj, down_proj). To fix that, we pad weight tensors to multiples of 16 in omnigen2_transformer.py, so the native CUTLASS FP8 tensor-core kernel is used instead of the Triton fallback.
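The padding rule described above can be sketched as follows (a minimal illustration with a hypothetical helper name, not the actual code in omnigen2_transformer.py):

```python
def pad_to_multiple(n: int, multiple: int = 16) -> int:
    """Round n up to the next multiple of `multiple`."""
    return n + (-n) % multiple

# OmniGen2's hidden_size is 2520; 2520 % 16 == 8, so eight zero
# columns/rows of padding bring each weight dimension to 2528,
# which satisfies cutlass_scaled_mm's alignment requirement.
padded = pad_to_multiple(2520)   # 2528
```

An already-aligned dimension passes through unchanged, so the helper is safe to apply unconditionally.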

Quantization layer details

Layers quantized (FP8)

| Layer | Type | Count |
| --- | --- | --- |
| attn.to_qkv | QKVParallelLinear | 38 |
| attn.to_out | RowParallelLinear | 38 |
| feed_forward.gate_up_proj | MergedColumnParallelLinear | 38 |
| feed_forward.down_proj | RowParallelLinear | 38 |

Layers kept at full precision

  • norm1.linear — produces scale_msa, gate_msa, scale_mlp, gate_mlp via tanh(). These multiplicative control signals are precision-sensitive; FP8 quantization errors compound across 38 blocks (32 main + 6 refiner) and visibly degrade generation quality.
  • Embedding layers (x_embedder, ref_image_patch_embedder, caption_embedder), timestep MLP, and output norm projections — small and precision-sensitive, consistent with existing diffusion model quantization policy.
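To illustrate why the multiplicative gates are kept at full precision, here is a toy, pure-Python sketch of per-tensor dynamic FP8 (E4M3) quantize-dequantize. This is an illustration only, not the vLLM/CUTLASS kernel: E4M3 has a max finite value of 448 and 3 mantissa bits, so each value is snapped to a coarse grid and picks up a relative error of up to ~6% per quantization, which then compounds multiplicatively across blocks.

```python
import math

FP8_E4M3_MAX = 448.0

def fake_quantize(values, mantissa_bits=3):
    """Quantize-dequantize a list of floats with a per-tensor dynamic scale."""
    # Map the largest magnitude onto the FP8 representable range.
    scale = max(abs(v) for v in values) / FP8_E4M3_MAX
    out = []
    for v in values:
        x = v / scale
        if x != 0.0:
            # Round to 1 + mantissa_bits significant binary digits.
            exp = math.floor(math.log2(abs(x)))
            step = 2.0 ** (exp - mantissa_bits)
            x = round(x / step) * step
        out.append(x * scale)
    return out

# Example: a tanh-style gate value of 0.3 picks up a visible rounding error.
gates = fake_quantize([1.0, 0.3, -0.07])
```

Applied to a gate signal that multiplies every block's output, even a ~5% per-value error is amplified over 38 blocks, which is the intuition behind excluding norm1.linear.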

Test Plan

Image editing task. We use two sample images (a girl and a dress), with the prompt:

"Edit the first image: replace the clothing of the person with the dress from the second image. Keep the person's identity, face, hairstyle, pose, body shape, and background unchanged. Photorealistic."

(Input images: girl, dress)

Test Result

| Metric | BF16 (no quantization) | FP8 quantization |
| --- | --- | --- |
| Image | omnigen2_bf16_30steps | omnigen2_fp8_final |
| Time | 14.653 s | 13.920 s (5.0% faster) |
| VRAM | 19.56 GiB | 19.39 GiB (0.9% less) |

GEdit-Bench Evaluation

We ran ~10% of GEdit-Bench (55 samples, 5 per task group, English only) using a local Qwen2.5-VL-7B-Instruct judge. Generation used 20 inference steps at 512×512, seed 42.

| Metric | BF16 | FP8 | Delta |
| --- | --- | --- | --- |
| Q_SC (semantics) | 4.80 | 4.53 | -0.27 |
| Q_PQ (quality) | 5.22 | 5.35 | +0.13 |
| Q_O (overall) | 4.72 | 4.66 | -0.06 (-1.3%) |

FP8 overall score is within 1.3% of BF16, confirming minimal quality degradation from quantization.
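The headline figure can be checked directly from the Q_O scores quoted above (BF16 4.72, FP8 4.66):

```python
# Relative change in the overall GEdit-Bench score (Q_O).
bf16_qo, fp8_qo = 4.72, 4.66
delta_pct = (fp8_qo - bf16_qo) / bf16_qo * 100
# delta_pct is about -1.27%, reported above as "within 1.3% of BF16".
```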

Per-task breakdown (Q_O overall score)
| Task | BF16 | FP8 | Delta |
| --- | --- | --- | --- |
| background_change | 6.57 | 6.56 | -0.01 |
| color_alter | 4.21 | 4.68 | +0.47 |
| material_alter | 3.02 | 2.87 | -0.15 |
| motion_change | 3.76 | 4.67 | +0.91 |
| ps_human | 5.60 | 5.34 | -0.26 |
| style_change | 5.76 | 5.56 | -0.20 |
| subject-add | 4.79 | 4.04 | -0.75 |
| subject-remove | 4.91 | 4.96 | +0.05 |
| subject-replace | 4.48 | 3.60 | -0.88 |
| text_change | 4.84 | 4.79 | -0.05 |
| tone_transfer | 4.00 | 4.20 | +0.20 |
Why minimal VRAM savings?

The OmniGen2 diffusion transformer is ~3.6B params (the full model including the Qwen2.5-VL backbone is ~11B, but only the transformer is FP8-quantized), so transformer weight memory is a smaller fraction of total VRAM; the dominant contributors are activations, the VAE decoder, and CUDA context overhead. Additionally, FP8 online quantization loads BF16 weights from disk and converts them to FP8 at runtime, so peak allocation during loading is still BF16-sized. Significant VRAM reduction requires FP8 serialized checkpoints (pre-quantized weights stored as FP8 on disk).
Why is the FP8 image not pixel-identical to BF16?

FP8 has 3–4 bits of mantissa versus BF16's 7 bits, so every quantized matmul introduces small rounding errors. These errors compound across 4 quantized linears × 38 blocks (32 main + 6 refiner) × 30 denoising steps = ~4,560 quantized matmuls per image. Diffusion models are sensitive to early-step perturbations, so small numerical differences can steer the denoising trajectory slightly differently. The output is semantically equivalent (same composition, quality, and identity) but not numerically identical; this is expected behavior for FP8 quantization, consistent with other FP8 diffusion models (FLUX, SD3, HunyuanImage).
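The ~4,560 figure above is just the product of the counts stated earlier in this description:

```python
# Quantized matmuls per generated image.
quantized_linears_per_block = 4   # to_qkv, to_out, gate_up_proj, down_proj
blocks = 38                       # 32 main + 6 refiner
denoising_steps = 30
total_quantized_matmuls = quantized_linears_per_block * blocks * denoising_steps
```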

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.


Signed-off-by: Zhang <jianmusings@gmail.com>
@lishunyang12
Collaborator

Hey @lishunyang12, do you mind telling me what GPU you were using to run quantisation? I am still experiencing a slowdown on a single H100 SXM (Hopper).

As shown in omnigen2_transformer.py, I quantised the MLP and attention layers as required. My implementation is quite similar to #2292, which uses ColumnParallelLinear, QKVParallelLinear, and RowParallelLinear. In my local testing the image output quality is the same, but the speed is very slow.

Thanks in advance!

Reproduce steps
Check out this PR's branch:

cd examples/online_serving/image_to_image

# start server for original BF16 Omnigen2
ENABLE_TORCH_PROFILER=1 PROFILE_DIR=/root/vllm-omni/outputs/my_traces_bf16 bash run_server_omnigen2_bf16.sh

# Wait until http://127.0.0.1:8092/health returns OK (server is up).
# in another terminal, run a sample
cd vllm-omni/examples/online_serving/image_to_image
python omnigen2_fp8_edit_client.py \
  --server http://127.0.0.1:8092 \
  --steps 2 \
  --profile \
  -o omnigen2_bf16_profiled.png
# runtime result is in vllm-omni/outputs/my_traces_bf16


# When result is saved, kill server
pkill -f "vllm-omni serve" || true

# Start a new server for quantised FP8 Omnigen2
ENABLE_TORCH_PROFILER=1 PROFILE_DIR=/root/vllm-omni/outputs/my_traces_fp8 bash run_server_omnigen2_fp8.sh

# When server is up, run a sample
cd vllm-omni/examples/online_serving/image_to_image
python omnigen2_fp8_edit_client.py \
  --server http://127.0.0.1:8092 \
  --steps 2 \
  --profile \
  -o omnigen2_fp8_profiled.png
# runtime result is in vllm-omni/outputs/my_traces_fp8

For BF16, aten::mm takes ~33% of GPU time.

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       pipeline_forward         0.00%       0.000us         0.00%       0.000us       0.000us        1.232s       108.83%        1.232s        1.232s             1  
                                               aten::mm         1.89%      24.011ms         2.66%      33.887ms      28.357us     371.439ms        32.82%     450.309ms     376.827us          1195  
                                    Command Buffer Full        22.89%     291.215ms        22.89%     291.215ms     107.183us     288.731ms        25.51%     288.731ms     106.268us          2717 

BF16_profile.txt

For FP8, scaled_mm_kernel takes 97% of GPU usage.

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       pipeline_forward         0.00%       0.000us         0.00%       0.000us       0.000us       27.047s       106.48%       27.047s       27.047s             1  
                                       scaled_mm_kernel         0.00%       0.000us         0.00%       0.000us       0.000us       24.575s        96.75%       24.575s      27.428ms           896  
                                              aten::mul         0.09%      23.045ms         2.71%     733.721ms     156.277us     166.155ms         0.65%     196.647ms      41.884us          4695 

FP8_profile.txt
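A quick sanity check of the per-call averages reported in the two profiles above (24.575 s of self-CUDA time over 896 scaled_mm_kernel calls for FP8; 450.309 ms of CUDA total over 1195 aten::mm calls for BF16):

```python
# FP8: scaled_mm_kernel average per call, in milliseconds.
fp8_avg_ms = 24.575 / 896 * 1000      # about 27.43 ms per call
# BF16: aten::mm average per call, in microseconds.
bf16_avg_us = 450.309 / 1195 * 1000   # about 376.8 us per call
```

So each FP8 scaled matmul is running roughly 70x slower per call than the BF16 matmul, consistent with hitting the Triton fallback rather than the CUTLASS kernel.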

I will run it another day.

Signed-off-by: Zhang <jianmusings@gmail.com>
@zhangj1an zhangj1an changed the title [WIP][Quantization] feat: add FP8 for Omnigen2 [Quantization] feat: add FP8 for Omnigen2 Apr 13, 2026
@zhangj1an zhangj1an marked this pull request as ready for review April 13, 2026 09:07
@zhangj1an
Contributor Author

Hey @lishunyang12, this PR is now ready for review. Please take a look when you are free, thank you!


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2d61c30023

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
Comment thread vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
@lishunyang12 lishunyang12 added the quantization Code related to quantization label Apr 15, 2026
@lishunyang12 lishunyang12 added the ready label to trigger buildkite CI label Apr 15, 2026
@lishunyang12 lishunyang12 merged commit 817e32d into vllm-project:main Apr 16, 2026
8 checks passed
lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
