Skip to content

Int8 Quantization Support for DiT (Z-Image & Qwen-Image)#1470

Merged
hsliuustc0106 merged 33 commits into
vllm-project:mainfrom
yjb767868009:int8-quant
Mar 19, 2026
Merged

Int8 Quantization Support for DiT (Z-Image & Qwen-Image)#1470
hsliuustc0106 merged 33 commits into
vllm-project:mainfrom
yjb767868009:int8-quant

Conversation

@yjb767868009
Copy link
Copy Markdown
Contributor

@yjb767868009 yjb767868009 commented Feb 25, 2026

Overview

This PR introduces online int8 quantization for DiT models in vLLM-Omni based. The currently supported model range follows the FP8 quantization, starting with Z-Image and Qwen-Image (text-to-image). Online int8 converts BF16/FP16 weights to int8 at model load time with dynamic activation scaling.

Device compatibility: W8A8
Per-layer control: ignored_layers lets users keep sensitive layers in BF16

Supported Models

Model HF Models Recommendation ignored_layers
Z-Image Tongyi-MAI/Z-Image-Turbo All layers None
Qwen-Image Qwen/Qwen-Image, Qwen/Qwen-Image-2512 All layers None

Changes

Quantization framework (vllm_omni/diffusion/quantization/)

  • int8.py — DiffusionInt8Config with Int8Config for device (dynamic activation scaling, online weight conversion)
  • __init__.py — DiffusionInt8Config is added and registered to the supported quantization methods

Tests (tests/diffusion/quantization/)

  • test_int8_config.py — Unit tests covering config creation, vLLM config extraction, ignored_layers, dict non-mutation, conflicting method warnings, and end-to-end integration with OmniDiffusionConfig

Documentation (docs/)

  • user_guide/diffusion/quantization/overview.md — Quantization methods overview
  • user_guide/diffusion/quantization/int8.md — Usage guide (Python API + CLI), parameter reference, per-model recommendations
  • user_guide/diffusion_acceleration.md — Updated model support table with int8 column
  • .nav.yml — Added int8 quantization section to docs navigation

Example (examples/offline_inference/text_to_image/text_to_image.py)

  • Added --quantization int8

How to Use

from vllm_omni import Omni

# Z-Image: all layers quantized
omni = Omni(model="Tongyi-MAI/Z-Image-Turbo", quantization="int8")

# Qwen-Image: skip sensitive img_mlp layers
omni = Omni(
    model="Qwen/Qwen-Image-2512",
    quantization_config={
        "method": "int8",
        "ignored_layers": ["img_mlp"],
    },
)
# CLI
python text_to_image.py --model Tongyi-MAI/Z-Image-Turbo --quantization int8

python text_to_image.py --model Qwen/Qwen-Image-2512 --quantization int8 \
    --ignored-layers "int8"

Test Plan

  • Lint/type check passes with no import or type errors
  • python examples/offline_inference/text_to_video/text_to_video.py --quantization int8--model <wan_model> initializes without errors
  • Running without --quantization produces identical behavior (quant_config=None flows through as no-op)

Test Result

Quantization Quality Benchmark for GPU

  • Qwen-Image-2512
Config Avg Time Speedup Memory (GiB) Mem Reduction Mean LPIPS
BF16 baseline 30.07s 65.18 (ref)
int8 20.45s 32% 47.64 20% 0.0197
int8 skip img_mlp 26.78s 11% 51.43 14% 0.0027
  • Z-Image-Turbo
Config Avg Time Speedup Memory (GiB) Mem Reduction Mean LPIPS
BF16 baseline 22.69s 24.95 (ref)
int8 14.95s 34% 20.16 19% 0.1597
int8 skip feed_forward 20.32s 10% 22.80 9% 0.0290

Quantization Quality Benchmark for Atlas A2

  • Qwen-Image-2512
Config Avg Time Speedup Memory (GiB) Mem Reduction Mean LPIPS
BF16 baseline 22.88s 59.93 (ref)
int8 21.18s 7% 47.75 20% 0.0312
int8 skip img_mlp 22.26s 3% 51.81 14% 0.0068
  • Z-Image-Turbo
Config Avg Time Speedup Memory (GiB) Mem Reduction Mean LPIPS
BF16 baseline 16.95s 23.43 (ref)
int8 14.78s 13% 18.03 23% 0.1474
int8 skip feed_forward 16.24s 4% 21.63 8% 0.0337

