[Feature] Add FP8 quantization for Qwen2.5-Omni (thinker LM only) #3466

Open
wuli666 wants to merge 5 commits into vllm-project:main from wuli666:feat/qwen2_5_omni_fp8_d
@wuli666 wuli666 commented May 9, 2026

Purpose

Add Qwen2.5-Omni FP8 dynamic quantization support, claimed in RFC #2136 (Quantization Matrix → 🌐 Qwen2.5-Omni FP8 D).

This PR mirrors the per-component routing pattern from #1764 (Qwen3-Omni FP8 D), restricting FP8 quantization to the thinker language model only — vision and audio encoders stay BF16 (they have no FP8 scale tensors in the checkpoint and would produce garbage embeddings if quantized). The talker stage receives the same quantization=fp8 flag via cross-process engine_args propagation; in the auto-wrap path it is also quantized to FP8 (additional ~1.8 GiB saving with no measurable quality loss). Users wanting a strict thinker-only configuration can pass --quantization-config '{"thinker.language_model": "fp8"}' once vLLM's per-component dict path is unblocked upstream.

The routing supports three input shapes:

  1. Pre-quantized checkpoint (e.g. ModelOpt FP8) — passed through as-is to the language model; encoders skipped via the PRE_QUANTIZED_METHODS check.
  2. Per-component dict (e.g. --quantization-config '{"thinker.language_model": "fp8"}') — a ComponentQuantizationConfig is supplied directly; resolved per layer prefix.
  3. Dynamic quant method (e.g. --quantization fp8 on a BF16 checkpoint) — auto-wrapped into ComponentQuantizationConfig({language_prefix: quant_config}, default=None) so encoders fall through to None. The wrapped config is propagated via replace(vllm_config, quant_config=wrapped) so all submodules within the thinker process see consistent routing.
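The three input shapes above can be sketched as a small routing function. This is a minimal, self-contained illustration under stated assumptions, not the code in this PR: `ComponentQuantSketch`, `EngineConfigSketch`, `route_quant_config`, the local `maybe_prefix` helper, and the contents of `PRE_QUANTIZED_METHODS` are hypothetical stand-ins, and quantization configs are reduced to plain strings.

```python
from dataclasses import dataclass, field, replace
from typing import Optional, Union

# Hypothetical stand-in for the set of pre-quantized checkpoint methods.
PRE_QUANTIZED_METHODS = {"modelopt"}


@dataclass
class ComponentQuantSketch:
    """Maps a component prefix to a quant method; other layers use `default`."""
    routes: dict = field(default_factory=dict)
    default: Optional[str] = None

    def resolve(self, layer_prefix: str) -> Optional[str]:
        for component, method in self.routes.items():
            # Match the component itself or anything nested below it.
            if layer_prefix == component or layer_prefix.startswith(component + "."):
                return method
        return self.default


@dataclass(frozen=True)
class EngineConfigSketch:
    """Hypothetical, simplified stand-in for the engine-level config."""
    quant_config: Union[str, ComponentQuantSketch, None] = None


def maybe_prefix(prefix: str, name: str) -> str:
    """Join a parent prefix and a child name, tolerating an empty prefix."""
    return f"{prefix}.{name}" if prefix else name


def route_quant_config(cfg: EngineConfigSketch, prefix: str = "thinker") -> EngineConfigSketch:
    qc = cfg.quant_config
    if qc is None or isinstance(qc, ComponentQuantSketch):
        # Shape 2 (or no quantization): already per-component, use as supplied.
        return cfg
    if qc in PRE_QUANTIZED_METHODS:
        # Shape 1: the language model receives the config as-is; in this
        # sketch the caller is assumed to skip the encoders separately.
        return cfg
    # Shape 3: wrap the bare dynamic method so only the language model matches.
    language_prefix = maybe_prefix(prefix, "language_model")
    wrapped = ComponentQuantSketch({language_prefix: qc}, default=None)
    return replace(cfg, quant_config=wrapped)


routed = route_quant_config(EngineConfigSketch(quant_config="fp8"))
lm_method = routed.quant_config.resolve("thinker.language_model.layers.0.qkv_proj")
encoder_method = routed.quant_config.resolve("thinker.visual.blocks.0.attn")
```

With `--quantization fp8` modelled as the string `"fp8"`, `lm_method` resolves to `"fp8"` while `encoder_method` falls through to `None`, which corresponds to the encoders staying BF16.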

Test Plan

End-to-end on a 2× RTX 4090 host (Ada sm_89, vLLM-Omni HEAD), using the mixed-modalities query (audio + image + video). Stage 0 alone on GPU 0; stages 1+2 cohabit GPU 1. Deploy YAML retuned locally for 24 GiB cards (not part of this PR; default YAML is sized for 80 GiB H100/H200).

cd examples/offline_inference/qwen2_5_omni

# BF16 baseline
HF_ENDPOINT=https://hf-mirror.com VLLM_WORKER_MULTIPROC_METHOD=spawn \
python end2end.py \
    --model /path/to/Qwen2.5-Omni-7B \
    --output-wav out_bf16 \
    --query-type use_mixed_modalities

# FP8 D
python end2end.py \
    --model /path/to/Qwen2.5-Omni-7B \
    --quantization fp8 \
    --output-wav out_fp8d \
    --query-type use_mixed_modalities

The --quantization flag is added to end2end.py locally for test convenience and is not part of this PR. The same routing is exercised by --quantization-config '{"thinker.language_model": "fp8"}' (existing flag) once vLLM's strict pairing check on quantization_config is relaxed upstream.

Validation environment

  • vLLM-Omni: this PR.
  • vLLM: built from source / version including vllm-project/vllm#41424 ("Fix FP8 Bias Loading", merged 2026-05-03), which fixes an upstream Fp8OnlineLinearMethod regression that produces garbage on biased Linear layers (Qwen2 family). The current PyPI vllm==0.20.1 predates #41424; users on pip will hit upstream garbage tokens until the next vLLM release ships. This is independent of the routing logic in this PR — once #41424 is in the user's vLLM, no further changes here are needed.

Test Result

Memory (per-stage model weights, RTX 4090 24 GiB)

| Stage | BF16 model | FP8 D model | Reduction | KV cache budget (BF16 → FP8 D) |
|---|---|---|---|---|
| Stage 0 (thinker) | 16.74 GiB | 10.62 GiB | -6.12 GiB (-37%) | 2.5 GiB → 8.27 GiB (3.3× larger) |
| Stage 1 (talker) | 6.03 GiB | 4.82 GiB | -1.21 GiB (-20%) | 2.4 GiB → 3.51 GiB |
| Stage 2 (code2wav) | 1.46 GiB | 1.46 GiB | 0 | |
| Total model weights | 24.23 GiB | 16.90 GiB | -7.33 GiB (-30%) | |

Peak GPU 0 VRAM (GPU 0 hosts the thinker only) is nearly identical: 21.04 GiB for BF16 vs 21.02 GiB for FP8 D. vLLM's gpu_memory_utilization=0.85 budget fills the available headroom with KV cache, so the saved weight memory is realized as a 3.3× larger KV cache budget rather than a smaller VRAM footprint at the same configuration. Lowering gpu_memory_utilization directly trades that headroom back for a smaller actual VRAM footprint.
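The reductions and the KV cache ratio reported above can be re-derived from the per-stage numbers. The snippet below is only an arithmetic check on the figures in the memory table; the variable and function names are illustrative.

```python
# Per-stage weight sizes in GiB (BF16, FP8 D), copied from the memory table.
weights_gib = {
    "thinker": (16.74, 10.62),
    "talker": (6.03, 4.82),
    "code2wav": (1.46, 1.46),
}


def reduction(bf16: float, fp8: float) -> tuple:
    """Return (GiB saved, percent saved) rounded the way the table reports them."""
    saved = bf16 - fp8
    return round(saved, 2), round(100 * saved / bf16)


per_stage = {name: reduction(*pair) for name, pair in weights_gib.items()}

total_bf16 = round(sum(bf16 for bf16, _ in weights_gib.values()), 2)
total_fp8 = round(sum(fp8 for _, fp8 in weights_gib.values()), 2)
total_saved, total_pct = reduction(total_bf16, total_fp8)

# Thinker KV cache budget: 2.5 GiB under BF16 vs 8.27 GiB under FP8 D.
kv_ratio = round(8.27 / 2.5, 1)

print(per_stage, total_bf16, total_fp8, total_saved, total_pct, kv_ratio)
```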

Wall time (mixed-modalities query, single prompt, end-to-end including init + audio decode + WAV write)

| Configuration | Wall time |
|---|---|
| BF16 baseline | 22.76 s |
| FP8 D | 36.45 s |

FP8 D is slower than BF16 on RTX 4090 (Ada sm_89), by about 1.6× in wall time here, because the CUTLASS FP8 GEMM kernel at these layer shapes does not outperform cuBLAS BF16 on first-generation FP8 tensor cores. The same code path on RTX 5090 (Blackwell sm_120, second-generation FP8 hardware) showed the inverse: FP8 D ran in 1m02s vs 2m28s for BF16 (~2.4× speedup). The speed/memory trade-off is therefore hardware-dependent; memory savings are the consistent benefit.
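For reference, the two ratios quoted above follow directly from the reported wall times. This is a small arithmetic sketch; `to_seconds` is an illustrative helper, not part of the PR.

```python
import re


def to_seconds(text: str) -> float:
    """Parse durations such as '22.76 s' or '1m02s' into seconds."""
    match = re.fullmatch(r"\s*(?:(\d+)m)?\s*(\d+(?:\.\d+)?)\s*s\s*", text)
    if match is None:
        raise ValueError(f"unrecognized duration: {text!r}")
    minutes = int(match.group(1) or 0)
    return 60 * minutes + float(match.group(2))


# RTX 4090: FP8 D wall time relative to BF16 (values above 1 mean FP8 D is slower).
rtx4090_slowdown = to_seconds("36.45 s") / to_seconds("22.76 s")

# RTX 5090: BF16 wall time relative to FP8 D (values above 1 mean FP8 D is faster).
rtx5090_speedup = to_seconds("2m28s") / to_seconds("1m02s")

print(f"RTX 4090 slowdown: {rtx4090_slowdown:.2f}x")
print(f"RTX 5090 speedup: {rtx5090_speedup:.2f}x")
```

Both figures are single-prompt, end-to-end timings that include initialization, so they are indicative rather than a kernel-level benchmark.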

Output quality

BF16 baseline (mixed_modalities query):

The audio recites "Mary had a little lamb". The first image shows cherry blossoms with a tower in the background. I'm not sure why it's considered funny without more context though. Maybe because it's an unexpected combination? What do you think?

FP8 D (same prompt, same seed):

The first part of the audio recites "Mary had a little lamb". The image shows a baby with glasses reading a book. The video might be considered funny because it's so cute to see a baby wearing glasses and actually reading a book. Babies usually don't read books like that! So, what do you think about it? Do you have any other thoughts on why it could be funny?

The FP8 D output is a near-paraphrase of BF16 and correctly grounds in all three input modalities (audio recitation, image content, video humor). Both runs produced valid output_*.wav files. This is the expected behavior for FP8 dynamic quantization — token-level output may differ slightly due to numerical noise, but semantic content is preserved.

Routing correctness (verified via instrumentation in ComponentQuantizationConfig.get_quant_method)

  • 112 LM linear layers (qkv_proj × 28, o_proj × 28, gate_up_proj × 28, down_proj × 28) routed to Fp8OnlineLinearMethod.
  • 56 LM attention layers receive Fp8KVCacheMethod marker (KV cache stays BF16 because kv_cache_dtype=auto).
  • 2 LM layers correctly skipped (lm_head + embed_tokens → None).
  • Visual + audio encoder layers all resolve to None → BF16 (verified by prefix not matching thinker.language_model).
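The linear-layer count in the list above can be reproduced with a simplified model of the prefix check. This is a sketch under assumptions: the layer names (`model.layers.{i}.self_attn.qkv_proj`, `thinker.visual...`, `thinker.audio_tower...`) are illustrative guesses at the runtime naming, and the suffix test for `lm_head` and `embed_tokens` is only a stand-in for however those layers end up resolving to None in the real code.

```python
from collections import Counter

NUM_LAYERS = 28  # decoder layers, per the "x 28" counts above
LANGUAGE_PREFIX = "thinker.language_model"
# Assumed sub-module paths for the four linear projections in each layer.
PROJECTIONS = (
    "self_attn.qkv_proj",
    "self_attn.o_proj",
    "mlp.gate_up_proj",
    "mlp.down_proj",
)


def resolve(layer_name: str):
    """Return 'fp8' for language-model linear layers, None for everything else."""
    if not layer_name.startswith(LANGUAGE_PREFIX + "."):
        return None  # visual and audio encoder layers stay BF16
    if layer_name.endswith(("lm_head", "embed_tokens")):
        return None  # modelled here as a simple suffix check
    return "fp8"


lm_linear = [
    f"{LANGUAGE_PREFIX}.model.layers.{i}.{proj}"
    for i in range(NUM_LAYERS)
    for proj in PROJECTIONS
]
skipped = [f"{LANGUAGE_PREFIX}.lm_head", f"{LANGUAGE_PREFIX}.model.embed_tokens"]
encoders = ["thinker.visual.blocks.0.attn.qkv", "thinker.audio_tower.layers.0.fc1"]

counts = Counter(resolve(name) for name in lm_linear + skipped + encoders)
print(counts)
```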

Both out_bf16/*.wav and out_fp8d/*.wav were generated successfully, matching the BF16 / FP8 D text outputs above.

wuli666 added 2 commits May 9, 2026 08:20
Mirrors PR vllm-project#1764 (Qwen3-Omni FP8 D) routing pattern for Qwen2.5-Omni thinker LM.
Routes user-supplied quant_config (e.g. --quantization fp8) to the language_model
only via ComponentQuantizationConfig; vision and audio encoders stay BF16.

Uses maybe_prefix(prefix, ...) for component keys so the routing works under
the nested prefix structure: the thinker is constructed inside the parent
Qwen2_5OmniForConditionalGeneration with prefix="thinker", so language
model layers are at "thinker.language_model.X" at runtime.

Signed-off-by: wuli666 <djjpro975@gmail.com>

Per reviewer guidance and matrix RFC vllm-project#2136: for omni models, dynamic FP8
should scope to the thinker/LLM only — talker, audio encoder, vision
encoder, and code2wav stay BF16.

The talker stage runs in a separate process and receives engine_args.quantization
propagated from the user. When that becomes a bare dynamic quant_config (e.g.
Fp8Config from --quantization fp8), the auto-wrap branch in talker.__init__
now sets quant_config=None and propagates the cleared config via
replace(vllm_config, quant_config=None) so talker submodules construct as BF16.

Pre-quantized checkpoints (modelopt) and explicit per-component dicts (where
the user includes a talker entry) continue to be honored.

Signed-off-by: wuli666 <djjpro975@gmail.com>
@wuli666 wuli666 force-pushed the feat/qwen2_5_omni_fp8_d branch from ac6898a to 7dae773 Compare May 9, 2026 08:21
@princepride
Collaborator

@wuli666 CI failed, PTAL

wuli666 added 3 commits May 9, 2026 08:24
Signed-off-by: wuli666 <djjpro975@gmail.com>
Signed-off-by: wuli666 <djjpro975@gmail.com>
Signed-off-by: wuli666 <djjpro975@gmail.com>
@hsliuustc0106
Collaborator

Review Notes

LGTM. Clean implementation that mirrors the Qwen3-Omni FP8 D pattern (PR #1764).

What works well

  • Correctly scopes FP8 quantization to the thinker LM only via ComponentQuantizationConfig
  • Properly handles pre-quantized checkpoints (ModelOpt FP8) via PRE_QUANTIZED_METHODS check
  • Talker stage correctly clears quant_config for dynamic quantization
  • Uses maybe_prefix() for nested prefix structure (thinker.language_model)
  • Comprehensive test results: 30% memory reduction, quality preservation verified

Minor observation

The 3.3× KV cache budget increase (from saved weight memory) is expected behavior with gpu_memory_utilization=0.85. Users wanting a smaller VRAM footprint should lower gpu_memory_utilization directly rather than expecting automatic reduction.

Documentation

PR body is thorough with test plan, results, and environment details. No additional docs needed for this scoped quantization change.

@hsliuustc0106 hsliuustc0106 requested a review from lishunyang12 May 9, 2026 11:32
@hsliuustc0106
Collaborator

have you checked the audio wav quality?

@wuli666
Author

wuli666 commented May 9, 2026

have you checked the audio wav quality?

Yes, the audio is fine and matches the BF16 baseline.

@hsliuustc0106
Collaborator

Hi @wuli666, friendly reminder — this PR hasn't had any activity (commits or reviews) in the past 7 days. 🕐

Could you please provide an update?

  • If you're still working on it, that's great — just let us know.
  • If you're blocked on something, feel free to ask for help.
  • If this PR is no longer being pursued, please consider closing it so we can keep the review queue manageable.

Thanks for your contribution! 🙏

@wuli666
Author

wuli666 commented May 17, 2026

Hi @hsliuustc0106, the code is ready, pre-commit is clean, and 18/18 tests are passing. Just waiting on review.