Memory Profiling

  • Qwen-Image-2512, 1024x1024, 50 steps
Config Weights Activations Peak Total Reduction
BF16, TP=1 55.78 GB 9.4 GB 65.18 GB -
Int8, TP=1 44.96 GB 9.76 GB 54.72 GB 16%
BF16, TP=2 43.13 GB 10.32 GB 53.45 GB -
Int8, TP=2 37.17 GB 10.52 GB 47.69 GB 11%
  • Z-Image-Turbo, 1024x1024, 50 steps
Config Weights Activations Peak Total Reduction
BF16, TP=1 20.51 GB 4.44 GB 24.95 GB -
Int8, TP=1 14.87 GB 5.29 GB 20.16 GB 19%
BF16, TP=2 14.83 GB 5.30 GB 20.13 GB -
Int8, TP=2 11.84 GB 5.28 GB 17.12 GB 15%

Qwen-Image

gpu-qwen-image-coffee-bf16 gpu-qwen-image-coffee-int8-quantization
bf16 vs int8 in GPU
qwen-image-coffee-int8-quantization qwen-image-coffee-bf16
bf16 vs int8 in NPU

Z-Image

gpu-zimage-coffee-bf16 gpu-zimage-coffee-int8-quantization
bf16 vs int8 in GPU
z-image-coffee-int8-quantization z-image-coffee-bf16
bf16 vs int8 in NPU

Related Issues


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please providing the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 54149deb01

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm_omni/diffusion/quantization/int8.py Outdated
Comment thread vllm_omni/diffusion/quantization/int8.py Outdated
Comment thread vllm_omni/diffusion/quantization/int8.py Outdated
@lishunyang12
Copy link
Copy Markdown
Collaborator

May i ask your way to measure the memory usage? I didn't take kv cache into account and only record weight for model loaded. Seems like there is some discrepancies between the two.

@yjb767868009
Copy link
Copy Markdown
Contributor Author

May i ask your way to measure the memory usage? I didn't take kv cache into account and only record weight for model loaded. Seems like there is some discrepancies between the two.

I observed the peak VRAM during runtime by using npu-smi info similar to nvidia-smi.

Copy link
Copy Markdown
Collaborator

@lishunyang12 lishunyang12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left a couple comments, mostly around the top-level torch_npu import

from typing import TYPE_CHECKING, Any, Optional

import torch
import torch_npu
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch_npu at module top level will crash on any non-NPU machine. Since __init__.py imports DiffusionInt8Config unconditionally, this breaks the entire vllm_omni.diffusion.quantization package — including FP8 codepaths.

Move this to a lazy import inside the methods that actually call torch_npu.* (e.g. apply(), process_weights_after_loading()).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following your suggestion, I have using lazy import.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks.

replace_parameter(layer, "weight_scale", weight_scale)


class DiffusionInt8Config(DiffusionQuantizationConfig):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing quant_config_cls = Int8Config — without it get_name() from the base class raises NotImplementedError. See how DiffusionFp8Config sets quant_config_cls = Fp8Config.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following your suggestion, I have added quantic_fig_cls = Int8Config.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed in the diff. Thanks.


logger = logging.getLogger(__name__)

CONDITION_IMAGE_SIZE = 384 * 384
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CONDITION_IMAGE_SIZE / VAE_IMAGE_SIZE refactor seems unrelated to int8 quantization. Worth splitting into its own PR to keep review scope tight.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was caused by an incorrect commit and has now been removed from the branch.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
@lishunyang12
Copy link
Copy Markdown
Collaborator

Can you post visual output for z-image?

@yjb767868009
Copy link
Copy Markdown
Contributor Author

Can you post visual output for z-image?

I will upload the visual output for z-image later.

Signed-off-by: JuboYu <767868009@qq.com>
Signed-off-by: juboyu <767868009@qq.com>
@david6666666
Copy link
Copy Markdown
Collaborator

@codex review

@hsliuustc0106
Copy link
Copy Markdown
Collaborator

any test result in gpu?


## Device Compatibility for Int8

| NPU Generation | Int8 Mode |
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GPU should also support it, right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This version only supports NPUs. The GPU version needs to be developed. I don't think GPUs need INT8, there are better options for FP8, while NPU currently only supports INT8.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you should add a note for TODO GPU support int8

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a117002e76

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

def from_config(cls, config: dict[str, Any]) -> "Int8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_int8_serialized = "int8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Fall back to dynamic activation scheme in from_config

Int8Config.from_config treats activation_scheme as mandatory via get_from_keys, so any int8 quantization config that omits this field will raise during config parsing instead of using the class default. This breaks loading paths that rely on minimal quantization metadata (e.g., only quant_method present) even though __init__ already defines "dynamic" as the default scheme; switching to an optional lookup with a default avoids this hard failure.

Useful? React with 👍 / 👎.

@yjb767868009
Copy link
Copy Markdown
Contributor Author

I found one blocking gap in the validation for this new int8 path.

The current tests mostly stop at config plumbing and toy linear-kernel smoke checks, but they never instantiate Omni on a supported DiT model such as Tongyi-MAI/Z-Image-Turbo or Qwen/Qwen-Image and verify that an actual int8 generation request succeeds. Since this PR adds a new user-facing quantization path in vllm_omni/diffusion/quantization/int8.py, updates the support tables, and includes quality/performance claims, I think it needs at least one inference-level regression test on a supported model so broken weight mapping, unsupported layers, or loader-path regressions are caught automatically.

For this risk area, I would specifically ask for model inference correctness evidence rather than only config/unit coverage: a seeded text-to-image generation smoke test on a supported model that validates successful generation and expected output shape (and ideally compares against a BF16 baseline if the path is intended to be stable enough for that).

The end-to-end tests for qwen-image and z-image have been included in the PR, and the corresponding test results are provided.

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correction: Review Error

I apologize — my previous comment requested "model inference correctness evidence" but I failed to recognize that this evidence was already provided in the PR description.

The PR description includes:

Evidence Provided
Quality metrics ✅ Mean LPIPS scores for both GPU and NPU
Performance benchmarks ✅ Timing, speedup %, memory reduction tables
Memory profiling ✅ Weights/Activations/Peak breakdown by TP config
Visual outputs ✅ BF16 vs Int8 comparison images for both models
Test plan checkmarks ✅ [x] items showing tests were run

This is comprehensive inference-level validation. The review should not have requested additional tests.

Assessment update: The PR provides sufficient evidence for the int8 quantization feature. The existing unit tests in test_int8_config.py plus the PR description evidence together provide reasonable coverage.

Sorry for the redundant feedback. 🙏

david6666666 and others added 3 commits March 19, 2026 11:48
…he prefix for quantization ignored_layers

Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: JuboYu <767868009@qq.com>
@david6666666
Copy link
Copy Markdown
Collaborator

Please fix pre-commit

Signed-off-by: juboyu <767868009@qq.com>
@yjb767868009
Copy link
Copy Markdown
Contributor Author

@lishunyang12 @david6666666 @hsliuustc0106 This PR has been prepared and is ready for your review.

@david6666666 david6666666 enabled auto-merge (squash) March 19, 2026 14:47
@hsliuustc0106 hsliuustc0106 disabled auto-merge March 19, 2026 14:57
@hsliuustc0106 hsliuustc0106 merged commit b766c47 into vllm-project:main Mar 19, 2026
7 checks passed
zhumingjue138 pushed a commit to zhumingjue138/vllm-omni that referenced this pull request Mar 20, 2026
…t#1470)

Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: JuboYu <767868009@qq.com>
Signed-off-by: Alicia <115451386+congw729@users.noreply.github.com>
Co-authored-by: Alicia <115451386+congw729@users.noreply.github.com>
Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
yiliu30 pushed a commit to yiliu30/vllm-omni-fork that referenced this pull request Mar 20, 2026
…t#1470)

Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: JuboYu <767868009@qq.com>
Signed-off-by: Alicia <115451386+congw729@users.noreply.github.com>
Co-authored-by: Alicia <115451386+congw729@users.noreply.github.com>
Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com>

Signed-off-by: yiliu30 <yi4.liu@intel.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
…t#1470)

Signed-off-by: juboyu <767868009@qq.com>
Signed-off-by: JuboYu <767868009@qq.com>
Signed-off-by: Alicia <115451386+congw729@users.noreply.github.com>
Co-authored-by: Alicia <115451386+congw729@users.noreply.github.com>
Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants