diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index d9a2315953..00823951dc 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -44,11 +44,19 @@ steps:
agents:
queue: "cpu_queue_premerge"
- # L4 Test — main+NIGHTLY=1 (scheduled), or PR with label nightly-test (e.g. add label then Rebuild)
+ # L4 Test — main+NIGHTLY=1 (scheduled), or PR with specific label (e.g. add label then Rebuild)
- label: "Upload Nightly Pipeline"
depends_on: image-build
key: upload-nightly-pipeline
- if: '(build.branch == "main" && build.env("NIGHTLY") == "1") || (build.branch != "main" && build.pull_request.labels includes "nightly-test")'
+ if: >-
+ (build.branch == "main" && build.env("NIGHTLY") == "1") ||
+ (build.branch != "main" && (
+ build.pull_request.labels includes "nightly-test" ||
+ build.pull_request.labels includes "omni-test" ||
+ build.pull_request.labels includes "tts-test" ||
+ build.pull_request.labels includes "diffusion-x2iat-test" ||
+ build.pull_request.labels includes "diffusion-x2v-test"
+ ))
commands:
- buildkite-agent pipeline upload .buildkite/test-nightly.yml
agents:
diff --git a/.buildkite/test-amd-merge.yml b/.buildkite/test-amd-merge.yml
index 60ba0d9d41..ac52f60b35 100644
--- a/.buildkite/test-amd-merge.yml
+++ b/.buildkite/test-amd-merge.yml
@@ -32,7 +32,6 @@ steps:
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - export GPU_ARCHS=gfx942
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- |
@@ -55,7 +54,7 @@ steps:
# - export GPU_ARCHS=gfx942
# - export VLLM_LOGGING_LEVEL=DEBUG
# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
+# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model
- label: "Diffusion Cache Backend Test"
agent_pool: mi325_1
@@ -63,13 +62,12 @@ steps:
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - export GPU_ARCHS=gfx942
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 15m pytest -s -v -m "core_model and cache and diffusion and not distributed_cuda and L4"
-- label: "Diffusion Sequence Parallelism Test"
- agent_pool: mi325_2
+- label: "Diffusion Sequence Parallelism Test (Need 4 GPUs)"
+ agent_pool: mi325_4
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
@@ -77,6 +75,7 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 20m pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py
+ - timeout 20m pytest -s -v tests/diffusion/distributed/test_ulysses_uaa_perf.py
# merge-only tests
- label: "Diffusion Tensor Parallelism Test"
@@ -95,22 +94,14 @@ steps:
commands:
- timeout 20m pytest -s -v tests/diffusion/test_diffusion_worker.py
-- label: "Benchmark & Engine Test"
- agent_pool: mi325_2
+- label: "Engine Test"
+ agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - |
- timeout 20m bash -c '
- set +e
- pytest -s -v tests/benchmarks/test_serve_cli.py
- EXIT1=\$?
- pytest -s -v tests/engine/test_async_omni_engine_abort.py
- EXIT2=\$?
- exit \$((EXIT1 | EXIT2))
- '
+ - timeout 20m pytest -s -v tests/engine/test_async_omni_engine_abort.py
- label: "Omni Model Test Qwen2-5-Omni"
agent_pool: mi325_2
@@ -121,6 +112,7 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py
+ - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen2_5_omni.py -m "advanced_model" --run-level "advanced_model"
- label: "Omni Model Test Qwen3-Omni"
agent_pool: mi325_2
@@ -131,11 +123,10 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py
- - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "advanced_model" --run-level "advanced_model"
+ - timeout 30m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py tests/e2e/online_serving/test_qwen3_omni.py tests/e2e/online_serving/test_mimo_audio.py -m "advanced_model" --run-level "advanced_model"
- label: "Qwen3-TTS CustomVoice E2E Test"
- agent_pool: mi325_2
+ agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
@@ -145,21 +136,21 @@ steps:
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "advanced_model" --run-level "advanced_model" && pytest -s -v tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
+ pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py tests/e2e/offline_inference/test_qwen3_tts_customvoice.py -m "advanced_model" --run-level "advanced_model"
'
- label: "Qwen3-TTS Base E2E Test"
- agent_pool: mi325_2
+ agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- |
- timeout 20m bash -c '
+ timeout 30m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py -m "advanced_model" --run-level "advanced_model" && pytest -s -v tests/e2e/offline_inference/test_qwen3_tts_base.py
+ pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py tests/e2e/offline_inference/test_qwen3_tts_base.py -m "advanced_model" --run-level "advanced_model"
'
- label: "Diffusion Image Edit Test"
@@ -173,43 +164,58 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
-# split Bagel Model Test with H100 (Real Weights) into three tests
-- label: "Bagel Text2Img Model Test"
- agent_pool: mi325_1
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - export GPU_ARCHS=gfx942
- - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_ROCM_USE_AITER_RMSNORM=0
- - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "advanced_model" --run-level "advanced_model" -k "shared_memory" -k "rocm"
+# TODO: Bagel test on ROCm is very unstable. @tjtanaa
+# Need to debug before reneable numerical changes across large PRs
+# # split Bagel Model Test with H100 (Real Weights) into three tests
+# - label: "Bagel Text2Img Model Test (1/3)"
+# agent_pool: mi325_1
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export GPU_ARCHS=gfx942
+# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - export VLLM_ROCM_USE_AITER_RMSNORM=0
+# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "advanced_model" --run-level "advanced_model" -k "shared_memory" -k "rocm"
-- label: "Bagel Img2Img Model Test"
- agent_pool: mi325_1
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - export GPU_ARCHS=gfx942
- - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_ROCM_USE_AITER_RMSNORM=0
- - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
+# - label: "Bagel Img2Img Model Test (2/3)"
+# agent_pool: mi325_1
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export GPU_ARCHS=gfx942
+# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - export VLLM_ROCM_USE_AITER_RMSNORM=0
+# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
+
+# - label: "Bagel Online Serving Test (3/3)"
+# agent_pool: mi325_1
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export GPU_ARCHS=gfx942
+# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+# - export VLLM_IMAGE_FETCH_TIMEOUT=60
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - export VLLM_ROCM_USE_AITER_RMSNORM=0
+# - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
-- label: "Bagel Online Serving Test"
+- label: "Voxtral-TTS E2E Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - export GPU_ARCHS=gfx942
- - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - export VLLM_IMAGE_FETCH_TIMEOUT=60
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_ROCM_USE_AITER_RMSNORM=0
- - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
+ - |
+ timeout 20m bash -c '
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py tests/e2e/offline_inference/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model"
+ '
diff --git a/.buildkite/test-amd-ready.yaml b/.buildkite/test-amd-ready.yaml
index 6e31163acc..30bbc76941 100644
--- a/.buildkite/test-amd-ready.yaml
+++ b/.buildkite/test-amd-ready.yaml
@@ -9,13 +9,37 @@ steps:
- export VLLM_ROCM_USE_AITER=0
- "timeout 20m pytest -v -s -m 'core_model and cpu' --cov=vllm_omni --cov-branch --cov-report=term-missing --cov-report=html --cov-report=xml"
+- label: "Voxtral TTS CUDA Unit Test"
+ agent_pool: mi325_1
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - timeout 10m pytest -s -v tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
+
- label: "Diffusion Model Test"
- agent_pool: mi325_2
+ agent_pool: mi325_1
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - timeout 30m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model"
+
+- label: "Diffusion Batching Test"
+ agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - timeout 20m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model"
+ - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py -m "core_model and diffusion" --run-level "core_model"
+
+- label: "Custom Pipeline Test"
+ agent_pool: mi325_1
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - timeout 20m pytest -s -v tests/e2e/offline_inference/custom_pipeline/ -m "core_model"
- label: "Diffusion Model CPU offloading Test"
agent_pool: mi325_1
@@ -23,7 +47,6 @@ steps:
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - export GPU_ARCHS=gfx942
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- |
@@ -46,7 +69,7 @@ steps:
# - export GPU_ARCHS=gfx942
# - export VLLM_LOGGING_LEVEL=DEBUG
# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
+# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model
- label: "Diffusion Cache Backend Test"
agent_pool: mi325_1
@@ -77,47 +100,58 @@ steps:
commands:
- timeout 20m pytest -s -v tests/diffusion/test_diffusion_worker.py
-- label: "Benchmark & Engine Test"
- agent_pool: mi325_2
+- label: "Engine Test"
+ agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- |
- timeout 30m bash -c '
- set +e
- pytest -s -v tests/benchmarks/test_serve_cli.py
- EXIT1=\$?
- pytest -s -v tests/engine/test_async_omni_engine_abort.py
- EXIT2=\$?
- exit \$((EXIT1 | EXIT2))
+ timeout 15m bash -c '
+ pytest -s -v tests/engine/test_async_omni_engine_abort.py
'
-- label: "Omni Model Test Qwen2-5-Omni"
- agent_pool: mi325_2
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - timeout 17m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py
-- label: "Omni Model Test Qwen3-Omni"
- agent_pool: mi325_2
+# NOTE: This test is not running any thing. It is skipped and deselected.
+# Currently it is = 1 skipped, 1 deselected, 17 warnings in 0.03s ======
+# - label: "Omni Model Test Qwen2-5-Omni"
+# agent_pool: mi325_2
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py -m "core_model" --run-level "core_model"
+
+# - label: "Omni Model Test Qwen3-Omni"
+# agent_pool: mi325_2
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+# - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py
+# - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "core_model" --run-level "core_model"
+
+- label: "MiMo-Audio E2E Test with H100"
+ agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py
- - timeout 10m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "core_model" --run-level "core_model"
+ - |
+ timeout 30m bash -c '
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ pytest -s -v tests/e2e/online_serving/test_mimo_audio.py -m "core_model" --run-level "core_model"
+ '
- label: "Qwen3-TTS E2E Test"
- agent_pool: mi325_2
+ agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
@@ -125,55 +159,82 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "core_model" --run-level "core_model"
+ - timeout 30m pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "core_model" --run-level "core_model"
-- label: "Diffusion Image Edit Test"
+- label: "Voxtral-TTS E2E Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - export GPU_ARCHS=gfx942
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
+ - |
+ timeout 20m bash -c '
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model"
+ pytest -s -v tests/e2e/offline_inference/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model"
+ '
-- label: "Bagel Text2Img Model Test"
+- label: "Diffusion Image Edit Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- export GPU_ARCHS=gfx942
- - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_ROCM_USE_AITER_RMSNORM=0
- - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "core_model" --run-level "core_model" -k "rocm"
+ - timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
-- label: "Bagel Img2Img Model Test"
- agent_pool: mi325_1
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - export GPU_ARCHS=gfx942
- - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_ROCM_USE_AITER_RMSNORM=0
- - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "core_model" --run-level "core_model" -k "rocm"
+# TODO: Bagel test on ROCm is very unstable. @tjtanaa
+# Need to debug before reneable numerical changes across large PRs
+# - label: "Bagel Text2Img Model Test"
+# agent_pool: mi325_1
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export GPU_ARCHS=gfx942
+# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - export VLLM_ROCM_USE_AITER_RMSNORM=0
+# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "core_model" --run-level "core_model" -k "rocm"
+
+# - label: "Bagel Img2Img Model Test"
+# agent_pool: mi325_1
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export GPU_ARCHS=gfx942
+# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - export VLLM_ROCM_USE_AITER_RMSNORM=0
+# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "core_model" --run-level "core_model" -k "rocm"
-- label: "Bagel Online Serving Test"
+# - label: "Bagel Online Serving Test"
+# agent_pool: mi325_1
+# depends_on: amd-build
+# mirror_hardwares: [amdproduction]
+# grade: Blocking
+# commands:
+# - export GPU_ARCHS=gfx942
+# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+# - export VLLM_IMAGE_FETCH_TIMEOUT=60
+# - export VLLM_LOGGING_LEVEL=DEBUG
+# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+# - export VLLM_ROCM_USE_AITER_RMSNORM=0
+# - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "core_model" --run-level "core_model" -k "rocm"
+
+- label: "CosyVoice3-TTS E2E Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - export GPU_ARCHS=gfx942
- - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - export VLLM_IMAGE_FETCH_TIMEOUT=60
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_ROCM_USE_AITER_RMSNORM=0
- - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "core_model" --run-level "core_model" -k "rocm"
+ - |
+ timeout 20m bash -c '
+ pytest -s -v tests/e2e/online_serving/test_cosyvoice3_tts.py -m "core_model" --run-level "core_model"
+ '
diff --git a/.buildkite/test-merge.yml b/.buildkite/test-merge.yml
index 7355e2b4c7..2a6cb6488a 100644
--- a/.buildkite/test-merge.yml
+++ b/.buildkite/test-merge.yml
@@ -76,24 +76,6 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "Audio Generation Model Test"
- timeout_in_minutes: 20
- depends_on: upload-merge-pipeline
- commands:
- - pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
- agents:
- queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- label: "Diffusion Cache Backend Test"
timeout_in_minutes: 15
depends_on: upload-merge-pipeline
@@ -113,7 +95,7 @@ steps:
- "/fsx/hf_cache:/fsx/hf_cache"
- label: "Diffusion Sequence Parallelism Test"
- timeout_in_minutes: 20
+ timeout_in_minutes: 25
depends_on: upload-merge-pipeline
commands:
- pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py tests/diffusion/distributed/test_ulysses_uaa_perf.py
diff --git a/.buildkite/test-nightly-diffusion.yml b/.buildkite/test-nightly-diffusion.yml
deleted file mode 100644
index 04b99c0a83..0000000000
--- a/.buildkite/test-nightly-diffusion.yml
+++ /dev/null
@@ -1,364 +0,0 @@
-# Nightly diffusion GPU tests — appended to the main nightly build via
-# buildkite-agent pipeline upload .buildkite/test-nightly-diffusion.yml
-# from test-nightly.yml (step key: nightly-diffusion-model-test). Top-level groups are
-# foldable in the Buildkite UI (Other / Wan / Qwen-Image).
-env:
- VLLM_WORKER_MULTIPROC_METHOD: spawn
- HF_HUB_DOWNLOAD_TIMEOUT: 300
- HF_HUB_ETAG_TIMEOUT: 60
-
-steps:
- - group: ":card_index_dividers: Other Model Test"
- key: nightly-other-model-test-group
- steps:
- - label: ":full_moon: Diffusion · Other · Function Test with H100"
- timeout_in_minutes: 120
- # Shared nightly vs PR label conditional; referenced below as *nightly_or_pr_label
- if: &nightly_or_pr_label 'build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"'
- commands:
- - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -k "not test_wan22_expansion and not test_wan_2_1_vace_expansion and not test_qwen_image" -m "advanced_model and diffusion and H100" --run-level "advanced_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Diffusion · Other · Function Test with L4"
- timeout_in_minutes: 60
- if: *nightly_or_pr_label
- commands:
- - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and diffusion and L4" --run-level "advanced_model"
- agents:
- queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- shm-size: "8gb"
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- - label: ":full_moon: Diffusion · Other · Doc Test"
- timeout_in_minutes: 60
- if: *nightly_or_pr_label
- commands:
- - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
- - pytest -s -v tests/examples/online_serving/test_text_to_image.py tests/examples/offline_inference/test_text_to_image.py -m "advanced_model and example and H100" --run-level "advanced_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - group: ":card_index_dividers: Wan Series Model Test"
- key: nightly-wan-model-test-group
- steps:
- - label: ":full_moon: Diffusion · Wan · Function Test"
- timeout_in_minutes: 90
- if: *nightly_or_pr_label
- commands:
- - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py tests/e2e/online_serving/test_wan_2_1_vace_expansion.py -m "advanced_model" --run-level "advanced_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Diffusion · Wan · Accuracy Test"
- key: nightly-wan22-i2v-accuracy
- timeout_in_minutes: 180
- if: *nightly_or_pr_label
- commands:
- - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - group: ":card_index_dividers: Qwen-Image Series Model Test"
- key: nightly-qwen-image-edit-group
- steps:
- - label: ":full_moon: Diffusion · Qwen-Image · Function Test with H100"
- timeout_in_minutes: 120
- if: *nightly_or_pr_label
- commands:
- - pytest -s -v tests/e2e/online_serving/test_qwen_image*_expansion.py -m "advanced_model and diffusion and H100" --run-level "advanced_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Diffusion · Qwen-Image · GEBench Accuracy Test"
- key: nightly-gebench-accuracy
- timeout_in_minutes: 60
- if: *nightly_or_pr_label
- commands:
- - pytest -s -v tests/e2e/accuracy/test_gebench_h100_smoke.py --run-level advanced_model --gebench-model Qwen/Qwen-Image-2512 --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gebench-port 8093 --accuracy-workers 1
- - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gebench_qwen-image-2512/summary*.json"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 1
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Diffusion · Qwen-Image · GEdit-Bench Accuracy Test"
- key: nightly-gedit-bench-accuracy
- timeout_in_minutes: 60
- if: *nightly_or_pr_label
- commands:
- - pytest -s -v tests/e2e/accuracy/test_gedit_bench_h100_smoke.py --run-level advanced_model --gedit-model Qwen/Qwen-Image-Edit --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gedit-port 8093 --gedit-samples-per-group 20 --accuracy-workers 1
- - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_vie_score_*.csv"
- - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_summary_*.json"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 1
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE
- value: "120"
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Diffusion · Qwen-Image · Perf Test"
- key: nightly-qwen-image-performance
- timeout_in_minutes: 180
- if: *nightly_or_pr_label
- commands:
- - export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results
- - export CACHE_DIT_VERSION=1.3.0
- - pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
- - buildkite-agent artifact upload "tests/dfx/perf/results/benchmark_results_*.json"
- - buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 4
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml
index 06b7c14ae1..02d8cced40 100644
--- a/.buildkite/test-nightly.yml
+++ b/.buildkite/test-nightly.yml
@@ -7,12 +7,11 @@ steps:
# Group: collapses under one heading in the Buildkite UI; child steps still run in parallel.
- group: ":card_index_dividers: Omni Model Test"
key: nightly-omni-test-group
+ depends_on: upload-nightly-pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test" || build.pull_request.labels includes "omni-test"
steps:
- - label: ":full_moon: Omni · Function Test with H100"
+ - label: ":full_moon: Omni · Function Test"
timeout_in_minutes: 90
- depends_on: upload-nightly-pipeline
- # Shared nightly vs PR label conditional; referenced below as *nightly_or_pr_label
- if: &nightly_or_pr_label 'build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"'
commands:
- pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and H100 and omni" --run-level "advanced_model"
agents:
@@ -49,13 +48,11 @@ steps:
path: /mnt/hf-cache
type: DirectoryOrCreate
- - label: ":full_moon: Omni · Function Test with L4"
+ - label: ":full_moon: Omni · Doc Test with L4"
timeout_in_minutes: 90
- depends_on: upload-nightly-pipeline
- if: *nightly_or_pr_label
commands:
- export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and L4 and omni" --run-level "advanced_model"
+ - pytest -s -v tests/examples/ -m "advanced_model and omni and L4" --run-level "advanced_model"
agents:
queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
plugins:
@@ -70,13 +67,211 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: ":full_moon: Omni · Doc Test with L4"
+ - label: ":full_moon: Omni · Doc Test with H100"
+ timeout_in_minutes: 90
+ commands:
+ - pytest -s -v tests/examples/ -m "advanced_model and omni and H100" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+ - label: ":full_moon: Omni · Perf Test"
+ key: nightly-omni-performance
+ timeout_in_minutes: 180
+ commands:
+ - export BENCHMARK_DIR=tests/dfx/perf/results
+ - |
+ set +e
+ pytest -s -v tests/dfx/perf/scripts/run_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_omni.json
+ EXIT=$$?
+ buildkite-agent artifact upload "tests/dfx/perf/results/*.json"
+ exit $$EXIT
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+
+ - group: ":card_index_dividers: TTS Model Test"
+ key: nightly-tts-test-group
+ depends_on: upload-nightly-pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test" || build.pull_request.labels includes "tts-test"
+ steps:
+ - label: ":full_moon: TTS · Function Test"
timeout_in_minutes: 90
- depends_on: upload-nightly-pipeline
- if: *nightly_or_pr_label
commands:
- export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - pytest -s -v tests/examples/ -m "advanced_model and omni and L4" --run-level "advanced_model"
+ - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and L4 and omni" --run-level "advanced_model"
+ agents:
+ queue: "gpu_1_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb"
+ environment:
+ - "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
+ - label: ":full_moon: TTS · Perf Test"
+ key: nightly-tts-performance
+ timeout_in_minutes: 180
+ commands:
+ - export BENCHMARK_DIR=tests/dfx/perf/results
+ - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
+ - |
+ set +e
+ pytest -s -v tests/dfx/perf/scripts/run_benchmark.py --test-config-file tests/dfx/perf/tests/test_tts.json
+ EXIT=$$?
+ buildkite-agent artifact upload "tests/dfx/perf/results/*.json"
+ exit $$EXIT
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 1
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+ # Diffusion X2I suite: x2i / x2a / x2t and related non-video paths; x2v is only in "Diffusion X2V Model Test" below.
+ - group: ":card_index_dividers: Diffusion X2I(&A&T) Model Test"
+ key: nightly-diffusion-x2iat-group
+ depends_on: upload-nightly-pipeline
+ if: >-
+ build.env("NIGHTLY") == "1" ||
+ build.pull_request.labels includes "nightly-test" ||
+ build.pull_request.labels includes "diffusion-x2iat-test"
+ steps:
+ - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with H100"
+ timeout_in_minutes: 120
+ commands:
+ - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -k "not test_wan22_expansion and not test_wan_2_1_vace_expansion and not hunyuan" -m "advanced_model and diffusion and H100" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+ - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with L4"
+ timeout_in_minutes: 60
+ commands:
+ - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -k "not test_wan22_expansion and not test_wan_2_1_vace_expansion and not hunyuan" -m "advanced_model and diffusion and L4" --run-level "advanced_model"
agents:
queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
plugins:
@@ -91,12 +286,11 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: ":full_moon: Omni · Doc Test with H100"
- timeout_in_minutes: 90
- depends_on: upload-nightly-pipeline
- if: *nightly_or_pr_label
+ - label: ":full_moon: Diffusion X2I(&A&T) · Doc Test"
+ timeout_in_minutes: 60
commands:
- - pytest -s -v tests/examples/ -m "advanced_model and omni and H100" --run-level "advanced_model"
+ - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
+ - pytest -s -v tests/examples/*/test_text_to_image.py -m "advanced_model and example and H100" --run-level "advanced_model"
agents:
queue: "mithril-h100-pool"
plugins:
@@ -131,17 +325,91 @@ steps:
path: /mnt/hf-cache
type: DirectoryOrCreate
- - label: ":full_moon: Omni · Perf Test"
- key: nightly-omni-performance
+ - label: ":full_moon: Diffusion X2I(&A&T) · GEBench Accuracy Test"
+ timeout_in_minutes: 60
+ commands:
+ - pytest -s -v tests/e2e/accuracy/test_gebench_h100_smoke.py --run-level advanced_model --gebench-model Qwen/Qwen-Image-2512 --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gebench-port 8093 --accuracy-workers 1
+ - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gebench_qwen-image-2512/summary*.json"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 1
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+ - label: ":full_moon: Diffusion X2I(&A&T) · GEdit-Bench Accuracy Test"
+ timeout_in_minutes: 60
+ commands:
+ - pytest -s -v tests/e2e/accuracy/test_gedit_bench_h100_smoke.py --run-level advanced_model --gedit-model Qwen/Qwen-Image-Edit --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gedit-port 8093 --gedit-samples-per-group 20 --accuracy-workers 1
+ - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_vie_score_*.csv"
+ - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_summary_*.json"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 1
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE
+ value: "120"
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+ - label: ":full_moon: Diffusion X2I(&A&T) · Accuracy Test"
timeout_in_minutes: 180
- depends_on: upload-nightly-pipeline
- if: *nightly_or_pr_label
commands:
- - export BENCHMARK_DIR=tests/dfx/perf/results
- - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - pytest -s -v tests/dfx/perf/scripts/run_benchmark.py
- - buildkite-agent artifact upload "tests/dfx/perf/results/*.json"
- - buildkite-agent artifact upload "tests/dfx/perf/results/*.html"
+ - pytest -s -v tests/e2e/accuracy/test_qwen_image*.py --run-level advanced_model
agents:
queue: "mithril-h100-pool"
plugins:
@@ -151,7 +419,65 @@ steps:
- image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
resources:
limits:
- nvidia.com/gpu: 2
+ nvidia.com/gpu: 1
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE
+ value: "120"
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+ - label: ":full_moon: Diffusion X2I(&A&T) · Perf Test"
+ key: nightly-diffusion-x2iat-performance
+ timeout_in_minutes: 180
+ commands:
+ - export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results
+ - export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN
+ - export CACHE_DIT_VERSION=1.3.0
+ # [HACK]: run upload in the same command block as pytest.
+ # Because `exit` aborts the entire commands list.
+ - |
+ set +e
+ pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
+ EXIT1=$$?
+ pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json
+ EXIT2=$$?
+ pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json
+ EXIT3=$$?
+ pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json
+ EXIT4=$$?
+ buildkite-agent artifact upload "tests/dfx/perf/results/diffusion_result_*.json"
+ buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log"
+ exit $$((EXIT1 | EXIT2 | EXIT3 | EXIT4))
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 4
volumeMounts:
- name: devshm
mountPath: /dev/shm
@@ -176,23 +502,96 @@ steps:
path: /mnt/hf-cache
type: DirectoryOrCreate
- # Dynamically appends steps from test-nightly-diffusion.yml into this build (same mechanism as
- # pipeline.yml → test-ready.yml / test-merge.yml / test-nightly.yml). Foldable groups stay in the
- # uploaded YAML (Other / Wan / Qwen-Image).
- - label: ":card_index_dividers: Diffusion Model Test"
- key: nightly-diffusion-model-test
+ # Diffusion x2v only (Wan, HunyuanVideo, …). x2i/x2a/x2t live in the X2I group above, not here.
+ - group: ":card_index_dividers: Diffusion X2V Model Test"
+ key: nightly-diffusion-x2v-group
depends_on: upload-nightly-pipeline
- if: *nightly_or_pr_label
- commands:
- - buildkite-agent pipeline upload .buildkite/test-nightly-diffusion.yml
- agents:
- queue: "cpu_queue_premerge"
+ if: >-
+ build.env("NIGHTLY") == "1" ||
+ build.pull_request.labels includes "nightly-test" ||
+ build.pull_request.labels includes "diffusion-x2v-test"
+ steps:
+ - label: ":full_moon: Diffusion X2V · Function Test"
+ timeout_in_minutes: 90
+ commands:
+ - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py tests/e2e/online_serving/test_wan_2_1_vace_expansion.py tests/e2e/online_serving/test_hunyuan_video_15_expansion.py -m "advanced_model" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
+ - label: ":full_moon: Diffusion X2V · Accuracy Test"
+ timeout_in_minutes: 180
+ commands:
+ - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- label: ":bar_chart: Testcase Statistics"
key: nightly-testcase-statistics
timeout_in_minutes: 120
depends_on: upload-nightly-pipeline
- if: *nightly_or_pr_label
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
commands:
- python tools/nightly/buildkite_testcase_statistics.py -o tests/dfx/perf/results/buildkite_testcase_statistics.html
- buildkite-agent artifact upload "tests/dfx/perf/results/*.html"
@@ -235,16 +634,18 @@ steps:
key: nightly-perf-distribution
depends_on:
- nightly-omni-performance
- - nightly-qwen-image-performance
+ - nightly-tts-performance
+ - nightly-diffusion-x2iat-performance
- nightly-testcase-statistics
if: build.env("NIGHTLY") == "1"
commands:
- pip install openpyxl
- export DEFAULT_INPUT_DIR=tests/dfx/perf/results
- export DEFAULT_OUTPUT_DIR=tests/dfx/perf/results
+ - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-tts-performance
- buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-omni-performance
- - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-qwen-image-performance
- - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . --step nightly-omni-performance
+ - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-diffusion-x2iat-performance
+ - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . --step nightly-testcase-statistics
- python tools/nightly/generate_nightly_perf_excel.py
- python tools/nightly/generate_nightly_perf_html.py
- python tools/nightly/send_nightly_email.py --report-file "tests/dfx/perf/results/*.xlsx, tests/dfx/perf/results/*.html"
diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml
index 2f1f05463a..3ca1747fe6 100644
--- a/.buildkite/test-ready.yml
+++ b/.buildkite/test-ready.yml
@@ -123,7 +123,7 @@ steps:
- label: "Audio Generation Model Test"
depends_on: upload-ready-pipeline
commands:
- - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
+ - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model
agents:
queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
plugins:
@@ -194,28 +194,6 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
-
- - label: "Omni Model Test"
- depends_on: upload-ready-pipeline
- commands:
- - |
- timeout 17m bash -c '
- export VLLM_LOGGING_LEVEL=DEBUG
- pytest -s -v tests/e2e/online_serving/test_qwen2_5_omni.py -m "core_model" --run-level "core_model"
- '
- agents:
- queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- label: "Omni Model Test with H100"
depends_on: upload-ready-pipeline
commands:
@@ -317,6 +295,56 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
+ - label: "VoxCPM E2E Test"
+ timeout_in_minutes: 20
+ depends_on: upload-ready-pipeline
+ commands:
+ - |
+ timeout 20m bash -c '
+ pip install voxcpm
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ pytest -s -v tests/e2e/offline_inference/test_voxcpm.py -m "core_model" --run-level "core_model"
+ '
+ agents:
+ queue: "gpu_1_queue"
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb"
+ environment:
+ - "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
+ - label: "VoxCPM2 Native AR E2E Test"
+ timeout_in_minutes: 20
+ depends_on: upload-ready-pipeline
+ commands:
+ - |
+ timeout 20m bash -c '
+ pip install voxcpm
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ pytest -s -v tests/e2e/offline_inference/test_voxcpm2.py -m "core_model" --run-level "core_model"
+ '
+ agents:
+ queue: "gpu_1_queue"
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb"
+ environment:
+ - "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
- label: "OmniVoice E2E Test"
timeout_in_minutes: 20
depends_on: upload-ready-pipeline
@@ -339,6 +367,33 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
+ - label: "Qwen3-TTS Base E2E Test (ModelRunner V2)"
+ depends_on: upload-ready-pipeline
+ soft_fail:
+ - exit_status: 1
+ commands:
+ - |
+ timeout 20m bash -c '
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
+ export VLLM_OMNI_USE_V2_RUNNER="1"
+ pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py -m "core_model" --run-level "core_model"
+ '
+ agents:
+ queue: "gpu_1_queue"
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb"
+ environment:
+ - "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
- label: "Voxtral-TTS E2E Test"
timeout_in_minutes: 20
depends_on: upload-ready-pipeline
diff --git a/.buildkite/test-template-amd-omni.j2 b/.buildkite/test-template-amd-omni.j2
index 8dc91a1172..f4c386a5fe 100644
--- a/.buildkite/test-template-amd-omni.j2
+++ b/.buildkite/test-template-amd-omni.j2
@@ -48,6 +48,9 @@
DOCKER_BUILDKIT: "1"
TEST_COMMAND: |-
(command rocm-smi || true) && cd {{ (step.working_dir or default_working_dir) | safe }}
+{% if "mi250" in step.agent_pool %}
+ python3 -m pip uninstall -y amd-aiter
+{% endif %}
{{ indented_cmd | safe }}
priority: 100
{% if step.grade and step.grade == "Blocking" %}
diff --git a/.claude/skills/add-diffusion-model/SKILL.md b/.claude/skills/add-diffusion-model/SKILL.md
new file mode 100644
index 0000000000..a7e0bbf9a5
--- /dev/null
+++ b/.claude/skills/add-diffusion-model/SKILL.md
@@ -0,0 +1,534 @@
+---
+name: add-diffusion-model
+description: Add a new diffusion model (text-to-image, text-to-video, image-to-video, text-to-audio, image editing) to vLLM-Omni, including Cache-DiT acceleration and parallelism support (TP, SP/USP, CFG-Parallel, HSDP). Use when integrating a new diffusion model, porting a diffusers pipeline or a custom model repo to vllm-omni, creating a new DiT transformer adapter, adding diffusion model support, or enabling multi-GPU parallelism and cache acceleration for an existing model.
+---
+
+# Adding a Diffusion Model to vLLM-Omni
+
+## Overview
+
+This skill guides you through adding a new diffusion model to vLLM-Omni. The model may come from HuggingFace Diffusers (structured pipeline) or from a private/custom repo. The workflow differs significantly depending on the source.
+
+## Prerequisites
+
+Before starting, determine:
+
+1. **Model category**: Text-to-Image, Text-to-Video, Image-to-Video, Image Editing, Text-to-Audio, or Omni
+2. **Reference source**: Diffusers pipeline, custom repo, or a combination
+3. **Model HuggingFace ID** or local checkpoint path
+4. **Architecture**: Scheduler, text encoder, VAE, transformer/backbone
+
+## Step 0: Classify the Migration Path
+
+Check the model's HF repo for `model_index.json`. This determines your path:
+
+| Scenario | How to identify | Migration path |
+|----------|----------------|----------------|
+| **Already supported** | `_class_name` in `model_index.json` matches a key in `_DIFFUSION_MODELS` in `registry.py` | Skip to Step 5 (test) and Step 7 (docs) |
+| **Diffusers-based** | Has standard `model_index.json` with `_diffusers_version`, subfolders for `transformer/`, `vae/`, etc. | Follow **Path A** below |
+| **Custom/private repo** | No diffusers `model_index.json`, weights in non-standard format, custom model code in a separate git repo | Follow **Path B** below |
+| **Hybrid** | Has some diffusers components (VAE) but custom transformer/fusion | Mix of Path A and Path B |
+
+## Path A: Diffusers-Based Model
+
+For models with a standard diffusers layout. See [references/transformer-adaptation.md](references/transformer-adaptation.md) for detailed code patterns.
+
+### A1. Analyze `model_index.json`
+
+Identify components: `transformer`, `scheduler`, `vae`, `text_encoder`, `tokenizer`.
+
+### A2. Create model directory
+
+```
+vllm_omni/diffusion/models/your_model_name/
+├── __init__.py
+├── pipeline_your_model.py
+└── your_model_transformer.py
+```
+
+### A3. Adapt transformer
+
+1. Copy from diffusers source. Remove mixins (`ModelMixin`, `ConfigMixin`, `AttentionModuleMixin`).
+2. Replace attention with `vllm_omni.diffusion.attention.layer.Attention` (QKV shape: `[B, seq, heads, head_dim]`).
+3. Add `od_config: OmniDiffusionConfig | None = None` to `__init__`.
+4. Add `load_weights()` method mapping diffusers weight names to vllm-omni names.
+5. Add class attributes: `_repeated_blocks`, `_layerwise_offload_blocks_attr`.
+
+### A4. Adapt pipeline
+
+Inherit from `nn.Module`. The key contract:
+
+```python
+class YourPipeline(nn.Module):
+ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
+ # Load VAE, text encoder, tokenizer via from_pretrained()
+ # Instantiate transformer (weights loaded later via weights_sources)
+ self.weights_sources = [
+ DiffusersPipelineLoader.ComponentSource(
+ model_or_path=od_config.model, subfolder="transformer",
+ prefix="transformer.", fall_back_to_pt=True)]
+
+ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
+ # Encode prompt → prepare latents → denoise loop → VAE decode
+ return DiffusionOutput(output=output)
+
+ def load_weights(self, weights):
+ return AutoWeightsLoader(self).load_weights(weights)
+```
+
+Add post/pre-process functions in the same pipeline file. Register them in `registry.py`.
+
+### A5. Register, test, docs → continue at Step 4 below.
+
+---
+
+## Path B: Custom/Private Repo Model
+
+For models without a diffusers pipeline — weights in custom format, model code in a private repo. Real examples: DreamID-Omni, BAGEL, HunyuanImage3.
+
+### B1. Understand the reference repo
+
+Study the original model's code to identify:
+- **Model architecture files** (transformers, fusion modules, embeddings)
+- **Weight format** (safetensors, `.pth`, custom checkpoint structure)
+- **Weight loading helpers** (custom init functions, checkpoint loaders)
+- **Pre/post-processing** (image/audio transforms, tokenization, VAE encode/decode)
+- **External dependencies** (packages not on PyPI)
+- **Config format** (JSON config files, hardcoded dicts)
+
+### B2. Decide what lives WHERE
+
+This is the key design decision for custom models. Follow these placement rules:
+
+| Code type | Where to place | Example |
+|-----------|---------------|---------|
+| **Pipeline orchestration** (init, forward, denoise loop) | `vllm_omni/diffusion/models//pipeline_.py` | Always required |
+| **Custom transformer/backbone** (ported and adapted to vllm-omni) | `vllm_omni/diffusion/models//_transformer.py` or similar | `wan2_2.py`, `fusion.py`, `bagel_transformer.py` |
+| **Custom sub-models** (VAE, fusion, autoencoder) | `vllm_omni/diffusion/models//` as separate files | `autoencoder.py`, `fusion.py` |
+| **External dependency code** (original repo utilities) | **External repo**, installed via download script or pip | `dreamid_omni` package via git clone |
+| **Hardcoded model configs** | Module-level dicts in pipeline file | `VIDEO_CONFIG`, `AUDIO_CONFIG` dicts |
+| **Download/setup script** | `examples/offline_inference//download_.py` | `download_dreamid_omni.py` |
+| **Custom `model_index.json`** | Generated by download script, placed at model root | Minimal: `{"_class_name": "YourPipeline", ...}` |
+
+### B3. Handle external dependencies
+
+If the model's code lives in a separate git repo:
+
+**Option 1: Import with graceful fallback** (recommended for models with external utils)
+
+```python
+try:
+ from external_model.utils import init_vae, load_checkpoint
+except ImportError:
+ raise ImportError(
+ "Failed to import from dependency 'external_model'. "
+ "Please run the download script first."
+ )
+```
+
+**Option 2: Port the code directly** (preferred when feasible)
+
+Copy the essential model files into `vllm_omni/diffusion/models//` and adapt them. This avoids external dependencies. BAGEL does this — `autoencoder.py` and `bagel_transformer.py` are ported directly.
+
+**Decision criteria**: Port if the code is self-contained and won't diverge. Use external deps if the model repo is actively maintained and the code is complex.
+
+### B4. Handle custom weight loading
+
+Custom models have two patterns for weight loading:
+
+**Pattern 1: Bypass standard loader** (DreamID-Omni style)
+
+When the original model has complex custom init functions that load weights in `__init__`:
+
+```python
+class CustomPipeline(nn.Module):
+ def __init__(self, *, od_config, prefix=""):
+ super().__init__()
+ model = od_config.model
+ # Load everything eagerly in __init__ using custom helpers
+ self.vae = custom_init_vae(model, device=self.device)
+ self.text_encoder = custom_init_text_encoder(model, device=self.device)
+ self.transformer = CustomFusionModel(CONFIG)
+ load_custom_checkpoint(self.transformer,
+ checkpoint_path=os.path.join(model, "model.safetensors"))
+ # NO weights_sources defined — bypasses standard loader
+
+ def load_weights(self, weights):
+ pass # No-op — all weights loaded in __init__
+```
+
+**Pattern 2: Use standard loader with custom `load_weights`** (BAGEL style)
+
+When weights are in safetensors format but need name remapping:
+
+```python
+class CustomPipeline(nn.Module):
+ def __init__(self, *, od_config, prefix=""):
+ super().__init__()
+ # Instantiate model architecture without weights
+ self.bagel = BagelModel(config)
+ self.vae = AutoEncoder(ae_params)
+
+ # Point loader at the safetensors in the model root
+ self.weights_sources = [
+ DiffusersPipelineLoader.ComponentSource(
+ model_or_path=od_config.model,
+ subfolder=None, # weights at root, not in subfolder
+ prefix="",
+ fall_back_to_pt=False,
+ )
+ ]
+
+ def load_weights(self, weights):
+ # Custom name remapping for non-diffusers weight names
+ params = dict(self.named_parameters())
+ loaded = set()
+ for name, tensor in weights:
+ # Remap original weight names to vllm-omni module names
+ name = self._remap_weight_name(name)
+ if name in params:
+ default_weight_loader(params[name], tensor)
+ loaded.add(name)
+ return loaded
+```
+
+### B5. Create the `model_index.json`
+
+Custom models need a `model_index.json` at the model root for vllm-omni to discover them. For custom models, this is minimal:
+
+```json
+{
+ "_class_name": "YourModelPipeline",
+ "custom_key": "path/to/custom_weights.safetensors"
+}
+```
+
+The `_class_name` must match a key in `_DIFFUSION_MODELS` in `registry.py`. Additional keys are model-specific (accessed via `od_config.model_config`).
+
+If the model's weights come from multiple HF repos, write a **download script** that:
+1. Downloads from each repo
+2. Assembles into a single directory
+3. Generates `model_index.json`
+4. Installs any external dependencies (git clone + `.pth` file)
+
+Place at: `examples/offline_inference//download_.py`
+
+### B6. Handle multi-modal inputs
+
+If the model accepts images, audio, or other multi-modal inputs, implement the protocol classes from `vllm_omni/diffusion/models/interface.py`:
+
+```python
+from vllm_omni.diffusion.models.interface import SupportImageInput, SupportAudioInput
+
+class MyPipeline(nn.Module, SupportImageInput, SupportAudioInput):
+ # Protocol markers — the engine uses these to enable proper input routing
+ pass
+```
+
+Preprocessing for custom models is typically done **inside `forward()`** rather than via registered pre-process functions, since the logic is often tightly coupled to the model.
+
+### B7. Continue at Step 4 below.
+
+---
+
+## Common Steps (Both Paths)
+
+### Step 4: Register Model in registry.py
+
+Edit `vllm_omni/diffusion/registry.py`:
+
+```python
+_DIFFUSION_MODELS = {
+ "YourModelPipeline": ("your_model_name", "pipeline_your_model", "YourModelPipeline"),
+}
+_DIFFUSION_POST_PROCESS_FUNCS = {
+ "YourModelPipeline": "get_your_model_post_process_func", # if applicable
+}
+_DIFFUSION_PRE_PROCESS_FUNCS = {
+ "YourModelPipeline": "get_your_model_pre_process_func", # if applicable
+}
+```
+
+The registry key is the `_class_name` from `model_index.json`. The tuple is `(folder_name, module_file, class_name)`.
+
+Create `__init__.py` exporting the pipeline class and any factory functions.
+
+### Step 5: Run, Test, Debug
+
+Use the appropriate existing example script:
+
+| Category | Script |
+|----------|--------|
+| Text-to-Image | `examples/offline_inference/text_to_image/text_to_image.py` |
+| Text-to-Video | `examples/offline_inference/text_to_video/text_to_video.py` |
+| Image-to-Video | `examples/offline_inference/image_to_video/image_to_video.py` |
+| Image-to-Image | `examples/offline_inference/image_to_image/image_edit.py` |
+| Text-to-Audio | `examples/offline_inference/text_to_audio/text_to_audio.py` |
+
+For custom/Omni models that don't fit these categories, create a dedicated example script.
+
+**Validation**: No errors, output is meaningful, quality matches reference implementation.
+
+See [references/troubleshooting.md](references/troubleshooting.md) for common errors.
+
+### Step 6: Add Example Scripts
+
+For Omni or custom models, create:
+- `examples/offline_inference/your_model_name/` — offline script + README
+- `examples/online_serving/your_model_name/` — server script + client
+- Download script if weights require assembly from multiple sources
+
+### Step 7: Update Documentation
+
+Required updates:
+1. `docs/user_guide/diffusion/parallelism_acceleration.md` — parallelism support table
+2. `docs/user_guide/diffusion/teacache.md` — if TeaCache supported
+3. `docs/user_guide/diffusion/cache_dit_acceleration.md` — if Cache-DiT supported
+4. `examples/offline_inference/xxx/README.md` — offline example docs
+5. `examples/online_serve/xxx/README.md` — online serve docs
+
+### Step 8: Add E2E Tests (Recommended)
+
+Create `tests/e2e/online_serving/test_your_model_expansion.py`.
+
+### Step 9: Add Cache-DiT Acceleration
+
+Cache-DiT accelerates inference by caching intermediate computation results across denoising steps. After your model is working correctly on a single GPU, add cache-dit support.
+
+See [references/cache-dit-patterns.md](references/cache-dit-patterns.md) for detailed code patterns.
+
+#### 9a. Determine your model type
+
+| Model Type | Description | Action |
+|------------|-------------|--------|
+| **Standard single-transformer** | One transformer with one `ModuleList` of blocks | No code needed — `CacheDiTBackend` auto-detects via `enable_cache_for_dit()` |
+| **Multi-block-list** | One transformer with multiple block lists (e.g., `transformer_blocks` + `single_transformer_blocks`) | Write custom enabler with `BlockAdapter` |
+| **Dual-transformer** | Two transformers (e.g., high-noise + low-noise) | Write custom enabler with `BlockAdapter` wrapping both |
+
+#### 9b. Standard models — verify automatic support
+
+For standard single-transformer models, test directly:
+
+```python
+omni = Omni(
+ model="your-model-name",
+ cache_backend="cache_dit",
+ cache_config={
+ "Fn_compute_blocks": 1,
+ "Bn_compute_blocks": 0,
+ "max_warmup_steps": 4,
+ }
+)
+```
+
+Check logs for "Cache-dit enabled successfully on xxx". If it works, skip to Step 9e.
+
+#### 9c. Custom architectures — write a custom enabler
+
+For multi-block-list or dual-transformer models, write a custom enabler function:
+
+```python
+from cache_dit import BlockAdapter, ForwardPattern, ParamsModifier, DBCacheConfig
+
+def enable_cache_for_your_model(pipeline, cache_config):
+ db_cache_config = DBCacheConfig(
+ num_inference_steps=None,
+ Fn_compute_blocks=cache_config.Fn_compute_blocks,
+ Bn_compute_blocks=cache_config.Bn_compute_blocks,
+ max_warmup_steps=cache_config.max_warmup_steps,
+ max_cached_steps=cache_config.max_cached_steps,
+ max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
+ residual_diff_threshold=cache_config.residual_diff_threshold,
+ )
+
+ cache_dit.enable_cache(
+ BlockAdapter(
+ transformer=pipeline.transformer,
+ blocks=[
+ pipeline.transformer.transformer_blocks,
+ pipeline.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
+ params_modifiers=[ParamsModifier(...)],
+ ),
+ cache_config=db_cache_config,
+ )
+
+ def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
+ cache_dit.refresh_context(
+ pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose
+ )
+ return refresh_cache_context
+```
+
+#### 9d. Register the custom enabler
+
+Add your enabler to `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`:
+
+```python
+CUSTOM_DIT_ENABLERS = {
+ "Wan22Pipeline": enable_cache_for_wan22,
+ "LongCatImagePipeline": enable_cache_for_longcat_image,
+ "YourModelPipeline": enable_cache_for_your_model, # Add here
+}
+```
+
+#### 9e. Test Cache-DiT
+
+```python
+omni = Omni(
+ model="your-model-name",
+ cache_backend="cache_dit",
+ cache_config={
+ "Fn_compute_blocks": 1, "Bn_compute_blocks": 0,
+ "max_warmup_steps": 4, "residual_diff_threshold": 0.24,
+ }
+)
+images = omni.generate("a beautiful landscape",
+ OmniDiffusionSamplingParams(num_inference_steps=50))
+```
+
+**Verify**: 1) logs show cache enabled, 2) 1.5-2x speedup, 3) output quality acceptable vs baseline.
+
+If quality degrades, lower `residual_diff_threshold` (try 0.12-0.18) or increase `max_warmup_steps` (try 6-8).
+
+---
+
+### Step 10: Add Parallelism Support
+
+After the model works on a single GPU, add multi-GPU parallelism. Add each type incrementally, testing after each addition.
+
+See [references/parallelism-patterns.md](references/parallelism-patterns.md) for detailed code patterns and API reference.
+
+**Recommended order**: TP → SP/USP → CFG Parallel → HSDP
+
+#### 10a. Tensor Parallelism (TP)
+
+Shards DiT linear layers across GPUs. Requires code changes in the transformer.
+
+**What to change in the transformer**:
+1. Replace `nn.Linear` with `ColumnParallelLinear` / `RowParallelLinear` / `QKVParallelLinear`
+2. Update `load_weights()` to handle QKV fusion with `stacked_params_mapping`
+3. Use `self.to_qkv.num_heads` (local heads) instead of total heads for split sizes
+
+```python
+from vllm.model_executor.layers.linear import (
+ QKVParallelLinear, RowParallelLinear, ColumnParallelLinear,
+)
+
+# Attention: QKV → RowParallel output
+self.to_qkv = QKVParallelLinear(dim, head_dim, num_heads, num_kv_heads)
+self.to_out = RowParallelLinear(dim, dim, input_is_parallel=True)
+
+# FFN: ColumnParallel → RowParallel
+self.w1 = ColumnParallelLinear(dim, ffn_dim)
+self.w2 = RowParallelLinear(ffn_dim, dim, input_is_parallel=True)
+```
+
+**Constraints**: `num_heads % tp_size == 0` and `num_kv_heads % tp_size == 0`.
+
+**Test**: `--tensor-parallel-size 2`
+
+#### 10b. Sequence Parallelism (SP / USP)
+
+Splits sequence tokens across GPUs. Non-intrusive via `_sp_plan` on the transformer class — no changes to `forward()`.
+
+**What to change in the transformer**:
+
+Add `_sp_plan` class attribute:
+
+```python
+from vllm_omni.diffusion.distributed.sp_plan import (
+ SequenceParallelInput, SequenceParallelOutput,
+)
+
+class YourTransformer(nn.Module):
+ _sp_plan = {
+ "blocks.0": {
+ "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
+ },
+ "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+ }
+```
+
+If inline tensor ops (e.g., `torch.cat`) exist between shard/gather points, extract them into `nn.Module` submodules so hooks can intercept them.
+
+For RoPE that needs splitting, add an entry for the RoPE module with `split_output=True`.
+
+**Test**: `--ulysses-degree 2` (offline) or `--usp 2` (online serving)
+
+#### 10c. CFG Parallel
+
+Distributes positive/negative CFG branches across 2 GPUs. Requires the pipeline to inherit `CFGParallelMixin`.
+
+**What to change in the pipeline**:
+
+```python
+from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
+
+class YourPipeline(nn.Module, CFGParallelMixin):
+ def diffuse(self, ...) -> torch.Tensor:
+ for i, t in enumerate(timesteps):
+ positive_kwargs = {...}
+ negative_kwargs = {...} if do_true_cfg else None
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg, true_cfg_scale=cfg_scale,
+ positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs,
+ )
+ latents = self.scheduler_step_maybe_with_cfg(
+ noise_pred, t, latents, do_true_cfg
+ )
+ return latents
+```
+
+Override `predict_noise()` if your transformer call is non-standard. Override `combine_cfg_noise()` for multi-output models (e.g., video + audio).
+
+**Constraint**: Exactly 2 GPUs. Only for models using classifier-free guidance.
+
+**Test**: `--cfg-parallel-size 2`
+
+#### 10d. HSDP (Hybrid Sharded Data Parallel)
+
+Shards transformer weights via PyTorch FSDP2 to reduce per-GPU VRAM. No code changes to the forward pass — just add a class attribute.
+
+**What to change in the transformer**:
+
+```python
+class YourTransformer(nn.Module):
+ @staticmethod
+ def _is_transformer_block(name: str, module) -> bool:
+ return "blocks" in name and name.split(".")[-1].isdigit()
+
+ _hsdp_shard_conditions = [_is_transformer_block]
+```
+
+**Constraint**: Cannot combine with TP. For standalone HSDP, set `hsdp_shard_size` explicitly.
+
+**Test**: `--use-hsdp` or `DiffusionParallelConfig(use_hsdp=True)`
+
+#### 10e. Update parallelism documentation
+
+After adding parallelism support, update:
+1. `docs/user_guide/diffusion/parallelism_acceleration.md` — add your model to the support table
+2. Record which parallelism methods are supported (USP, Ring, CFG, TP, HSDP, VAE-Patch)
+
+---
+
+## Iterative Development Tips
+
+1. **Start minimal**: Basic generation first, no parallelism/caching
+2. **Use `--enforce-eager`**: Disable torch.compile during debugging
+3. **Use small models**: Test with smaller variants first
+4. **Check tensor shapes**: Most errors are reshape mismatches in attention
+5. **Add features incrementally**: Single GPU → TP → SP → CFG → HSDP → Cache-DiT
+6. **For custom models**: Get the model running with the original code first, then progressively replace components with vllm-omni equivalents
+7. **Cache-DiT before parallelism tuning**: Cache-DiT is lossy — verify quality at baseline before combining with parallelism
+8. **Combine lossless + lossy**: e.g., TP + SP + Cache-DiT for maximum throughput
+
+## Reference Files
+
+- [Transformer Adaptation](references/transformer-adaptation.md) — porting transformers from diffusers
+- [Custom Model Patterns](references/custom-model-patterns.md) — patterns for non-diffusers models
+- [Parallelism Patterns](references/parallelism-patterns.md) — TP, SP/USP, CFG parallel, HSDP implementation details
+- [Cache-DiT Patterns](references/cache-dit-patterns.md) — cache-dit acceleration for standard and custom architectures
+- [Troubleshooting](references/troubleshooting.md) — common errors and fixes
diff --git a/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md b/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md
new file mode 100644
index 0000000000..d34ce0e0f4
--- /dev/null
+++ b/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md
@@ -0,0 +1,254 @@
+# Cache-DiT Patterns Reference
+
+## Overview
+
+Cache-DiT accelerates Diffusion Transformers by caching intermediate computation results across denoising steps. Adjacent steps produce similar features, so redundant computations can be skipped.
+
+Three caching strategies:
+- **DBCache**: Dynamic block-level caching — selectively computes or caches transformer blocks based on residual differences
+- **TaylorSeer**: Calibration-based prediction using Taylor expansion to estimate block outputs
+- **SCM** (Step Computation Masking): Dynamic step skipping based on configurable policies
+
+**Typical speedup**: 1.5-2.5x depending on model and configuration.
+
+**Official docs**: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/cache_dit
+
+## Architecture
+
+vLLM-Omni integrates cache-dit through `CacheDiTBackend`:
+
+| Component | Purpose |
+|-----------|---------|
+| `CacheDiTBackend` | Unified backend — auto-selects enabler (standard or custom) |
+| `enable_cache_for_dit()` | Default enabler for standard single-transformer models |
+| `CUSTOM_DIT_ENABLERS` dict | Registry of custom enablers keyed by pipeline class name |
+| `BlockAdapter` | Wraps complex architectures (multi-block-list or multi-transformer) |
+| `ForwardPattern` | Specifies block forward signature: `Pattern_0`, `Pattern_1`, `Pattern_2` |
+| `ParamsModifier` | Per-transformer or per-block-list config customization |
+| `DBCacheConfig` | Configuration for DBCache parameters |
+| `cache_dit.refresh_context()` | Updates cache context when `num_inference_steps` changes |
+
+**Source files**:
+- `vllm_omni/diffusion/cache/cache_dit_backend.py` — `CacheDiTBackend`, enablers, `CUSTOM_DIT_ENABLERS`
+- `vllm_omni/diffusion/cache/` — cache backend implementations
+
+## Standard Models: Automatic Support
+
+Most DiT models follow this pattern:
+- Single transformer with one `nn.ModuleList` of blocks
+- Standard forward signature
+- Compatible with cache-dit's automatic detection
+
+**Examples**: Qwen-Image, Z-Image, FLUX
+
+No code changes needed. `CacheDiTBackend` automatically uses `enable_cache_for_dit()`:
+
+```python
+from vllm_omni import Omni
+
+omni = Omni(
+ model="Qwen/Qwen-Image",
+ cache_backend="cache_dit",
+ cache_config={
+ "Fn_compute_blocks": 1,
+ "Bn_compute_blocks": 0,
+ "max_warmup_steps": 4,
+ }
+)
+```
+
+What happens automatically:
+
+```python
+def enable_cache_for_dit(pipeline, cache_config):
+ db_cache_config = DBCacheConfig(
+ num_inference_steps=None,
+ Fn_compute_blocks=cache_config.Fn_compute_blocks,
+ Bn_compute_blocks=cache_config.Bn_compute_blocks,
+ max_warmup_steps=cache_config.max_warmup_steps,
+ max_cached_steps=cache_config.max_cached_steps,
+ max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
+ residual_diff_threshold=cache_config.residual_diff_threshold,
+ )
+
+ cache_dit.enable_cache(pipeline.transformer, cache_config=db_cache_config)
+
+ def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
+ cache_dit.refresh_context(
+ pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose
+ )
+ return refresh_cache_context
+```
+
+## Custom Architectures: Writing Custom Enablers
+
+### When you need a custom enabler
+
+- Model has multiple block lists in one transformer (e.g., `transformer_blocks` + `single_transformer_blocks`)
+- Model has two transformers (e.g., high-noise + low-noise like Wan2.2)
+- Model uses non-standard block forward signature
+
+### Pattern 1: Multi-Block-List (LongCat-Image style)
+
+Single transformer with two block lists:
+
+```python
+import cache_dit
+from cache_dit import BlockAdapter, ForwardPattern, ParamsModifier, DBCacheConfig
+
+def enable_cache_for_your_model(pipeline, cache_config):
+ db_cache_config = DBCacheConfig(
+ num_inference_steps=None,
+ Fn_compute_blocks=cache_config.Fn_compute_blocks,
+ Bn_compute_blocks=cache_config.Bn_compute_blocks,
+ max_warmup_steps=cache_config.max_warmup_steps,
+ max_cached_steps=cache_config.max_cached_steps,
+ max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
+ residual_diff_threshold=cache_config.residual_diff_threshold,
+ )
+
+ cache_dit.enable_cache(
+ BlockAdapter(
+ transformer=pipeline.transformer,
+ blocks=[
+ pipeline.transformer.transformer_blocks,
+ pipeline.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
+ params_modifiers=[ParamsModifier(...)],
+ ),
+ cache_config=db_cache_config,
+ )
+
+ def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
+ cache_dit.refresh_context(
+ pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose
+ )
+ return refresh_cache_context
+```
+
+For single transformer with multiple block lists, `refresh_context` works the same as standard models — call it once on the transformer.
+
+### Pattern 2: Dual-Transformer (Wan2.2 style)
+
+Two transformers with separate configs:
+
+```python
+def enable_cache_for_dual_transformer(pipeline, cache_config):
+ db_cache_config = DBCacheConfig(...)
+
+ cache_dit.enable_cache(
+ BlockAdapter(
+ transformer=[pipeline.transformer, pipeline.transformer_2],
+ blocks=[pipeline.transformer.blocks, pipeline.transformer_2.blocks],
+ forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
+ params_modifiers=[
+ ParamsModifier(...), # Config for transformer 1
+ ParamsModifier(...), # Config for transformer 2
+ ],
+ ),
+ cache_config=db_cache_config,
+ )
+
+ def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
+ high_steps, low_steps = _split_inference_steps(num_inference_steps)
+ cache_dit.refresh_context(
+ pipeline.transformer, num_inference_steps=high_steps, verbose=verbose
+ )
+ cache_dit.refresh_context(
+ pipeline.transformer_2, num_inference_steps=low_steps, verbose=verbose
+ )
+ return refresh_cache_context
+```
+
+Key difference: `refresh_context` must be called on **each transformer separately** with its own step count.
+
+### Choosing the ForwardPattern
+
+| Pattern | Block forward signature | Example models |
+|---------|------------------------|----------------|
+| `Pattern_0` | `block(hidden_states, **kwargs)` → residual added inside block | Default |
+| `Pattern_1` | `block(hidden_states, **kwargs)` → returns `(hidden_states, ...)` tuple | FLUX-style single blocks |
+| `Pattern_2` | `block(hidden_states, **kwargs)` → `(hidden_states, ...)` with different residual pattern | Wan2.2 blocks |
+
+Inspect your block's `forward()` return type and residual connection pattern to choose the right one. See [Cache-DiT API Reference](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/) for details.
+
+## Registering Custom Enablers
+
+Add your enabler to `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`:
+
+```python
+CUSTOM_DIT_ENABLERS = {
+ "Wan22Pipeline": enable_cache_for_wan22,
+ "LongCatImagePipeline": enable_cache_for_longcat_image,
+ "YourModelPipeline": enable_cache_for_your_model,
+}
+```
+
+The key must match `pipeline.__class__.__name__`.
+
+## Configuration Parameters
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| `Fn_compute_blocks` | 1 | Number of blocks to always compute at the front |
+| `Bn_compute_blocks` | 0 | Number of blocks to always compute at the back |
+| `max_warmup_steps` | 4 | Steps to run without caching at the beginning |
+| `max_cached_steps` | — | Max total cached steps |
+| `max_continuous_cached_steps` | — | Max consecutive cached steps |
+| `residual_diff_threshold` | 0.24 | Threshold for deciding whether to cache a block |
+
+### Tuning for quality vs speed
+
+| Goal | Adjustments |
+|------|-------------|
+| **More speed, acceptable quality loss** | Higher `residual_diff_threshold` (0.24-0.4), lower `max_warmup_steps` (2-4) |
+| **Better quality, less speed** | Lower `residual_diff_threshold` (0.12-0.18), higher `max_warmup_steps` (6-8), lower `max_continuous_cached_steps` (2) |
+
+## Testing
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+omni = Omni(
+ model="your-model-name",
+ cache_backend="cache_dit",
+ cache_config={
+ "Fn_compute_blocks": 1,
+ "Bn_compute_blocks": 0,
+ "max_warmup_steps": 4,
+ "residual_diff_threshold": 0.24,
+ }
+)
+images = omni.generate(
+ "a beautiful landscape",
+ OmniDiffusionSamplingParams(num_inference_steps=50),
+)
+```
+
+CLI (online serving):
+
+```bash
+vllm serve your-model --omni --port 8098 \
+ --cache-backend cache_dit \
+ --cache-config '{"Fn_compute_blocks": 1, "Bn_compute_blocks": 0, "max_warmup_steps": 4}'
+```
+
+**Verification checklist**:
+1. Logs show "Cache-dit enabled successfully on xxx"
+2. Performance: 1.5-2x speedup vs no cache
+3. Quality: compare output with `cache_backend=None`
+
+## Excluded Models
+
+Models listed in `_NO_CACHE_ACCELERATION` in `vllm_omni/diffusion/registry.py` do not support cache-dit (e.g., `NextStep11Pipeline`, `StableDiffusionPipeline`). Check this set before attempting to enable cache-dit.
+
+## Reference Implementations
+
+| Model | Path | Notes |
+|-------|------|-------|
+| Standard DiT | `cache_dit_backend.py::enable_cache_for_dit` | Default enabler, automatic |
+| Wan2.2 | `cache_dit_backend.py::enable_cache_for_wan22` | Dual-transformer, auto-detects mode |
+| LongCat | `cache_dit_backend.py::enable_cache_for_longcat_image` | Multi-block-list |
+| BAGEL | `cache_dit_backend.py::enable_cache_for_bagel` | Complex omni model |
diff --git a/.claude/skills/add-diffusion-model/references/custom-model-patterns.md b/.claude/skills/add-diffusion-model/references/custom-model-patterns.md
new file mode 100644
index 0000000000..2434e0b5da
--- /dev/null
+++ b/.claude/skills/add-diffusion-model/references/custom-model-patterns.md
@@ -0,0 +1,273 @@
+# Custom Model Patterns Reference
+
+Patterns for adding models that don't come from the standard diffusers pipeline format.
+
+## Directory Structure Comparison
+
+### Diffusers-based model (e.g., Wan2.2)
+
+```
+vllm_omni/diffusion/models/wan2_2/
+├── __init__.py # Exports pipeline + transformer + helpers
+├── pipeline_wan2_2.py # Pipeline: loads components via from_pretrained()
+├── pipeline_wan2_2_i2v.py # Variant pipeline for image-to-video
+└── wan2_2_transformer.py # Transformer: ported from diffusers, uses Attention layer
+```
+
+The transformer is loaded separately via `weights_sources` + `load_weights()`. Non-transformer components (VAE, text encoder) are loaded in `__init__` via `from_pretrained()`.
+
+### Custom model with external deps (e.g., DreamID-Omni)
+
+```
+vllm_omni/diffusion/models/dreamid_omni/
+├── __init__.py # Exports pipeline only
+├── pipeline_dreamid_omni.py # Pipeline: loads ALL weights in __init__ via custom helpers
+├── fusion.py # Custom fusion architecture (video + audio cross-attention)
+└── wan2_2.py # Re-implemented Wan backbone with split API
+
+examples/offline_inference/x_to_video_audio/
+└── download_dreamid_omni.py # Downloads weights from 3 HF repos + clones code repo
+```
+
+All weights loaded eagerly in `__init__`. `load_weights()` is a no-op. External dependency (`dreamid_omni` package) imported with try/except.
+
+### Custom model with ported code (e.g., BAGEL)
+
+```
+vllm_omni/diffusion/models/bagel/
+├── __init__.py
+├── pipeline_bagel.py # Pipeline: instantiates models, uses weights_sources
+├── bagel_transformer.py # Full LLM backbone (Qwen2-MoT) ported into vllm-omni
+└── autoencoder.py # Custom VAE ported from original repo
+```
+
+Model code is fully ported (no external dependency). Uses `weights_sources` and `load_weights()` with custom name remapping to handle non-diffusers safetensors format.
+
+## Weight Loading Patterns
+
+### Pattern 1: Standard diffusers flow (Wan2.2, Z-Image, FLUX)
+
+```
+init → create transformer (empty) → set weights_sources → [loader calls load_weights()]
+```
+
+- `weights_sources` points to safetensors in HF subfolder (e.g., `transformer/`)
+- `load_weights()` receives `(name, tensor)` pairs from the loader
+- Name remapping handles diffusers→vllm-omni differences (QKV fusion, Sequential index removal)
+
+### Pattern 2: Custom safetensors at root (BAGEL)
+
+```
+init → create all models (empty) → set weights_sources(subfolder=None) → [loader calls load_weights()]
+```
+
+- `weights_sources` points to **root** of model directory, not a subfolder
+- Weights have non-diffusers names (e.g., `bagel.language_model.model.layers.0.self_attn.q_proj.weight`)
+- `load_weights()` does heavy name normalization
+
+```python
+self.weights_sources = [
+ DiffusersPipelineLoader.ComponentSource(
+ model_or_path=od_config.model,
+ subfolder=None, # root directory
+ prefix="", # no prefix stripping
+ fall_back_to_pt=False,
+ )
+]
+```
+
+### Pattern 3: Fully custom loading (DreamID-Omni)
+
+```
+init → load ALL weights eagerly via custom helpers → load_weights() = no-op
+```
+
+- No `weights_sources` attribute — standard loader finds nothing to iterate
+- Custom init functions (e.g., `init_wan_vae_2_2()`, `load_fusion_checkpoint()`) handle downloading and loading
+- `load_weights()` is `pass`
+- Weights may come from multiple HF repos in different formats (`.pth`, `.safetensors`)
+
+Use this when:
+- The original model has complex, well-tested loading code you don't want to rewrite
+- Weights span multiple HF repos
+- Weight format is non-standard (e.g., a single `.pth` file, not sharded safetensors)
+
+## model_index.json for Custom Models
+
+Standard diffusers `model_index.json`:
+```json
+{
+ "_class_name": "WanPipeline",
+ "_diffusers_version": "0.35.0.dev0",
+ "scheduler": ["diffusers", "UniPCMultistepScheduler"],
+ "transformer": ["diffusers", "WanTransformer3DModel"],
+ "vae": ["diffusers", "AutoencoderKLWan"]
+}
+```
+
+Custom model `model_index.json` (minimal):
+```json
+{
+ "_class_name": "DreamIDOmniPipeline",
+ "fusion": "DreamID-Omni/dreamid_omni.safetensors"
+}
+```
+
+The only **required** field is `_class_name` — it must match a key in `_DIFFUSION_MODELS` in `registry.py`. Other fields are model-specific and accessible via `od_config.model_config` dict.
+
+## External Dependency Management
+
+### Git clone + .pth injection (DreamID-Omni pattern)
+
+```python
+def download_dependency():
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
+ with open(LOCK_FILE, "w") as f:
+ fcntl.flock(f, fcntl.LOCK_EX)
+ if not DEPENDENCY_DIR.exists():
+ subprocess.run([
+ "git", "clone", "--depth", "1",
+ REPO_URL, "--branch", BRANCH,
+ str(DEPENDENCY_DIR)
+ ], check=True)
+ fcntl.flock(f, fcntl.LOCK_UN)
+
+ # Add to Python path via .pth file
+ site_packages = Path(site.getsitepackages()[0])
+ pth_file = site_packages / "vllm_omni_dependency.pth"
+ pth_file.write_text(str(DEPENDENCY_DIR))
+```
+
+### Direct port (BAGEL pattern)
+
+Copy essential files from the original repo into `vllm_omni/diffusion/models//`. Adapt imports to use vllm-omni utilities. Benefits: no external dependency, no git clone step. Drawback: must maintain the ported code.
+
+## Multi-Modal Input/Output Protocols
+
+Custom models that handle images, audio, or video I/O should implement protocol classes:
+
+```python
+from vllm_omni.diffusion.models.interface import (
+ SupportImageInput, # Model accepts image input
+ SupportAudioInput, # Model accepts audio input
+ SupportAudioOutput, # Model produces audio output
+)
+
+class MyPipeline(nn.Module, SupportImageInput, SupportAudioInput, SupportAudioOutput):
+ pass # Protocol markers enable proper engine routing
+```
+
+The engine checks `isinstance(pipeline, SupportImageInput)` at startup to configure input validation and warmup behavior.
+
+## Hardcoded Config vs Config Files
+
+Diffusers models use `config.json` in each subfolder. Custom models often use:
+
+**Module-level config dicts** (DreamID-Omni):
+```python
+VIDEO_CONFIG = {
+ "patch_size": [1, 2, 2], "model_type": "ti2v",
+ "dim": 3072, "ffn_dim": 14336, "num_heads": 24, "num_layers": 30, ...
+}
+```
+
+**Loaded from custom JSON** (BAGEL):
+```python
+cfg_path = os.path.join(model_path, "config.json")
+with open(cfg_path) as f:
+ bagel_cfg = json.load(f)
+vae_cfg = bagel_cfg.get("vae_config", {})
+```
+
+## Custom Architecture Patterns
+
+### Split forward API (DreamID-Omni)
+
+When a fusion model needs to interleave blocks from two backbones:
+
+```python
+class WanModel(nn.Module):
+ def prepare_transformer_block_kwargs(self, x, t, context, ...):
+ # Patch embed, time embed, text embed, RoPE
+ return x, e, kwargs
+
+ def post_transformer_block_out(self, x, grid_sizes, e):
+ # Output projection, unpatchify
+ return output
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError # Fusion model handles block iteration
+```
+
+The `FusionModel` then iterates blocks in lock-step:
+```python
+for video_block, audio_block in zip(self.video_model.blocks, self.audio_model.blocks):
+ video_out = video_block(video_hidden, ...)
+ audio_out = audio_block(audio_hidden, ...)
+ # Cross-attend between modalities
+ video_out = cross_attention(video_out, audio_out)
+ audio_out = cross_attention(audio_out, video_out)
+```
+
+### LLM-as-denoiser (BAGEL)
+
+When the backbone is a language model that also does diffusion:
+
+```python
+class BagelModel(nn.Module):
+ def __init__(self):
+ self.language_model = Qwen2MoTForCausalLM(config)
+ self.vit_model = SiglipVisionModel(vit_config)
+```
+
+The LLM processes both text tokens and latent image tokens in a single forward pass, using KV caching for the text portion.
+
+## Pre/Post Processing for Custom Models
+
+Custom models typically handle pre/post processing **inside `forward()`** rather than via registered functions, because the logic is tightly coupled:
+
+```python
+def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
+ # Inline preprocessing
+ image = self._load_and_resize_image(req.prompts[0].get("multi_modal_data", {}).get("image"))
+ image_latent = self._vae_encode(image)
+
+ # ... denoising loop ...
+
+ # Inline postprocessing
+ pil_image = self._decode_to_pil(latents)
+ return DiffusionOutput(output=[pil_image])
+```
+
+If pre/post functions are not registered in `_DIFFUSION_PRE_PROCESS_FUNCS` / `_DIFFUSION_POST_PROCESS_FUNCS`, the engine simply skips those steps.
+
+## Download Script Template
+
+```python
+# examples/offline_inference//download_.py
+from huggingface_hub import snapshot_download
+import json, os
+
+def main(output_dir):
+ # Download model weights from HF
+ snapshot_download(repo_id="org/model-weights", local_dir=os.path.join(output_dir, "weights"))
+
+ # Download additional components if from separate repos
+ snapshot_download(repo_id="org/vae-weights", local_dir=os.path.join(output_dir, "vae"),
+ allow_patterns=["*.safetensors"])
+
+ # Generate model_index.json
+ config = {"_class_name": "YourPipeline", "custom_key": "weights/model.safetensors"}
+ with open(os.path.join(output_dir, "model_index.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ # Install external code dependency (if needed)
+ download_dependency()
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output-dir", default="./your_model")
+ args = parser.parse_args()
+ main(args.output_dir)
+```
diff --git a/.claude/skills/add-diffusion-model/references/parallelism-patterns.md b/.claude/skills/add-diffusion-model/references/parallelism-patterns.md
new file mode 100644
index 0000000000..933e2d2320
--- /dev/null
+++ b/.claude/skills/add-diffusion-model/references/parallelism-patterns.md
@@ -0,0 +1,571 @@
+# Parallelism Patterns Reference
+
+## Overview
+
+vLLM-Omni supports multiple parallelism strategies for diffusion models. Each targets a different bottleneck:
+
+| Strategy | Splits | Best For | Constraint |
+|----------|--------|----------|------------|
+| Tensor Parallel (TP) | Model layers across GPUs | Latency reduction, large models | Requires fast GPU interconnect, `num_heads % tp == 0` |
+| Sequence Parallel (SP/USP) | Sequence tokens across GPUs | Long sequences (video, high-res) | Near-linear scaling |
+| CFG Parallel | Positive/negative CFG branches | Models using classifier-free guidance | Exactly 2 GPUs |
+| HSDP | Weight shards via FSDP2 | VRAM reduction | Cannot combine with TP |
+| VAE Patch Parallel | VAE decode spatial tiles | Large VAE outputs | Auto-enables tiling |
+
+**Recommended integration order**: TP → SP → CFG Parallel → HSDP
+
+**Official design docs**:
+- TP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/tensor_parallel
+- SP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/sequence_parallel
+- CFG: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/cfg_parallel
+- HSDP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/hsdp
+
+---
+
+## Tensor Parallelism (TP)
+
+Replace standard `nn.Linear` with vLLM's parallel linear layers. This is the most invasive change but provides direct VRAM savings and compute speedup.
+
+### Layer replacement rules
+
+| Pattern | vLLM Layer | When to Use |
+|---------|-----------|-------------|
+| Fan-out (first in FFN) | `ColumnParallelLinear` | Projection that splits output across ranks |
+| Fan-in (second in FFN) | `RowParallelLinear` | Projection that gathers across ranks |
+| QKV projection | `QKVParallelLinear` | Fused Q/K/V for self-attention |
+| Single Q or K or V | `ColumnParallelLinear` | Separate projections (cross-attention) |
+| Attention output | `RowParallelLinear` | Output projection after attention |
+| Must not shard | `ReplicatedLinear` | Layers that must stay replicated |
+
+### MLP Block (Up-Down Pattern)
+
+```python
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear, RowParallelLinear,
+)
+
+class TPFeedForward(nn.Module):
+ def __init__(self, dim, ffn_dim):
+ super().__init__()
+ self.fc1 = ColumnParallelLinear(dim, ffn_dim, bias=False, return_bias=False)
+ self.fc2 = RowParallelLinear(
+ ffn_dim, dim, bias=False,
+ input_is_parallel=True, # Input already sharded from fc1
+ return_bias=False,
+ )
+
+ def forward(self, x):
+ x, _ = self.fc1(x)
+ x = torch.nn.functional.gelu(x)
+ x, _ = self.fc2(x)
+ return x
+```
+
+### Attention Block (QKV-Out Pattern)
+
+```python
+from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
+from vllm_omni.diffusion.attention.layer import Attention
+
+class TPSelfAttention(nn.Module):
+ def __init__(self, dim, num_heads, num_kv_heads=None):
+ super().__init__()
+ num_kv_heads = num_kv_heads or num_heads
+ self.head_dim = dim // num_heads
+
+ self.to_qkv = QKVParallelLinear(
+ hidden_size=dim,
+ head_size=self.head_dim,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_kv_heads,
+ bias=False,
+ return_bias=False,
+ )
+ self.to_out = RowParallelLinear(
+ dim, dim, bias=False,
+ input_is_parallel=True,
+ return_bias=False,
+ )
+ self.attn = Attention(
+ num_heads=self.to_qkv.num_heads, # Local heads per GPU
+ head_size=self.head_dim,
+ softmax_scale=1.0 / (self.head_dim ** 0.5),
+ causal=False,
+ num_kv_heads=self.to_qkv.num_kv_heads, # Local KV heads per GPU
+ )
+
+ def forward(self, x):
+ qkv, _ = self.to_qkv(x)
+ q, k, v = qkv.split(
+ [self.to_qkv.num_heads * self.head_dim,
+ self.to_qkv.num_kv_heads * self.head_dim,
+ self.to_qkv.num_kv_heads * self.head_dim],
+ dim=-1,
+ )
+ B, S, _ = x.shape
+ q = q.view(B, S, self.to_qkv.num_heads, self.head_dim)
+ k = k.view(B, S, self.to_qkv.num_kv_heads, self.head_dim)
+ v = v.view(B, S, self.to_qkv.num_kv_heads, self.head_dim)
+ out = self.attn(q, k, v)
+ out = out.reshape(B, S, -1)
+ out, _ = self.to_out(out)
+ return out
+```
+
+### QKV Fusion in load_weights
+
+When you fuse separate Q/K/V into `QKVParallelLinear`, map diffusers' separate weight names:
+
+```python
+stacked_params_mapping = [
+ ("to_qkv", "to_q", "q"),
+ ("to_qkv", "to_k", "k"),
+ ("to_qkv", "to_v", "v"),
+]
+
+def load_weights(self, weights):
+ params = dict(self.named_parameters())
+ loaded = set()
+ for name, tensor in weights:
+ for fused_name, orig_name, shard_id in stacked_params_mapping:
+ if orig_name in name:
+ name = name.replace(orig_name, fused_name)
+ param = params[name]
+ param.weight_loader(param, tensor, shard_id)
+ loaded.add(name)
+ break
+ else:
+ if name in params:
+ param = params[name]
+ if hasattr(param, "weight_loader"):
+ param.weight_loader(param, tensor)
+ else:
+ default_weight_loader(param, tensor)
+ loaded.add(name)
+ return loaded
+```
+
+### RMSNorm with TP
+
+When RMSNorm sits between TP-sharded dimensions, use `DistributedRMSNorm` — it computes global RMS via all-reduce across TP ranks. See the Wan2.2 implementation for the pattern.
+
+### TP Constraints
+
+- `num_heads % tp_size == 0`
+- `num_kv_heads % tp_size == 0`
+- Use `self.to_qkv.num_heads` (local per-GPU count), not total heads, for split sizes
+
+### Testing TP
+
+```bash
+python text_to_image.py --model Your-org/your-model \
+ --tensor-parallel-size 2 --output "tp_test.png"
+```
+
+**Verify**: speedup, memory reduction proportional to TP size, quality matches single-GPU.
+
+### Reference implementations
+
+| Model | Path |
+|-------|------|
+| Z-Image | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` |
+| FLUX | `vllm_omni/diffusion/models/flux/flux_transformer.py` |
+| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` |
+
+---
+
+## Sequence Parallelism (SP / USP)
+
+SP splits sequence tokens across GPUs using Ulysses (all-to-all) or Ring (P2P) communication. It is applied non-intrusively via the `_sp_plan` dict — no changes to `forward()` logic.
+
+### Approach 1: Non-Intrusive `_sp_plan` (Recommended)
+
+The framework automatically registers hooks to shard inputs and gather outputs at `nn.Module` boundaries.
+
+#### Step 1: Identify module boundaries
+
+Find where tensors need sharding/gathering:
+
+```python
+class MyTransformer(nn.Module):
+ def __init__(self):
+ self.patch_embed = PatchEmbed() # Before blocks
+ self.pos_embed = RoPE() # RoPE may need splitting
+ self.blocks = nn.ModuleList([...]) # Blocks process sharded x
+ self.norm_out = LayerNorm()
+ self.proj_out = Linear() # Gather after this
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ pos = self.pos_embed(x)
+ for block in self.blocks:
+ x = block(x, pos)
+ x = self.norm_out(x)
+ return self.proj_out(x)
+```
+
+#### Step 2: Handle inline operations
+
+`_sp_plan` hooks only work at `nn.Module` boundaries. Inline ops like `torch.cat()` must be extracted into submodules:
+
+```python
+# BAD: Inline — hooks can't intercept
+unified = torch.cat([x, cap_feats], dim=1)
+
+# GOOD: Extract into submodule
+class UnifiedPrepare(nn.Module):
+ def forward(self, x, cap_feats):
+ return torch.cat([x, cap_feats], dim=1)
+
+self.unified_prepare = UnifiedPrepare()
+unified = self.unified_prepare(x, cap_feats)
+```
+
+Common cases: `torch.cat()`, `pad_sequence()`, `tensor.reshape()`, complex preprocessing.
+
+#### Step 3: Write `_sp_plan`
+
+**Pattern 1: Shard at first block, gather at output** (most common)
+
+```python
+from vllm_omni.diffusion.distributed.sp_plan import (
+ SequenceParallelInput, SequenceParallelOutput,
+)
+
+class StandardTransformer(nn.Module):
+ _sp_plan = {
+ "blocks.0": {
+ "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
+ },
+ "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+ }
+```
+
+**Pattern 2: Shard RoPE outputs separately**
+
+```python
+class TransformerWithRoPE(nn.Module):
+ _sp_plan = {
+ "rope": {
+ 0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ "blocks.0": {
+ "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
+ },
+ "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+ }
+```
+
+**Pattern 3: Dual-stream (shard image, replicate text)**
+
+```python
+class DualStreamTransformer(nn.Module):
+ _sp_plan = {
+ "rope_preparer": {
+ 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "transformer_blocks.0": {
+ "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
+ },
+ "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+ }
+```
+
+### API Reference
+
+**SequenceParallelInput**:
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `split_dim` | int | Dimension to split (usually 1 for sequence) |
+| `expected_dims` | int/None | Expected tensor rank for validation |
+| `split_output` | bool | `False`: shard input params; `True`: shard output tensors |
+| `auto_pad` | bool | Auto-pad if sequence not divisible by world_size |
+
+**SequenceParallelOutput**:
+
+| Parameter | Type | Description |
+|-----------|------|-------------|
+| `gather_dim` | int | Dimension to gather (usually 1 for sequence) |
+| `expected_dims` | int/None | Expected tensor rank for validation |
+
+**Module naming**:
+
+| Key | Meaning |
+|-----|---------|
+| `"blocks.0"` | First element of ModuleList |
+| `"blocks.*"` | All elements of ModuleList |
+| `"rope"` | Named submodule |
+
+**Dictionary value types**:
+
+| Key type | split_output | Description |
+|----------|-------------|-------------|
+| `"param_name"` (str) | False | Shard input parameter by name |
+| `0, 1, ...` (int) | True | Shard output tuple by index |
+
+### Approach 2: Intrusive Modification (Complex Cases)
+
+For dynamic sharding logic that can't be expressed via `_sp_plan`:
+
+```python
+from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather
+
+def forward(self, hidden_states, ...):
+ if self.parallel_config.sequence_parallel_size > 1:
+ hidden_states = sp_shard(hidden_states, dim=1)
+ for block in self.blocks:
+ hidden_states = block(hidden_states)
+ if self.parallel_config.sequence_parallel_size > 1:
+ hidden_states = sp_gather(hidden_states, dim=1)
+ return hidden_states
+```
+
+Use intrusive modification as a last resort — `_sp_plan` is preferred for maintainability.
+
+### UAA Mode (Experimental)
+
+`ulysses_mode="advanced_uaa"` handles arbitrary sequence lengths and head counts that aren't divisible by `ulysses_degree`. Uses variable all-to-all split sizes and temporary head padding.
+
+### Combining SP methods
+
+Ulysses and Ring can be combined: `ulysses_degree × ring_degree = total SP GPUs`.
+
+```python
+DiffusionParallelConfig(ulysses_degree=2, ring_degree=2) # 4 GPUs total
+```
+
+### Testing SP
+
+```bash
+# Offline
+python text_to_image.py --model Your-model --ulysses-degree 2
+
+# Online serving
+vllm serve Your-model --omni --usp 2
+```
+
+### Reference implementations
+
+| Model | Path |
+|-------|------|
+| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` |
+| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` |
+| Z-Image | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` |
+
+---
+
+## CFG Parallelism
+
+Distributes positive/negative Classifier-Free Guidance branches across 2 GPUs.
+
+### Implementation
+
+Inherit `CFGParallelMixin` and implement `diffuse()`:
+
+```python
+from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
+
+class YourPipeline(nn.Module, CFGParallelMixin):
+ def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds,
+ do_true_cfg, true_cfg_scale, **kwargs):
+ for i, t in enumerate(timesteps):
+ positive_kwargs = {
+ "hidden_states": latents,
+ "encoder_hidden_states": prompt_embeds,
+ "timestep": t,
+ }
+ negative_kwargs = {
+ "hidden_states": latents,
+ "encoder_hidden_states": negative_embeds,
+ "timestep": t,
+ } if do_true_cfg else None
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=true_cfg_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ )
+ latents = self.scheduler_step_maybe_with_cfg(
+ noise_pred, t, latents, do_true_cfg
+ )
+ return latents
+```
+
+### Customization hooks
+
+| Method | Override when |
+|--------|-------------|
+| `predict_noise()` | Non-standard transformer call (e.g., dual-transformer like Wan2.2) |
+| `cfg_normalize_function()` | Custom normalization (e.g., LongCat with clamping) |
+| `combine_cfg_noise()` | Multi-output models (e.g., video + audio: CFG on video, positive-only on audio) |
+
+**Custom predict_noise** (Wan2.2 — selects active transformer):
+
+```python
+def predict_noise(self, current_model=None, **kwargs):
+ if current_model is None:
+ current_model = self.transformer
+ return current_model(**kwargs)[0]
+```
+
+**Custom combine_cfg_noise** (multi-output):
+
+```python
+def combine_cfg_noise(self, positive_pred, negative_pred, scale, normalize):
+ video_pos, audio_pos = positive_pred
+ video_neg, audio_neg = negative_pred
+ video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize)
+ return (video_combined, audio_pos)
+```
+
+### Composite scheduler for multi-output
+
+When each output has its own schedule:
+
+```python
+class VideoAudioScheduler:
+ def __init__(self, video_scheduler, audio_scheduler):
+ self.video_scheduler = video_scheduler
+ self.audio_scheduler = audio_scheduler
+
+ def step(self, noise_pred, t, latents, return_dict=False, generator=None):
+ video_out = self.video_scheduler.step(
+ noise_pred[0], t[0], latents[0], return_dict=False, generator=generator
+ )[0]
+ audio_out = self.audio_scheduler.step(
+ noise_pred[1], t[1], latents[1], return_dict=False, generator=generator
+ )[0]
+ return ((video_out, audio_out),)
+```
+
+### Testing CFG Parallel
+
+```bash
+python text_to_image.py --model Your-model \
+ --cfg-parallel-size 2 --cfg-scale 4.0 \
+ --negative-prompt "ugly, unclear"
+```
+
+**Constraint**: `guidance_scale > 1.0` and negative prompt must be provided.
+
+### Reference implementations
+
+| Model | Path |
+|-------|------|
+| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/cfg_parallel.py` |
+| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py` |
+| Mixin base | `vllm_omni/diffusion/distributed/cfg_parallel.py` |
+
+---
+
+## HSDP (Hybrid Sharded Data Parallel)
+
+Shards model weights across GPUs using PyTorch FSDP2. Reduces per-GPU VRAM without changing computation.
+
+### Implementation
+
+Add `_hsdp_shard_conditions` to the transformer class:
+
+```python
+class YourTransformer(nn.Module):
+ @staticmethod
+ def _is_transformer_block(name: str, module) -> bool:
+ return "blocks" in name and name.split(".")[-1].isdigit()
+
+ _hsdp_shard_conditions = [_is_transformer_block]
+```
+
+For MoE models, add additional conditions:
+
+```python
+class MoETransformer(nn.Module):
+ @staticmethod
+ def _is_transformer_block(name, module):
+ return "blocks" in name and name.split(".")[-1].isdigit()
+
+ @staticmethod
+ def _is_moe_expert(name, module):
+ return "experts" in name and name.split(".")[-1].isdigit()
+
+ _hsdp_shard_conditions = [_is_transformer_block, _is_moe_expert]
+```
+
+A module is sharded if **any** condition returns `True`.
+
+### Constraints
+
+- Cannot combine with Tensor Parallelism
+- For standalone HSDP (no other parallelism), `hsdp_shard_size` must be specified explicitly
+- Can combine with SP: HSDP reduces memory while SP distributes sequence
+
+### Testing HSDP
+
+```python
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+parallel_config = DiffusionParallelConfig(use_hsdp=True, hsdp_shard_size=8)
+omni = Omni(model="your-model", parallel_config=parallel_config)
+```
+
+Or CLI:
+
+```bash
+vllm serve Your-model --omni --use-hsdp
+```
+
+**Verify**: logs show "HSDP Inference: replicate_size=..., shard_size=..." and "Sharded N modules + root". Check VRAM reduction.
+
+### Reference implementations
+
+| Model | Path |
+|-------|------|
+| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` |
+| HSDP Core | `vllm_omni/diffusion/distributed/hsdp.py` |
+
+---
+
+## VAE Patch Parallelism
+
+Shards VAE decode spatially across ranks using tiling:
+
+```bash
+python text_to_image.py --model Your-model --vae-patch-parallel-size 4
+```
+
+Auto-enables `--vae-use-tiling`. Uses `DistributedAutoencoderKLWan` or similar distributed VAE. Set `vae_patch_parallel_size` in `DiffusionParallelConfig`.
+
+---
+
+## Combining Parallelism Methods
+
+Common multi-GPU recipes:
+
+```bash
+# 4 GPUs: CFG (2) × Ulysses (2)
+python text_to_image.py --model Qwen/Qwen-Image \
+ --cfg-parallel-size 2 --ulysses-degree 2
+
+# 8 GPUs: Ulysses (4) × Ring (2) + VAE patch (8)
+python text_to_video.py --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \
+ --ulysses-degree 4 --ring-degree 2 --vae-patch-parallel-size 8
+
+# 2 GPUs: HSDP + Ulysses (cannot combine HSDP with TP)
+vllm serve Your-model --omni --use-hsdp --usp 2
+```
+
+## Discovering Parallelism Support
+
+Check which parallelism methods a model supports:
+
+| Check | How |
+|-------|-----|
+| **Ulysses / Ring SP** | Transformer defines `_sp_plan`. Search: `grep -r '_sp_plan' vllm_omni/diffusion/models/` |
+| **CFG Parallel** | Pipeline inherits `CFGParallelMixin`. Search: `grep -r 'CFGParallelMixin' vllm_omni/diffusion/models/` |
+| **TP** | Uses `ColumnParallelLinear` / `QKVParallelLinear`. Search: `grep -r 'ParallelLinear\|QKVParallel' vllm_omni/diffusion/models//` |
+| **HSDP** | Transformer defines `_hsdp_shard_conditions`. Search: `grep -r '_hsdp_shard_conditions' vllm_omni/diffusion/models/` |
+
+The canonical per-model support table is in `docs/user_guide/diffusion/parallelism_acceleration.md`.
diff --git a/.claude/skills/add-diffusion-model/references/transformer-adaptation.md b/.claude/skills/add-diffusion-model/references/transformer-adaptation.md
new file mode 100644
index 0000000000..6e344b6a66
--- /dev/null
+++ b/.claude/skills/add-diffusion-model/references/transformer-adaptation.md
@@ -0,0 +1,218 @@
+# Transformer Adaptation Reference
+
+## Adapting a Diffusers Transformer to vLLM-Omni
+
+### Step-by-step Checklist
+
+1. Copy the transformer class from diffusers source
+2. Remove all mixin classes — inherit only from `nn.Module`
+3. Replace attention dispatch with `vllm_omni.diffusion.attention.layer.Attention`
+4. Replace logger with `vllm.logger.init_logger`
+5. Add `od_config: OmniDiffusionConfig | None = None` to `__init__`
+6. Remove training-only code (gradient checkpointing, dropout)
+7. Add `load_weights()` method for weight loading from safetensors
+8. Add class-level attributes for acceleration features
+
+### Mixin Removal
+
+Remove these diffusers mixins (and their imports):
+
+```python
+# Remove all of these:
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention_processor import AttentionModuleMixin
+from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin
+
+# Replace:
+class MyTransformer(ModelMixin, ConfigMixin, AttentionModuleMixin):
+# With:
+class MyTransformer(nn.Module):
+```
+
+Also remove `@register_to_config` decorators from `__init__`.
+
+### Attention Replacement
+
+The vLLM-Omni `Attention` layer wraps backend selection (FlashAttention, SDPA, SageAttn, etc.) and supports sequence parallelism hooks.
+
+**QKV tensor shape must be `[batch, seq_len, num_heads, head_dim]`.**
+
+#### Self-Attention Pattern
+
+```python
+from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
+
+class SelfAttentionBlock(nn.Module):
+ def __init__(self, dim, num_heads):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.to_q = nn.Linear(dim, dim)
+ self.to_k = nn.Linear(dim, dim)
+ self.to_v = nn.Linear(dim, dim)
+ self.to_out = nn.Linear(dim, dim)
+
+ self.attn = Attention(
+ num_heads=num_heads,
+ head_size=self.head_dim,
+ softmax_scale=1.0 / (self.head_dim ** 0.5),
+ causal=False,
+ num_kv_heads=num_heads,
+ )
+
+ def forward(self, x, attn_mask=None):
+ B, S, _ = x.shape
+ q = self.to_q(x).view(B, S, self.num_heads, self.head_dim)
+ k = self.to_k(x).view(B, S, self.num_heads, self.head_dim)
+ v = self.to_v(x).view(B, S, self.num_heads, self.head_dim)
+
+ attn_metadata = AttentionMetadata(attn_mask=attn_mask)
+ out = self.attn(q, k, v, attn_metadata=attn_metadata)
+ out = out.reshape(B, S, -1)
+ return self.to_out(out)
+```
+
+#### Fused QKV with TP (Advanced)
+
+For tensor parallelism, use vLLM's parallel linear layers:
+
+```python
+from vllm.model_executor.layers.linear import (
+ QKVParallelLinear, RowParallelLinear
+)
+
+class TPSelfAttention(nn.Module):
+ def __init__(self, dim, num_heads):
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.to_qkv = QKVParallelLinear(
+ hidden_size=dim,
+ head_size=self.head_dim,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_heads,
+ )
+ self.to_out = RowParallelLinear(dim, dim)
+
+ self.attn = Attention(
+ num_heads=num_heads,
+ head_size=self.head_dim,
+ softmax_scale=1.0 / (self.head_dim ** 0.5),
+ causal=False,
+ num_kv_heads=num_heads,
+ )
+```
+
+### Logger Replacement
+
+```python
+# Replace:
+from diffusers.utils import logging
+logger = logging.get_logger(__name__)
+
+# With:
+from vllm.logger import init_logger
+logger = init_logger(__name__)
+```
+
+### Custom Layers from vLLM-Omni
+
+Available utility layers:
+
+```python
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm_omni.diffusion.layers.rope import RotaryEmbedding
+from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm
+```
+
+### Config Support
+
+```python
+from vllm_omni.diffusion.data import OmniDiffusionConfig
+
+class MyTransformer(nn.Module):
+ def __init__(self, *, od_config=None, num_layers=28, hidden_size=3072, **kwargs):
+ super().__init__()
+ self.od_config = od_config
+ self.parallel_config = od_config.parallel_config if od_config else None
+ # ... build layers
+```
+
+The transformer config values come from `model_index.json` → `config.json` in the transformer subfolder. The pipeline uses `get_transformer_config_kwargs(od_config.tf_model_config, TransformerClass)` to filter config keys to match the `__init__` signature.
+
+### Weight Loading
+
+The `load_weights` method receives an iterable of `(name, tensor)` from safetensors files, with the prefix (e.g., `"transformer."`) already stripped by the loader.
+
+```python
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+class MyTransformer(nn.Module):
+ def load_weights(self, weights):
+ params = dict(self.named_parameters())
+ loaded = set()
+ for name, tensor in weights:
+ # Optional: remap names from diffusers to vllm-omni naming
+ # e.g., "ff.net.0.proj" -> "ff.net_0.proj"
+
+ if name in params:
+ param = params[name]
+ if hasattr(param, "weight_loader"):
+ param.weight_loader(param, tensor)
+ else:
+ default_weight_loader(param, tensor)
+ loaded.add(name)
+ return loaded
+```
+
+#### QKV Fusion in load_weights
+
+If you fused separate Q/K/V into a `QKVParallelLinear`, you need to map diffusers' separate weight names:
+
+```python
+stacked_params_mapping = [
+ ("to_qkv", "to_q", "q"),
+ ("to_qkv", "to_k", "k"),
+ ("to_qkv", "to_v", "v"),
+]
+
+def load_weights(self, weights):
+ params = dict(self.named_parameters())
+ loaded = set()
+ for name, tensor in weights:
+ for fused_name, orig_name, shard_id in stacked_params_mapping:
+ if orig_name in name:
+ name = name.replace(orig_name, fused_name)
+ param = params[name]
+ param.weight_loader(param, tensor, shard_id)
+ loaded.add(name)
+ break
+ else:
+ # Normal loading
+ ...
+ return loaded
+```
+
+### Class-Level Attributes for Features
+
+```python
+class MyTransformer(nn.Module):
+ # torch.compile: list block class names that repeat and can be compiled
+ _repeated_blocks = ["MyTransformerBlock"]
+
+ # CPU offload: attribute name of the nn.ModuleList containing blocks
+ _layerwise_offload_blocks_attr = "blocks"
+
+ # LoRA: mapping of fused param names to original param names
+ packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]}
+
+ # Sequence parallelism plan (advanced — add after basic impl works)
+ _sp_plan = {
+ "blocks.0": SequenceParallelInput(split_dim=1),
+ "proj_out": SequenceParallelOutput(gather_dim=1),
+ }
+```
diff --git a/.claude/skills/add-diffusion-model/references/troubleshooting.md b/.claude/skills/add-diffusion-model/references/troubleshooting.md
new file mode 100644
index 0000000000..27acdd8d15
--- /dev/null
+++ b/.claude/skills/add-diffusion-model/references/troubleshooting.md
@@ -0,0 +1,178 @@
+# Troubleshooting Reference
+
+## Common Errors When Adding a Diffusion Model
+
+### ImportError / ModuleNotFoundError
+
+**Cause**: Missing or incorrect registration.
+
+**Fix checklist**:
+1. Model registered in `vllm_omni/diffusion/registry.py` `_DIFFUSION_MODELS` dict
+2. `__init__.py` exports the pipeline class
+3. Pipeline file exists at the correct path: `vllm_omni/diffusion/models/{folder}/{file}.py`
+4. Class name in registry matches the actual class name in the file
+
+### Shape Mismatch in Attention
+
+**Symptom**: `RuntimeError: shape mismatch` or `expected 4D tensor`
+
+**Cause**: QKV tensors not reshaped to `[batch, seq_len, num_heads, head_dim]`.
+
+**Fix**: Before calling `self.attn(q, k, v, ...)`, ensure:
+```python
+q = q.view(batch, seq_len, self.num_heads, self.head_dim)
+k = k.view(batch, kv_seq_len, self.num_kv_heads, self.head_dim)
+v = v.view(batch, kv_seq_len, self.num_kv_heads, self.head_dim)
+```
+
+After attention, reshape back:
+```python
+out = out.reshape(batch, seq_len, -1)
+```
+
+### Weight Loading Failures
+
+**Symptom**: `RuntimeError: size mismatch for parameter ...` or missing keys
+
+**Debugging**:
+1. Print diffusers weight names: `safetensors.safe_open(path, "pt").keys()`
+2. Print model parameter names: `dict(model.named_parameters()).keys()`
+3. Compare and add name remappings in `load_weights()`
+
+**Common remappings needed**:
+- `ff.net.0.proj` → `ff.net_0.proj` (PyTorch Sequential indexing)
+- `.to_out.0.` → `.to_out.` (Sequential unwrapping)
+- `scale_shift_table` → moved to a wrapper module
+
+### Black/Blank/Noisy Output
+
+**Possible causes**:
+1. **Wrong latent normalization**: Check VAE expects latents scaled by `vae.config.scaling_factor`
+2. **Wrong scheduler**: Using the wrong scheduler class or wrong `flow_shift`
+3. **Missing CFG**: Some models require `guidance_scale > 1.0` with negative prompt
+4. **Wrong timestep format**: Some schedulers expect float, others expect int/long
+5. **Missing post-processing**: Raw VAE output may need denormalization
+
+**Quick test**: Run with diffusers directly using the same seed and compare latents at each step.
+
+### OOM (Out of Memory)
+
+**Solutions** (in order of preference):
+1. `--enforce-eager` to disable torch.compile (saves compile memory)
+2. `--enable-cpu-offload` for model-level offload
+3. `--enable-layerwise-offload` for block-level offload (better for large models)
+4. `--vae-use-slicing --vae-use-tiling` for VAE memory reduction
+5. Reduce resolution: `--height 480 --width 832`
+6. Use TP: `--tensor-parallel-size 2`
+
+### Different Output vs Diffusers Reference
+
+**Common causes**:
+1. **Attention backend difference**: FlashAttention vs SDPA may produce slightly different results. Set `DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA` to match diffusers
+2. **Float precision**: vLLM-Omni may use bfloat16 where diffusers uses float32 for some operations
+3. **Missing normalization**: Check all LayerNorm/RMSNorm are preserved
+4. **Scheduler rounding**: Some schedulers have numerical sensitivity
+
+### Tensor Parallel Errors
+
+**Symptom**: `AssertionError: not divisible` or incorrect output with TP>1
+
+**Fix**:
+1. Verify `num_heads % tp_size == 0` and `num_kv_heads % tp_size == 0`
+2. Ensure `ColumnParallelLinear` / `RowParallelLinear` are used correctly
+3. Check that norms between parallel layers use distributed norm if needed
+4. Verify `load_weights` handles TP sharding for norm weights
+5. Use `self.to_qkv.num_heads` (local heads per GPU) for QKV split sizes, not total heads
+
+**Missing `input_is_parallel=True`**:
+
+`RowParallelLinear` expects sharded input from `ColumnParallelLinear`:
+```python
+self.w1 = ColumnParallelLinear(dim, hidden_dim, return_bias=False)
+self.w2 = RowParallelLinear(hidden_dim, dim, input_is_parallel=True, return_bias=False)
+```
+
+### Sequence Parallel Errors
+
+**Symptom**: Incorrect output or crashes with `--ulysses-degree N` or `--usp N`
+
+**Possible causes**:
+1. **Inline operations between shard/gather points**: `torch.cat()`, `pad_sequence()` etc. not at `nn.Module` boundaries. Fix: extract into submodule.
+2. **Wrong `split_dim`**: Check the tensor shape at the shard point. Sequence dimension is typically `dim=1` for `[B, S, D]` tensors.
+3. **RoPE not sharded**: If RoPE is computed separately, add it to `_sp_plan` with `split_output=True`.
+4. **Sequence not divisible by SP degree**: Use `auto_pad=True` in `SequenceParallelInput` or switch to `ulysses_mode="advanced_uaa"`.
+
+**Debugging**: Add `expected_dims=N` to `SequenceParallelInput`/`Output` for shape validation at runtime.
+
+### CFG Parallel Errors
+
+**Symptom**: CFG parallel not activating, no speedup
+
+**Fix checklist**:
+1. Pipeline inherits `CFGParallelMixin`
+2. `guidance_scale > 1.0`
+3. Negative prompt provided (even if empty string)
+4. `--cfg-parallel-size 2` specified
+5. `diffuse()` method calls `predict_noise_maybe_with_cfg()` and `scheduler_step_maybe_with_cfg()`
+
+**Symptom**: Different output with CFG parallel vs sequential
+
+**Possible cause**: Non-deterministic scheduler. Fix: pass `generator=torch.Generator(device).manual_seed(seed)` to `scheduler_step_maybe_with_cfg()`.
+
+### HSDP Errors
+
+**Symptom**: HSDP not activating or errors during weight loading
+
+**Fix checklist**:
+1. Transformer defines `_hsdp_shard_conditions` class attribute
+2. Shard condition functions return `True` for correct modules (test with `model.named_modules()`)
+3. Not combining with TP (HSDP and TP are incompatible)
+4. For standalone HSDP, `hsdp_shard_size` is specified explicitly
+
+**Verify**: Check logs for "HSDP Inference: replicate_size=..., shard_size=..." and "Sharded N modules + root".
+
+### Cache-DiT Not Applied
+
+**Symptom**: No speedup, no cache-related log messages
+
+**Fix checklist**:
+1. Model not in `_NO_CACHE_ACCELERATION` in `registry.py`
+2. Pipeline class name matches `CUSTOM_DIT_ENABLERS` key (if using custom enabler)
+3. `cache_backend="cache_dit"` specified
+4. Check logs for "Cache-dit enabled successfully on xxx"
+
+**Verify pipeline name**: `print(pipeline.__class__.__name__)` — must match registry key.
+
+### Cache-DiT Quality Degradation
+
+**Symptom**: Artifacts or lower quality with cache-dit
+
+**Fix**: Reduce aggressiveness:
+```python
+cache_config={
+ "residual_diff_threshold": 0.12, # Lower from 0.24
+ "max_warmup_steps": 6, # Increase from 4
+ "max_continuous_cached_steps": 2, # Reduce if higher
+}
+```
+
+If quality is still poor, the model may need a custom enabler with per-block-list `ParamsModifier` tuning.
+
+### Model Not Detected / Wrong Pipeline Class
+
+**Symptom**: `ValueError: Model class ... not found in diffusion model registry`
+
+**Cause**: The model's `model_index.json` has a `_class_name` for the pipeline that doesn't match registry keys.
+
+**Fix**: The registry key must match the diffusers pipeline class name from `model_index.json`. If using a different name, map it in the registry:
+```python
+"DiffusersPipelineClassName": ("your_folder", "your_file", "YourVllmClassName"),
+```
+
+## Debugging Workflow
+
+1. **Add verbose logging**: Use `logger.info()` to print tensor shapes at each stage
+2. **Compare step-by-step**: Run diffusers and vllm-omni side by side, comparing tensors after each major operation
+3. **Use small configs**: Reduce `num_inference_steps=2`, small resolution for fast iteration
+4. **Test transformer isolation**: Feed the same input to both diffusers and vllm-omni transformers, compare outputs
+5. **Binary search for bugs**: Comment out blocks/layers to isolate where divergence starts
diff --git a/.claude/skills/add-tts-model/SKILL.md b/.claude/skills/add-tts-model/SKILL.md
new file mode 100644
index 0000000000..e64e7e763e
--- /dev/null
+++ b/.claude/skills/add-tts-model/SKILL.md
@@ -0,0 +1,284 @@
+---
+name: add-tts-model
+description: "Integrate a new text-to-speech model into vLLM-Omni from HuggingFace reference implementation through production-ready serving with streaming and CUDA graph acceleration. Use when adding a new TTS model, wiring stage separation for speech synthesis, enabling online voice generation serving, debugging TTS integration behavior, or building audio output pipelines."
+---
+
+# TTS Model Integration Workflow
+
+## Overview
+
+```
+HF Reference -> Stage Separation -> Online Serving -> Async Chunk -> CUDA Graph
+ (Phase 1) (Phase 2) (Phase 3) (Phase 4) (Phase 5)
+```
+
+## Phase 1: HuggingFace Reference
+
+**Goal**: Understand the reference implementation and verify it produces correct audio.
+
+### Steps
+
+1. **Run the reference model** end-to-end using the official HuggingFace / GitHub code
+2. **Document the architecture**:
+ - What are the sub-models? (AR decoder, codec decoder, vocoder, etc.)
+ - What is the token vocabulary? (semantic codes, RVQ codebooks, special tokens)
+ - What is the output format? (sample rate, channels, codec type)
+3. **Capture reference outputs** for comparison during integration
+4. **Identify the config structure**: `config.json` fields, `model_type`, sub-model configs
+
+### Key Questions
+
+- How many codebooks? What are the codebook sizes?
+- What special tokens exist? (`<|voice|>`, `<|audio_start|>`, `<|im_end|>`, etc.)
+- What is the token-to-ID mapping for codec codes?
+- What is the hop length / frame rate of the codec?
+- Does the model support voice cloning? How? (reference audio encoding, speaker embeddings, etc.)
+
+### Deliverables
+
+- Working reference script that produces audio
+- Architecture diagram / notes
+- Token vocabulary mapping
+- Reference audio samples for regression testing
+
+## Phase 2: Stage Separation (Offline Inference)
+
+**Goal**: Split the model into vLLM-Omni stages and get offline inference working.
+
+### Steps
+
+1. **Register the model** in `vllm_omni/model_executor/models/registry.py`
+2. **Create config classes** (`configuration_.py`) with `model_type` registration
+3. **Implement Stage 0** (AR model):
+ - Subclass appropriate base (e.g., wrap Qwen3 decoder layers)
+ - Implement `forward()` for autoregressive token generation
+ - Handle special token logic (start/stop tokens, codec token mapping)
+ - If dual-AR (like Fish Speech), implement Fast AR as a nested module
+4. **Implement Stage 1** (Decoder):
+ - Load codec weights (may need lazy loading from separate checkpoint)
+ - Implement `forward()`: codec codes -> audio waveform
+ - Return `OmniOutput` with `multimodal_outputs`
+5. **Create stage config YAML** defining both stages, memory allocation, and model paths
+6. **Create stage input processor** for prompt building
+7. **Write end2end.py** test script
+
+### Critical Parameters to Get Right
+
+| Parameter | Impact if Wrong |
+|-----------|----------------|
+| Hop length | Audio duration wrong, streaming noise |
+| Token ID mapping | Garbage codes -> noise output |
+| Codebook count/size | Shape mismatch crashes |
+| Stop token | Generation never stops or stops too early |
+| dtype / autocast | Numerical issues, silent quality degradation |
+| Repetition penalty | Must match reference (often 1.0 for TTS) |
+
+### Debugging Priority (from experience)
+
+When audio output is wrong, check in this order:
+
+1. **RoPE / attention**: Are position encodings correct? Is the attention mask right?
+2. **Normalization**: RMSNorm epsilon, layer norm placement (pre vs post)
+3. **Hop length**: Product of all upsample rates in the codec decoder
+4. **Token mapping**: Are codec IDs correctly offset from the vocabulary base?
+5. **Sampling parameters**: Temperature, top_k, top_p, repetition_penalty
+6. **Tensor layout**: Codebook-major vs frame-major ordering
+7. **dtype**: Float32 for codec decoders (autocast can corrupt audio)
+
+### Deliverables
+
+- Model files in `vllm_omni/model_executor/models//`
+- Stage config YAML
+- Working `end2end.py` with correct audio output
+- README.md in the example directory
+
+## Phase 3: Online Serving
+
+**Goal**: Expose the model via `/v1/audio/speech` API endpoint.
+
+### Steps
+
+1. **Register in `serving_speech.py`**:
+ - Add model stage name to `_TTS_MODEL_STAGES` set
+ - Add model detection flag (e.g., `_is_fish_speech`)
+ - Implement prompt builder method (e.g., `_build_fish_speech_prompt()`)
+2. **Handle model-specific parameters**:
+ - Voice cloning: `ref_audio` encoding and prompt injection
+ - `max_new_tokens` override in sampling params
+ - Model-specific default values
+3. **Create client scripts**: `speech_client.py`, `run_server.sh`
+4. **Test all response formats**: wav, mp3, flac, pcm
+5. **Add Gradio demo**: Interactive web UI with streaming support
+
+### Voice Cloning Pattern
+
+```python
+import base64
+from pathlib import Path
+
+def build_voice_clone_prompt(ref_audio_path: str, text: str, codec) -> list:
+ """Build prompt with reference audio for voice cloning in serving_speech.py."""
+ audio_bytes = Path(ref_audio_path).read_bytes()
+ codes = codec.encode(audio_bytes) # Encode on CPU using model's codec (e.g., DAC)
+ token_ids = [code + codec.vocab_offset for code in codes.flatten().tolist()]
+ return [
+ {"role": "system", "content": f"<|voice|>{''.join(chr(t) for t in token_ids)}"},
+ {"role": "user", "content": text},
+ ]
+```
+
+### Deliverables
+
+- Updated `serving_speech.py` with model-specific prompt builder
+- Client scripts and server launcher
+- Gradio demo with streaming and voice cloning UI
+- Documentation (offline + online serving docs)
+
+## Phase 4: Async Chunk (Streaming)
+
+**Goal**: Enable inter-stage streaming so audio chunks are produced while AR generation continues.
+
+### Steps
+
+1. **Update stage config YAML**:
+ ```yaml
+ async_chunk: true
+ codec_chunk_frames: 25 # frames per chunk
+ codec_left_context_frames: 25 # overlap for smooth boundaries
+ ```
+2. **Implement chunk handling in Stage 1**:
+ - Accept partial input (chunk of codec codes)
+ - Handle left context for smooth audio boundaries
+ - Return partial audio in `OmniOutput`
+3. **Test streaming**:
+ - Verify audio quality matches non-streaming output
+ - Check for artifacts at chunk boundaries
+ - Measure TTFA (time to first audio)
+4. **Update online serving** to support `stream=true` with PCM output
+
+### Streaming Architecture
+
+```
+Stage 0 (AR) Stage 1 (Decoder)
+ | |
+ |-- chunk 0 (25 frames) ------> decode -> audio chunk 0 -> client
+ |-- chunk 1 (25 frames) ------> decode -> audio chunk 1 -> client
+ |-- chunk 2 (25 frames) ------> decode -> audio chunk 2 -> client
+ ...
+```
+
+### Key Considerations
+
+- **Left context overlap**: Prevents audible artifacts at chunk boundaries
+- **Hop length matters**: `context_audio_samples = context_frames * hop_length`
+- **First chunk latency**: Can use larger initial chunk for better quality, then smaller chunks
+
+### Deliverables
+
+- Updated stage config with async_chunk enabled
+- Smooth streaming audio without boundary artifacts
+- TTFA metrics
+
+## Phase 5: CUDA Graph Acceleration
+
+**Goal**: Capture the AR loop as a CUDA graph for significant speedup.
+
+### Steps
+
+1. **Identify the hot loop**: The AR decoding loop that runs N steps per token
+2. **Create static buffers**:
+ - KV caches with fixed max sequence length
+ - Pre-built causal masks and position tensors per step
+ - Static input/output tensors
+3. **Implement graph capture**:
+ - Warm up with real data
+ - Capture the forward pass
+ - Replay with updated inputs
+4. **Handle constraints**:
+ - Use `torch.argmax` instead of `torch.multinomial` (graph-safe)
+ - Fixed batch size (fall back to eager for other sizes)
+ - No dynamic control flow inside the graph
+
+### Example: Code Predictor CUDA Graph (Qwen3-TTS)
+
+```python
+import torch
+
+class CodePredictorGraph:
+ """Captures the 16-step code predictor AR loop as a single CUDA graph."""
+
+ def setup_graph(self, device: torch.device, kv_heads: int = 4, head_dim: int = 64):
+ self.num_steps = 16
+ self.kv_cache = torch.zeros(1, kv_heads, self.num_steps, head_dim, device=device)
+ self.positions = torch.arange(self.num_steps, device=device)
+ self.causal_mask = torch.tril(torch.ones(self.num_steps, self.num_steps, device=device))
+ self.input_buf = torch.zeros(1, 1, kv_heads * head_dim, device=device)
+ self.output_buf = torch.zeros(1, self.num_steps, device=device, dtype=torch.long)
+ # Warm up, then: self.graph = torch.cuda.CUDAGraph(); self.graph.capture(...)
+
+ def run_graph(self, initial_input: torch.Tensor) -> torch.Tensor:
+ self.input_buf.copy_(initial_input)
+ self.graph.replay()
+ return self.output_buf.clone()
+```
+
+### Performance Expectations
+
+Based on Qwen3-TTS code predictor experience:
+- **3-5x speedup** for the graphed component
+- Only effective for fixed batch sizes (typically batch_size=1)
+- Falls back to eager mode for unsupported configurations
+
+### Deliverables
+
+- CUDA graph implementation for the AR hot loop
+- Benchmark script comparing eager vs graph performance
+- Documentation of constraints and fallback behavior
+
+## Integration Checklist
+
+Use this checklist when integrating a new TTS model:
+
+### Phase 1: HF Reference
+- [ ] Reference model runs and produces correct audio
+- [ ] Architecture documented (stages, codebooks, tokens, sample rate)
+- [ ] Reference audio samples saved for comparison
+
+### Phase 2: Stage Separation
+- [ ] Model registered in `registry.py`
+- [ ] Config classes created with `model_type` registration
+- [ ] Stage 0 (AR) implemented and generates correct tokens
+- [ ] Stage 1 (Decoder) produces correct audio from tokens
+- [ ] Stage config YAML created
+- [ ] `end2end.py` produces audio matching reference quality
+- [ ] README.md written
+
+### Phase 3: Online Serving
+- [ ] Model added to `serving_speech.py`
+- [ ] Prompt builder handles text input correctly
+- [ ] Voice cloning works (if supported)
+- [ ] All response formats work (wav, mp3, flac, pcm)
+- [ ] Client scripts and server launcher created
+- [ ] Gradio demo working
+- [ ] Documentation added (offline + online docs, nav, supported models)
+
+### Phase 4: Async Chunk
+- [ ] Stage config updated with `async_chunk: true`
+- [ ] Stage 1 handles partial chunks correctly
+- [ ] No audio artifacts at chunk boundaries
+- [ ] Streaming via API (`stream=true`) works
+- [ ] TTFA measured and acceptable
+
+### Phase 5: CUDA Graph
+- [ ] Hot loop identified and profiled
+- [ ] Static buffers allocated
+- [ ] Graph captured and replays correctly
+- [ ] Benchmark shows meaningful speedup
+- [ ] Fallback to eager works for unsupported configs
+
+## References
+
+- [TTS audio skill](../vllm-omni-audio-tts/SKILL.md) -- supported models and usage
+- [Fish Speech integration](../vllm-omni-audio-tts/references/fish-speech.md) -- complete example of Phases 1-3
+- [Qwen3-TTS reference](../vllm-omni-audio-tts/references/qwen-tts.md) -- complete example of all 5 phases
+- [Adding a TTS model (developer guide)](https://github.com/vllm-project/vllm-omni/blob/main/docs/contributing/model/adding_tts_model.md)
diff --git a/.claude/skills/readme.md b/.claude/skills/readme.md
new file mode 100644
index 0000000000..b66f2ecd13
--- /dev/null
+++ b/.claude/skills/readme.md
@@ -0,0 +1,34 @@
+# Claude Skills for vLLM-Omni
+
+This directory contains Claude Code skills maintained for the `vllm-omni`
+repository. These skills capture repeatable workflows for common contributor
+tasks such as model integration, pull request review, and release note
+generation.
+
+## Directory Structure
+
+Each skill lives in its own directory under `.claude/skills/`. A skill may
+include:
+
+- `SKILL.md`: the main workflow and operating instructions
+- `references/`: focused reference material used by the skill
+- `scripts/`: small helper scripts used by the skill
+
+## Available Skills
+
+- `add-diffusion-model`: guides integration of a new diffusion model into
+ `vllm-omni`
+- `add-omni-model`: covers addition of new omni-modality model support
+- `add-tts-model`: covers integration of new TTS models and related serving
+ workflows
+- `generate-release-note`: helps prepare release notes for repository changes
+- `review-pr`: provides a structured workflow for reviewing pull requests
+
+## Maintenance Guidelines
+
+- Keep skill names short and task-oriented.
+- Prefer repository-local paths, commands, and examples.
+- Avoid hardcoding fast-changing support matrices unless the skill is actively
+ maintained alongside those changes.
+- Treat skills as contributor tooling: optimize for clarity, actionability, and
+ low maintenance overhead.
diff --git a/.claude/skills/vllm-omni-npu-upgrade/SKILL.md b/.claude/skills/vllm-omni-npu-upgrade/SKILL.md
new file mode 100644
index 0000000000..1ef7ab3930
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/SKILL.md
@@ -0,0 +1,300 @@
+---
+name: vllm-omni-npu-model-runner-upgrade
+description: "Upgrade vllm-omni NPU model runners (OmniNPUModelRunner, NPUARModelRunner, NPUGenerationModelRunner) to align with the latest vllm-ascend NPUModelRunner while preserving omni-specific logic."
+---
+
+# vLLM-Omni NPU Model Runner Upgrade Skill
+
+## Overview
+
+This skill guides the process of upgrading vllm-omni's NPU model runners to align with the latest vllm-ascend codebase while preserving omni-specific enhancements. The NPU runners are designed to run omni multimodal models (like Qwen3-Omni, Bagel, MiMoAudio) on Ascend NPUs.
+
+## File Structure
+
+### NPU Model Runner Files
+```
+vllm-omni/vllm_omni/platforms/npu/worker/
+├── __init__.py
+├── npu_model_runner.py # OmniNPUModelRunner (base class)
+├── npu_ar_model_runner.py # NPUARModelRunner (autoregressive)
+├── npu_ar_worker.py # AR worker
+├── npu_generation_model_runner.py # NPUGenerationModelRunner (diffusion/non-AR)
+└── npu_generation_worker.py # Generation worker
+```
+
+### GPU Reference Files (for omni-specific logic sync)
+```
+vllm-omni/vllm_omni/worker/
+├── __init__.py
+├── gpu_model_runner.py # OmniGPUModelRunner
+├── gpu_ar_model_runner.py # GPUARModelRunner
+├── gpu_ar_worker.py
+├── gpu_generation_model_runner.py
+├── gpu_generation_worker.py
+├── mixins.py
+├── base.py
+└── gpu_memory_utils.py
+```
+
+### vllm-ascend Reference Files
+```
+vllm-ascend/vllm_ascend/worker/
+├── model_runner_v1.py # NPUModelRunner (base class to copy from)
+├── npu_input_batch.py
+├── block_table.py
+├── pcp_utils.py
+└── worker.py
+```
+
+## Inheritance Hierarchy
+
+```
+ GPUModelRunner (vllm)
+ |
+ +----------------+----------------+
+ | |
+ OmniGPUModelRunner NPUModelRunner (vllm-ascend)
+ (vllm_omni/worker) (vllm_ascend/worker)
+ | |
+ +----------- OmniNPUModelRunner --+
+ (multiple inheritance)
+ |
+ +---------------+---------------+
+ | |
+ NPUARModelRunner NPUGenerationModelRunner
+ (autoregressive) (non-autoregressive/diffusion)
+```
+
+## Omni-Specific Comment Markers
+
+Omni-specific logic is marked with comment blocks:
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# ... omni-specific code ...
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+Or simpler variations:
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# ------------------------------------------------------------------------------------------------
+```
+
+**Important**:
+- Always preserve and add these markers when modifying code.
+- **The reference documents (`references/omni-specific-blocks.md`) may not be up-to-date.** Always grep for `Omni-new` in the GPU implementations to find the authoritative list of omni-specific blocks.
+- When you discover new omni-specific code that is not documented in the references, please update the reference files.
+
+## Key Methods Requiring Attention
+
+### OmniNPUModelRunner (npu_model_runner.py)
+
+| Method | Description | Omni-Specific Logic |
+|--------|-------------|---------------------|
+| `load_model` | Load model and initialize talker_mtp | Uses `ACLGraphWrapper` instead of `CUDAGraphWrapper`, initializes talker buffers |
+| `_dummy_run` | Warmup/profiling run | talker_mtp dummy forward, `extract_multimodal_outputs` |
+| `_model_forward` | Forward pass wrapper | Injects `model_kwargs_extra`, wraps with `OmniOutput`, NPU-specific graph updates |
+| `_talker_mtp_forward` | Talker MTP forward for Qwen3-Omni | Uses `set_ascend_forward_context` |
+
+### NPUARModelRunner (npu_ar_model_runner.py)
+
+| Method | Description | Omni-Specific Logic |
+|--------|-------------|---------------------|
+| `__init__` | Initialize with KV transfer manager | `OmniKVTransferManager` setup |
+| `execute_model` | Main inference entry | KV transfer handling, `_update_states` override, `extract_multimodal_outputs` |
+| `sample_tokens` | Token sampling | Hidden states extraction, multimodal outputs processing, `OmniModelRunnerOutput` |
+| `_resolve_global_request_id` | Request ID resolution | For disaggregated inference |
+
+### NPUGenerationModelRunner (npu_generation_model_runner.py)
+
+| Method | Description | Omni-Specific Logic |
+|--------|-------------|---------------------|
+| `_update_request_states` | Update request states for async chunk | async_chunk handling |
+| `execute_model` | Generation forward | async_chunk, `seq_token_counts`, `_run_generation_model` |
+| `sample_tokens` | Output processing | multimodal output packaging to `OmniModelRunnerOutput` |
+| `_dummy_run` | Dummy run override | model_kwargs initialization, multimodal extraction |
+| `_run_generation_model` | Run generation model | Calls `_model_forward` with sampler |
+
+## Upgrade Workflow
+
+### Step 1: Preparation
+
+1. **Identify target versions**(Use gh cli to check):
+ - We're using vllm-omni main branch
+ - Check the last release of vllm-omni
+ - Target vllm-ascend version(Just directly use the local latest vllm-ascend code)
+
+2. **Check GPU-side changes** (since last release):
+ ```bash
+ cd /root/vllm-workspace/vllm-omni
+ git log --oneline --since="" -- vllm_omni/worker/
+ ```
+
+3. **Read latest vllm-ascend code**:
+ - We don't track vllm-ascend changes - just directly use the latest code from `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py`
+ - Copy the relevant methods and re-insert omni-specific blocks
+
+### Step 2: Analyze Omni-Specific Logic
+
+For each NPU model runner file:
+
+1. **Extract existing omni-specific blocks**:
+ ```bash
+ grep -n "Omni-new" vllm_omni/platforms/npu/worker/npu_model_runner.py
+ ```
+
+2. **Document each omni block**:
+ - Which method it belongs to
+ - What functionality it provides
+ - Dependencies on other omni code
+
+### Step 3: Update Base Class (OmniNPUModelRunner)
+
+**Note**: Always check the GPU implementation `gpu_model_runner.py` for any new omni logic not yet documented in references.
+
+1. **Read the latest vllm-ascend `NPUModelRunner.load_model`**
+2. **Copy the method, keeping the structure**
+3. **Re-insert omni-specific logic** (check GPU `gpu_model_runner.py` for authoritative list):
+ - Replace `CUDAGraphWrapper` with `ACLGraphWrapper`
+ - Keep talker_mtp initialization
+ - Preserve buffer allocations for talker
+ - Check for any new omni blocks added since last sync
+
+4. **Update `_dummy_run`**:
+ - Copy from vllm-ascend
+ - Compare with GPU `_dummy_run` for omni-specific blocks
+ - Re-insert all `Omni-new` marked code from GPU version
+
+5. **Update `_model_forward`**:
+ - Keep the omni wrapper logic
+ - Update NPU-specific parts (graph params, SP all-gather)
+ - Check GPU version for any new omni logic
+
+### Step 4: Update AR Model Runner
+
+1. **Compare with GPU `gpu_ar_model_runner.py`** for any new omni features
+2. **Copy `execute_model` from vllm-ascend**
+3. **Re-insert omni blocks** (reference `references/omni-specific-blocks.md`, but note it may be incomplete):
+ - **IMPORTANT**: Always check the GPU implementation `gpu_ar_model_runner.py` for all `Omni-new` marked code blocks
+ - The reference doc may not include newly added omni logic - treat it as a starting point, not exhaustive
+ - When discovering new omni code blocks, please update `references/omni-specific-blocks.md`
+ - Common omni blocks include but are not limited to: KV transfer, multimodal outputs, sampling_metadata handling, etc.
+
+4. **Update `sample_tokens`** (also compare with GPU implementation):
+ - Compare with `gpu_ar_model_runner.py`'s `sample_tokens` method
+ - Identify all `Omni-new` marked code blocks
+ - Ensure NPU version includes all omni-specific logic
+
+### Step 5: Update Generation Model Runner
+
+**Note**: Generation model runner may have unique omni logic for diffusion/non-AR models.
+
+1. **Compare with GPU `gpu_generation_model_runner.py`** - grep for all `Omni-new` blocks
+2. **Update `execute_model`**:
+ - Check GPU version for all omni-specific blocks
+ - Keep async_chunk handling
+ - Keep `seq_token_counts` injection
+ - Update forward/context setup from vllm-ascend
+ - Look for any new omni logic not documented in references
+
+3. **Update `_dummy_run`**:
+ - Copy from vllm-ascend base
+ - Compare with GPU `_dummy_run` if exists
+ - Re-insert all omni-specific logic
+
+### Step 6: Update Imports
+
+Check and update imports at the top of each file:
+
+```python
+# Common vllm-ascend imports
+from vllm_ascend.ascend_forward_context import get_forward_context, set_ascend_forward_context
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import using_paged_attention
+from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
+from vllm_ascend.ops.rotary_embedding import update_cos_sin
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable
+from vllm_ascend.worker.model_runner_v1 import SEQ_LEN_WITH_MAX_PA_WORKSPACE, NPUModelRunner
+
+# Omni-specific imports
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+from vllm_omni.outputs import OmniModelRunnerOutput
+from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
+```
+
+### Step 7: Sync GPU-Side Omni Changes
+
+1. **Check recent GPU worker changes**:
+ ```bash
+ git diff .. -- vllm_omni/worker/gpu_model_runner.py
+ git diff .. -- vllm_omni/worker/gpu_ar_model_runner.py
+ ```
+
+2. **Identify new omni features** that need to be ported to NPU
+
+3. **Apply corresponding changes** to NPU runners
+
+### Step 8: Validation
+
+1. **Run type checking**:
+ ```bash
+ cd /root/vllm-workspace/vllm-omni
+ python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py
+ python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
+ python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py
+ ```
+
+2. **Run import test**:
+ ```bash
+ python -c "from vllm_omni.platforms.npu.worker import *"
+ ```
+
+3. **Run model serving test** (if hardware available):
+ ```bash
+ vllm serve --trust-remote-code
+ ```
+
+## Common Pitfalls
+
+### 1. Forward Context Differences
+- GPU uses `set_forward_context`
+- NPU uses `set_ascend_forward_context`
+- Parameters may differ slightly
+
+### 2. Graph Wrapper Differences
+- GPU: `CUDAGraphWrapper`
+- NPU: `ACLGraphWrapper`
+- Constructor parameters may differ
+
+### 3. Buffer Creation
+- GPU: `_make_buffer` returns different structure
+- NPU: May need numpy=True/False parameter
+
+### 4. Attention Metadata
+- GPU: Uses vllm attention metadata builders
+- NPU: Uses `AscendCommonAttentionMetadata`
+
+### 5. Sampling
+- GPU: Uses vllm sampler
+- NPU: Uses `AscendSampler`
+
+## Checklist Before Commit
+
+- [ ] All omni-specific comment markers preserved
+- [ ] New omni logic from GPU side synced
+- [ ] Imports updated to latest vllm-ascend
+- [ ] No `CUDAGraphWrapper` references in NPU code
+- [ ] `set_ascend_forward_context` used instead of `set_forward_context`
+- [ ] `ACLGraphWrapper` used for talker_mtp wrapping
+- [ ] Type hints match vllm-ascend signatures
+- [ ] No duplicate code blocks
+- [ ] Python syntax valid (py_compile passes)
+
+## Reference Files for Comparison
+
+When upgrading, keep these files open for reference:
+
+1. **vllm-ascend NPUModelRunner**: `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py`
+2. **vllm GPUModelRunner**: `/root/vllm-workspace/vllm/vllm/v1/worker/gpu_model_runner.py`
+3. **vllm-omni OmniGPUModelRunner**: `/root/vllm-workspace/vllm-omni/vllm_omni/worker/gpu_model_runner.py`
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md b/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md
new file mode 100644
index 0000000000..89067d37b2
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md
@@ -0,0 +1,335 @@
+# GPU to NPU Translation Patterns
+
+This document provides a quick reference for translating GPU code patterns to NPU equivalents when porting omni-specific logic.
+
+## Import Translations
+
+### Forward Context
+```python
+# GPU
+from vllm.forward_context import set_forward_context
+
+# NPU
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+```
+
+### Graph Wrapper
+```python
+# GPU
+from vllm.compilation.cuda_graph import CUDAGraphWrapper
+
+# NPU
+from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
+```
+
+### Attention State
+```python
+# GPU (no equivalent - uses FlashAttention states directly)
+
+# NPU
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+```
+
+### Utilities
+```python
+# GPU
+# (directly use torch.cuda functions)
+
+# NPU
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable
+from vllm_ascend.ops.rotary_embedding import update_cos_sin
+```
+
+## Context Manager Translations
+
+### Forward Context Setup
+```python
+# GPU
+with set_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_tokens_padded,
+ num_tokens_across_dp=num_tokens_across_dp,
+ cudagraph_runtime_mode=cudagraph_mode,
+ batch_descriptor=batch_desc,
+):
+ # forward pass
+
+# NPU
+with set_ascend_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_tokens_padded,
+ num_tokens_across_dp=num_tokens_across_dp,
+ aclgraph_runtime_mode=cudagraph_mode, # Note: 'aclgraph' not 'cudagraph'
+ batch_descriptor=batch_desc,
+ num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
+ model_instance=self.model,
+):
+ # forward pass
+```
+
+### Graph Capture Context
+```python
+# GPU
+from vllm.compilation.cuda_graph import graph_capture as cuda_graph_capture
+with cuda_graph_capture(self.device):
+ # capture
+
+# NPU
+from vllm_ascend.worker.model_runner_v1 import graph_capture
+with graph_capture(self.device):
+ # capture
+```
+
+## Graph Wrapper Usage
+
+### Creating Graph Wrapper
+```python
+# GPU
+if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ self.talker_mtp = CUDAGraphWrapper(
+ talker_mtp,
+ self.vllm_config,
+ runtime_mode=CUDAGraphMode.FULL
+ )
+
+# NPU
+if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ self.talker_mtp = ACLGraphWrapper(
+ talker_mtp,
+ self.vllm_config,
+ runtime_mode=CUDAGraphMode.FULL
+ )
+```
+
+### Checking Graph Wrapper Type
+```python
+# GPU
+if not isinstance(self.talker_mtp, CUDAGraphWrapper):
+ _cudagraph_mode = CUDAGraphMode.NONE
+
+# NPU
+if not isinstance(self.talker_mtp, ACLGraphWrapper):
+ _cudagraph_mode = CUDAGraphMode.NONE
+```
+
+## Device Operations
+
+### Synchronization
+```python
+# GPU
+torch.cuda.synchronize()
+
+# NPU
+torch.npu.synchronize()
+```
+
+### Stream Operations
+```python
+# GPU
+stream = torch.cuda.Stream(device=device)
+torch.cuda.current_stream()
+
+# NPU
+stream = torch.npu.Stream(device=device)
+torch.npu.current_stream()
+```
+
+## Attention Metadata
+
+### State Setting (NPU-specific)
+```python
+# GPU - handled internally by attention backends
+
+# NPU - explicit state setting required
+self.attn_state = AscendAttentionState.DecodeOnly
+if self.speculative_config and self.speculative_config.method == "mtp":
+ if self.vllm_config.model_config.use_mla:
+ self.attn_state = AscendAttentionState.SpecDecoding
+ else:
+ self.attn_state = AscendAttentionState.ChunkedPrefill
+```
+
+### Building Attention Metadata
+```python
+# GPU - uses vllm attention builders
+
+# NPU - may need additional parameters
+(attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata(
+ num_tokens=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded,
+ num_reqs=num_reqs,
+ num_reqs_padded=num_reqs_padded,
+ max_query_len=max_num_scheduled_tokens,
+ ubatch_slices=ubatch_slices_attn,
+ logits_indices=logits_indices,
+ use_spec_decode=use_spec_decode,
+ num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
+ num_scheduled_tokens_np=num_scheduled_tokens_np,
+ cascade_attn_prefix_lens=cascade_attn_prefix_lens,
+)
+```
+
+## Rotary Embedding
+
+### Update Cos/Sin Cache
+```python
+# GPU - typically handled inside attention
+
+# NPU - explicit update required before forward
+from vllm_ascend.ops.rotary_embedding import update_cos_sin
+update_cos_sin(positions)
+```
+
+## Sequence Parallelism
+
+### Enable SP Check
+```python
+# GPU - use vllm distributed utilities
+
+# NPU - use vllm-ascend wrapper
+from vllm_ascend.utils import enable_sp
+
+if enable_sp():
+ # sequence parallelism enabled
+```
+
+## Sampler
+
+### Sampler Type
+```python
+# GPU - uses vllm sampler
+self.sampler = Sampler()
+
+# NPU - uses AscendSampler
+from vllm_ascend.sample.sampler import AscendSampler
+self.sampler = AscendSampler()
+```
+
+## Input Batch
+
+### Batch Class
+```python
+# GPU
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+# NPU
+from vllm_ascend.worker.npu_input_batch import NPUInputBatch
+```
+
+## Graph Parameter Updates
+
+### Full Graph Params Update (NPU-specific)
+```python
+# GPU - not needed
+
+# NPU - required for FULL graph mode
+from vllm_ascend.compilation.acl_graph import update_full_graph_params
+
+forward_context = get_forward_context()
+if (
+ forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
+ and not forward_context.capturing
+ and not self.use_sparse
+):
+ update_full_graph_params(
+ self.attn_backend,
+ self.update_stream,
+ forward_context,
+ num_tokens_padded,
+ self.vllm_config,
+ self.speculative_config,
+ positions.shape[0],
+ )
+```
+
+## Paged Attention Check
+
+```python
+# GPU - not typically needed
+
+# NPU
+from vllm_ascend.attention.utils import using_paged_attention
+
+if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config):
+ seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE
+```
+
+## Common Method Signature Differences
+
+### _dummy_run Parameters
+```python
+# GPU (v0.17.0)
+def _dummy_run(
+ self,
+ num_tokens: int,
+ cudagraph_runtime_mode: CUDAGraphMode | None = None,
+ force_attention: bool = False,
+ uniform_decode: bool = False,
+ allow_microbatching: bool = True,
+ skip_eplb: bool = False,
+ is_profile: bool = False,
+ create_mixed_batch: bool = False,
+ remove_lora: bool = True,
+ is_graph_capturing: bool = False,
+ num_active_loras: int = 0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+
+# NPU (v0.17.0) - adds with_prefill, activate_lora
+def _dummy_run(
+ self,
+ num_tokens: int,
+ with_prefill: bool = False,
+ cudagraph_runtime_mode: CUDAGraphMode | None = None,
+ force_attention: bool = False,
+ uniform_decode: bool = False,
+ is_profile: bool = False,
+ create_mixed_batch: bool = False,
+ allow_microbatching: bool = True,
+ skip_eplb: bool = False,
+ remove_lora: bool = True,
+ activate_lora: bool = False,
+ is_graph_capturing: bool = False,
+ num_active_loras: int = 0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+```
+
+### _model_forward Parameters
+```python
+# GPU - no num_tokens_padded
+def _model_forward(
+ self,
+ input_ids: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **model_kwargs: dict[str, Any],
+):
+
+# NPU - has num_tokens_padded as first parameter
+def _model_forward(
+ self,
+ num_tokens_padded: int,
+ input_ids: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **model_kwargs: dict[str, Any],
+):
+```
+
+## Quick Reference Table
+
+| Feature | GPU | NPU |
+|---------|-----|-----|
+| Graph wrapper | `CUDAGraphWrapper` | `ACLGraphWrapper` |
+| Forward context | `set_forward_context` | `set_ascend_forward_context` |
+| Runtime mode param | `cudagraph_runtime_mode` | `aclgraph_runtime_mode` |
+| Device sync | `torch.cuda.synchronize()` | `torch.npu.synchronize()` |
+| Stream | `torch.cuda.Stream` | `torch.npu.Stream` |
+| Current stream | `torch.cuda.current_stream()` | `torch.npu.current_stream()` |
+| Input batch | `InputBatch` | `NPUInputBatch` |
+| Sampler | `Sampler` | `AscendSampler` |
+| Attention state | N/A | `AscendAttentionState` |
+| RoPE update | N/A | `update_cos_sin()` |
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md b/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md
new file mode 100644
index 0000000000..8c5d32ab4c
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md
@@ -0,0 +1,374 @@
+# Omni-Specific Code Blocks Reference
+
+This document catalogs omni-specific code blocks in the NPU model runners, making it easier to identify what needs to be preserved during upgrades.
+
+> **IMPORTANT**: This document may not be complete or up-to-date!
+>
+> - Always grep for `Omni-new` in the GPU implementations (`vllm_omni/worker/`) to find the authoritative list
+> - New omni features may be added that are not yet documented here
+> - When you discover new omni-specific blocks during an upgrade, please update this document
+> - Last verified: Check git history for this file
+
+## OmniNPUModelRunner (npu_model_runner.py)
+
+### load_model - Talker MTP Initialization
+
+```python
+def load_model(self, *args, **kwargs) -> None:
+ NPUModelRunner.load_model(self, *args, **kwargs)
+ # Initialize enable_sp cache to avoid get_current_vllm_config() error
+ # in _pad_for_sequence_parallelism during execute_model.
+ # This is a workaround for vllm-ascend not passing vllm_config to enable_sp().
+ enable_sp(self.vllm_config)
+ # TODO move this model specific logic to a separate class
+ # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS.
+ talker_mtp = getattr(self.model, "talker_mtp", None)
+ if talker_mtp is not None:
+ self.talker_mtp = talker_mtp # type: ignore[assignment]
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
+ # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
+ # have a separate .talker sub-module. TTS models' code predictor
+ # has internal AR loops / torch.multinomial — not graph-safe.
+ has_separate_talker = getattr(self.model, "talker", None) is not None
+ if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ # NOTE: Use ACLGraphWrapper on NPU, not CUDAGraphWrapper
+ self.talker_mtp = ACLGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+ # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
+ hidden_size = int(
+ getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size")
+ )
+ max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size)
+ self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32)
+ self.talker_mtp_inputs_embeds = self._make_buffer(
+ max_batch_size, hidden_size, dtype=self.dtype, numpy=False
+ )
+ self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
+ self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
+```
+
+### _dummy_run - Talker MTP Dummy Forward
+
+Location: Inside `set_ascend_forward_context` block, before main model forward
+
+```python
+# ---------------------------------------Omni-new----------------------------------------------
+if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"):
+ num_tokens_padded_talker_mtp = num_tokens_padded
+ if num_tokens_padded_talker_mtp == self.max_num_tokens:
+ num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0]
+ outputs = self.talker_mtp(
+ self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp],
+ self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp],
+ self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp],
+ self.text_step.gpu[:num_tokens_padded_talker_mtp],
+ )
+ self.compilation_config.cache_dir = None
+# ---------------------------------------Omni-new----------------------------------------------
+```
+
+### _dummy_run - Extract Multimodal Outputs
+
+Location: After model forward, before dummy_compute_logits
+
+```python
+# ---------------------------------------Omni-new----------------------------------------------
+hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
+# ---------------------------------------Omni-new----------------------------------------------
+```
+
+### _model_forward - Omni Output Wrapping
+
+```python
+def _model_forward(
+ self,
+ num_tokens_padded: int,
+ input_ids: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **model_kwargs: dict[str, Any],
+):
+ """Override to combine NPUModelRunner's signature with OmniGPUModelRunner's logic."""
+ # Omni-specific: build and inject extra model kwargs
+ model_kwargs_extra = self._build_model_kwargs_extra()
+
+ # Call the model forward (same as NPUModelRunner)
+ assert self.model is not None
+ model_output = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **model_kwargs,
+ **model_kwargs_extra,
+ )
+
+ # Omni-specific: wrap output if needed
+ if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"):
+ model_output = self.model.make_omni_output(model_output, **model_kwargs_extra)
+
+ # Omni-specific: cache model output for later sample_tokens
+ self._omni_last_model_output = model_output
+
+ # NPU-specific: update full graph params (keep from vllm-ascend)
+ forward_context = get_forward_context()
+ # ... NPU graph update logic ...
+
+ # NPU-specific: all-gather for sequence parallelism (keep from vllm-ascend)
+ if get_forward_context().sp_enabled and not isinstance(model_output, IntermediateTensors):
+ model_output = self._all_gather_hidden_states_and_aux(model_output)
+
+ return model_output
+```
+
+---
+
+## NPUARModelRunner (npu_ar_model_runner.py)
+
+### __init__ - KV Transfer Manager
+
+```python
+def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
+ # each model stage has their own hidden size
+ self.hidden_size = self.model_config.hf_text_config.hidden_size
+ self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False)
+ # Initialize KV cache manager (preserve vllm_config fallback behavior)
+ self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config)
+```
+
+### execute_model - KV Transfer Before Update States
+
+Location: At the very beginning of execute_model
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# [Omni] Handle KV transfer BEFORE updating states (which removes finished requests)
+self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer(
+ finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}),
+ kv_caches=self.kv_caches,
+ block_size=self.cache_config.block_size,
+ cache_dtype=str(self.cache_config.cache_dtype),
+ request_id_resolver=self._resolve_global_request_id,
+)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### execute_model - Custom _update_states Call
+
+Location: Inside synchronize_input_prep context
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+self._update_states(scheduler_output)
+# ------------------------------------------------------------------------------------------------
+```
+
+### execute_model - Extract Multimodal Outputs
+
+Location: In post process section, after hidden_states assignment
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
+
+if multimodal_outputs is not None:
+ keys_or_type = (
+ list(multimodal_outputs.keys())
+ if isinstance(multimodal_outputs, dict)
+ else type(multimodal_outputs)
+ )
+ logger.debug(f"[AR] execute_model: multimodal_outputs keys = {keys_or_type}")
+else:
+ logger.debug("[AR] execute_model: multimodal_outputs is None")
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### execute_model - Compute Logits with sampling_metadata
+
+Location: In both broadcast_pp_output True and False branches
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# Try with sampling_metadata first; fall back to without for models that don't support it
+try:
+ logits = self.model.compute_logits(
+ sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata
+ )
+except TypeError:
+ logits = self.model.compute_logits(sample_hidden_states)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### sample_tokens - KV Extracted Req IDs
+
+Location: At the beginning of sample_tokens
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None)
+self.kv_extracted_req_ids = None
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### sample_tokens - Process Additional Information and Build Output
+
+Location: After bookkeeping sync, replacing the original output construction
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+hidden_states_cpu = hidden_states.detach().to("cpu").contiguous()
+num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None)
+if num_scheduled_tokens_np is None:
+ req_ids = self.input_batch.req_ids
+ num_scheduled_tokens_np = np.array(
+ [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids],
+ dtype=np.int32,
+ )
+
+self._process_additional_information_updates(
+ hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output
+)
+
+pooler_output: list[dict[str, object]] = []
+for rid in req_ids_output_copy:
+ idx = req_id_to_index_output_copy[rid]
+ start = int(self.query_start_loc.cpu[idx])
+ sched = int(num_scheduled_tokens_np[idx])
+ end = start + sched
+ hidden_slice = hidden_states_cpu[start:end]
+ payload: dict[str, object] = {"hidden": hidden_slice}
+ if isinstance(multimodal_outputs, dict) and multimodal_outputs:
+ # ... multimodal output slicing logic ...
+ pooler_output.append(payload)
+
+model_runner_output = OmniModelRunnerOutput(
+ req_ids=req_ids_output_copy,
+ req_id_to_index=req_id_to_index_output_copy,
+ sampled_token_ids=valid_sampled_token_ids,
+ logprobs=logprobs_lists,
+ prompt_logprobs_dict=prompt_logprobs_dict,
+ pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None),
+ kv_connector_output=kv_connector_output,
+)
+model_runner_output.kv_extracted_req_ids = kv_extracted_req_ids
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+---
+
+## NPUGenerationModelRunner (npu_generation_model_runner.py)
+
+### execute_model - Async Chunk Update
+
+Location: Inside prepare input section, before synchronize_input_prep
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+if self.model_config.async_chunk and num_scheduled_tokens:
+ self._update_request_states(scheduler_output)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### execute_model - Seq Token Counts
+
+Location: After _preprocess call
+
+```python
+# [Omni] Pass token counts per request for code2wav output slicing
+model_kwargs["seq_token_counts"] = tokens
+```
+
+### execute_model - Run Generation Model
+
+Location: Inside forward context
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+outputs = self._run_generation_model(
+ num_tokens_padded=num_tokens_padded,
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ model_kwargs=model_kwargs,
+ logits_indices=logits_indices,
+)
+_, multimodal_outputs = self.extract_multimodal_outputs(outputs)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### sample_tokens - Multimodal Output Processing
+
+The entire sample_tokens method body is omni-specific for generation models:
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+pooler_output: list[object] = []
+if isinstance(multimodal_outputs, torch.Tensor):
+ # ... tensor handling ...
+elif isinstance(multimodal_outputs, list):
+ # ... list handling ...
+elif isinstance(multimodal_outputs, dict):
+ # ... dict handling per request ...
+else:
+ raise RuntimeError("Unsupported diffusion output type")
+# [Omni] Copy req_id mappings to avoid async scheduling mutation.
+req_ids_output_copy = self.input_batch.req_ids.copy()
+req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
+output = OmniModelRunnerOutput(
+ req_ids=req_ids_output_copy,
+ req_id_to_index=req_id_to_index_output_copy,
+ sampled_token_ids=[],
+ logprobs=None,
+ prompt_logprobs_dict={},
+ pooler_output=pooler_output,
+ kv_connector_output=kv_connector_output,
+ num_nans_in_logits={},
+ ec_connector_output=ec_connector_output if self.supports_mm_inputs else None,
+)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### _dummy_run - Model Kwargs Init and Multimodal Extract
+
+Location: Before model forward and after
+
+```python
+model_kwargs = self._init_model_kwargs() # Before forward
+
+# ... forward ...
+
+# -------------------------------------- Omni-new -------------------------------------------------
+hidden_states, _ = self.extract_multimodal_outputs(hidden_states)
+# -------------------------------------------------------------------------------------------------
+```
+
+---
+
+## ExecuteModelState Extension
+
+The `ExecuteModelState` NamedTuple is extended for omni:
+
+```python
+class ExecuteModelState(NamedTuple):
+ """Ephemeral cached state transferred between execute_model() and
+ sample_tokens(), after execute_model() returns None."""
+
+ scheduler_output: SchedulerOutput
+ logits: torch.Tensor
+ spec_decode_metadata: SpecDecodeMetadata | None
+ spec_decode_common_attn_metadata: AscendCommonAttentionMetadata | None
+ hidden_states: torch.Tensor
+ sample_hidden_states: torch.Tensor
+ aux_hidden_states: list[torch.Tensor] | None
+ attn_metadata: PerLayerAttnMetadata
+ positions: torch.Tensor
+ ec_connector_output: ECConnectorOutput | None
+ cudagraph_stats: CUDAGraphStat | None
+ multimodal_outputs: Any # <-- Omni extension
+```
+
+This extended state must be imported from `npu_ar_model_runner` in `npu_generation_model_runner`.
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md b/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md
new file mode 100644
index 0000000000..4f184df0ec
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md
@@ -0,0 +1,222 @@
+# NPU Model Runner Upgrade Workflow Checklist
+
+> **Note**: Reference documents (`omni-specific-blocks.md`) may not be complete. Always grep for `Omni-new` in GPU implementations to find all omni-specific code blocks. Update the reference docs when discovering new blocks.
+
+## Pre-Upgrade Preparation
+
+### 1. Version Information
+- [ ] Identify current vllm-omni version: `_________`
+- [ ] Identify target vllm-ascend version: `_________`
+- [ ] Identify target vllm version: `_________`
+- [ ] Last release date for GPU worker changes: `_________`
+
+### 2. Gather Git History
+```bash
+# GPU-side omni changes since last release
+cd /root/vllm-workspace/vllm-omni
+git log --oneline --since="YYYY-MM-DD" -- vllm_omni/worker/
+
+# vllm-ascend NPUModelRunner changes
+cd /root/vllm-workspace/vllm-ascend
+git log --oneline .. -- vllm_ascend/worker/model_runner_v1.py
+```
+
+### 3. Backup Current Files
+- [ ] Create backup of current NPU runners:
+ ```bash
+ cp -r vllm_omni/platforms/npu/worker vllm_omni/platforms/npu/worker.backup
+ ```
+
+---
+
+## OmniNPUModelRunner (npu_model_runner.py)
+
+### Read and Understand
+- [ ] Read current `npu_model_runner.py`
+- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py`
+- [ ] Read latest `vllm_omni/worker/gpu_model_runner.py`
+
+### Method: load_model
+- [ ] Document existing omni-specific logic
+- [ ] Copy latest NPUModelRunner.load_model structure
+- [ ] Re-insert: `enable_sp(self.vllm_config)` call
+- [ ] Re-insert: talker_mtp detection and setup
+- [ ] Replace: `CUDAGraphWrapper` → `ACLGraphWrapper`
+- [ ] Re-insert: Buffer allocations (talker_mtp_input_ids, etc.)
+
+### Method: _dummy_run
+- [ ] Document existing omni-specific logic locations
+- [ ] Copy latest NPUModelRunner._dummy_run
+- [ ] Re-insert: talker_mtp dummy forward block (inside context)
+- [ ] Re-insert: `extract_multimodal_outputs` call
+- [ ] Verify: Comment markers are present
+
+### Method: _model_forward
+- [ ] Copy latest NPUModelRunner._model_forward structure
+- [ ] Re-insert: `_build_model_kwargs_extra()` call
+- [ ] Re-insert: OmniOutput wrapping logic
+- [ ] Re-insert: `_omni_last_model_output` caching
+- [ ] Keep: NPU graph params update
+- [ ] Keep: SP all-gather logic
+
+### Method: _talker_mtp_forward
+- [ ] Verify: Uses `set_ascend_forward_context`
+- [ ] Verify: Uses `ACLGraphWrapper` check
+- [ ] Sync any changes from GPU `_talker_mtp_forward`
+
+### Imports
+- [ ] Update vllm-ascend imports to latest paths
+- [ ] Verify all omni imports are present
+- [ ] Remove any deprecated imports
+
+---
+
+## NPUARModelRunner (npu_ar_model_runner.py)
+
+### Read and Understand
+- [ ] Read current `npu_ar_model_runner.py`
+- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py` execute_model
+- [ ] Read latest `vllm_omni/worker/gpu_ar_model_runner.py`
+
+### Method: __init__
+- [ ] Sync any new initialization from GPU side
+- [ ] Keep: `OmniKVTransferManager` setup
+- [ ] Keep: Custom buffer allocations
+
+### Method: execute_model
+- [ ] Document all omni blocks with line numbers
+- [ ] Copy latest NPUModelRunner.execute_model structure
+- [ ] Re-insert: KV transfer handling (beginning)
+- [ ] Re-insert: Custom `_update_states` call
+- [ ] Re-insert: `extract_multimodal_outputs`
+- [ ] Re-insert: `compute_logits` with sampling_metadata try/except
+- [ ] Update: ExecuteModelState to include multimodal_outputs
+
+### Method: sample_tokens
+- [ ] Document all omni blocks
+- [ ] Copy latest NPUModelRunner.sample_tokens structure
+- [ ] Re-insert: `kv_extracted_req_ids` handling
+- [ ] Re-insert: Hidden states CPU copy
+- [ ] Re-insert: `_process_additional_information_updates`
+- [ ] Re-insert: `OmniModelRunnerOutput` construction
+
+### ExecuteModelState
+- [ ] Verify: `multimodal_outputs` field is present
+- [ ] Verify: Imported/used correctly in execute_model
+
+### Imports
+- [ ] Update all vllm-ascend imports
+- [ ] Keep omni-specific imports
+
+---
+
+## NPUGenerationModelRunner (npu_generation_model_runner.py)
+
+### Read and Understand
+- [ ] Read current `npu_generation_model_runner.py`
+- [ ] Read latest GPU `gpu_generation_model_runner.py`
+
+### Method: _update_request_states
+- [ ] Verify: async_chunk handling is correct
+- [ ] Sync any changes from GPU side
+
+### Method: execute_model
+- [ ] Document all omni blocks
+- [ ] Copy latest NPUModelRunner.execute_model base structure
+- [ ] Re-insert: async_chunk update logic
+- [ ] Re-insert: `seq_token_counts` injection
+- [ ] Re-insert: `_run_generation_model` call
+- [ ] Re-insert: `extract_multimodal_outputs`
+- [ ] Use: ExecuteModelState from npu_ar_model_runner
+
+### Method: sample_tokens
+- [ ] Keep: Entire omni multimodal output processing
+- [ ] Update: Any new output fields needed
+- [ ] Keep: `OmniModelRunnerOutput` construction
+
+### Method: _run_generation_model
+- [ ] Sync any changes from GPU side
+- [ ] Keep: `_model_forward` call with sampler
+
+### Method: _dummy_run
+- [ ] Copy latest NPUModelRunner._dummy_run
+- [ ] Re-insert: `model_kwargs = self._init_model_kwargs()`
+- [ ] Re-insert: `extract_multimodal_outputs` at end
+
+### Imports
+- [ ] Import ExecuteModelState from npu_ar_model_runner
+- [ ] Update vllm-ascend imports
+
+---
+
+## Post-Upgrade Validation
+
+### Syntax Validation
+- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py`
+- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py`
+- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py`
+
+### Import Validation
+- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner"`
+- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_ar_model_runner import NPUARModelRunner"`
+- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_generation_model_runner import NPUGenerationModelRunner"`
+
+### Comment Markers
+- [ ] Grep for "Omni-new" in all three files
+- [ ] Verify all omni blocks have closing markers
+
+### Code Review
+- [ ] No `CUDAGraphWrapper` references
+- [ ] All `set_forward_context` replaced with `set_ascend_forward_context`
+- [ ] Parameter names correct (`aclgraph_runtime_mode` not `cudagraph_runtime_mode`)
+- [ ] No duplicate code blocks
+- [ ] No missing imports
+
+---
+
+## Git Commit
+
+### Commit Message Template
+```
+[NPU] Upgrade model runners to align with vllm-ascend vX.Y.Z
+
+- Update OmniNPUModelRunner with latest NPUModelRunner base
+- Update NPUARModelRunner execute_model and sample_tokens
+- Update NPUGenerationModelRunner for async_chunk changes
+- Sync GPU-side omni changes from vX.Y.Z release
+- Preserve all omni-specific logic (marked with Omni-new comments)
+
+Changes from vllm-ascend:
+-
+
+Changes synced from GPU:
+-
+```
+
+### Files to Stage
+- [ ] `vllm_omni/platforms/npu/worker/npu_model_runner.py`
+- [ ] `vllm_omni/platforms/npu/worker/npu_ar_model_runner.py`
+- [ ] `vllm_omni/platforms/npu/worker/npu_generation_model_runner.py`
+- [ ] Any other modified files
+
+---
+
+## Troubleshooting
+
+### Import Errors
+- Check if vllm-ascend module paths have changed
+- Verify PYTHONPATH includes both vllm-ascend and vllm-omni
+
+### Type Errors
+- Check method signatures match between GPU and NPU
+- Verify NamedTuple fields match expected structure
+
+### Runtime Errors
+- Enable debug logging: `export VLLM_LOGGING_LEVEL=DEBUG`
+- Check graph capture issues: try `--enforce-eager`
+- Check attention issues: verify AscendAttentionState usage
+
+### Performance Regression
+- Compare with previous version on same model
+- Check if graph capture is working: look for ACLGraph logs
+- Verify SP/EP configurations are correct
diff --git a/.gitignore b/.gitignore
index 7f101a784c..35dc7571ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,7 +158,19 @@ cython_debug/
# Claude
CLAUDE.md
-.claude/
+/.claude/*
+!.claude/skills/
+!.claude/skills/readme.md
+!.claude/skills/add-diffusion-model/
+!.claude/skills/add-diffusion-model/SKILL.md
+!.claude/skills/add-diffusion-model/references/
+!.claude/skills/add-diffusion-model/references/*.md
+!.claude/skills/add-tts-model/
+!.claude/skills/add-tts-model/SKILL.md
+!.claude/skills/review-pr/
+!.claude/skills/review-pr/SKILL.md
+!.claude/skills/review-pr/references/
+!.claude/skills/review-pr/references/*.md
# Codex
AGENTS.md
@@ -191,6 +203,7 @@ checkpoints/
# Cache directories
cache/
!vllm_omni/diffusion/cache/
+!tests/diffusion/cache/
.cache/
diffusion_cache/
kv_cache/
@@ -250,3 +263,5 @@ tmp_test
vllm_omni/_version.py
# output files
*.wav
+# CI overlay yamls materialized from tests/utils.py:_CI_OVERLAYS at test time
+tests/.ci_generated/
diff --git a/benchmarks/build_dataset/download_process_data_seedtts.md b/benchmarks/build_dataset/download_process_data_seedtts.md
index ec16f64424..faf072303b 100644
--- a/benchmarks/build_dataset/download_process_data_seedtts.md
+++ b/benchmarks/build_dataset/download_process_data_seedtts.md
@@ -27,7 +27,7 @@ pip install gdown
Download the dataset from Google Drive:
```bash
-gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
+gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
```
### 4. Extract the Dataset
@@ -74,7 +74,7 @@ rm meta.lst
# Full setup and benchmark
cd benchmarks/build_dataset
pip install gdown
-gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
+gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
tar -xf seedtts_testset.tar
cp seedtts_testset/en/meta.lst meta.lst
python extract_tts_prompts.py -i meta.lst -o top100.txt -n 100
diff --git a/benchmarks/diffusion/backends.py b/benchmarks/diffusion/backends.py
index fa53f87aed..13ce7c8309 100644
--- a/benchmarks/diffusion/backends.py
+++ b/benchmarks/diffusion/backends.py
@@ -306,6 +306,8 @@ async def async_request_v1_videos(
video_bytes = await content_response.read()
output.response_body = video_bytes
output.success = True
+ if "stage_durations" in poll_json:
+ output.stage_durations = poll_json["stage_durations"] or {}
if "peak_memory_mb" in poll_json:
output.peak_memory_mb = poll_json["peak_memory_mb"]
elif "peak_memory_mb" in resp_json:
diff --git a/benchmarks/diffusion/diffusion_benchmark_serving.py b/benchmarks/diffusion/diffusion_benchmark_serving.py
index aad955b0d1..32ec48a698 100644
--- a/benchmarks/diffusion/diffusion_benchmark_serving.py
+++ b/benchmarks/diffusion/diffusion_benchmark_serving.py
@@ -558,6 +558,7 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool
super().__init__(args, api_url, model)
self.num_prompts = args.num_prompts
self.enable_negative_prompt = enable_negative_prompt
+ self.num_input_images = max(1, args.num_input_images)
self.random_request_config = getattr(args, "random_request_config", None)
if self.random_request_config:
self.random_request_config = json.loads(self.random_request_config)
@@ -580,11 +581,7 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool
# Random image generate
if self.args.task in ["i2v", "ti2v", "ti2i", "i2i"]:
- img = Image.new("RGB", (512, 512), (255, 255, 255))
-
- image_path = os.path.join(tempfile.gettempdir(), "diffusion_benchmark_random_image.png")
- self._random_image_path = [image_path]
- img.save(image_path)
+ self._random_image_path = self._generate_random_image_paths()
else:
self._random_image_path = None
@@ -619,6 +616,18 @@ def __getitem__(self, idx: int) -> RequestFuncInput:
def get_requests(self) -> list[RequestFuncInput]:
return [self[i] for i in range(len(self))]
+ def _generate_random_image_paths(self) -> list[str]:
+ image_paths: list[str] = []
+ for image_idx in range(self.num_input_images):
+ img = Image.new("RGB", (512, 512), (255, 255, 255))
+ image_path = os.path.join(
+ tempfile.gettempdir(),
+ f"diffusion_benchmark_random_image_{image_idx}.png",
+ )
+ img.save(image_path)
+ image_paths.append(image_path)
+ return image_paths
+
def _compute_expected_latency_ms_from_base(req: RequestFuncInput, args, base_time_ms: float | None) -> float | None:
"""Compute expected execution time (ms) based on a base per-step-per-frame unit time.
@@ -1115,6 +1124,15 @@ async def limited_request_func(req, session, pbar):
'{"width":768,"height":768,"num_inference_steps":20,"weight":0.85}]'
),
)
+ parser.add_argument(
+ "--num-input-images",
+ type=int,
+ default=1,
+ help=(
+ "Number of synthetic input images to attach for image-conditioned tasks "
+ "(i2v, ti2v, ti2i, i2i) when using random dataset."
+ ),
+ )
args = parser.parse_args()
diff --git a/benchmarks/qwen3-omni/README.md b/benchmarks/qwen3-omni/README.md
index de27c05c2c..dc282d0525 100644
--- a/benchmarks/qwen3-omni/README.md
+++ b/benchmarks/qwen3-omni/README.md
@@ -9,7 +9,7 @@ cd benchmarks/build_dataset
pip install gdown
# Download SeedTTS test set from Google Drive
-gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
+gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
# Extract
tar -xf seedtts_testset.tar
diff --git a/benchmarks/qwen3-tts/README.md b/benchmarks/qwen3-tts/README.md
index 9c01f29aa9..a1c2ebe12f 100644
--- a/benchmarks/qwen3-tts/README.md
+++ b/benchmarks/qwen3-tts/README.md
@@ -35,8 +35,8 @@ MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice bash run_benchmark.sh --async-only
# Use a Voice Clone model
MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only
-# Use bs16 config for higher throughput
-STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs16.yaml bash run_benchmark.sh --async-only
+# Use batch size 16 for higher throughput
+BATCH_SIZE=16 bash run_benchmark.sh --async-only
# Custom GPU, prompt count, concurrency levels
GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh
@@ -50,7 +50,8 @@ GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh
CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \
"Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" \
--omni --host 127.0.0.1 --port 8000 \
- --stage-configs-path benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-overrides '{"0":{"max_num_seqs":1,"gpu_memory_utilization":0.3,"max_num_batched_tokens":512},"1":{"max_num_seqs":1,"gpu_memory_utilization":0.3,"max_num_batched_tokens":8192}}' \
--trust-remote-code
```
@@ -84,16 +85,19 @@ python benchmarks/qwen3-tts/plot_results.py \
--output results/comparison.png
```
-## Stage Configs
+## Batch-size presets
-| Config | max_num_seqs | Description |
-|--------|:------------:|-------------|
-| `vllm_omni/configs/qwen3_tts_bs1.yaml` | 1 | Single-request processing (lowest latency) |
-| `vllm_omni/configs/qwen3_tts_bs16.yaml` | 16 | High-throughput concurrent processing |
+The bench script loads the bundled production deploy (`vllm_omni/deploy/qwen3_tts.yaml`) and layers per-stage budgets on top via `--stage-overrides`, driven by the `BATCH_SIZE` env var. Each batch size picks compatible per-stage `max_num_seqs`, `max_num_batched_tokens`, and `gpu_memory_utilization` defaults:
-All configs use a 2-stage pipeline (Talker -> Code2Wav) with `async_chunk` streaming enabled. The `SharedMemoryConnector` streams codec frames (25-frame chunks with 25-frame context overlap) between stages.
+| `BATCH_SIZE` | Description |
+|:--:|-------------|
+| `1` (default) | Single-request processing (lowest latency) |
+| `4` | Moderate-throughput concurrent processing |
+| `16` | High-throughput concurrent processing |
-The model is specified via the CLI `--model` flag (or `MODEL` env var), so the same configs work for both the 0.6B and 1.7B model variants.
+The 2-stage pipeline (Talker -> Code2Wav) runs with `async_chunk` streaming enabled via the prod deploy; the `SharedMemoryConnector` streams codec frames (25-frame chunks with 25-frame context overlap) between stages.
+
+The model is specified via the CLI `--model` flag (or `MODEL` env var), so the same bench script works for both the 0.6B and 1.7B model variants.
## Metrics
diff --git a/benchmarks/qwen3-tts/run_benchmark.sh b/benchmarks/qwen3-tts/run_benchmark.sh
index 283b6b844c..8c3e46903c 100755
--- a/benchmarks/qwen3-tts/run_benchmark.sh
+++ b/benchmarks/qwen3-tts/run_benchmark.sh
@@ -26,8 +26,8 @@
# # Use Voice Clone model
# MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only
#
-# # Use batch_size=4 config:
-# STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs4.yaml bash run_benchmark.sh --async-only
+# # Use batch_size=4:
+# BATCH_SIZE=4 bash run_benchmark.sh --async-only
#
# Environment variables:
# GPU_DEVICE - GPU index to use (default: 0)
@@ -35,9 +35,9 @@
# CONCURRENCY - Space-separated concurrency levels (default: "1 4 10")
# MODEL - Model name (default: Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice)
# PORT - Server port (default: 8000)
-# GPU_MEM_TALKER - gpu_memory_utilization for talker stage (default: 0.3)
-# GPU_MEM_CODE2WAV - gpu_memory_utilization for code2wav stage (default: 0.2)
-# STAGE_CONFIG - Path to stage config YAML (default: configs/qwen3_tts_bs1.yaml)
+# BATCH_SIZE - Per-stage ``max_num_seqs`` for both talker and code2wav (default: 1)
+# GPU_MEM_TALKER - gpu_memory_utilization for talker stage (default: 0.3 at bs=1, else 0.2)
+# GPU_MEM_CODE2WAV - gpu_memory_utilization for code2wav stage (default: 0.3 at bs=1, else 0.2)
# TASK_TYPE - Task type: CustomVoice, VoiceDesign, Base (default: CustomVoice)
set -euo pipefail
@@ -51,14 +51,36 @@ NUM_PROMPTS="${NUM_PROMPTS:-50}"
CONCURRENCY="${CONCURRENCY:-1 4 10}"
MODEL="${MODEL:-Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice}"
PORT="${PORT:-8000}"
-GPU_MEM_TALKER="${GPU_MEM_TALKER:-0.3}"
-GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV:-0.2}"
+BATCH_SIZE="${BATCH_SIZE:-1}"
+DEFAULT_MEM=$([ "${BATCH_SIZE}" = "1" ] && echo "0.3" || echo "0.2")
+GPU_MEM_TALKER="${GPU_MEM_TALKER:-${DEFAULT_MEM}}"
+GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV:-${DEFAULT_MEM}}"
NUM_WARMUPS="${NUM_WARMUPS:-3}"
-STAGE_CONFIG="${STAGE_CONFIG:-vllm_omni/configs/qwen3_tts_bs1.yaml}"
+DEPLOY_CONFIG="vllm_omni/deploy/qwen3_tts.yaml"
RESULT_DIR="${SCRIPT_DIR}/results"
TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
TASK_TYPE="${TASK_TYPE:-CustomVoice}"
+# Build --stage-overrides JSON from BATCH_SIZE + GPU_MEM_*.
+STAGE_OVERRIDES=$(
+ BATCH_SIZE="${BATCH_SIZE}" \
+ GPU_MEM_TALKER="${GPU_MEM_TALKER}" \
+ GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV}" \
+ python - <<'PYEOF'
+import json, os
+bs = int(os.environ["BATCH_SIZE"])
+mem_t = float(os.environ["GPU_MEM_TALKER"])
+mem_c = float(os.environ["GPU_MEM_CODE2WAV"])
+# Prefill budget grows with batch size on both stages.
+talker_batched = 512 if bs <= 4 else 4096
+code2wav_batched = 8192 if bs <= 4 else 32768
+print(json.dumps({
+ "0": {"max_num_seqs": bs, "gpu_memory_utilization": mem_t, "max_num_batched_tokens": talker_batched},
+ "1": {"max_num_seqs": bs, "gpu_memory_utilization": mem_c, "max_num_batched_tokens": code2wav_batched},
+}))
+PYEOF
+)
+
# Parse args
RUN_ASYNC=true
RUN_HF=true
@@ -75,41 +97,27 @@ mkdir -p "${RESULT_DIR}"
echo "============================================================"
echo " Qwen3-TTS Benchmark"
echo "============================================================"
-echo " GPU: ${GPU_DEVICE}"
-echo " Model: ${MODEL}"
-echo " Prompts: ${NUM_PROMPTS}"
-echo " Concurrency: ${CONCURRENCY}"
-echo " Port: ${PORT}"
-echo " Stage config: ${STAGE_CONFIG}"
-echo " Results: ${RESULT_DIR}"
-echo " Task type: ${TASK_TYPE}"
+echo " GPU: ${GPU_DEVICE}"
+echo " Model: ${MODEL}"
+echo " Prompts: ${NUM_PROMPTS}"
+echo " Concurrency: ${CONCURRENCY}"
+echo " Port: ${PORT}"
+echo " Deploy config: ${DEPLOY_CONFIG}"
+echo " Batch size: ${BATCH_SIZE}"
+echo " GPU mem T/C: ${GPU_MEM_TALKER} / ${GPU_MEM_CODE2WAV}"
+echo " Results: ${RESULT_DIR}"
+echo " Task type: ${TASK_TYPE}"
echo "============================================================"
-# Prepare stage config with correct GPU device and memory settings
-prepare_config() {
- local config_template="$1"
- local config_name="$2"
- local output_path="${RESULT_DIR}/${config_name}_stage_config.yaml"
-
- # Use sed to patch GPU device and memory utilization
- sed \
- -e "s/devices: \"0\"/devices: \"${GPU_DEVICE}\"/g" \
- -e "s/gpu_memory_utilization: 0.3/gpu_memory_utilization: ${GPU_MEM_TALKER}/g" \
- -e "s/gpu_memory_utilization: 0.2/gpu_memory_utilization: ${GPU_MEM_CODE2WAV}/g" \
- "${config_template}" > "${output_path}"
-
- echo "${output_path}"
-}
-
# Start server and wait for it to be ready
start_server() {
- local stage_config="$1"
- local config_name="$2"
+ local config_name="$1"
local log_file="${RESULT_DIR}/server_${config_name}_${TIMESTAMP}.log"
echo ""
echo "Starting server with config: ${config_name}"
- echo " Stage config: ${stage_config}"
+ echo " Deploy config: ${DEPLOY_CONFIG}"
+ echo " Stage overrides: ${STAGE_OVERRIDES}"
echo " Log file: ${log_file}"
VLLM_WORKER_MULTIPROC_METHOD=spawn \
@@ -118,7 +126,8 @@ start_server() {
--omni \
--host 127.0.0.1 \
--port "${PORT}" \
- --stage-configs-path "${stage_config}" \
+ --deploy-config "${DEPLOY_CONFIG}" \
+ --stage-overrides "${STAGE_OVERRIDES}" \
--stage-init-timeout 120 \
--trust-remote-code \
--disable-log-stats \
@@ -175,17 +184,13 @@ trap 'stop_server' EXIT
# Run benchmark for a given config
run_bench() {
local config_name="$1"
- local config_template="$2"
echo ""
echo "============================================================"
echo " Benchmarking: ${config_name}"
echo "============================================================"
- local stage_config
- stage_config=$(prepare_config "${config_template}" "${config_name}")
-
- start_server "${stage_config}" "${config_name}"
+ start_server "${config_name}"
# Convert concurrency string to args
local conc_args=""
@@ -212,7 +217,7 @@ run_bench() {
# Run vllm-omni benchmark
if [ "${RUN_ASYNC}" = true ]; then
- run_bench "async_chunk" "${SCRIPT_DIR}/${STAGE_CONFIG}"
+ run_bench "async_chunk"
fi
# Run HuggingFace baseline benchmark
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml
deleted file mode 100644
index 2cc5cf5353..0000000000
--- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml
+++ /dev/null
@@ -1,94 +0,0 @@
-# Qwen3-TTS max_num_seqs=16 config (streaming with async_chunk)
-# High-throughput concurrent request processing
-# 2-stage pipeline: Talker -> Code2Wav
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 16
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 16
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 16384
- max_model_len: 32768
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 16
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 25
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml
deleted file mode 100644
index 5de107d497..0000000000
--- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml
+++ /dev/null
@@ -1,94 +0,0 @@
-# Qwen3-TTS batch_size=4 config (streaming with async_chunk)
-# Enables concurrent request processing
-# 2-stage pipeline: Talker -> Code2Wav
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 32768
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 4
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 25
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
index 61cf7757a9..0ede359ea3 100755
--- a/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
+++ b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
@@ -31,8 +31,11 @@ PORT_OFF="${PORT_OFF:-8001}"
RESULT_DIR="${SCRIPT_DIR}/results"
TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
-STAGE_CONFIG_ON="vllm_omni/model_executor/stage_configs/qwen3_tts.yaml"
-STAGE_CONFIG_OFF="vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml"
+# The bundled ``vllm_omni/deploy/qwen3_tts.yaml`` is auto-loaded by the model
+# registry; no ``--deploy-config`` flag needed on the default (ON) path.
+# async_chunk OFF is selected by the ``--no-async-chunk`` CLI flag —
+# the single ``qwen3_tts`` pipeline dispatches to the end-to-end codec
+# processor when ``deploy.async_chunk`` is false.
mkdir -p "${RESULT_DIR}"
@@ -77,7 +80,6 @@ wait_for_server() {
echo ""
echo "[Phase 1] Starting async_chunk ON server on port ${PORT_ON}..."
CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \
- --stage-configs-path "${STAGE_CONFIG_ON}" \
--host 0.0.0.0 --port "${PORT_ON}" \
--trust-remote-code --enforce-eager --omni \
> "${RESULT_DIR}/server_on_${TIMESTAMP}.log" 2>&1 &
@@ -104,7 +106,7 @@ sleep 5
echo ""
echo "[Phase 2] Starting async_chunk OFF server on port ${PORT_OFF}..."
CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \
- --stage-configs-path "${STAGE_CONFIG_OFF}" \
+ --no-async-chunk \
--host 0.0.0.0 --port "${PORT_OFF}" \
--trust-remote-code --enforce-eager --omni \
> "${RESULT_DIR}/server_off_${TIMESTAMP}.log" 2>&1 &
diff --git a/benchmarks/voxcpm/README.md b/benchmarks/voxcpm/README.md
new file mode 100644
index 0000000000..17f904101b
--- /dev/null
+++ b/benchmarks/voxcpm/README.md
@@ -0,0 +1,119 @@
+# VoxCPM Benchmark
+
+This directory contains both:
+
+- online serving benchmark through the OpenAI-compatible `/v1/audio/speech` API
+- offline benchmark for `Omni` / `AsyncOmni`
+- full offline smoke-matrix orchestration
+
+Both benchmark paths report:
+
+- TTFP: time to first PCM packet
+- E2E latency
+- RTF: real-time factor (`e2e / audio_duration`)
+
+## Offline Benchmark
+
+Single offline benchmark run:
+
+```bash
+python benchmarks/voxcpm/vllm_omni/bench_tts_offline.py \
+ --model /path/to/voxcpm-model \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm.yaml \
+ --text "This is a split-stage VoxCPM synthesis example running on vLLM Omni." \
+ --warmup-runs 1 \
+ --output-dir benchmarks/voxcpm/results/offline_single
+```
+
+Streaming offline benchmark:
+
+```bash
+python benchmarks/voxcpm/vllm_omni/bench_tts_offline.py \
+ --model /path/to/voxcpm-model \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \
+ --text "This is a split-stage VoxCPM streaming example running on vLLM Omni." \
+ --warmup-runs 1 \
+ --output-dir benchmarks/voxcpm/results/offline_streaming
+```
+
+Full fixed offline matrix, equivalent to the old `examples/offline_inference/voxcpm/test.py`:
+
+```bash
+python benchmarks/voxcpm/vllm_omni/run_offline_matrix.py \
+ --model /path/to/voxcpm-model \
+ --ref-audio /path/to/reference.wav \
+ --ref-text "The exact transcript spoken in reference.wav." \
+ --output-root benchmarks/voxcpm/results/offline_matrix
+```
+
+The full matrix covers both routes:
+
+- streaming: `voxcpm_async_chunk.yaml`
+- sync: `voxcpm.yaml`
+
+And these six scenarios under each route:
+
+- warmup + single TTS
+- warmup + single voice cloning
+- warmup + batch TTS
+- warmup + batch voice cloning
+- cold single TTS
+- cold single voice cloning
+
+`bench_tts_offline.py` itself no longer writes `summary.json` / `results.json`; it prints TTFP / RTF inline and saves generated WAV files only. The matrix runner keeps only per-case `run.log`.
+
+## Start the Server
+
+Async-chunk:
+
+```bash
+vllm serve /path/to/voxcpm-model \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \
+ --trust-remote-code \
+ --enforce-eager \
+ --omni \
+ --port 8091
+```
+
+Non-streaming:
+
+```bash
+vllm serve /path/to/voxcpm-model \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm.yaml \
+ --trust-remote-code \
+ --enforce-eager \
+ --omni \
+ --port 8091
+```
+
+## Run the Benchmark
+
+```bash
+python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \
+ --host 127.0.0.1 \
+ --port 8091 \
+ --num-prompts 20 \
+ --max-concurrency 1 \
+ --result-dir /tmp/voxcpm_bench
+```
+
+Voice cloning benchmark:
+
+```bash
+python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \
+ --host 127.0.0.1 \
+ --port 8091 \
+ --num-prompts 10 \
+ --max-concurrency 1 \
+ --ref-audio https://example.com/reference.wav \
+ --ref-text "The exact transcript spoken in the reference audio." \
+ --result-dir /tmp/voxcpm_clone_bench
+```
+
+## Notes
+
+- The benchmark uses `stream=true` and `response_format=pcm` so TTFP is measured from the first audio packet.
+- `RTF < 1.0` means the server generates audio faster than real time.
+- For `voxcpm_async_chunk.yaml`, keep concurrency at `1`. This matches native VoxCPM streaming more closely.
+- Do not benchmark concurrent online streaming on `voxcpm_async_chunk.yaml`; use `voxcpm.yaml` for multi-request throughput runs.
+- For the offline matrix mode, `--ref-audio` and `--ref-text` are required because clone cases are part of the fixed coverage set.
diff --git a/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py b/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py
new file mode 100644
index 0000000000..a3bad3e692
--- /dev/null
+++ b/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py
@@ -0,0 +1,890 @@
+"""Offline VoxCPM benchmark for vLLM Omni.
+
+Supports both:
+- sync one-shot (Omni.generate)
+- streaming (AsyncOmni.generate with async_chunk config)
+- text-only synthesis
+- voice cloning
+- text/clone batch inputs from txt or jsonl
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import os
+import tempfile
+import time
+import uuid
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import torch
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+from vllm_omni import AsyncOmni, Omni
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+DEFAULT_STAGE_ASYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm_async_chunk.yaml"
+DEFAULT_STAGE_SYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True, slots=True)
+class PromptSpec:
+ text: str
+ label: str
+ ref_audio: str | None = None
+ ref_text: str | None = None
+
+
+def _require_soundfile():
+ try:
+ import soundfile as sf # type: ignore
+ except ModuleNotFoundError as exc:
+ raise RuntimeError(
+ "soundfile is required to write VoxCPM benchmark WAV outputs. Install it with: pip install soundfile"
+ ) from exc
+ return sf
+
+
+def _build_prompt(
+ args,
+ *,
+ text: str,
+ ref_audio: str | None = None,
+ ref_text: str | None = None,
+ global_request_id: str | None = None,
+) -> dict[str, Any]:
+ additional_information: dict[str, list[Any]] = {
+ "text": [text],
+ "cfg_value": [args.cfg_value],
+ "inference_timesteps": [args.inference_timesteps],
+ "min_len": [args.min_len],
+ "max_new_tokens": [args.max_new_tokens],
+ }
+ if args.streaming_prefix_len is not None:
+ additional_information["streaming_prefix_len"] = [args.streaming_prefix_len]
+
+ if ref_audio:
+ additional_information["ref_audio"] = [ref_audio]
+ if ref_text:
+ additional_information["ref_text"] = [ref_text]
+ if global_request_id is not None:
+ additional_information["global_request_id"] = [global_request_id]
+
+ return {
+ "prompt_token_ids": [1],
+ "additional_information": additional_information,
+ }
+
+
+def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor:
+ audio = mm.get("audio", mm.get("model_outputs"))
+ if audio is None:
+ raise ValueError("No audio output found in multimodal output.")
+ if isinstance(audio, list):
+ parts = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio]
+ audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0)
+ if not isinstance(audio, torch.Tensor):
+ audio = torch.as_tensor(audio)
+ return audio.float().cpu().reshape(-1)
+
+
+def _extract_sample_rate(mm: dict[str, Any]) -> int:
+ sr_raw = mm.get("sr", 24000)
+ if isinstance(sr_raw, list) and sr_raw:
+ sr_raw = sr_raw[-1]
+ if hasattr(sr_raw, "item"):
+ return int(sr_raw.item())
+ return int(sr_raw)
+
+
+def _emit_offline_metrics(
+ *,
+ request_id: str,
+ elapsed_s: float,
+ first_audio_elapsed: float | None,
+ audio_duration_s: float,
+) -> None:
+ metrics = {
+ "request_id": request_id,
+ "ttfp_ms": round(first_audio_elapsed * 1000.0, 3) if first_audio_elapsed is not None else None,
+ "audio_duration_s": round(audio_duration_s, 6),
+ "rtf": round(elapsed_s / audio_duration_s, 6) if audio_duration_s > 0 else None,
+ }
+ print(f"[OfflineMetrics] {metrics}")
+
+
+def _write_audio_tensor(output_path: Path, audio_tensor: Any, sample_rate: int) -> None:
+ sf = _require_soundfile()
+ if isinstance(audio_tensor, torch.Tensor):
+ audio_np = audio_tensor.float().cpu().clamp(-1.0, 1.0).numpy()
+ else:
+ audio_np = torch.as_tensor(audio_tensor).float().cpu().clamp(-1.0, 1.0).numpy()
+ sf.write(
+ output_path,
+ audio_np,
+ sample_rate,
+ format="WAV",
+ subtype="PCM_16",
+ )
+
+
+def _save_wav(mm: dict[str, Any], output_dir: Path, request_id: str) -> Path:
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_path = output_dir / f"output_{request_id}.wav"
+ _write_audio_tensor(output_path, _extract_audio_tensor(mm), _extract_sample_rate(mm))
+ return output_path
+
+
+def _iter_request_multimodal_outputs(request_output: Any):
+ outputs = getattr(request_output, "outputs", None)
+ if outputs:
+ for output in outputs:
+ mm = getattr(output, "multimodal_output", None)
+ if isinstance(mm, dict):
+ yield mm
+
+ mm = getattr(request_output, "multimodal_output", None)
+ if isinstance(mm, dict):
+ yield mm
+
+
+def _read_non_empty_lines(path: str) -> list[str]:
+ with open(path, encoding="utf-8") as f:
+ return [line.strip() for line in f if line.strip()]
+
+
+def _load_prompt_specs(args) -> list[PromptSpec]:
+ specs: list[PromptSpec] = []
+
+ if args.txt_prompts is not None:
+ texts = _read_non_empty_lines(args.txt_prompts)
+ if not texts:
+ raise ValueError(f"No prompts found in {args.txt_prompts}")
+ for idx, text in enumerate(texts, start=1):
+ specs.append(
+ PromptSpec(
+ text=text,
+ label=f"item{idx:03d}",
+ ref_audio=args.ref_audio,
+ ref_text=args.ref_text,
+ )
+ )
+ return specs
+
+ if args.jsonl_prompts is not None:
+ with open(args.jsonl_prompts, encoding="utf-8") as f:
+ for line_no, raw_line in enumerate(f, start=1):
+ line = raw_line.strip()
+ if not line:
+ continue
+ try:
+ item = json.loads(line)
+ except json.JSONDecodeError as exc:
+ raise ValueError(f"{args.jsonl_prompts}:{line_no} is not valid JSON: {exc}") from exc
+ if not isinstance(item, dict):
+ raise ValueError(f"{args.jsonl_prompts}:{line_no} must be a JSON object")
+
+ text = item.get("text")
+ if not isinstance(text, str) or not text.strip():
+ raise ValueError(f"{args.jsonl_prompts}:{line_no} requires non-empty string field 'text'")
+
+ ref_audio = item.get("ref_audio", args.ref_audio)
+ ref_text = item.get("ref_text", args.ref_text)
+ if (ref_audio is None) != (ref_text is None):
+ raise ValueError(
+ f"{args.jsonl_prompts}:{line_no} must provide both 'ref_audio' and 'ref_text' together"
+ )
+
+ specs.append(
+ PromptSpec(
+ text=text.strip(),
+ label=f"item{len(specs) + 1:03d}",
+ ref_audio=ref_audio,
+ ref_text=ref_text,
+ )
+ )
+
+ if not specs:
+ raise ValueError(f"No prompts found in {args.jsonl_prompts}")
+ return specs
+
+ specs.append(
+ PromptSpec(
+ text=args.text,
+ label="item001",
+ ref_audio=args.ref_audio,
+ ref_text=args.ref_text,
+ )
+ )
+ return specs
+
+
+def _build_prompt_for_spec(args, spec: PromptSpec, *, global_request_id: str | None = None) -> dict[str, Any]:
+ return _build_prompt(
+ args,
+ text=spec.text,
+ ref_audio=spec.ref_audio,
+ ref_text=spec.ref_text,
+ global_request_id=global_request_id,
+ )
+
+
+def _count_voice_clone_prompts(prompt_specs: list[PromptSpec]) -> int:
+ return sum(1 for spec in prompt_specs if spec.ref_audio is not None)
+
+
+def _get_warmup_specs(prompt_specs: list[PromptSpec]) -> list[PromptSpec]:
+ return prompt_specs[:1]
+
+
+def _extract_stream_finished(stage_output: Any) -> bool:
+ request_output = getattr(stage_output, "request_output", None)
+ request_finished = getattr(request_output, "finished", None)
+ if request_finished is not None:
+ return bool(request_finished)
+ return bool(getattr(stage_output, "finished", False))
+
+
+def _build_profiled_stage_config(
+ stage_configs_path: str,
+ profiler_dir: str,
+) -> str:
+ stage_config_path = Path(stage_configs_path)
+ yaml_text = stage_config_path.read_text(encoding="utf-8")
+ injected_lines: list[str] = []
+ injected_count = 0
+
+ for line in yaml_text.splitlines():
+ injected_lines.append(line)
+ if line.strip() != "engine_args:":
+ continue
+ indent = line[: len(line) - len(line.lstrip())]
+ child_indent = indent + " "
+ grandchild_indent = child_indent + " "
+ injected_lines.extend(
+ [
+ f"{child_indent}profiler_config:",
+ f'{grandchild_indent}profiler: "torch"',
+ f'{grandchild_indent}torch_profiler_dir: "{profiler_dir}"',
+ f"{grandchild_indent}torch_profiler_with_stack: true",
+ ]
+ )
+ injected_count += 1
+
+ if injected_count == 0:
+ raise ValueError(f"No engine_args block found in stage config: {stage_configs_path}")
+
+ tmp = tempfile.NamedTemporaryFile(
+ mode="w",
+ encoding="utf-8",
+ delete=False,
+ suffix=".yaml",
+ prefix=f"{stage_config_path.stem}_profile_",
+ )
+ tmp.write("\n".join(injected_lines) + "\n")
+ tmp.close()
+ return tmp.name
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(
+ description="Offline split-stage VoxCPM inference with vLLM Omni (auto sync/streaming by stage config)"
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default=os.environ.get("VOXCPM_MODEL"),
+ help="Local VoxCPM model directory. Defaults to $VOXCPM_MODEL.",
+ )
+ parser.add_argument(
+ "--text",
+ type=str,
+ default="This is a split-stage VoxCPM synthesis example running on vLLM Omni.",
+ help="Text to synthesize. Ignored when --txt-prompts or --jsonl-prompts is used.",
+ )
+ parser.add_argument(
+ "--txt-prompts",
+ type=str,
+ default=None,
+ help="Path to a .txt file with one synthesis text per line.",
+ )
+ parser.add_argument(
+ "--jsonl-prompts",
+ type=str,
+ default=None,
+ help=(
+ "Path to a .jsonl file. Each line must contain at least {'text': ...}; "
+ "clone rows can also set ref_audio/ref_text, and ref_text must be the "
+ "real transcript of ref_audio."
+ ),
+ )
+ parser.add_argument(
+ "--ref-audio",
+ type=str,
+ default=None,
+ help=(
+ "Optional reference audio path for voice cloning. With --txt-prompts, "
+ "the same reference is applied to every line."
+ ),
+ )
+ parser.add_argument(
+ "--ref-text",
+ type=str,
+ default=None,
+ help=(
+ "Real transcript of the reference audio. Placeholder text or mismatched "
+ "text will usually produce noisy/electronic clone audio."
+ ),
+ )
+ parser.add_argument(
+ "--stage-configs-path",
+ type=str,
+ default=str(DEFAULT_STAGE_SYNC),
+ help="Stage config YAML path. Routing is selected only from this path.",
+ )
+ parser.add_argument(
+ "--cfg-value",
+ type=float,
+ default=2.0,
+ help="Classifier-free guidance value for VoxCPM.",
+ )
+ parser.add_argument(
+ "--inference-timesteps",
+ type=int,
+ default=10,
+ help="Number of inference timesteps.",
+ )
+ parser.add_argument(
+ "--min-len",
+ type=int,
+ default=2,
+ help="Minimum generated token length.",
+ )
+ parser.add_argument(
+ "--max-new-tokens",
+ type=int,
+ default=4096,
+ help="Maximum generated token length.",
+ )
+ parser.add_argument(
+ "--streaming-prefix-len",
+ type=int,
+ default=None,
+ help="VoxCPM streaming window (optional, streaming mode only).",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default=None,
+ help="Directory for output WAV files.",
+ )
+ parser.add_argument(
+ "--stage-init-timeout",
+ type=int,
+ default=600,
+ help="Stage initialization timeout in seconds.",
+ )
+ parser.add_argument(
+ "--log-stats",
+ dest="log_stats",
+ action="store_true",
+ help="Enable vLLM Omni stats logging.",
+ )
+ parser.add_argument(
+ "--no-log-stats",
+ dest="log_stats",
+ action="store_false",
+ help="Disable vLLM Omni stats logging.",
+ )
+ parser.set_defaults(log_stats=True)
+ parser.add_argument(
+ "--num-runs",
+ type=int,
+ default=1,
+ help="Number of full inference runs (same prompt each time). Default 1.",
+ )
+ parser.add_argument(
+ "--warmup-runs",
+ type=int,
+ default=0,
+ help=(
+ "Optional number of warmup passes before measured runs. Warmup uses only "
+ "the first prompt and does not save outputs."
+ ),
+ )
+ parser.add_argument(
+ "--enable-profiler",
+ action="store_true",
+ help=(
+ "Enable torch profiler for the configured stages. A temporary profiled "
+ "stage config is generated automatically."
+ ),
+ )
+ parser.add_argument(
+ "--profiler-dir",
+ type=str,
+ default=None,
+ help="Directory for profiler traces. Defaults to /profiler when profiling is enabled.",
+ )
+ parser.add_argument(
+ "--profiler-stages",
+ type=int,
+ nargs="*",
+ default=None,
+ help="Optional stage ids to profile. Defaults to all stages that have profiler_config.",
+ )
+ parser.add_argument(
+ "--profiler-wait-seconds",
+ type=float,
+ default=30.0,
+ help="Seconds to wait after stop_profile for trace files to flush.",
+ )
+ args = parser.parse_args()
+
+ if not args.model:
+ parser.error("--model is required unless $VOXCPM_MODEL is set")
+ if args.txt_prompts is not None and args.jsonl_prompts is not None:
+ parser.error("--txt-prompts and --jsonl-prompts are mutually exclusive")
+ if (args.ref_audio is None) != (args.ref_text is None):
+ parser.error("--ref-audio and --ref-text must be provided together")
+ if args.num_runs < 1:
+ parser.error("--num-runs must be >= 1")
+ if args.warmup_runs < 0:
+ parser.error("--warmup-runs must be >= 0")
+ if args.output_dir is None:
+ args.output_dir = (
+ "output_audio_streaming" if _is_streaming_stage_config(args.stage_configs_path) else "output_audio"
+ )
+ if args.enable_profiler and args.profiler_dir is None:
+ args.profiler_dir = str(Path(args.output_dir) / "profiler")
+ try:
+ args.prompt_specs = _load_prompt_specs(args)
+ except ValueError as exc:
+ parser.error(str(exc))
+
+ return args
+
+
+def _is_streaming_stage_config(stage_configs_path: str) -> bool:
+ cfg_name = Path(stage_configs_path).name.lower()
+ # Keep routing purely config-path based:
+ # - voxcpm.yaml => sync
+ # - voxcpm_async_chunk.yaml => streaming
+ return "async_chunk" in cfg_name
+
+
+async def _collect_streaming_audio(
+ omni: AsyncOmni,
+ args: Any,
+ spec: PromptSpec,
+ request_id: str,
+ *,
+ phase_label: str,
+ prompt_index: int,
+ prompt_count: int,
+ print_prompt: bool = False,
+) -> tuple[torch.Tensor, int, float, float | None]:
+ prompt = _build_prompt_for_spec(args, spec, global_request_id=request_id)
+ delta_chunks: list[torch.Tensor] = []
+ sample_rate = 24000
+ chunk_i = 0
+ prev_total_samples = 0
+ t_start = time.perf_counter()
+ first_audio_elapsed: float | None = None
+
+ if print_prompt:
+ print(f"---prompt---:{prompt}")
+
+ async for stage_output in omni.generate(prompt, request_id=request_id):
+ mm = getattr(stage_output, "multimodal_output", None)
+ if not isinstance(mm, dict):
+ ro = getattr(stage_output, "request_output", None)
+ if ro is None:
+ continue
+ mm = getattr(ro, "multimodal_output", None)
+ if not isinstance(mm, dict) and getattr(ro, "outputs", None):
+ seq = ro.outputs[0]
+ mm = getattr(seq, "multimodal_output", None)
+ if not isinstance(mm, dict):
+ continue
+ sample_rate = _extract_sample_rate(mm)
+ try:
+ w = _extract_audio_tensor(mm)
+ n = int(w.numel())
+ if n == 0:
+ continue
+ finished = _extract_stream_finished(stage_output)
+ if n > prev_total_samples:
+ delta = w.reshape(-1)[prev_total_samples:]
+ prev_total_samples = n
+ elif finished and n == prev_total_samples:
+ delta = w.reshape(-1)[:0]
+ else:
+ delta = w.reshape(-1)
+ prev_total_samples += int(delta.numel())
+ if int(delta.numel()) > 0:
+ delta_chunks.append(delta)
+ if first_audio_elapsed is None and int(delta.numel()) > 0:
+ first_audio_elapsed = time.perf_counter() - t_start
+ logger.info(
+ "%s prompt=%d/%d chunk=%d delta_samples=%d buf_len=%d finished=%s",
+ phase_label,
+ prompt_index + 1,
+ prompt_count,
+ chunk_i,
+ int(delta.numel()),
+ n,
+ finished,
+ )
+ chunk_i += 1
+ except ValueError:
+ if not _extract_stream_finished(stage_output):
+ logger.debug("skip non-audio partial output chunk=%d", chunk_i)
+
+ if not delta_chunks:
+ raise RuntimeError("No audio chunks received; check stage config and logs.")
+
+ audio_cat = torch.cat([c.reshape(-1) for c in delta_chunks], dim=0)
+ elapsed = time.perf_counter() - t_start
+ return audio_cat, sample_rate, elapsed, first_audio_elapsed
+
+
+async def _abort_streaming_residual_work(
+ omni: AsyncOmni,
+ request_id: str,
+ *,
+ settle_seconds: float = 0.1,
+) -> None:
+ """Stop any late stage-0 work once the final audio has been collected."""
+ await omni.engine.abort_async([request_id])
+ if settle_seconds > 0:
+ await asyncio.sleep(settle_seconds)
+
+
+async def _run_streaming_single(
+ omni: AsyncOmni,
+ args: Any,
+ spec: PromptSpec,
+ output_dir: Path,
+ request_id: str,
+ *,
+ run_index: int,
+ num_runs: int,
+ prompt_index: int,
+ prompt_count: int,
+) -> Path:
+ audio_cat, sample_rate, elapsed, first_audio_elapsed = await _collect_streaming_audio(
+ omni,
+ args,
+ spec,
+ request_id,
+ phase_label=f"run={run_index + 1}/{num_runs}",
+ prompt_index=prompt_index,
+ prompt_count=prompt_count,
+ print_prompt=(run_index == 0 and prompt_index == 0),
+ )
+ await _abort_streaming_residual_work(omni, request_id)
+ output_path = output_dir / f"output_run{run_index + 1}_{spec.label}.wav"
+ _write_audio_tensor(output_path, audio_cat, sample_rate)
+ audio_duration_s = float(audio_cat.numel()) / float(sample_rate) if sample_rate > 0 else 0.0
+ ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else ""
+ rtf_text = f", rtf={elapsed / audio_duration_s:.3f}" if audio_duration_s > 0 else ""
+ print(
+ f"Saved (streaming) run {run_index + 1}/{num_runs}, "
+ f"prompt {prompt_index + 1}/{prompt_count}: {output_path} ({elapsed:.2f}s{ttfp_text}{rtf_text})"
+ )
+ _emit_offline_metrics(
+ request_id=request_id,
+ elapsed_s=elapsed,
+ first_audio_elapsed=first_audio_elapsed,
+ audio_duration_s=audio_duration_s,
+ )
+ return output_path
+
+
+async def _run_streaming_warmup(args, omni: AsyncOmni) -> None:
+ if args.warmup_runs == 0:
+ return
+
+ warmup_specs = _get_warmup_specs(args.prompt_specs)
+ print(
+ f"Warmup: {args.warmup_runs} run(s) using the first prompt "
+ f"({len(warmup_specs)} prompt(s)); outputs will be discarded."
+ )
+ for warmup_index in range(args.warmup_runs):
+ t_warmup = time.perf_counter()
+ tasks = []
+ request_ids: list[str] = []
+ for prompt_index, spec in enumerate(warmup_specs):
+ request_id = f"warmup_stream_{warmup_index + 1}_{spec.label}_{uuid.uuid4().hex[:8]}"
+ request_ids.append(request_id)
+ tasks.append(
+ _collect_streaming_audio(
+ omni,
+ args,
+ spec,
+ request_id,
+ phase_label=f"warmup={warmup_index + 1}/{args.warmup_runs}",
+ prompt_index=prompt_index,
+ prompt_count=len(warmup_specs),
+ )
+ )
+ results = await asyncio.gather(*tasks)
+ for request_id in request_ids:
+ await _abort_streaming_residual_work(omni, request_id)
+ total_samples = sum(int(audio.numel()) for audio, _, _, _ in results)
+ warmup_ttfps = [ttfp for _, _, _, ttfp in results if ttfp is not None]
+ ttfp_text = f", ttfp={min(warmup_ttfps):.2f}s" if warmup_ttfps else ""
+ print(
+ f"Warmup (streaming) {warmup_index + 1}/{args.warmup_runs} finished: "
+ f"{len(results)} prompt(s), {total_samples} sample(s) "
+ f"({time.perf_counter() - t_warmup:.2f}s{ttfp_text})"
+ )
+
+
+async def _run_streaming(args) -> list[Path]:
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ omni = AsyncOmni(
+ model=args.model,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ )
+
+ await _run_streaming_warmup(args, omni)
+ profiler_started = False
+ if args.enable_profiler:
+ profile_prefix = f"voxcpm_streaming_{int(time.time())}"
+ stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured"
+ print(f"Starting profiler (streaming): stages={stages_text}, dir={args.profiler_dir}")
+ await omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages)
+ profiler_started = True
+ t_total = time.perf_counter()
+ total_elapsed = 0.0
+ paths: list[Path] = []
+ prompt_specs: list[PromptSpec] = args.prompt_specs
+ try:
+ for run in range(args.num_runs):
+ for prompt_index, spec in enumerate(prompt_specs):
+ request_id = f"stream_{run + 1}_{spec.label}_{uuid.uuid4().hex[:8]}"
+ paths.append(
+ await _run_streaming_single(
+ omni,
+ args,
+ spec,
+ output_dir,
+ request_id,
+ run_index=run,
+ num_runs=args.num_runs,
+ prompt_index=prompt_index,
+ prompt_count=len(prompt_specs),
+ )
+ )
+ total_elapsed = time.perf_counter() - t_total
+ finally:
+ if profiler_started:
+ print("Stopping profiler (streaming)...")
+ await omni.stop_profile(stages=args.profiler_stages)
+ if args.profiler_wait_seconds > 0:
+ print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...")
+ await asyncio.sleep(args.profiler_wait_seconds)
+
+ print(
+ f"All streaming runs finished: {args.num_runs} run(s), "
+ f"{len(prompt_specs)} prompt(s), {len(paths)} file(s) in {total_elapsed:.2f}s total"
+ )
+ return paths
+
+
+def _run_sync(args) -> list[Path]:
+ output_dir = Path(args.output_dir)
+
+ omni = Omni(
+ model=args.model,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ )
+
+ def _run_sync_single(
+ spec: PromptSpec,
+ *,
+ request_prefix: str,
+ save_outputs: bool,
+ run_index: int | None = None,
+ ) -> tuple[list[Path], int, float | None, float, float, str]:
+ global_request_id = f"{request_prefix}_{spec.label}"
+ prompt = _build_prompt_for_spec(args, spec, global_request_id=global_request_id)
+ if save_outputs and run_index == 0 and spec.label == "item001":
+ print(f"---prompt---:{prompt}")
+
+ saved_paths: list[Path] = []
+ output_count = 0
+ first_audio_elapsed: float | None = None
+ total_audio_duration_s = 0.0
+ metrics_request_id = global_request_id
+ t_start = time.perf_counter()
+ for stage_outputs in omni.generate(prompt):
+ request_output = stage_outputs.request_output
+ if request_output is None:
+ continue
+ request_output_id = getattr(request_output, "request_id", None)
+ if isinstance(request_output_id, str) and request_output_id:
+ metrics_request_id = request_output_id
+ for j, mm in enumerate(_iter_request_multimodal_outputs(request_output)):
+ output_count += 1
+ if first_audio_elapsed is None:
+ try:
+ audio_tensor = _extract_audio_tensor(mm)
+ if int(audio_tensor.numel()) > 0:
+ first_audio_elapsed = time.perf_counter() - t_start
+ total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm))
+ except ValueError:
+ pass
+ else:
+ try:
+ audio_tensor = _extract_audio_tensor(mm)
+ total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm))
+ except ValueError:
+ pass
+ if not save_outputs:
+ continue
+ save_stem = f"run{run_index + 1}_{spec.label}" if j == 0 else f"run{run_index + 1}_{spec.label}_{j}"
+ saved_paths.append(_save_wav(mm, output_dir, save_stem))
+
+ if output_count == 0:
+ raise RuntimeError("No output from Omni.generate")
+ elapsed_s = time.perf_counter() - t_start
+ return saved_paths, output_count, first_audio_elapsed, elapsed_s, total_audio_duration_s, metrics_request_id
+
+ if args.warmup_runs:
+ warmup_specs = _get_warmup_specs(args.prompt_specs)
+ print(
+ f"Warmup: {args.warmup_runs} run(s) using the first prompt "
+ f"({len(warmup_specs)} prompt(s)); outputs will be discarded."
+ )
+ for warmup_index in range(args.warmup_runs):
+ t_warmup = time.perf_counter()
+ _, output_count, first_audio_elapsed, elapsed_s, audio_duration_s, _ = _run_sync_single(
+ warmup_specs[0],
+ request_prefix=f"warmup_sync{warmup_index + 1}",
+ save_outputs=False,
+ )
+ ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else ""
+ rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else ""
+ print(
+ f"Warmup (sync) {warmup_index + 1}/{args.warmup_runs} finished: "
+ f"{output_count} output(s) ({time.perf_counter() - t_warmup:.2f}s{ttfp_text}{rtf_text})"
+ )
+
+ profiler_started = False
+ if args.enable_profiler:
+ profile_prefix = f"voxcpm_sync_{int(time.time())}"
+ stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured"
+ print(f"Starting profiler (sync): stages={stages_text}, dir={args.profiler_dir}")
+ omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages)
+ profiler_started = True
+
+ t_total = time.perf_counter()
+ total_elapsed = 0.0
+ saved_paths: list[Path] = []
+ prompt_specs: list[PromptSpec] = args.prompt_specs
+ try:
+ for run in range(args.num_runs):
+ t_run = time.perf_counter()
+ run_paths: list[Path] = []
+ for prompt_index, spec in enumerate(prompt_specs):
+ prompt_paths, _, first_audio_elapsed, elapsed_s, audio_duration_s, metrics_request_id = (
+ _run_sync_single(
+ spec,
+ request_prefix=f"sync_run{run + 1}_{prompt_index + 1:03d}",
+ save_outputs=True,
+ run_index=run,
+ )
+ )
+ run_paths.extend(prompt_paths)
+ ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else ""
+ rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else ""
+ print(
+ f"Saved (sync) run {run + 1}/{args.num_runs}, "
+ f"prompt {prompt_index + 1}/{len(prompt_specs)}: {len(prompt_paths)} file(s){ttfp_text}{rtf_text}"
+ )
+ _emit_offline_metrics(
+ request_id=metrics_request_id,
+ elapsed_s=elapsed_s,
+ first_audio_elapsed=first_audio_elapsed,
+ audio_duration_s=audio_duration_s,
+ )
+
+ saved_paths.extend(run_paths)
+ print(
+ f"Run {run + 1}/{args.num_runs} finished: {len(run_paths)} file(s) ({time.perf_counter() - t_run:.2f}s)"
+ )
+ for path in run_paths:
+ print(f" {path}")
+
+ total_elapsed = time.perf_counter() - t_total
+ finally:
+ if profiler_started:
+ print("Stopping profiler (sync)...")
+ omni.stop_profile(stages=args.profiler_stages)
+ if args.profiler_wait_seconds > 0:
+ print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...")
+ time.sleep(args.profiler_wait_seconds)
+
+ print(
+ f"All sync runs finished: {args.num_runs} run(s), "
+ f"{len(prompt_specs)} prompt(s), {len(saved_paths)} file(s) in {total_elapsed:.2f}s total"
+ )
+ return saved_paths
+
+
+def main(args) -> int:
+ logging.basicConfig(level=logging.INFO)
+ profiled_stage_config_path: str | None = None
+ original_stage_config_path = args.stage_configs_path
+ if args.enable_profiler:
+ Path(args.profiler_dir).mkdir(parents=True, exist_ok=True)
+ profiled_stage_config_path = _build_profiled_stage_config(
+ args.stage_configs_path,
+ str(Path(args.profiler_dir).resolve()),
+ )
+ args.stage_configs_path = profiled_stage_config_path
+
+ is_streaming = _is_streaming_stage_config(args.stage_configs_path)
+ voice_clone_count = _count_voice_clone_prompts(args.prompt_specs)
+ print(f"Model: {args.model}")
+ print(f"Stage config: {original_stage_config_path}")
+ print(f"Route: {'streaming' if is_streaming else 'sync'} (from stage-configs-path)")
+ print(f"Prompt count: {len(args.prompt_specs)}")
+ print("Batch mode: sequential (aligned with native VoxCPM)")
+ print(f"Warmup runs: {args.warmup_runs}")
+ print(f"Voice cloning prompts: {voice_clone_count}/{len(args.prompt_specs)}")
+ if args.enable_profiler:
+ print(f"Profiler: enabled (dir={args.profiler_dir}, stages={args.profiler_stages or 'all-configured'})")
+ print(f"Profiled stage config: {args.stage_configs_path}")
+ if voice_clone_count:
+ print("Voice cloning note: --ref-text/ref_text must match the spoken content of the reference audio.")
+ print(f"Num runs: {args.num_runs}")
+ try:
+ if is_streaming:
+ asyncio.run(_run_streaming(args))
+ else:
+ _run_sync(args)
+ finally:
+ if profiled_stage_config_path is not None and os.path.exists(profiled_stage_config_path):
+ os.unlink(profiled_stage_config_path)
+ return 0
+
+
+if __name__ == "__main__":
+ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ raise SystemExit(main(parse_args()))
diff --git a/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py b/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py
new file mode 100644
index 0000000000..816df32796
--- /dev/null
+++ b/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py
@@ -0,0 +1,283 @@
+"""Benchmark VoxCPM via /v1/audio/speech.
+
+Reports TTFP (time to first packet), E2E latency, and RTF (real-time factor).
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import time
+from dataclasses import asdict, dataclass, field
+from datetime import datetime
+from pathlib import Path
+
+import aiohttp
+import numpy as np
+from tqdm.asyncio import tqdm
+
+DEFAULT_MODEL = "OpenBMB/VoxCPM1.5"
+DEFAULT_SAMPLE_RATE = 24000
+PROMPTS = [
+ "Hello, welcome to the VoxCPM speech benchmark.",
+ "This is a short benchmark prompt for online text-to-speech generation.",
+ "The quick brown fox jumps over the lazy dog near the riverbank.",
+ "Please remember to bring your identification documents tomorrow morning.",
+ "Learning a new language takes patience, practice, and curiosity.",
+ "This benchmark reports TTFP and RTF for the VoxCPM online serving path.",
+]
+
+
+@dataclass
+class RequestResult:
+ success: bool = False
+ ttfp: float = 0.0
+ e2e: float = 0.0
+ audio_bytes: int = 0
+ audio_duration: float = 0.0
+ rtf: float = 0.0
+ prompt: str = ""
+ error: str = ""
+
+
+@dataclass
+class BenchmarkResult:
+ concurrency: int = 0
+ num_prompts: int = 0
+ completed: int = 0
+ failed: int = 0
+ duration_s: float = 0.0
+ mean_ttfp_ms: float = 0.0
+ median_ttfp_ms: float = 0.0
+ p95_ttfp_ms: float = 0.0
+ mean_e2e_ms: float = 0.0
+ median_e2e_ms: float = 0.0
+ p95_e2e_ms: float = 0.0
+ mean_rtf: float = 0.0
+ median_rtf: float = 0.0
+ p95_rtf: float = 0.0
+ total_audio_duration_s: float = 0.0
+ request_throughput: float = 0.0
+ per_request: list[dict[str, float | str]] = field(default_factory=list)
+
+
+def pcm_bytes_to_duration(num_bytes: int, sample_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2) -> float:
+ num_samples = num_bytes / sample_width
+ return num_samples / sample_rate
+
+
+async def send_tts_request(
+ session: aiohttp.ClientSession,
+ api_url: str,
+ *,
+ model: str,
+ prompt: str,
+ ref_audio: str | None,
+ ref_text: str | None,
+ pbar: tqdm | None = None,
+) -> RequestResult:
+ payload: dict[str, object] = {
+ "model": model,
+ "input": prompt,
+ "stream": True,
+ "response_format": "pcm",
+ }
+ if ref_audio is not None:
+ payload["ref_audio"] = ref_audio
+ if ref_text is not None:
+ payload["ref_text"] = ref_text
+
+ result = RequestResult(prompt=prompt)
+ started_at = time.perf_counter()
+
+ try:
+ async with session.post(api_url, json=payload) as response:
+ if response.status != 200:
+ result.error = f"HTTP {response.status}: {await response.text()}"
+ return result
+
+ first_chunk = True
+ total_bytes = 0
+ async for chunk in response.content.iter_any():
+ if not chunk:
+ continue
+ if first_chunk:
+ result.ttfp = time.perf_counter() - started_at
+ first_chunk = False
+ total_bytes += len(chunk)
+
+ result.e2e = time.perf_counter() - started_at
+ result.audio_bytes = total_bytes
+ result.audio_duration = pcm_bytes_to_duration(total_bytes)
+ if result.audio_duration > 0:
+ result.rtf = result.e2e / result.audio_duration
+ result.success = True
+ except Exception as e:
+ result.error = str(e)
+ result.e2e = time.perf_counter() - started_at
+
+ if pbar is not None:
+ pbar.update(1)
+ return result
+
+
+async def run_benchmark(
+ *,
+ host: str,
+ port: int,
+ model: str,
+ num_prompts: int,
+ max_concurrency: int,
+ num_warmups: int,
+ ref_audio: str | None,
+ ref_text: str | None,
+) -> BenchmarkResult:
+ api_url = f"http://{host}:{port}/v1/audio/speech"
+ connector = aiohttp.TCPConnector(limit=max_concurrency, limit_per_host=max_concurrency, keepalive_timeout=60)
+ timeout = aiohttp.ClientTimeout(total=600)
+
+ async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
+ if num_warmups > 0:
+ print(f" Warming up with {num_warmups} requests...")
+ warmup_tasks = [
+ send_tts_request(
+ session,
+ api_url,
+ model=model,
+ prompt=PROMPTS[i % len(PROMPTS)],
+ ref_audio=ref_audio,
+ ref_text=ref_text,
+ )
+ for i in range(num_warmups)
+ ]
+ await asyncio.gather(*warmup_tasks)
+ print(" Warmup done.")
+
+ request_prompts = [PROMPTS[i % len(PROMPTS)] for i in range(num_prompts)]
+ semaphore = asyncio.Semaphore(max_concurrency)
+ pbar = tqdm(total=num_prompts, desc=f" concurrency={max_concurrency}")
+
+ async def limited_request(prompt: str) -> RequestResult:
+ async with semaphore:
+ return await send_tts_request(
+ session,
+ api_url,
+ model=model,
+ prompt=prompt,
+ ref_audio=ref_audio,
+ ref_text=ref_text,
+ pbar=pbar,
+ )
+
+ started_at = time.perf_counter()
+ results = await asyncio.gather(*[asyncio.create_task(limited_request(prompt)) for prompt in request_prompts])
+ duration = time.perf_counter() - started_at
+ pbar.close()
+
+ succeeded = [result for result in results if result.success]
+ bench = BenchmarkResult(
+ concurrency=max_concurrency,
+ num_prompts=num_prompts,
+ completed=len(succeeded),
+ failed=len(results) - len(succeeded),
+ duration_s=duration,
+ )
+
+ if not succeeded:
+ return bench
+
+ ttfps = np.array([result.ttfp * 1000 for result in succeeded], dtype=np.float64)
+ e2es = np.array([result.e2e * 1000 for result in succeeded], dtype=np.float64)
+ rtfs = np.array([result.rtf for result in succeeded], dtype=np.float64)
+ audio_durations = np.array([result.audio_duration for result in succeeded], dtype=np.float64)
+
+ bench.mean_ttfp_ms = float(np.mean(ttfps))
+ bench.median_ttfp_ms = float(np.median(ttfps))
+ bench.p95_ttfp_ms = float(np.percentile(ttfps, 95))
+ bench.mean_e2e_ms = float(np.mean(e2es))
+ bench.median_e2e_ms = float(np.median(e2es))
+ bench.p95_e2e_ms = float(np.percentile(e2es, 95))
+ bench.mean_rtf = float(np.mean(rtfs))
+ bench.median_rtf = float(np.median(rtfs))
+ bench.p95_rtf = float(np.percentile(rtfs, 95))
+ bench.total_audio_duration_s = float(np.sum(audio_durations))
+ bench.request_throughput = len(succeeded) / duration if duration > 0 else 0.0
+ bench.per_request = [
+ {
+ "prompt": result.prompt,
+ "ttfp_ms": result.ttfp * 1000,
+ "e2e_ms": result.e2e * 1000,
+ "rtf": result.rtf,
+ "audio_duration_s": result.audio_duration,
+ }
+ for result in succeeded
+ ]
+
+ return bench
+
+
+def print_summary(result: BenchmarkResult) -> None:
+ width = 54
+ print("")
+ print("=" * width)
+ print(f"{'VoxCPM Serving Benchmark':^{width}}")
+ print("=" * width)
+ print(f"concurrency : {result.concurrency}")
+ print(f"requests : {result.completed}/{result.num_prompts} succeeded")
+ print(f"wall time (s) : {result.duration_s:.3f}")
+ print(f"mean TTFP (ms) : {result.mean_ttfp_ms:.2f}")
+ print(f"p95 TTFP (ms) : {result.p95_ttfp_ms:.2f}")
+ print(f"mean E2E (ms) : {result.mean_e2e_ms:.2f}")
+ print(f"p95 E2E (ms) : {result.p95_e2e_ms:.2f}")
+ print(f"mean RTF : {result.mean_rtf:.3f}")
+ print(f"p95 RTF : {result.p95_rtf:.3f}")
+ print(f"request throughput : {result.request_throughput:.2f} req/s")
+ print("=" * width)
+
+
+async def main_async(args) -> None:
+ result_dir = Path(args.result_dir)
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ all_results: list[BenchmarkResult] = []
+ for concurrency in args.max_concurrency:
+ result = await run_benchmark(
+ host=args.host,
+ port=args.port,
+ model=args.model,
+ num_prompts=args.num_prompts,
+ max_concurrency=concurrency,
+ num_warmups=args.num_warmups,
+ ref_audio=args.ref_audio,
+ ref_text=args.ref_text,
+ )
+ print_summary(result)
+ all_results.append(result)
+
+ payload = {
+ "model": args.model,
+ "created_at": datetime.utcnow().isoformat() + "Z",
+ "results": [asdict(result) for result in all_results],
+ }
+ result_path = result_dir / "bench_tts_serve.json"
+ result_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
+ print(f"Saved results to: {result_path}")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Benchmark VoxCPM via /v1/audio/speech")
+ parser.add_argument("--host", default="127.0.0.1", help="Server host")
+ parser.add_argument("--port", type=int, default=8091, help="Server port")
+ parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path")
+ parser.add_argument("--num-prompts", type=int, default=20, help="Number of prompts to send")
+ parser.add_argument("--max-concurrency", type=int, nargs="+", default=[1], help="Concurrency levels to benchmark")
+ parser.add_argument("--num-warmups", type=int, default=3, help="Warmup request count")
+ parser.add_argument("--ref-audio", default=None, help="Reference audio URL or data URL for voice cloning")
+ parser.add_argument("--ref-text", default=None, help="Reference audio transcript for voice cloning")
+ parser.add_argument("--result-dir", default="results", help="Directory to save benchmark JSON")
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ asyncio.run(main_async(parse_args()))
diff --git a/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py b/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py
new file mode 100644
index 0000000000..cee46c0f86
--- /dev/null
+++ b/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py
@@ -0,0 +1,303 @@
+"""Run the full offline VoxCPM smoke matrix.
+
+This script keeps the old `test.py` coverage, but delegates each case to
+`bench_tts_offline.py` so the benchmark runner itself stays focused on a
+single execution path.
+"""
+
+from __future__ import annotations
+
+import shlex
+import subprocess
+import sys
+import time
+from dataclasses import dataclass
+from pathlib import Path
+
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+BENCH_SCRIPT = Path(__file__).with_name("bench_tts_offline.py")
+DEFAULT_STAGE_ASYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm_async_chunk.yaml"
+DEFAULT_STAGE_SYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
+DEFAULT_OUTPUT_ROOT = BENCH_SCRIPT.parents[1] / "results" / "offline_matrix"
+
+SINGLE_TTS_TEXT = "This is a single text-to-speech smoke test for VoxCPM on vLLM Omni."
+SINGLE_CLONE_TEXT = "This sentence is synthesized with the cloned voice for validation."
+BATCH_TTS_TEXTS = [
+ "The first batch text-to-speech sample validates sequential batch execution.",
+ "The second batch text-to-speech sample checks another prompt in the same file.",
+ "The third batch text-to-speech sample completes the sequential batch path.",
+]
+BATCH_CLONE_TEXTS = [
+ "The first cloned sample validates sequential batch voice cloning.",
+ "The second cloned sample checks the same reference voice on another prompt.",
+ "The third cloned sample finishes the shared-reference clone batch path.",
+]
+
+
+@dataclass(frozen=True, slots=True)
+class ModeSpec:
+ name: str
+ stage_config: Path
+
+
+@dataclass(frozen=True, slots=True)
+class CaseSpec:
+ name: str
+ warmup_runs: int
+ prompt_kind: str
+ voice_clone: bool
+
+
+@dataclass(frozen=True, slots=True)
+class CaseResult:
+ mode: str
+ case: str
+ returncode: int
+ elapsed_s: float
+ output_dir: Path
+ log_path: Path
+
+ @property
+ def ok(self) -> bool:
+ return self.returncode == 0
+
+
+MODE_SPECS = [
+ ModeSpec(name="streaming", stage_config=DEFAULT_STAGE_ASYNC),
+ ModeSpec(name="sync", stage_config=DEFAULT_STAGE_SYNC),
+]
+
+CASE_SPECS = [
+ CaseSpec(name="warmup_single_tts", warmup_runs=1, prompt_kind="single", voice_clone=False),
+ CaseSpec(name="warmup_single_clone", warmup_runs=1, prompt_kind="single", voice_clone=True),
+ CaseSpec(name="warmup_batch_tts", warmup_runs=1, prompt_kind="batch", voice_clone=False),
+ CaseSpec(name="warmup_batch_clone", warmup_runs=1, prompt_kind="batch", voice_clone=True),
+ CaseSpec(name="cold_single_tts", warmup_runs=0, prompt_kind="single", voice_clone=False),
+ CaseSpec(name="cold_single_clone", warmup_runs=0, prompt_kind="single", voice_clone=True),
+]
+
+
+def _write_lines(path: Path, lines: list[str]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+
+def _prepare_batch_inputs(output_root: Path) -> tuple[Path, Path]:
+ input_dir = output_root / "inputs"
+ batch_tts_path = input_dir / "batch_tts_prompts.txt"
+ batch_clone_path = input_dir / "batch_clone_prompts.txt"
+ _write_lines(batch_tts_path, BATCH_TTS_TEXTS)
+ _write_lines(batch_clone_path, BATCH_CLONE_TEXTS)
+ return batch_tts_path, batch_clone_path
+
+
+def _base_command(args, mode: ModeSpec, output_dir: Path) -> list[str]:
+ cmd = [
+ args.python,
+ str(BENCH_SCRIPT),
+ "--model",
+ args.model,
+ "--stage-configs-path",
+ str(mode.stage_config),
+ "--output-dir",
+ str(output_dir),
+ "--num-runs",
+ str(args.num_runs),
+ "--stage-init-timeout",
+ str(args.stage_init_timeout),
+ ]
+ cmd.append("--log-stats" if args.log_stats else "--no-log-stats")
+ cmd.extend(["--cfg-value", str(args.cfg_value)])
+ cmd.extend(["--inference-timesteps", str(args.inference_timesteps)])
+ cmd.extend(["--min-len", str(args.min_len)])
+ cmd.extend(["--max-new-tokens", str(args.max_new_tokens)])
+ if args.streaming_prefix_len is not None:
+ cmd.extend(["--streaming-prefix-len", str(args.streaming_prefix_len)])
+ if args.enable_profiler:
+ profiler_dir = Path(args.profiler_dir) if args.profiler_dir is not None else (output_dir / "profiler")
+ cmd.append("--enable-profiler")
+ cmd.extend(["--profiler-dir", str(profiler_dir)])
+ cmd.extend(["--profiler-wait-seconds", str(args.profiler_wait_seconds)])
+ if args.profiler_stages is not None:
+ cmd.append("--profiler-stages")
+ cmd.extend(str(stage_id) for stage_id in args.profiler_stages)
+ return cmd
+
+
+def _build_case_command(
+ args,
+ mode: ModeSpec,
+ case: CaseSpec,
+ *,
+ batch_tts_path: Path,
+ batch_clone_path: Path,
+ output_dir: Path,
+) -> list[str]:
+ cmd = _base_command(args, mode, output_dir)
+ cmd.extend(["--warmup-runs", str(case.warmup_runs)])
+ if case.prompt_kind == "single":
+ cmd.extend(["--text", SINGLE_CLONE_TEXT if case.voice_clone else SINGLE_TTS_TEXT])
+ else:
+ cmd.extend(["--txt-prompts", str(batch_clone_path if case.voice_clone else batch_tts_path)])
+ if case.voice_clone:
+ cmd.extend(["--ref-audio", args.ref_audio, "--ref-text", args.ref_text])
+ return cmd
+
+
+def _run_case(
+ args,
+ mode: ModeSpec,
+ case: CaseSpec,
+ *,
+ batch_tts_path: Path,
+ batch_clone_path: Path,
+ output_root: Path,
+) -> CaseResult:
+ case_output_dir = output_root / mode.name / case.name
+ case_output_dir.mkdir(parents=True, exist_ok=True)
+ case_log_path = case_output_dir / "run.log"
+ cmd = _build_case_command(
+ args,
+ mode,
+ case,
+ batch_tts_path=batch_tts_path,
+ batch_clone_path=batch_clone_path,
+ output_dir=case_output_dir,
+ )
+
+ print()
+ print("=" * 80)
+ print(f"[{mode.name}] {case.name}")
+ print(f"Output directory: {case_output_dir}")
+ print(shlex.join(cmd))
+
+ start = time.perf_counter()
+ with case_log_path.open("w", encoding="utf-8") as log_fp:
+ process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1,
+ )
+ assert process.stdout is not None
+ for line in process.stdout:
+ print(line, end="")
+ log_fp.write(line)
+ process.wait()
+
+ elapsed_s = time.perf_counter() - start
+ status = "PASS" if (process.returncode or 0) == 0 else f"FAIL({process.returncode})"
+ print(f"[{mode.name}] {case.name} -> {status} ({elapsed_s:.2f}s)")
+ return CaseResult(
+ mode=mode.name,
+ case=case.name,
+ returncode=int(process.returncode or 0),
+ elapsed_s=elapsed_s,
+ output_dir=case_output_dir,
+ log_path=case_log_path,
+ )
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(description="Run the full offline VoxCPM smoke matrix.")
+ parser.add_argument("--model", type=str, required=True, help="Local VoxCPM model directory.")
+ parser.add_argument("--ref-audio", type=str, required=True, help="Reference audio path for clone cases.")
+ parser.add_argument("--ref-text", type=str, required=True, help="Exact transcript spoken in --ref-audio.")
+ parser.add_argument("--output-root", type=str, default=str(DEFAULT_OUTPUT_ROOT), help="Root directory for outputs.")
+ parser.add_argument("--python", type=str, default=sys.executable, help="Python executable used to launch cases.")
+ parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.")
+ parser.add_argument("--log-stats", dest="log_stats", action="store_true", help="Enable vLLM Omni stats logging.")
+ parser.add_argument(
+ "--no-log-stats",
+ dest="log_stats",
+ action="store_false",
+ help="Disable vLLM Omni stats logging.",
+ )
+ parser.set_defaults(log_stats=True)
+ parser.add_argument("--num-runs", type=int, default=1, help="Number of measured runs per case.")
+ parser.add_argument("--cfg-value", type=float, default=2.0, help="Classifier-free guidance value for VoxCPM.")
+ parser.add_argument("--inference-timesteps", type=int, default=10, help="Number of inference timesteps.")
+ parser.add_argument("--min-len", type=int, default=2, help="Minimum generated token length.")
+ parser.add_argument("--max-new-tokens", type=int, default=4096, help="Maximum generated token length.")
+ parser.add_argument(
+ "--streaming-prefix-len",
+ type=int,
+ default=None,
+ help="Optional VoxCPM streaming window passed to streaming cases.",
+ )
+ parser.add_argument("--enable-profiler", action="store_true", help="Enable torch profiler for each case.")
+ parser.add_argument(
+ "--profiler-dir",
+ type=str,
+ default=None,
+ help="Profiler output root. Defaults to /profiler.",
+ )
+ parser.add_argument(
+ "--profiler-stages",
+ type=int,
+ nargs="*",
+ default=None,
+ help="Optional stage ids to profile. Defaults to all configured stages.",
+ )
+ parser.add_argument(
+ "--profiler-wait-seconds",
+ type=float,
+ default=30.0,
+ help="Seconds to wait after stopping profiler for traces to flush.",
+ )
+ args = parser.parse_args()
+ if args.num_runs < 1:
+ parser.error("--num-runs must be >= 1")
+ return args
+
+
+def main(args) -> int:
+ output_root = Path(args.output_root)
+ output_root.mkdir(parents=True, exist_ok=True)
+ batch_tts_path, batch_clone_path = _prepare_batch_inputs(output_root)
+
+ print(f"Model: {args.model}")
+ print(f"Reference audio: {args.ref_audio}")
+ print(f"Reference text: {args.ref_text}")
+ print(f"Python: {args.python}")
+ print(f"Output root: {output_root}")
+ print(f"Cases: {len(MODE_SPECS) * len(CASE_SPECS)}")
+
+ results: list[CaseResult] = []
+ for mode in MODE_SPECS:
+ for case in CASE_SPECS:
+ results.append(
+ _run_case(
+ args,
+ mode,
+ case,
+ batch_tts_path=batch_tts_path,
+ batch_clone_path=batch_clone_path,
+ output_root=output_root,
+ )
+ )
+
+ failed = [result for result in results if not result.ok]
+ print()
+ print("=" * 80)
+ print("Summary:")
+ for result in results:
+ status = "PASS" if result.ok else f"FAIL({result.returncode})"
+ print(f"- [{result.mode}] {result.case}: {status} ({result.elapsed_s:.2f}s)")
+ print(f" output_dir={result.output_dir}")
+ print(f" log={result.log_path}")
+
+ print(f"Passed: {len(results) - len(failed)}/{len(results)}")
+ if failed:
+ print("Failed cases:")
+ for result in failed:
+ print(f"- [{result.mode}] {result.case}: see {result.log_path}")
+ return 1
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main(parse_args()))
diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci
index 24ce39bafd..9cbf89d0b7 100644
--- a/docker/Dockerfile.ci
+++ b/docker/Dockerfile.ci
@@ -7,7 +7,7 @@ COPY . .
# Install system dependencies
RUN apt-get update && \
- apt-get install -y espeak-ng ffmpeg git sox libsox-fmt-all jq && \
+ apt-get install -y espeak-ng git jq && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
diff --git a/docker/Dockerfile.cuda b/docker/Dockerfile.cuda
index 754d491d86..28e10f4fb8 100644
--- a/docker/Dockerfile.cuda
+++ b/docker/Dockerfile.cuda
@@ -7,7 +7,7 @@ WORKDIR ${COMMON_WORKDIR}
# Step 1: Setup - Install system dependencies
RUN apt-get update && \
- apt-get install -y ffmpeg git sox libsox-fmt-all jq && \
+ apt-get install -y git jq && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index bfbb060bcb..a54aa3b793 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -18,8 +18,10 @@ ARG COMMON_WORKDIR=/app
WORKDIR ${COMMON_WORKDIR}
# Step 1: Setup - Install system dependencies
+# Need to include ffmpeg because vllm rocm upstream docker image
+# does not include it.
RUN apt-get update && \
- apt-get install -y espeak-ng ffmpeg git sox libsox-fmt-all jq && \
+ apt-get install -y espeak-ng ffmpeg git jq && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
@@ -39,6 +41,24 @@ RUN if [ "${USE_NIGHTLY_BUILD}" = "1" ]; then \
# Step 3: Copy vllm-omni code and install without uv
RUN mkdir -p ${COMMON_WORKDIR}/vllm-omni
COPY . ${COMMON_WORKDIR}/vllm-omni
+
+# This is a workaround to ensure pytest exits with the correct status code in CI tests.
+RUN printf '%s\n' \
+ 'import os' \
+ '' \
+ '_exit_code = 1' \
+ '' \
+ 'def pytest_sessionfinish(session, exitstatus):' \
+ ' global _exit_code' \
+ ' _exit_code = int(exitstatus)' \
+ '' \
+ 'def pytest_unconfigure(config):' \
+ ' import sys' \
+ ' sys.stdout.flush()' \
+ ' sys.stderr.flush()' \
+ ' os._exit(_exit_code)' \
+ > ${COMMON_WORKDIR}/vllm-omni/conftest.py
+
RUN cd ${COMMON_WORKDIR}/vllm-omni && uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir ".[dev]" --no-build-isolation
RUN ln -sf /usr/bin/python3 /usr/bin/python
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 17f1aebf0d..25d5d0c800 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -15,9 +15,7 @@ RUN apt clean && apt-get update -y && \
apt-get install -y --no-install-recommends --fix-missing \
curl \
espeak-ng \
- ffmpeg \
git \
- libsndfile1 \
libsm6 \
libxext6 \
libgl1 \
diff --git a/docs/.nav.yml b/docs/.nav.yml
index 86ce4a3b0c..455a052505 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -64,6 +64,7 @@ nav:
- FP8: user_guide/diffusion/quantization/fp8.md
- Int8: user_guide/diffusion/quantization/int8.md
- GGUF: user_guide/diffusion/quantization/gguf.md
+ - Frame Interpolation: user_guide/diffusion/frame_interpolation.md
- Parallelism:
- Overview: user_guide/diffusion/parallelism/overview.md
- CFG Parallel: user_guide/diffusion/parallelism/cfg_parallel.md
@@ -97,6 +98,7 @@ nav:
- design/feature/disaggregated_inference.md
- design/feature/ray_based_execution.md
- design/feature/omni_connectors/
+ - design/feature/prefix_caching.md
- design/feature/cfg_parallel.md
- design/feature/expert_parallel.md
- design/feature/sequence_parallel.md
@@ -105,7 +107,7 @@ nav:
- design/feature/hsdp.md
- design/feature/cache_dit.md
- design/feature/teacache.md
- - design/feature/async_chunk_design.md
+ - design/feature/async_chunk.md
- design/feature/vae_parallel.md
- design/feature/diffusion_step_execution.md
- Module Design:
diff --git a/docs/api/README.md b/docs/api/README.md
index f65cbb525d..0147f19e12 100644
--- a/docs/api/README.md
+++ b/docs/api/README.md
@@ -5,7 +5,7 @@
Main entry points for vLLM-Omni inference and serving.
- [vllm_omni.entrypoints.async_omni.AsyncOmni][]
-- [vllm_omni.entrypoints.cfg_companion_tracker.CfgCompanionTracker][]
+- [vllm_omni.engine.cfg_companion_tracker.CfgCompanionTracker][]
- [vllm_omni.entrypoints.cli.benchmark.base.OmniBenchmarkSubcommandBase][]
- [vllm_omni.entrypoints.cli.benchmark.main.OmniBenchmarkSubcommand][]
- [vllm_omni.entrypoints.cli.benchmark.serve.OmniBenchmarkServingSubcommand][]
diff --git a/docs/assets/WeChat.jpg b/docs/assets/WeChat.jpg
index c32ece6c10..83252b7569 100644
Binary files a/docs/assets/WeChat.jpg and b/docs/assets/WeChat.jpg differ
diff --git a/docs/configuration/README.md b/docs/configuration/README.md
index b5761a7f1b..390176e9ce 100644
--- a/docs/configuration/README.md
+++ b/docs/configuration/README.md
@@ -6,7 +6,7 @@ For options within a vLLM Engine. Please refer to [vLLM Configuration](https://d
Currently, the main options are maintained by stage configs for each model.
-For specific example, please refer to [Qwen2.5-omni stage config](stage_configs/qwen2_5_omni.yaml)
+For a specific example, see the [Qwen2.5-Omni deploy config](gh-file:vllm_omni/deploy/qwen2_5_omni.yaml). The matching frozen pipeline topology lives at [vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py](gh-file:vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py).
For introduction, please check [Introduction for stage config](./stage_configs.md)
diff --git a/docs/configuration/pd_disaggregation.md b/docs/configuration/pd_disaggregation.md
index 1cf6189e60..9196bdb024 100644
--- a/docs/configuration/pd_disaggregation.md
+++ b/docs/configuration/pd_disaggregation.md
@@ -11,7 +11,7 @@ deployment-specific values usually change per environment:
- connector backend and connector ports
- connector IPs or bootstrap addresses
-Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml)
+Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml)
and copy it to your own file, for example `qwen3_omni_pd.yaml`. Then apply the
changes below.
@@ -145,19 +145,13 @@ Compared with the default Qwen3-Omni config:
```yaml
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
- from: 2
to: 3
- window_size: -1
```
## 4. Launch with your custom config
diff --git a/docs/configuration/stage_configs.md b/docs/configuration/stage_configs.md
index 95c42afcc7..55b4053cc7 100644
--- a/docs/configuration/stage_configs.md
+++ b/docs/configuration/stage_configs.md
@@ -3,7 +3,147 @@
In vLLM-Omni, the target model is separated into multiple stages, which are processed by different LLMEngines, DiffusionEngines or other types of engines. Depending on different types of stages, such as Autoregressive (AR) stage or Diffusion transformer (DiT) stage, each can choose corresponding schedulers, model workers to load with the Engines in a plug-in fashion.
!!! note
- Default stage config YAMLs (for example, `vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml` and `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`) are bundled and loaded automatically when `stage_configs_path` is not provided. They have been verified to work on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni.
+ Default deploy config YAMLs (for example, `vllm_omni/deploy/qwen2_5_omni.yaml`, `vllm_omni/deploy/qwen3_omni_moe.yaml`, and `vllm_omni/deploy/qwen3_tts.yaml`) are bundled and loaded automatically when neither `--stage-configs-path` nor `--deploy-config` is provided — the model registry resolves the right pipeline + deploy YAML by `model_type`. The bundled defaults have been verified on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni. Models that have not yet migrated to the new schema continue to use the legacy `vllm_omni/model_executor/stage_configs/.yaml` files via `--stage-configs-path`.
+
+## New deploy schema reference
+
+The new deploy schema lives under `vllm_omni/deploy/` and is paired with a frozen `PipelineConfig` registered by the model's `pipeline.py`. Each deploy YAML has these top-level fields:
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `base_config` | str (path) | optional | — | Overlay parent (relative or absolute). `stages:` / `platforms:` deep-merged by stage_id; other scalars overlay-wins. Intended for user-authored overlays; prod yamls stay flat. |
+| `async_chunk` | bool | optional | `true` | Enable chunked streaming between stages. Pin to `false` if the pipeline runs end-to-end. |
+| `connectors` | dict | optional | `null` | Named connector specs (`{name, extra}`). Referenced by each stage's `input_connectors` / `output_connectors`. See [Connector schema](#connector-schema). |
+| `edges` | list | optional | `null` | Explicit edge list for the KV transfer graph. Auto-derived from stage inputs if omitted. |
+| `stages` | list | required | — | Per-stage engine args + wiring (see [Stage fields](#stage-fields)). |
+| `platforms` | dict | optional | `null` | Keyed by `npu` / `rocm` / `xpu`, each contains a `stages:` list with per-platform overrides applied on top of the CUDA defaults. |
+| `pipeline` | str | optional | `null` | Override the auto-detected pipeline registry key (used for structural variants like `qwen2_5_omni_thinker_only`). |
+| `trust_remote_code` | bool | optional | `true` | **Pipeline-wide.** Trust HF remote code on model load; applies to every stage. |
+| `distributed_executor_backend` | str | optional | `"mp"` | **Pipeline-wide.** Executor backend (`"mp"` or `"ray"`). |
+| `dtype` | str \| null | optional | `null` | **Pipeline-wide.** Model dtype for every stage. |
+| `quantization` | str \| null | optional | `null` | **Pipeline-wide.** Quantization method for every stage. |
+| `enable_prefix_caching` | bool | optional | `false` | **Pipeline-wide.** Prefix cache toggle applied to every stage. |
+| `enable_chunked_prefill` | bool \| null | optional | `null` | **Pipeline-wide.** Chunked prefill toggle applied to every stage. |
+| `data_parallel_size` | int | optional | `1` | **Pipeline-wide.** DP degree for every stage. |
+| `pipeline_parallel_size` | int | optional | `1` | **Pipeline-wide.** PP degree for every stage. |
+
+### Stage fields
+
+Each entry under `stages:` accepts any `StageDeployConfig` field directly (no nested `engine_args:`). Only fields whose value legitimately varies across stages live here; pipeline-wide settings (trust_remote_code, distributed_executor_backend, dtype, quantization, prefix/chunked prefill, DP/PP sizes) are declared at the top level and applied to every stage. Unknown keys fall through to `engine_extras:` and are forwarded to the engine.
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `stage_id` | int | required | — | Stage identity; matched against `PipelineConfig.stages[*].stage_id`. |
+| `max_num_seqs` | int | optional | `64` | Max concurrent sequences per stage. |
+| `gpu_memory_utilization` | float | optional | `0.9` | Per-stage memory budget. |
+| `tensor_parallel_size` | int | optional | `1` | TP degree for this stage. |
+| `enforce_eager` | bool | optional | `false` | Disable CUDA graphs. |
+| `max_num_batched_tokens` | int | optional | `32768` | Prefill budget. |
+| `max_model_len` | int \| null | optional | `null` | Per-stage context length (auto-sets `VLLM_ALLOW_LONG_MAX_MODEL_LEN=1` when larger than HF default). |
+| `async_scheduling` | bool \| null | optional | `null` | Per-stage async scheduling toggle. |
+| `devices` | str | optional | `"0"` | `CUDA_VISIBLE_DEVICES`-style device list. |
+| `output_connectors` | dict \| null | optional | `null` | Keyed by `to_stage_`; values are names registered under top-level `connectors:`. |
+| `input_connectors` | dict \| null | optional | `null` | Keyed by `from_stage_`; values are names registered under top-level `connectors:`. |
+| `default_sampling_params` | dict \| null | optional | `null` | Baseline sampling params. Deep-merged with pipeline `sampling_constraints` (pipeline wins). |
+| `engine_extras` | dict | optional | `{}` | Catch-all for keys not listed above; deep-merged across overlays. Also carries per-stage overrides of pipeline-wide settings (e.g. stage-specific `dtype`). |
+
+### Connector schema
+
+Each entry under top-level `connectors:` follows this shape:
+
+```yaml
+connectors:
+ :
+ name: # required — class registered in vllm_omni.distributed
+ extra: # optional — forwarded to the connector's __init__
+ :
+ ...
+```
+
+| Connector class | Use case | `extra` keys |
+|-----------------|----------|--------------|
+| `SharedMemoryConnector` | Same-host KV transfer between stages (default for bundled YAMLs). | `shm_threshold_bytes` (int, default `65536`). |
+| `MooncakeStoreConnector` | Cross-host KV transfer over TCP. Required for multi-node deployments. | `host`, `metadata_server`, `master`, `segment` (int bytes), `localbuf` (int bytes), `proto` (`"tcp"` / `"rdma"`). |
+
+A stage references a connector by name in its `input_connectors` / `output_connectors`:
+
+```yaml
+connectors:
+ shm:
+ name: SharedMemoryConnector
+
+stages:
+ - stage_id: 0
+ output_connectors: {to_stage_1: shm}
+ - stage_id: 1
+ input_connectors: {from_stage_0: shm}
+```
+
+### CLI flags introduced in this refactor
+
+| Flag | Description |
+|------|-------------|
+| `--deploy-config PATH` | Load a new-schema deploy YAML. Takes precedence over `--stage-configs-path`. **Optional** — when omitted, the bundled `vllm_omni/deploy/.yaml` is auto-loaded by the model registry. |
+| `--stage-overrides JSON` | Per-stage JSON overrides, e.g. `'{"0":{"gpu_memory_utilization":0.5}}'`. Per-stage values always win over global flags. |
+| `--async-chunk` / `--no-async-chunk` | Flip the deploy YAML's `async_chunk:` bool. Unset (default) leaves the YAML value in force. |
+| `--stage-configs-path` | **Deprecated.** Accepts legacy `stage_args` yamls and (auto-detected) new deploy yamls; emits a deprecation warning. Migrate to `--deploy-config`. To be removed in a follow-up PR. |
+
+### Precedence
+
+From highest to lowest:
+
+1. Per-stage flags (`--stage-overrides` JSON, `--stage--` if registered)
+2. Explicit global CLI flags (`--gpu-memory-utilization 0.85`, etc.)
+3. Platform section (`platforms.npu.stages`, etc.) on top of the base `stages:`
+4. Overlay YAML (via `base_config:`) on top of the base YAML
+5. Parser defaults
+
+### Worked override example
+
+Starting from the bundled `vllm_omni/deploy/qwen3_omni_moe.yaml`:
+
+```yaml
+# vllm_omni/deploy/qwen3_omni_moe.yaml (excerpt)
+async_chunk: true
+stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.9
+ max_num_seqs: 32
+ - stage_id: 1
+ gpu_memory_utilization: 0.7
+ max_num_seqs: 16
+```
+
+A user-authored overlay that inherits the base and overrides only stage 1:
+
+```yaml
+# my_overrides.yaml
+base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
+stages:
+ - stage_id: 1
+ gpu_memory_utilization: 0.5 # smaller GPU
+```
+
+Launched with both an explicit global flag and a per-stage override:
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
+ --deploy-config my_overrides.yaml \
+ --max-model-len 16384 \
+ --stage-overrides '{"0": {"max_num_seqs": 8}}'
+```
+
+Effective config per stage after the merge:
+
+| Stage | Field | Final value | Source |
+|-------|-------|-------------|--------|
+| 0 | `gpu_memory_utilization` | `0.9` | base YAML (overlay didn't touch stage 0) |
+| 0 | `max_num_seqs` | `8` | per-stage CLI (`--stage-overrides`) — wins over base `32` |
+| 0 | `max_model_len` | `16384` | global CLI |
+| 1 | `gpu_memory_utilization` | `0.5` | overlay YAML — wins over base `0.7` |
+| 1 | `max_num_seqs` | `16` | base YAML (overlay didn't touch this field) |
+| 1 | `max_model_len` | `16384` | global CLI |
+| 2 | (all defaults) | — | base YAML (no overrides apply) |
Therefore, as a core part of vLLM-Omni, the stage configs for a model have several main functions:
@@ -35,7 +175,7 @@ stage_args:
- stage_id: 0 # mark the unique id for each stage
runtime: # The disaggregated configuration
process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ devices: "0" # Logical device index for this stage (mapped through CUDA_VISIBLE_DEVICES / ASCEND_RT_VISIBLE_DEVICES if set)
engine_args: # Engine arguments for a certain engine
model_stage: thinker
max_num_seqs: 1
@@ -114,16 +254,12 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
+
edges:
- from: 0 # thinker → talker: trigger only after receiving full input (-1)
to: 1
- window_size: -1
- from: 1 # talker → code2wav: trigger only after receiving full input (-1)
to: 2
- window_size: -1
```
@@ -155,7 +291,9 @@ Default: `true`
#### `runtime.devices`
-Visible devices for this stage, specified as a string. This controls which GPU devices are available to the stage process, similar to setting `CUDA_VISIBLE_DEVICES` or using `torch.cuda.set_device()`. For example, `"0"` uses GPU 0, `"1"` uses GPU 1, and `"0,1"` makes both GPUs 0 and 1 visible.
+Logical device indices for this stage, specified as a string. Values are **logical indices** (`0`, `1`, `2`, ...) — not physical GPU IDs — and are mapped through the platform's visibility env var (`CUDA_VISIBLE_DEVICES` on CUDA, `ASCEND_RT_VISIBLE_DEVICES` on NPU) before being applied via `torch.cuda.set_device()` (or the equivalent).
+
+Example: if `CUDA_VISIBLE_DEVICES=0,2,4` is set in the environment, then `devices: "0"` selects physical GPU 0 (the first visible), `devices: "1"` selects physical GPU 2, and `devices: "0,1"` makes physical GPUs 0 and 2 available to the stage. If no visibility env var is set, logical and physical IDs coincide.
Default: `"0"`
diff --git a/docs/configuration/stage_configs/qwen2_5_omni.yaml b/docs/configuration/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 690577b84a..0000000000
--- a/docs/configuration/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,94 +0,0 @@
-# stage config for running qwen2.5-omni with AsyncOmniEngine + Orchestrator runtime.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/docs/contributing/ci/CI_5levels.md b/docs/contributing/ci/CI_5levels.md
index 74ae1a38eb..2452ef5d4a 100644
--- a/docs/contributing/ci/CI_5levels.md
+++ b/docs/contributing/ci/CI_5levels.md
@@ -86,7 +86,8 @@ Through five levels (L1-L5) and common (Common) specifications, the system clari
/tests/e2e/online_serving/test_{model_name}_expansion.py
/tests/e2e/offline_inference/test_{model_name}_expansion.py
Performance:
- /tests/dfx/perf/tests/test.json
+ /tests/dfx/perf/tests/test_qwen_omni.json (Omni), test_tts.json (TTS),
+ and /tests/dfx/perf/tests/test_{diffusion_model}_vllm_omni.json (Diffusion)
Doc Test:
tests/example/online_serving/test_{model_name}.py
tests/example/offline_inference/test_{model_name}.py
@@ -230,8 +231,7 @@ vllm_omni/ tests/
│ ├── test_qwen3_omni_expansion.py
│ ├── test_mimo_audio.py
│ ├── test_image_gen_edit.py
- │ ├── test_images_generations_lora.py
- │ └── stage_configs/
+ │ └── test_images_generations_lora.py
└── offline_inference/ ✅
├── test_qwen2_5_omni.py
├── test_qwen3_omni.py
@@ -242,16 +242,17 @@ vllm_omni/ tests/
├── test_zimage_tensor_parallel.py
├── test_cache_dit.py
├── test_teacache.py
- ├── test_stable_audio_model.py
+ ├── test_stable_audio_expansion.py
├── test_diffusion_cpu_offload.py
├── test_diffusion_layerwise_offload.py
├── test_diffusion_lora.py
├── test_sequence_parallel.py
- └── stage_configs/
- ├── qwen2_5_omni_ci.yaml
- ├── qwen3_omni_ci.yaml
- ├── bagel_*.yaml
- └── npu/, rocm/, etc.
+ └── stage_configs/ (legacy schema, still
+ ├── bagel_*.yaml present for unmigrated
+ └── npu/, rocm/, etc. models)
+
+# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under
+# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md.
```
@@ -530,13 +531,13 @@ L4 level testing is a comprehensive quality audit before a version release. It e
### 3.2 Testing Content and Scope
- ***Full Functionality Testing***: Executes all test cases defined in `test_{model_name}_expansion.py`, covering all implemented features, positive flows, boundary conditions, and exception handling.
-- ***Performance Testing***: Uses the `tests/dfx/perf/tests/test.json` configuration file to drive performance testing tools for stress, load, and endurance tests, collecting metrics like throughput, response time, and resource utilization.
+- ***Performance Testing***: Uses `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, and diffusion configs in the form `tests/dfx/perf/tests/test_*_vllm_omni.json` (passed to `run_benchmark.py` via `--test-config-file`) to drive performance testing tools for stress, load, and endurance tests, collecting metrics like throughput, response time, and resource utilization.
- ***Documentation Testing***: Verifies whether the example code provided to users is runnable and its results match the description.
### 3.3 Test Directory and Execution Files
- ***Functional Testing***: Same directories as L3.
-- ***Performance Test Configuration***: `tests/dfx/perf/tests/test.json`
+- ***Performance Test Configuration***: `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, and diffusion configs `tests/dfx/perf/tests/test_*_vllm_omni.json` (e.g. `test_qwen_image_vllm_omni.json`)
- ***Documentation Example Tests***:
- - `tests/example/online_serving/test_{model_name}.py`
- `tests/example/offline_inference/test_{model_name}.py`
diff --git a/docs/contributing/ci/test_examples/l4_performance_tests.inc.md b/docs/contributing/ci/test_examples/l4_performance_tests.inc.md
index 8093e1459f..f1f3073dc5 100644
--- a/docs/contributing/ci/test_examples/l4_performance_tests.inc.md
+++ b/docs/contributing/ci/test_examples/l4_performance_tests.inc.md
@@ -1,4 +1,4 @@
-When you want to add L4-level ***performance test*** cases, you can refer to the following format for case addition in tests/dfx/perf/tests/test.json:
+When you want to add L4-level ***performance test*** cases, you can refer to the following format for case addition in `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, or diffusion configs such as `tests/dfx/perf/tests/test_*_vllm_omni.json` (selected via `pytest ... run_benchmark.py --test-config-file `):
```JSON
{
diff --git a/docs/contributing/ci/test_guide.md b/docs/contributing/ci/test_guide.md
index 425f24332c..08b2e3b4ea 100644
--- a/docs/contributing/ci/test_guide.md
+++ b/docs/contributing/ci/test_guide.md
@@ -45,7 +45,6 @@ Our test scripts use the pytest framework. First, please use `git clone https://
=== "L3 level & L4 level"
```bash
- cd tests
pytest -s -v -m "advanced_model" --run-level=advanced_model
```
If you only want to run L3 test case, you can use:
@@ -60,9 +59,9 @@ Our test scripts use the pytest framework. First, please use `git clone https://
```bash
pytest -s -v -m "core_model and distributed_cuda and L4" --run-level=core_model
```
- Note: To run performance tests, use:
+ Note: To run performance tests (defaults to ``test_qwen_omni.json``; use ``--test-config-file tests/dfx/perf/tests/test_tts.json`` for TTS):
```bash
- pytest -s -v perf/scripts/run_benchmark.py
+ pytest -s -v tests/dfx/perf/scripts/run_benchmark.py
```
The latest L3 test commands for various test suites can be found in the [pipeline](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-merge.yml).
diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md
index 8b10cf4cc1..392f004721 100644
--- a/docs/contributing/ci/tests_style.md
+++ b/docs/contributing/ci/tests_style.md
@@ -135,8 +135,7 @@ vllm_omni/ tests/
│ ├── test_qwen3_omni_expansion.py
│ ├── test_mimo_audio.py
│ ├── test_image_gen_edit.py
- │ ├── test_images_generations_lora.py
- │ └── stage_configs/
+ │ └── test_images_generations_lora.py
└── offline_inference/ ✅
├── test_qwen2_5_omni.py
├── test_qwen3_omni.py
@@ -147,17 +146,18 @@ vllm_omni/ tests/
├── test_zimage_tensor_parallel.py
├── test_cache_dit.py
├── test_teacache.py
- ├── test_stable_audio_model.py
+ ├── test_stable_audio_expansion.py
├── test_diffusion_cpu_offload.py
├── test_diffusion_layerwise_offload.py
├── test_diffusion_lora.py
├── test_sequence_parallel.py
├── test_qwen_image_edit_expansion.py
- └── stage_configs/
- ├── qwen2_5_omni_ci.yaml
- ├── qwen3_omni_ci.yaml
- ├── bagel_*.yaml
+ └── stage_configs/ (legacy schema, still present
+ ├── bagel_*.yaml for unmigrated models)
└── npu/, rocm/, etc.
+
+# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under
+# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md.
examples/ tests
│ └── examples
├── online_serving/ → ├── online_serving/
@@ -229,6 +229,7 @@ from tests.conftest import (
generate_synthetic_video,
merge_base64_and_convert_to_text,
)
+from tests.utils import get_deploy_config_path
from vllm_omni.platforms import current_omni_platform
# Edit: model name and stage config path
@@ -236,7 +237,7 @@ models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
#If you use the default configuration file, you can directly use the following address.
def get_default_config():
- return str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
+ return get_deploy_config_path("ci/qwen3_omni_moe.yaml")
#If you need to modify the configuration file, you can use modify_stage_config.
def get_chunk_config():
diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md
index a0619e3381..1eaff10596 100644
--- a/docs/contributing/model/adding_omni_model.md
+++ b/docs/contributing/model/adding_omni_model.md
@@ -313,7 +313,7 @@ The registry uses lazy loading, so the model class is imported only when needed.
## Stage Configuration
-Create a YAML configuration file in `vllm_omni/model_executor/stage_configs/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml).
+Create a YAML configuration file in `vllm_omni/deploy/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml).
### Key Configuration Fields
@@ -408,18 +408,17 @@ Understanding the data structures is crucial for implementing stage transitions:
**Input to your function:**
- `stage_list[source_stage_id].engine_outputs`: List of `EngineCoreOutput` objects
- - Each contains `outputs`: List of `RequestOutput` objects
- - Each `RequestOutput` has:
- - `token_ids`: Generated token IDs
- - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc.
- - These are the hidden states or intermediate outputs from the model's forward pass
- - `prompt_token_ids`: Original prompt token IDs
+- - Each contains `outputs`: List of `RequestOutput` objects
+ - Each `RequestOutput` has:
+- - - `token_ids`: Generated token IDs
+ - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc.These are the hidden states or intermediate outputs from the model's forward pass
+ - `prompt_token_ids`: Original prompt token IDs
**Output from your function:**
- Must return `list[OmniTokensPrompt]` where each `OmniTokensPrompt` contains:
- - `prompt_token_ids`: List[int] - Token IDs for the next stage
- - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states)
- - `multi_modal_data`: Optional multimodal data if needed
+- - `prompt_token_ids`: List[int] - Token IDs for the next stage
+ - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states)
+ - `multi_modal_data`: Optional multimodal data if needed
### How Model Outputs Are Stored
@@ -614,7 +613,7 @@ For a complete reference implementation, see:
- **Thinker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py`
- **Talker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py`
- **Code2Wav**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py`
-- **Stage config**: `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`
+- **Stage config**: `vllm_omni/deploy/qwen3_omni_moe.yaml`
- **Input processors**: `vllm_omni/model_executor/stage_input_processors/qwen3_omni.py`
- **Registry**: `vllm_omni/model_executor/models/registry.py`
- **Testing**: `vllm_omni/tests/e2e/offline_inference/test_qwen3_omni.py`
diff --git a/docs/contributing/model/adding_tts_model.md b/docs/contributing/model/adding_tts_model.md
index e48ae5049f..622064173c 100644
--- a/docs/contributing/model/adding_tts_model.md
+++ b/docs/contributing/model/adding_tts_model.md
@@ -28,7 +28,7 @@ and can be placed on different devices. Qwen3-TTS has two stages:
Each stage is a separate model class configured independently via YAML. The two stages
are connected by the `async_chunk` framework, which enables inter-stage streaming for
-low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk_design.md)).
+low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk.md)).
### Without async_chunk (batch mode)
@@ -120,8 +120,18 @@ vllm_omni/model_executor/stage_configs/
| `models/qwen3_tts/qwen3_tts.py` | Unified model class |
| `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | Stage 0 - optimized AR |
| `models/qwen3_tts/qwen3_tts_code2wav.py` | Stage 1 - decoder |
-| `stage_configs/qwen3_tts.yaml` | Stage config (async_chunk enabled) |
-| `stage_configs/qwen3_tts_batch.yaml` | Batch mode config |
+| `deploy/qwen3_tts.yaml` (new schema) | Deploy config (async_chunk enabled) — paired with `models/qwen3_tts/pipeline.py` for the frozen topology |
+
+> **Chunked vs end-to-end modes**: `qwen3_tts` registers a single
+> pipeline whose stage 1 declares alternate processor functions — an
+> `async_chunk_process_next_stage_input_func` (per-chunk streaming, used
+> when `deploy.async_chunk=True`) and a `sync_process_input_func`
+> (batch-end, used when `deploy.async_chunk=False`). The loader selects
+> one at merge time based on the bool, so `--no-async-chunk` alone
+> switches modes — no variant yaml or variant pipeline registration is
+> needed. Pipelines that only make sense in one mode (e.g.
+> `qwen3_omni_moe` is always chunked) can keep using the unconditional
+> `custom_process_*` fields.
| `stage_input_processors/qwen3_tts.py` | Stage transition processors |
## Step-by-Step Implementation
@@ -574,11 +584,12 @@ Adding a TTS model to vLLM-Omni involves:
| `models/qwen3_tts/qwen3_tts.py` | Unified model class |
| `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | AR stage with vLLM fused ops |
| `models/qwen3_tts/qwen3_tts_code2wav.py` | Decoder stage with `chunked_decode_streaming()` |
-| `stage_configs/qwen3_tts.yaml` | Stage configuration |
+| `models/qwen3_tts/pipeline.py` | Frozen pipeline topology (registered at import time) |
+| `deploy/qwen3_tts.yaml` | Deploy config (user-editable, async_chunk + SharedMemoryConnector) |
| `stage_input_processors/qwen3_tts.py` | Stage transition processors |
For more information, see:
- [Architecture Overview](../../design/architecture_overview.md)
-- [Async Chunk Design](../../design/feature/async_chunk_design.md)
+- [Async Chunk Design](../../design/feature/async_chunk.md)
- [Stage Configuration Guide](../../configuration/stage_configs.md)
diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index 7a2e64f131..6c209e5659 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -1,216 +1,193 @@
# Profiling vLLM-Omni
-> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production.
+> **Warning:** Profiling is for development and debugging only. It adds significant overhead and should not be enabled in production.
-vLLM-Omni uses the PyTorch Profiler to analyze performance across both **multi-stage omni-modality models** and **diffusion models**.
+vLLM-Omni supports two profiler backends through `profiler_config`:
-### 1. Configure Profiling in the Stage YAML
+- `torch`: detailed CPU/CUDA traces written to `torch_profiler_dir`
+- `cuda`: low-overhead CUDA range control for NVIDIA Nsight Systems (`nsys`)
-Enable profiling by adding `profiler_config` under `engine_args` for the stage(s) you want to profile in your stage config YAML:
+## 1. Configure Profiling
+
+Use the same `profiler_config` shape everywhere:
+
+```yaml
+profiler_config:
+ profiler: torch
+ torch_profiler_dir: ./perf
+```
+
+Supported fields:
+
+| Field | Description |
+|---|---|
+| `profiler` | Profiler backend. Supported values: `torch`, `cuda`. |
+| `torch_profiler_dir` | Output directory for torch traces. Required when `profiler: torch`. |
+| `delay_iterations` | Number of worker iterations to skip before profiling starts. |
+| `max_iterations` | Maximum number of worker iterations to capture before auto-stop. |
+| `warmup_iterations` | Torch-profiler warmup iterations. |
+| `active_iterations` | Torch-profiler active iterations. |
+| `wait_iterations` | Torch-profiler wait iterations before warmup. |
+
+For multi-stage omni pipelines, put `profiler_config` under the target stage's `engine_args`.
```yaml
stage_args:
- stage_id: 0
stage_type: llm
engine_args:
- # ... other engine args ...
profiler_config:
profiler: torch
torch_profiler_dir: ./perf
```
-| Field | Description |
-|---|---|
-| `profiler` | Profiler backend to use. Currently supports `torch`. |
-| `torch_profiler_dir` | Directory where trace files are saved. Created automatically if it doesn't exist. |
-
-> **Tip:** Only enable `profiler_config` on stages you actually need to profile. Stages without it will not start a profiler, keeping overhead minimal.
-
-### 2. Profiling Omni-Modality Models
+For single-stage diffusion usage, pass `profiler_config` directly to `Omni(...)` or `vllm serve`.
-**Selective Stage Profiling**
+## 2. Profiling Omni Pipelines
-It is highly recommended to profile specific stages to prevent producing overly large trace files:
+It is usually best to profile only the stages you need.
```python
-# Profile all stages
-omni_llm.start_profile()
+# Profile all stages.
+omni.start_profile()
-# Only profile Stage 1
-omni_llm.start_profile(stages=[1])
-
-# Stage 0 (Thinker) and Stage 2 (Audio Decoder) for qwen omni
-omni_llm.start_profile(stages=[0, 2])
+# Profile selected stages only.
+omni.start_profile(stages=[0, 2])
+...
+omni.stop_profile(stages=[0, 2])
```
-> **Important:** Always pass the same `stages` list to both `start_profile()` and `stop_profile()`. If you omit `stages` from `stop_profile()`, it defaults to stopping all stages — including ones that were never started — which will produce errors.
+Always stop the same stage set that you started. If only some stages have `profiler_config`, pass an explicit `stages=[...]` list instead of relying on the default "all stages" behavior.
-**Python Usage**: Wrap your generation logic with `start_profile()` and `stop_profile()`.
+Examples:
-```python
-profiler_stages = [0] # Only profile the stages you need
+1. [Qwen2.5-Omni end2end](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py)
+2. [Qwen3-Omni end2end](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py)
-# 1. Start profiling
-omni.start_profile(stages=profiler_stages)
+## 3. Profiling Single-Stage Diffusion
-# Initialize generator
-omni_generator = omni.generate(prompts, sampling_params_list, py_generator=args.py_generator)
+Single-stage diffusion models use the same `start_profile()` / `stop_profile()` controls, but you must provide `profiler_config` explicitly.
-total_requests = len(prompts)
-processed_count = 0
+### PyTorch profiler
-# Main Processing Loop
-for stage_outputs in omni_generator:
-
- # ... [Output processing logic for text/audio would go here] ...
+```python
+from vllm_omni import Omni
+
+omni = Omni(
+ model="Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+ profiler_config={
+ "profiler": "torch",
+ "torch_profiler_dir": "./perf",
+ },
+)
+
+omni.start_profile()
+...
+omni.stop_profile()
+```
- # Update count to track when to stop profiling
- processed_count += len(stage_outputs.request_output)
+### Nsight Systems (`nsys`)
- # 2. Check if all requests are done to stop the profiler safely
- if profiler_enabled and processed_count >= total_requests:
- print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
+For Nsight Systems, use `profiler: cuda` and wrap the process with `nsys profile`.
- # Stop the profiler while workers are still active
- # Pass the same stages list used in start_profile()
- omni_llm.stop_profile(stages=profiler_stages)
+```bash
+nsys profile \
+ --trace-fork-before-exec=true \
+ --cuda-graph-trace=node \
+ --capture-range=cudaProfilerApi \
+ --capture-range-end=repeat \
+ -o diffusion_trace \
+ python image_to_video.py ...
+```
- # Wait for traces to flush to disk
- print("[Info] Waiting 30s for workers to write trace files to disk...")
- time.sleep(30)
- print("[Info] Trace export wait time finished.")
+The Python process being profiled must create the diffusion engine with:
-omni_llm.close()
+```python
+profiler_config={"profiler": "cuda"}
```
+Then call `start_profile()` before the requests you want to capture and `stop_profile()` after them. The diffusion worker processes open and close the CUDA capture range themselves, so `nsys` sees the actual GPU work instead of only the parent process.
-**CLI Usage** (using `end2end.py`):
-```bash
-# Profile only Stage 0 (Thinker)
-python end2end.py --output-wav output_audio \
- --query-type text --enable-profiler --profiler-stages 0
+Examples:
-# Profile Stage 0 and Stage 2
-python end2end.py --output-wav output_audio \
- --query-type text --enable-profiler --profiler-stages 0 2
-
-# Profile all stages (omit --profiler-stages)
-python end2end.py --output-wav output_audio \
- --query-type text --enable-profiler
-```
+1. [Image edit example](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py)
+2. [Image to video example](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video)
-**Examples**:
+## 4. Profiling Online Serving
-1. **Qwen2.5-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py)
+When any stage has `profiler_config.profiler` set, the server exposes:
-2. **Qwen3-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py)
+- `POST /start_profile`
+- `POST /stop_profile`
-### 3. Profiling diffusion models
+### Start the server
-Diffusion profiling is End-to-End, capturing encoding, denoising loops, and decoding. Standalone diffusion scripts use `--profiler-dir` to enable profiling.
+Multi-stage omni serving:
-**CLI Usage:**
```bash
-python image_to_video.py \
- --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \
- --image qwen-bear.png \
- --prompt "A cat playing with yarn, smooth motion" \
- --profiler-dir \
- \
- # Minimize Spatial Dimensions (Optional but helpful):
- # Drastically reduces memory usage so the profiler doesn't
- # crash due to overhead, though for accurate performance
- # tuning you often want target resolutions.
- --height 48 \
- --width 64 \
- \
- # Minimize Temporal Dimension (Frames):
- # Video models process 3D tensors (Time, Height, Width).
- # Reducing frames to the absolute minimum (2) keeps the
- # tensor size small, ensuring the trace file doesn't become
- # multi-gigabytes in size.
- --num-frames 2 \
- \
- # Minimize Iteration Loop (Steps):
- # This is the most critical setting for profiling.
- # Diffusion models run the same loop X times.
- # Profiling 2 steps gives you the exact same performance
- # data as 50 steps, but saves minutes of runtime and
- # prevents the trace viewer from freezing.
- --num-inference-steps 2 \
- \
- --guidance-scale 5.0 \
- --guidance-scale-high 6.0 \
- --boundary-ratio 0.875 \
- --flow-shift 12.0 \
- --fps 16 \
- --output i2v_output.mp4
+vllm serve Qwen/Qwen2.5-Omni-7B \
+ --omni \
+ --port 8091
```
-> **Note:** For diffusion stages within a multi-stage omni pipeline, use `profiler_config` in the stage YAML instead (see Section 1).
-
-**Examples**:
-
-1. **Qwen image edit**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py)
-
-2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video)
-
-### 4. Profiling Online Serving
+(The default deploy config at `vllm_omni/deploy/qwen2_5_omni.yaml` is loaded automatically. Pass `--deploy-config /path/to/custom.yaml` to override.)
-When `profiler_config` is set in the stage YAML, the server automatically exposes `/start_profile` and `/stop_profile` HTTP endpoints.
+Single-stage diffusion serving with torch profiler:
-**1. Start the server** with a stage YAML that has `profiler_config` enabled:
```bash
-vllm serve Qwen/Qwen2.5-Omni-7B \
- --omni \
- --stage-configs-path qwen2_5_omni.yaml \
- --port 8091
+vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \
+ --omni \
+ --port 8091 \
+ --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
```
-Or for one stage diffusion models:
+Single-stage diffusion serving with Nsight Systems:
```bash
-vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers --omni --port 8091 --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
+nsys profile \
+ --trace-fork-before-exec=true \
+ --cuda-graph-trace=node \
+ --capture-range=cudaProfilerApi \
+ --capture-range-end=repeat \
+ -o serving_trace \
+ vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \
+ --omni \
+ --port 8091 \
+ --profiler-config '{"profiler": "cuda"}'
```
-**2. Start profiling** by sending a POST request:
+### Control capture
+
```bash
-# Profile all stages that have profiler_config set
+# Start profiling on all profiled stages.
curl -X POST http://localhost:8091/start_profile
-# Profile specific stages only
+# Start profiling on selected stages.
curl -X POST http://localhost:8091/start_profile \
- -H "Content-Type: application/json" \
- -d '{"stages": [0]}'
-```
+ -H "Content-Type: application/json" \
+ -d '{"stages": [0]}'
-**3. Send your inference requests** as normal while the profiler is running.
-
-**4. Stop profiling** and collect traces:
-```bash
-# Stop all stages
+# Stop profiling.
curl -X POST http://localhost:8091/stop_profile
-
-# Stop specific stages (must match the stages you started)
-curl -X POST http://localhost:8091/stop_profile \
- -H "Content-Type: application/json" \
- -d '{"stages": [0]}'
```
-Trace files are written to the `torch_profiler_dir` specified in your stage YAML.
+For mixed-stage pipelines, use explicit `stages` and pass the same stage list to both endpoints.
+
+## 5. Analyze Results
-> **Important:** Always stop the same stages you started. Stopping a stage that was never started will produce errors.
+Torch profiler output:
-### 5. Analyzing Traces
+- Chrome/Perfetto traces under `torch_profiler_dir`
+- Optional aggregated CUDA-time tables under the same directory
-Output files are saved to the `torch_profiler_dir` specified in your stage YAML config.
+CUDA profiler / Nsight Systems output:
-**Output**
-**Chrome Trace** (`.json.gz`): Visual timeline of kernels and stages. Open in Perfetto UI.
+- `.nsys-rep` report files written by `nsys -o ...`
-**Viewing Tools:**
+Recommended viewers:
-- [Perfetto](https://ui.perfetto.dev/) (recommended)
-- `chrome://tracing` (Chrome only)
+- [Perfetto](https://ui.perfetto.dev/) for torch traces
+- `nsys stats .nsys-rep` for CLI summaries
+- Nsight Systems GUI for CUDA kernel timelines
-**Note**: vLLM-Omni reuses the PyTorch Profiler infrastructure from vLLM. See the official vLLM profiler documentation: [vLLM Profiling Guide](https://docs.vllm.ai/en/stable/contributing/profiling/)
+vLLM-Omni reuses the vLLM profiling infrastructure where possible. For the upstream reference, see the [vLLM profiling guide](https://docs.vllm.ai/en/stable/contributing/profiling/).
diff --git a/docs/design/feature/async_chunk_design.md b/docs/design/feature/async_chunk.md
similarity index 80%
rename from docs/design/feature/async_chunk_design.md
rename to docs/design/feature/async_chunk.md
index 202ef0e18e..57b4209b8d 100644
--- a/docs/design/feature/async_chunk_design.md
+++ b/docs/design/feature/async_chunk.md
@@ -1,4 +1,4 @@
-# Async Chunk Design
+# Async Chunk
## Table of Contents
@@ -19,7 +19,7 @@ The `async_chunk` feature enables asynchronous, chunked processing of data acros
For qwen3-omni:
- **Thinker → Talker**: Per decode step (typically chunk_size=1)
-- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFA. Use the per-request `initial_codec_chunk_frames` API field to override.
+- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFP. Use the per-request `initial_codec_chunk_frames` API field to override.
- **Code2Wav**: Streaming decode with code2wav chunk_size
With `async_chunk`:
@@ -75,26 +75,85 @@ Enabling **async_chunk** (False→True) sharply reduces time-to-first-audio (TTF
## Architecture
-### Data Flow
-#### Sequential Flow
+### Async Chunk Pipeline Overview
+
+The following diagram illustrates the **Async Chunk Architecture** for multi-stage models (e.g., Qwen3-Omni with Thinker → Talker → Code2Wav), showing how data flows through the 4-stage pipeline with parallel processing and dual-stream output:
+
-
-
+
+
-#### Async Chunk Flow
+**Diagram Legend:**
+
+| Step | Stage Type | Description |
+|------|-----------|------------|
+| `prefill` | Initialization | Context processing, KV cache initialization |
+| `decode` | Autoregressive | Token-by-token generation in AR stages |
+| `codes` | Audio Encoding | RVQ codec codes from Talker stage |
+| `output` | Final Output | Text chunks or audio waveforms |
+
+### Data Flow
+
+#### Stage 0: Thinker (Multimodal Understanding + Text Generation)
+- **Prefill**: Processes multimodal input (text/image/audio/video), initializes KV cache
+- **Decode Loop**: Generates text tokens autoregressively
+- **Chunk Triggers**: Each decode step (typically `chunk_size=1`) can trigger downstream processing
+- **Dual Output**:
+ - **Text Stream**: `text_0`, `text_1`, `text_2`... `text_n` streamed to output
+ - **Hidden States**: Passed to Talker stage for audio synthesis
+
+#### Stage 1: Talker (Text → RVQ Audio Codes)
+- **Prefill**: Receives hidden states from Thinker as semantic condition
+- **Decode Loop**: Generates RVQ codec codes autoregressively
+- **Accumulation**: Codes accumulate to `codec_chunk_frames` (default=25) before forwarding
+- **Dynamic IC**: Initial chunk size auto-selected based on server load to optimize TTFP
+- **Output**: `codes` blocks (chunk 0, 1, ... n) sent to Code2Wav
+
+#### Stage 2: Code2Wav (Vocoder Decoder)
+- **Non-Autoregressive**: Processes RVQ codes in parallel batches
+- **Streaming Decode**: Converts codes to audio waveforms chunk-by-chunk
+- **Batching**: Supports batched inference for multiple concurrent requests
+- **Output**: Audio segments `audio_0`, `audio_1`, ... `audio_n`
+
+#### Stage 3: Output (Dual Stream)
+- **Text Streaming**: `text_0` → `text_1` → `text_2` → ... (user sees response in real-time)
+- **Audio Streaming**: `audio_0` → `audio_1` → ... (user hears audio progressively)
+
+### Execution Timeline
+```
+Timeline: Parallel vs Sequential
+
+Sequential (async_chunk=false):
+[Thinker: ████████████████████] (2.0s)
+ [Talker: ████████████████████] (3.0s)
+ [Code2Wav: ████] (1.0s)
+Total: 6.0s, TTFP: 6.0s
+
+Async Chunk (async_chunk=true):
+[Thinker: ████░░░░████░░░░████] (2.0s, streaming)
+ [Talker: ░░████░░░░████░░] (3.0s, parallel)
+ [Code2Wav: ░░░░████░░] (1.0s, batched)
+Total: ~3.5s, TTFP: ~0.5s
+
+█ = Active computation ░ = Waiting/idle
+```
+
+#### Sequential Flow (for comparison)
-
-
+
+
-### Async Chunk architecture
+In sequential mode, each stage must wait for the previous stage to complete entirely before starting.
+
+### Async Chunk System Architecture
diff --git a/docs/design/feature/cfg_parallel.md b/docs/design/feature/cfg_parallel.md
index 64decbe956..c73a87749f 100644
--- a/docs/design/feature/cfg_parallel.md
+++ b/docs/design/feature/cfg_parallel.md
@@ -25,7 +25,9 @@ In standard Classifier-Free Guidance, each diffusion step requires two forward p
1. **Positive/Conditional**: Guided by the text prompt
2. **Negative/Unconditional**: Typically using empty or negative prompt
-CFG-Parallel eliminates this bottleneck by distributing the two forward passes across different GPU ranks, allowing them to execute simultaneously rather than sequentially.
+Some models require 3 or more CFG branches (see [N-Branch CFG](#n-branch-cfg-3-branches)).
+
+CFG-Parallel eliminates this bottleneck by distributing the forward passes across different GPU ranks, allowing them to execute simultaneously rather than sequentially.
### Architecture
@@ -33,9 +35,11 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic.
| Method | Purpose | Automatic Behavior |
|--------|---------|-------------------|
-| [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with CFG | Detects parallel mode, distributes computation, gathers results |
+| [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with 2-branch CFG | Detects parallel mode, distributes computation, gathers results |
+| [`predict_noise_with_multi_branch_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with N-branch CFG | Round-robin dispatches N branches across M GPUs |
| [`scheduler_step_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Step scheduler | All ranks step locally (no broadcast needed) |
-| [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine positive/negative | Applies CFG formula with optional normalization |
+| [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine 2-branch predictions | Applies CFG formula with optional normalization |
+| [`combine_multi_branch_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine N-branch predictions | Override for custom multi-branch combine logic |
| [`predict_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Forward pass wrapper | Override for custom transformer calls |
| [`cfg_normalize_function()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Normalize CFG output | Override for custom normalization |
@@ -57,6 +61,22 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic.
- All ranks compute the scheduler step locally — no broadcast needed because `predict_noise_maybe_with_cfg` already ensures all ranks have identical noise predictions after `all_gather` + local combine.
+### N-Branch CFG (3+ branches)
+
+Some models require more than 2 CFG branches. For example, Bagel and OmniGen2 use 3 branches, DreamID Omni uses 4 branches.
+
+`predict_noise_with_multi_branch_cfg()` handles these by automatically dispatching N branches across M GPUs using round-robin (rule: branch `i` → rank `i % M`):
+
+| Branches (N) | GPUs (M) | Dispatch |
+|:---:|:---:|:---|
+| 3 | 2 | `[[0, 2], [1]]` |
+| 3 | 3 | `[[0], [1], [2]]` |
+| 4 | 2 | `[[0, 2], [1, 3]]` |
+| 4 | 3 | `[[0, 3], [1], [2]]` |
+| 4 | 4 | `[[0], [1], [2], [3]]` |
+
+When a rank handles multiple branches, it runs them sequentially. After `all_gather`, all ranks execute `combine_multi_branch_cfg_noise()` locally, producing identical results.
+
---
## Step-by-Step Implementation
@@ -98,6 +118,7 @@ class YourModelPipeline(nn.Module, CFGParallelMixin):
- `positive_kwargs`: transformer arguments for conditional (text-guided) prediction
- `negative_kwargs`: transformer arguments for unconditional prediction (set to `None` if CFG disabled)
- For image editing pipelines, add `output_slice=image_seq_len` to extract the generative image portion
+- For models with 3+ CFG branches, see [Multi-Branch CFG](#multi-branch-cfg-3-branches) in the Customization section
### Step 2: Call `diffuse`
@@ -171,20 +192,42 @@ class LongCatImagePipeline(nn.Module, CFGParallelMixin):
```
-### Override `combine_cfg_noise()` for Multi-Output Models
+### Multi-Branch CFG (3+ branches)
+
+For models with 3 or more CFG branches, use `predict_noise_with_multi_branch_cfg()` instead of `predict_noise_maybe_with_cfg()`, and override `combine_multi_branch_cfg_noise()` for custom combine logic. This interface also works for standard 2-branch CFG — just pass 2 branches in `branches_kwargs`.
-When `predict_noise()` returns a tuple (e.g., video + audio), the default `combine_cfg_noise()` applies CFG to every element. Override it to apply different logic per element — for example, CFG on video but positive-only on audio:
+**Example (3-branch with dual guidance scale):**
```python
-class MyVideoAudioPipeline(nn.Module, CFGParallelMixin):
- def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize):
- (video_pos, audio_pos) = positive_noise_pred
- (video_neg, audio_neg) = negative_noise_pred
- video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize)
- return (video_combined, audio_pos) # audio: positive only, no CFG
+class YourMultiBranchPipeline(nn.Module, CFGParallelMixin):
+ def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
+ text_scale = true_cfg_scale["text"]
+ image_scale = true_cfg_scale["image"]
+ pos, ref, uncond = predictions
+ return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)
+
+ def diffuse(self, ...):
+ for i, t in enumerate(timesteps):
+ positive_kwargs = {...} # conditional prompt
+ ref_neg_kwargs = {...} # negative prompt + reference
+ uncond_kwargs = {...} # unconditional
+
+ noise_pred = self.predict_noise_with_multi_branch_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale={"text": text_guidance_scale, "image": image_guidance_scale},
+ branches_kwargs=[positive_kwargs, ref_neg_kwargs, uncond_kwargs],
+ )
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+
+ return latents
```
-This also requires `predict_noise()` to return a tuple (see [Override predict_noise](#override-predict_noise-for-custom-transformer-calls) above).
+### Override Combine Functions
+
+There are two combine functions for different scenarios:
+
+- **`combine_cfg_noise()`** — Used by `predict_noise_maybe_with_cfg()`. Override when `predict_noise()` returns a tuple (e.g., video + audio) and you need per-element CFG logic.
+- **`combine_multi_branch_cfg_noise()`** — Used by `predict_noise_with_multi_branch_cfg()`. Override to implement custom multi-branch combine formulas (see [Multi-Branch CFG](#multi-branch-cfg-3-branches) above).
### Implement a Composite Scheduler for Multi-Output Models
@@ -303,4 +346,5 @@ Adding CFG-Parallel support:
1. ✅ **Create mixin** - Inherit from `CFGParallelMixin` and implement `diffuse()` method
2. ✅ **(Optional) Customize** - Override `predict_noise()` or `cfg_normalize_function()` for custom behavior
-3. ✅ **Test** - Verify with `--cfg-parallel-size 2` and compare performance
+3. ✅ **(Optional) Multi-branch** - For 3+ branch models, use `predict_noise_with_multi_branch_cfg()` and override `combine_multi_branch_cfg_noise()`
+4. ✅ **Test** - Verify with `--cfg-parallel-size 2` (or 3/4 for multi-branch) and compare performance
diff --git a/docs/design/feature/expert_parallel.md b/docs/design/feature/expert_parallel.md
index 9a7c4cdbac..e05eec3361 100644
--- a/docs/design/feature/expert_parallel.md
+++ b/docs/design/feature/expert_parallel.md
@@ -207,9 +207,9 @@ Complete examples in the codebase:
| Model | Path | Pattern | Notes |
|-------|------|---------|-------|
-| **HunyuanImage3.0** | `vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py` | Standard EP | Full implementation with validation |
+| **HunyuanImage3.0** | `vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py` | Standard EP | Full implementation with validation |
| **EP Tests** | `vllm-omni/tests/e2e/offline_inference/test_expert_parallel.py` | E2E testing | EP correctness and performance |
-| **Constraint Tests** | `vllm-omni/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py` | Unit testing | Validation logic |
+| **Constraint Tests** | `vllm-omni/tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py` | Unit testing | Validation logic |
---
## Summary
diff --git a/docs/design/feature/prefix_caching.md b/docs/design/feature/prefix_caching.md
new file mode 100644
index 0000000000..ebad8b6910
--- /dev/null
+++ b/docs/design/feature/prefix_caching.md
@@ -0,0 +1,164 @@
+# Automatic Prefix Caching in Omni Models
+
+
+---
+
+## Table of Contents
+
+- [Overview](#overview)
+- [High-Level Approach](#high-level-approach)
+- [Example](#example)
+- [What About Multimodal Inputs?](#what-about-multimodal-inputs)
+
+---
+
+### Overview
+
+Prefix caching in the context of kv-cache management is a useful optimization for avoiding redundant computations. The main idea is that we store portions of the kv-cache from processed requests, so that we can reuse them if incoming requests have the same prefix as previous requests.
+
+vLLM manages the kv-cache as blocks, which represent a span of tokens of a fixed length. Blocks are hashable by the content that they contain, which typically means the tokens within the span, but also could be influenced by other factors, e.g., LoRA and multimodal data.
+
+vLLM implements automatic prefix caching for managing its kv-cache, which is best understood by reading the design document [here](https://docs.vllm.ai/en/latest/design/prefix_caching/). vLLM-Omni builds on top of the prefix caching mechanism in a noninvasive way to allow caching between stages in Omni pipelines. This typically means for a given stage we aim to support caching for the following:
+
+- The last hidden states produced by the stage
+- Model / stage specific multimodal data
+
+!!! note "Note 1"
+ This document describes vLLM-Omni's mechanism for caching tensor outputs that are meant to be passed between stages, when requests have common prefixes, similar to the way in which vLLM has prefix caching for the kv-cache. This works in conjunction with vLLM's multimodal encoder caching, but is distinct. See the final section for a concrete example for how they tie together in practice.
+
+### High-Level Approach
+!!! note "Note 2"
+ Prior to reading this section, it's recommended to take a look at the design documents in vLLM for [Automatic Prefix Caching](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching/), which will make some of the concepts more clear.
+
+The main focus of vLLM-Omni's approach to prefix caching stage outputs is to build on vLLM's prefix caching in the least invasive way possible while minimizing impact for cache misses, and consuming a minimal amount of GPU memory. To understand the implementation, there are a few important things to note:
+
+- Between stages, device tensors are generally moved to CPU; this is important since we're just caching the outputs of stages, so it is okay to keep the entire cache on the CPU.
+
+- For a tensor to be considered cacheable, the first dimension (currently) needs to be the same as the token count, as it allows us to reuse block/slot mappings for our externally maintained tensor caches. This allows us to dynamically discover the tensors to be marked as cacheable outputs in each Omni model without having to explicitly specify cacheable output field names in every model.
+
+With this in mind, consider the set of blocks in a 2D layout, where the row represents the index of blocks being considered, and the columns represent the slots corresponding to tokens within each block. Since we know the `num_blocks` and `block_size` from our kv cache config, if we want to cache a tensor with feature size `D`, we can preallocate a CPU tensor of size `(num_blocks, block_size, D)`, and use the same block index and slot mapping to retrieve the corresponding feature vector.
+
+
+### Example
+!!! note "Note 3"
+ Prefix caching in vLLM-Omni currently is only supported on AutoRegressive stages with one kv-cache group. It can be enabled/disabled per-stage via the `enable_prefix_caching` parameter in the model's stage config.
+
+The way in which vLLM-Omni ties into vLLM's prefix caching is best understood by example. Say that we have the following:
+
+- `num_blocks=8`
+- `block_size=4`
+- `hidden_size=2`
+- A stage specific multimodal output tensor named `mm_feature` with feature dimension `16`
+
+The prefix cache flow is then outlined below.
+
+1. When the model is initialized, we can determine the `hidden_size` from the `ModelConfig`, and allocate a cache of size `(num_blocks, block_size, hidden_size)`.
+
+2. Say we process the request `The quick brown fox was tired and slept beneath the shady tree`, which is 12 tokens and evenly divides into 3 blocks as shown below.
+
+```
+ [ The quick brown fox ] [ was tired and slept ] [beneath the shady tree ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->|
+```
+
+When the request processes, we inspect the multimodal outputs and identify the `mm_feature` tensor, which will be of shape `(seq_len, feature_dim)`, i.e., `(12, 16)` in this example. We note that the first axis is dependent on the `seq_len` and add a new cache_tensor of shape `(num_blocks, block_size, feature_dim)` to our multimodal cache for tensors.
+
+
+3. If we lay out the cache as a 2D tensor of shape (`num_blocks`, `block_size`), we'll have something like the following:
+
+```
+0: [ The quick brown fox ]
+1: [ was tired and slept ]
+2: [beneath the shady tree ]
+3: [EMPTY]
+...
+7: [EMPTY]
+```
+
+Or, if we flatten it down to 1D,
+```
+0: The
+1: quick
+2: brown
+3: fox
+...
+11: tree
+12: [EMPTY]
+...
+```
+
+which we can think of as row indices into the hidden states tensor if we view it as the 2D shape `(num_blocks x block_size, feature_dim)`. That is, the analogous flattened (from 3D -> 2D) mapping of the cache for hidden states becomes the following.
+```
+0:
+1:
+2:
+3:
+...
+11:
+12: [EMPTY]
+...
+```
+
+Similarly, for the multimodal outputs cache, the flattened coordinates are the same, but the `mm_feature` maps to vectors of length `16` instead of the hidden size of `2`. Note that in practice, we may have multiple multimodal output tensors per forward pass, which may have different names and different feature dimensions.
+
+
+4. Now, say that we receive a new request `The quick brown fox jumped over the dog`.
+
+```
+ [ The quick brown fox ] [ jumped over the dog ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+```
+
+Here, we will have a cache hit for `Block 1` which will be detected by vLLM based on the hash of the first block when it's handling the prefix caching on the kv-cache. As a result, when we get the output from the scheduler, we will see that `num_computed_tokens=4` (corresponding to the cached first block), and we only need to process the remaining 4 new tokens in the new prefill.
+
+Since we have the block indices / slot mappings from the kv cache manager, we can simply mirror the mappings and leverage the same indices for the cached hidden states and multimodal outputs. This allows us to look up the correct tensors from our externally maintained 3D caches.
+
+```
+0: [ The quick brown fox ] < already in the cache
+1: [ was tired and slept ]
+2: [beneath the shady tree ]
+3: [ jumped over the dog ] < added on the second request
+4: [EMPTY]
+...
+7: [EMPTY]
+...
+```
+
+Finally, to pass the full hidden states and multimodal outputs to the next stage, we simply concatenate the cached contents with the corresponding new tensors computed from the current forward call.
+
+
+### What About Multimodal Inputs?
+It's also useful to consider the case about how Omni prefix caching is handled when we have multimodal inputs that don't cleanly end on block boundaries, as well as how this works with multimodal encoder caching in vLLM. For example:
+
+```
+ [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 foo ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+```
+
+In this case, only `Block 1` will have outputs stored in the prefix tensor cache, because vLLM does not store partial blocks. This may appear to be a problem at first glance, because the multimodal input is fragmented across a new block that wasn't cached.
+
+In reality, this isn't a big problem for correctness, because vLLM also maintains an encoder cache for multimodal inputs. In other words, after the first pass, we'll have the following:
+
+- The Block 1 hash, which is used for prefix caching
+- The hash describing the image data starting at position 0 and with length 6
+- In vLLM's encoder cache, a mapping from the image hash above to the encoder output
+
+
+To understand what happens, say we get the following input as a second request:
+```
+ [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 bar baz ]
+Block 1: |<--- block tokens ---->|
+Block 2: |<------- prefix ------>| |<--- block tokens --->|
+```
+
+First, the scheduler will check for a prefix cache hit, which we will see on `Block 1`. As a result, we will have 4 tokens marked as precomputed, and only see the remaining 4 tokens in the following prefill.
+
+Because we have multimodal data in a scheduled span that isn't fully precomputed, we still need to call the visual encoder. However, since we have the image hash and encoder cache, we will retrieve the encoder outputs for `Im4` and `Im5` as we create the multimodal embeddings.
+
+When we pass our multimodal tensors to the language model component in the same stage, we'll then expect the same outputs, because the prefix caching behaviors in vLLM-Omni / vLLM match, so the LLM will use vLLM's KV cache manager's prefix caching to correctly handle the attention information for `Block 1` while calculating the outputs for `Block 2`, giving us the correct results for processing `Block 2` with the context of `Block 1`.
+
+Finally, we look up the output hidden states/multimodal tensors corresponding to the prefix cache hit `Block 1` and concatenate it with the forward pass result to get the final result, which is expected to be identical to the full hidden states when prefix caching is disabled.
diff --git a/docs/design/feature/teacache.md b/docs/design/feature/teacache.md
index 9fa315cee7..8577cff1f0 100644
--- a/docs/design/feature/teacache.md
+++ b/docs/design/feature/teacache.md
@@ -326,9 +326,41 @@ for prompt in tqdm(prompts, desc="Collecting data"):
# Estimate coefficients
coeffs = estimator.estimate(poly_order=4)
-print(f"Estimated coefficients: {coeffs.tolist()}")
+print(f"Estimated coefficients: {coeffs}")
```
+Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation.
+```python
+from vllm_omni.diffusion.forward_context import set_forward_context
+from vllm_omni.diffusion.distributed.parallel_state import (
+ init_distributed_environment,
+ initialize_model_parallel,
+)
+from vllm.config import VllmConfig
+...
+
+if __name__ == "__main__":
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "8192"
+ os.environ["LOCAL_RANK"] = "0"
+ os.environ["RANK"] = "0"
+ os.environ["WORLD_SIZE"] = "1"
+
+ vllm_config = VllmConfig()
+ init_distributed_environment()
+ initialize_model_parallel()
+
+ # NOTE: you may have to pass an initialized OmniDiffusionConfig as a kwarg
+ # here to make current sp checks happy; if this is the case, just create one
+ # .from_kwargs() with the model name to get around this check for now,
+ # since your estimator subclass should handle the actual model configuration.
+ #
+ # This will be cleaned up in the future
+ with set_forward_context(vllm_config):
+
+```
+
+
**Data Statistics Guide:**
| Metric | Good Range | Warning Signs |
diff --git a/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png
new file mode 100644
index 0000000000..15112d5862
Binary files /dev/null and b/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png
new file mode 100644
index 0000000000..2f0615f77b
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png
new file mode 100644
index 0000000000..62d8bc79b6
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png
new file mode 100644
index 0000000000..5838b45319
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png
new file mode 100644
index 0000000000..24be814b7e
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png
new file mode 100644
index 0000000000..c8df58ebcd
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png
new file mode 100644
index 0000000000..2d1a04e9c2
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png
new file mode 100644
index 0000000000..e598b54343
Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png
new file mode 100644
index 0000000000..54452013eb
Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png
new file mode 100644
index 0000000000..04c5ad7396
Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png differ
diff --git a/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png b/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png
new file mode 100644
index 0000000000..d93ba0b2af
Binary files /dev/null and b/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png b/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png
new file mode 100644
index 0000000000..04087b5910
Binary files /dev/null and b/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png differ
diff --git a/docs/design/figures/omni/Summary_RTF_vs_features.png b/docs/design/figures/omni/Summary_RTF_vs_features.png
new file mode 100644
index 0000000000..c2c8ad4083
Binary files /dev/null and b/docs/design/figures/omni/Summary_RTF_vs_features.png differ
diff --git a/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png b/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png
new file mode 100644
index 0000000000..3dcc1c5537
Binary files /dev/null and b/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png differ
diff --git a/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png
new file mode 100644
index 0000000000..9a5b6c9bda
Binary files /dev/null and b/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png
new file mode 100644
index 0000000000..68f0ef17e8
Binary files /dev/null and b/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png
new file mode 100644
index 0000000000..44be96e96d
Binary files /dev/null and b/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png
new file mode 100644
index 0000000000..2e5d1482bd
Binary files /dev/null and b/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png
new file mode 100644
index 0000000000..04d8f0bac5
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png
new file mode 100644
index 0000000000..eb85ec0dd4
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png
new file mode 100644
index 0000000000..6f0e0e2529
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png
new file mode 100644
index 0000000000..89ea30a864
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png
new file mode 100644
index 0000000000..2b207b8898
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png
new file mode 100644
index 0000000000..f5f7ad72c8
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png
new file mode 100644
index 0000000000..6f8c1da4a5
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png
new file mode 100644
index 0000000000..b0fe1d02a9
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png
new file mode 100644
index 0000000000..008ba9bf78
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png differ
diff --git a/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png
new file mode 100644
index 0000000000..7c65aa1177
Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png differ
diff --git a/docs/design/figures/tts/Summary_mean_rtf_vs_features.png b/docs/design/figures/tts/Summary_mean_rtf_vs_features.png
new file mode 100644
index 0000000000..71bb2c5468
Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_rtf_vs_features.png differ
diff --git a/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png
new file mode 100644
index 0000000000..cef2546d6f
Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png differ
diff --git a/docs/design/qwen3_omni_tts_performance_optimization.md b/docs/design/qwen3_omni_tts_performance_optimization.md
new file mode 100644
index 0000000000..2f18a1b1bc
--- /dev/null
+++ b/docs/design/qwen3_omni_tts_performance_optimization.md
@@ -0,0 +1,539 @@
+# Speech Generation on vLLM-Omni: Performance Optimizations for Qwen3-Omni and Qwen3-TTS
+
+## Summary
+
+vLLM-Omni supports end-to-end serving for speech-generating models, including both **Qwen3-Omni** (multimodal understanding + speech) and **Qwen3-TTS** (text-to-speech). Despite their different architectures, both models share the same multi-stage pipeline design and benefit from the same set of stacked optimizations:
+
+1. **Batching** improves GPU utilization stage by stage and increases overall throughput.
+2. **CUDA Graph** reduces CPU launch overhead and decode-time jitter on stable shapes.
+3. **Async Chunk and Streaming Output** overlap compute and communication across stages and emit audio incrementally, improving both TTFP and E2E.
+
+### Model architectures
+
+**Qwen3-Omni** is a native multimodal model that understands text, audio, image, and video inputs, and generates both text and speech outputs. Its pipeline has three stages:
+
+- **Thinker**: multimodal understanding and text generation
+- **Talker (+ Talker-MTP / code predictor path)**: converts semantic/text representations into codec tokens
+- **Code2Wav**: decodes codec tokens into waveform audio
+
+**Qwen3-TTS** is a lightweight, high-quality text-to-speech model. Its pipeline has two stages:
+
+- **Talker (AR decoder)**: auto-regressively generates codec tokens from text input
+- **Code2Wav (vocoder)**: decodes codec tokens into waveform audio
+
+The optimizations described in this post apply to both models. We present results for each side by side.
+
+### vLLM-Omni vs HF Transformers
+
+Compared with **HF Transformers** (offline, single request), vLLM-Omni with the full optimization stack delivers dramatically lower latency and higher efficiency for both models.
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | vLLM-Omni | HF Transformers | Improvement |
+| --- | --- | --- | --- |
+| E2E latency (s) | 23.78 | 336.10 | ~93% reduction |
+| TTFP (s) | 0.934 | 336.10 | ~99.7% reduction |
+| RTF | 0.32 | 3.776 | ~91% reduction (~12× faster) |
+
+- **E2E latency**: 23.78 s vs 336.10 s - **~93%** reduction
+- **TTFP**: 0.934 s vs 336.10 s - **~99.7%** reduction
+- **RTF**: 0.32 vs 3.776 - **~91%** reduction (~12x faster)
+
+**Qwen3-TTS** (H200, concurrency 1):
+
+
+
+| Metric | vLLM-Omni | HF Transformers | Improvement |
+| --- | --- | --- | --- |
+| E2E latency (ms) | 941 | 15,513 | ~94% reduction |
+| TTFP (ms) | 64 | 15,513 | ~99.6% reduction (242× faster) |
+| RTF | 0.16 | 2.64 | ~94% reduction (~16.5× faster) |
+
+- **E2E latency**: 941 ms vs 15,513 ms - **~94%** reduction
+- **TTFP**: 64 ms vs 15,513 ms - **~99.6%** reduction (242x faster)
+- **RTF**: 0.16 vs 2.64 - **~94%** reduction (~16.5x faster)
+
+### Stacked optimization summary
+
+Each optimization stacks on the previous one. The summary plots below show the cumulative effect at each step, with one line per concurrency level (1, 4, 10).
+
+**Qwen3-Omni** (A100):
+
+
+
+- **E2EL reduction**: ~74% at concurrency 10 (410,054 ms -> 104,901 ms); ~90% at concurrency 1 (426,529 ms -> 41,216 ms)
+- **TTFP reduction**: ~96% at concurrency 10 (409,705 ms -> 16,482 ms); ~99.7% at concurrency 1 (426,078 ms -> 1,164 ms)
+- **RTF reduction**: ~74% at concurrency 10 (2.83 -> 0.74); ~90% at concurrency 1 (2.08 -> 0.21)
+
+**Qwen3-TTS** (H200):
+
+
+
+- **E2EL reduction**: ~85% at concurrency 10 (12,141 ms -> 1,767 ms); ~29% at concurrency 1 (1,323 ms -> 941 ms)
+- **TTFP reduction**: ~96.5% at concurrency 10 (12,141 ms -> 425 ms); ~95% at concurrency 1 (1,323 ms -> 64 ms)
+- **RTF reduction**: ~86% at concurrency 10 (2.19 -> 0.31); ~30% at concurrency 1 (0.23 -> 0.16)
+
+**Benchmark environment:**
+
+| | Qwen3-Omni | Qwen3-TTS |
+| --- |-----------------------------| --- |
+| **GPU** | A100 | H200 |
+| **Model** | Qwen3-Omni-30B-A3B-Instruct | Qwen3-TTS-12Hz-1.7B-CustomVoice |
+| **vLLM** | v0.17.0 | v0.18.0 |
+| **vllm-omni** | commit 199f7832 | v0.18.0rc2 |
+| **CUDA** | 12.9 | 12.8 |
+
+This post walks through each optimization in the same order they are typically enabled in practice, then ends with deployment playbooks for both models.
+
+---
+
+## Pipeline Batching
+
+### How stage-wise batching works
+
+For both Qwen3-Omni and Qwen3-TTS, batching is a pipeline-level optimization:
+
+- Requests are grouped per stage using `runtime.max_batch_size`
+- Each stage executes batch inference with its own scheduler/worker
+- Stage outputs are routed to downstream stages with per-request mapping preserved
+
+**Batching strategy by stage:** The understanding and decode stages (Thinker for Omni, Talker for both) use **continuous batching**: requests can join and leave the batch over time. Code2Wav uses **static batching**: once a batch is formed, the stage runs the whole batch before starting the next. This matches the decode pattern of Code2Wav and keeps implementation simple while still improving throughput.
+
+### Batching results (Baseline vs. Batch)
+
+Batching alone greatly reduces E2EL and RTF across all concurrencies. The biggest gains appear at high concurrency where requests share GPU resources.
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | Concurrency | Baseline | + Batch | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 426,529 | 307,719 | 1.4× |
+| E2EL (ms) | 4 | 407,213 | 376,934 | 1.1× |
+| E2EL (ms) | 10 | 410,054 | 234,844 | 1.7× |
+| TTFP (ms) | 1 | 426,078 | 307,262 | 1.4× |
+| TTFP (ms) | 4 | 406,843 | 376,466 | 1.1× |
+| TTFP (ms) | 10 | 409,705 | 234,557 | 1.7× |
+| RTF | 1 | 2.08 | 1.51 | 1.4× |
+| RTF | 4 | 2.55 | 1.83 | 1.4× |
+| RTF | 10 | 2.83 | 2.28 | 1.2× |
+
+At concurrency 10, E2EL drops from ~410 s to ~235 s; at concurrency 1, from ~427 s to ~308 s.
+
+**Qwen3-TTS** (H200):
+
+
+
+| Metric | Concurrency | Baseline | + Batch | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 1,323 | 1,339 | 1.0× |
+| E2EL (ms) | 4 | 5,171 | 1,471 | 3.5× |
+| E2EL (ms) | 10 | 12,141 | 1,705 | 7.1× |
+| RTF | 1 | 0.230 | 0.234 | 1.0× |
+| RTF | 4 | 0.908 | 0.255 | 3.6× |
+| RTF | 10 | 2.186 | 0.292 | 7.5× |
+| Throughput (audio-s/wall-s) | 10 | 3.99 | 33.53 | 8.4× |
+
+At concurrency 10, batching alone brings Qwen3-TTS RTF from 2.19 (slower than realtime) down to 0.29 (faster than realtime), and throughput from 4.0 to 33.5 audio-sec/wall-sec.
+
+---
+
+## CUDA Graph on the Critical Decode Path
+
+### Why CUDA Graph helps here
+
+In decode-heavy serving, repeatedly launching many small kernels from CPU can become a visible overhead. CUDA Graph reduces this overhead by capturing and replaying stable execution graphs.
+
+In stage configs, this is represented by `enforce_eager: false` for stages where graph capture is desired (Thinker/Talker), while Code2Wav keeps eager mode depending on stage behavior.
+
+### CUDA Graph results on top of batching
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | Concurrency | Batch | + CUDA Graph | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 307,719 | 61,613 | 5.0× |
+| E2EL (ms) | 4 | 376,934 | 79,019 | 4.8× |
+| E2EL (ms) | 10 | 234,844 | 126,867 | 1.9× |
+| TTFP (ms) | 1 | 307,262 | 61,257 | 5.0× |
+| TTFP (ms) | 4 | 376,466 | 78,634 | 4.8× |
+| TTFP (ms) | 10 | 234,557 | 126,534 | 1.9× |
+| RTF | 1 | 1.51 | 0.32 | 4.7× |
+| RTF | 4 | 1.83 | 0.43 | 4.3× |
+| RTF | 10 | 2.28 | 0.90 | 2.5× |
+
+For the larger Qwen3-Omni model (30B-A3B), CUDA Graph provides a significant improvement. At concurrency 1, E2EL drops from ~308 s to ~62 s; at concurrency 10, from ~235 s to ~127 s.
+
+**Qwen3-TTS** (H200):
+
+
+
+| Metric | Concurrency | Batch | + CUDA Graph | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 1,339 | 733 | 1.8× |
+| E2EL (ms) | 4 | 1,471 | 987 | 1.5× |
+| E2EL (ms) | 10 | 1,705 | 1,197 | 1.4× |
+| RTF | 1 | 0.234 | 0.124 | 1.9× |
+| RTF | 10 | 0.292 | 0.203 | 1.4× |
+| Throughput (audio-s/wall-s) | 10 | 33.53 | 47.15 | 1.4× |
+
+At concurrency 1, CUDA Graph reduces E2EL from 1,339 ms to 733 ms and RTF from 0.234 to 0.124 - nearly a 2x improvement. The benefit is consistent across all concurrency levels.
+
+---
+
+## Async Chunk and Streaming Output: Earlier Audio and Cross-Stage Overlap
+
+### Why this step matters for first-packet latency
+
+Two mechanisms work together to improve user-visible latency:
+
+- **Streaming output**: audio streaming emits audio chunks as soon as they are decoded (lower **TTFP**). Without streaming, the client waits for larger buffers or end-of-sequence.
+- **Async chunk** is the main enabler for *earlier* audio: instead of handing off whole-request results between stages, each stage forwards **chunks** so the next stage can start as soon as the first chunk is ready. For Omni: Thinker -> Talker forwards hidden-state chunks; for both: Talker -> Code2Wav forwards codec chunks; Code2Wav decodes and emits packets incrementally. This **overlaps compute and communication** across stages and directly reduces time-to-first-audio-packet (TTFP) and end-to-end latency (E2EL).
+
+So in practice: streaming output defines *how* bytes are sent to the client; async chunk defines *when* the pipeline can produce the first bytes.
+
+**Dependency between the two:** Async chunk and audio streaming output are mutually dependent. Without async chunk, **audio streaming output cannot truly take effect**. Without audio streaming output, async chunk's **TTFP advantage is not fully realized**: the client would still wait for larger buffers or end-of-sequence instead of hearing the first packet as soon as it is ready. We therefore recommend enabling **both** on top of batching + CUDA Graph; the benchmarks in this post use both.
+
+### Results: Batch + CUDA Graph vs. Batch + CUDA Graph + Async Chunk + Streaming Output
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 61,613 | 41,216 | 1.5× |
+| E2EL (ms) | 4 | 79,019 | 67,584 | 1.2× |
+| E2EL (ms) | 10 | 126,867 | 104,901 | 1.2× |
+| TTFP (ms) | 1 | 61,257 | 1,164 | 53× |
+| TTFP (ms) | 4 | 78,634 | 3,152 | 24.9× |
+| TTFP (ms) | 10 | 126,534 | 16,482 | 7.7× |
+| RTF | 1 | 0.32 | 0.21 | 1.5× |
+| RTF | 4 | 0.43 | 0.34 | 1.3× |
+| RTF | 10 | 0.90 | 0.74 | 1.2× |
+
+Enabling both brings TTFP down sharply (concurrency 1: 61,257 ms -> 1,164 ms, **~98% reduction**; concurrency 4: 78,634 ms -> 3,152 ms, **~96% reduction**). E2EL and RTF also improve at every concurrency.
+
+**Qwen3-TTS** (H200):
+
+
+
+| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement |
+| --- | --- | --- | --- | --- |
+| TTFP (ms) | 1 | 733 | **64** | **11.5×** |
+| TTFP (ms) | 4 | 987 | **119** | **8.3×** |
+| TTFP (ms) | 10 | 1,197 | **425** | **2.8×** |
+| E2EL (ms) | 1 | 733 | 941 | 0.8× |
+| E2EL (ms) | 10 | 1,197 | 1,767 | 0.7× |
+| RTF | 1 | 0.124 | 0.160 | 0.8× |
+| RTF | 10 | 0.203 | 0.314 | 0.6× |
+
+The TTFP improvement is the headline result for both models. For Qwen3-TTS at concurrency 1, users hear the first audio in **64 ms** instead of 733 ms - an **11.5x reduction**. For Qwen3-Omni at concurrency 1, TTFP drops from 61 s to 1.2 s - a **53x reduction**.
+
+### Why E2EL and RTF are higher with async chunk (TTS)
+
+The table above shows that enabling async chunk + streaming *increases* E2EL and RTF for TTS compared to CUDA Graph alone. This is expected - the two configurations optimize for fundamentally different metrics:
+
+- **CUDA Graph (no async chunk)** generates the entire audio end-to-end before returning. No chunking overhead, so total compute is minimized.
+- **Async Chunk + Streaming** splits the pipeline into incremental chunks, adding overhead from chunked transport, context overlap in Code2Wav (`codec_left_context_frames=25`), and smaller effective batch sizes per chunk.
+
+**The tradeoff is intentional.** Async chunk trades ~30% higher total compute for **11x faster time-to-first-audio**. For interactive applications (voice assistants, chatbots), TTFP determines perceived responsiveness. For offline batch processing, CUDA Graph without async chunk is the better choice.
+
+---
+
+## TTS-Specific: Code Predictor Re-prefill + `torch.compile`
+
+Qwen3-TTS has a **code predictor** - a small 5-layer transformer that generates residual codebook tokens (groups 1 through Q-1) autoregressively. Each AR step operates on very short sequences (2 to ~16 tokens).
+
+The naive approach uses a KV cache for this small transformer, similar to the main Talker. But the KV cache machinery (block tables, slot mappings, paged attention) introduces significant overhead relative to the tiny model. Two optimizations replace that:
+
+### Re-prefill (stateless forward, no KV cache)
+
+Instead of maintaining a KV cache across steps, the code predictor **re-feeds the full growing sequence** at each AR step using `F.scaled_dot_product_attention`. With sequences of at most ~16 tokens through 5 layers, the O(T^2) attention cost is negligible - and removing the KV cache machinery (block table management, `set_forward_context`, slot mapping) saves far more time than it costs.
+
+### `torch.compile` on the code predictor forward
+
+The 5-layer transformer forward pass launches ~60 small CUDA kernels per step. `torch.compile(mode="default", dynamic=True)` fuses these into fewer kernels via Inductor:
+
+```python
+self._compiled_model_fwd = torch.compile(
+ self.model.forward,
+ mode="default", # no Inductor CUDA graphs, avoids conflict with vLLM's CUDAGraphWrapper
+ dynamic=True, # sequence length grows each step (2, 3, ..., num_groups+1)
+)
+```
+
+`mode="default"` is used instead of `mode="reduce-overhead"` to avoid conflicts with vLLM's own CUDA graph capture on the main Talker model. `dynamic=True` handles the growing sequence length without recompilation.
+
+These optimizations are always-on in the current codebase - all Qwen3-TTS benchmark results in this post include them.
+
+---
+
+## TTS-Specific: Dynamic Initial Chunk for Faster First Audio
+
+In the async chunk pipeline, the standard `codec_chunk_frames` is 25 (each chunk = ~2 seconds of audio at 12 Hz). Waiting for 25 frames before forwarding the first chunk to Code2Wav adds unnecessary TTFP. The **initial codec chunk** optimization sends a smaller first chunk so Code2Wav can start decoding earlier.
+
+**Dynamic initial chunk sizing (default behavior):**
+
+Rather than using a fixed initial chunk size, vLLM-Omni dynamically selects it based on current server load. The initial chunk size is chosen from power-of-2 steps [2, 4, 8, 16] based on load factor (`active_requests / max_batch_size`):
+
+| Server load | Initial chunk frames | Rationale |
+| --- | --- | --- |
+| Low (e.g. 1/10 active) | **2** (~167 ms of audio) | Minimize TTFP when there's headroom |
+| Medium (e.g. 5/10 active) | **4-8** | Balance TTFP vs decode efficiency |
+| High (e.g. 10/10 active) | **16** | Larger first chunk to amortize decode cost |
+
+After the initial chunk, all subsequent chunks use the standard `codec_chunk_frames` (25) size.
+
+**How it works in the pipeline:**
+
+1. Talker generates codec tokens auto-regressively
+2. The stage input processor checks current load and picks an initial chunk size (e.g. **2 frames** at low load)
+3. After that many frames, the first chunk is forwarded to Code2Wav
+4. Code2Wav decodes this small chunk and emits the first audio packet
+5. Subsequent chunks use the standard 25-frame size for efficient batch decoding
+
+**Per-request override:** Clients can also set a fixed initial chunk size via the API:
+
+```json
+{"initial_codec_chunk_frames": 2}
+```
+
+This overrides the dynamic calculation for that request.
+
+**Config (server-side):**
+
+```yaml
+runtime:
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ codec_streaming: true
+ codec_chunk_frames: 25 # standard chunk size (~2s of audio)
+ codec_left_context_frames: 25
+ # initial chunk is computed dynamically by default
+ # set initial_codec_chunk_frames: 2 to force a fixed value
+```
+
+The 64 ms TTFP result reported above for Qwen3-TTS at concurrency 1 uses the dynamic initial chunk, which picks `initial_codec_chunk_frames=2` at low load. At higher concurrency the dynamic sizing increases the initial chunk to maintain decode efficiency.
+
+---
+
+## Live Demo: Streaming TTS over WebSocket
+
+vLLM-Omni supports real-time streaming audio output for Qwen3-TTS over WebSocket ([PR #1719](https://github.com/vllm-project/vllm-omni/pull/1719)). With `stream_audio: true`, the server sends chunked PCM audio frames as they are generated, so clients can start playback before full sentence synthesis completes.
+
+The WebSocket protocol uses `audio.start` / binary PCM chunks / `audio.done` framing per sentence:
+
+```json
+// Client sends:
+{"type":"session.config","voice":"Vivian","response_format":"pcm","stream_audio":true}
+{"type":"input.text","text":"Hello world. This is a streaming demo."}
+{"type":"input.done"}
+
+// Server streams back per sentence:
+{"type":"audio.start","sentence_index":0,"sentence_text":"Hello world.","format":"pcm","sample_rate":24000}
+
+
+...
+{"type":"audio.done","sentence_index":0,"total_bytes":96000,"error":false}
+{"type":"audio.start","sentence_index":1,"sentence_text":"This is a streaming demo.","format":"pcm","sample_rate":24000}
+
+...
+{"type":"audio.done","sentence_index":1,"total_bytes":72000,"error":false}
+{"type":"session.done","total_sentences":2}
+```
+
+VIDEO
+
+---
+
+## Deployment Playbook
+
+### Qwen3-Omni
+
+#### 1) Serve with the default 3-stage config
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --omni \
+ --port 8091
+```
+
+Notes:
+
+- `runtime.max_batch_size` controls stage-level batching.
+- Thinker/Talker commonly use `enforce_eager: false` for CUDA Graph paths.
+- Code2Wav often remains eager (`enforce_eager: true`) depending on runtime behavior.
+
+#### 2) Enable async chunk
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --omni \
+ --port 8091 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+```
+
+#### 3) Key config knobs
+
+```yaml
+async_chunk: true
+stage_args:
+ - stage_id: 0 # thinker
+ runtime:
+ max_batch_size: 64
+ engine_args:
+ enforce_eager: false
+ max_num_batched_tokens: 32768
+ custom_process_next_stage_input_func: >-
+ vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
+
+ - stage_id: 1 # talker
+ runtime:
+ max_batch_size: 64
+ engine_args:
+ enforce_eager: false
+ max_num_batched_tokens: 32768
+ custom_process_next_stage_input_func: >-
+ vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
+
+ - stage_id: 2 # code2wav
+ runtime:
+ max_batch_size: 64
+ engine_args:
+ enforce_eager: true
+ max_num_batched_tokens: 51200
+```
+
+#### Reproduce Qwen3-Omni benchmarks
+
+```bash
+vllm bench serve \
+ --dataset-name random \
+ --port ${PORT} \
+ --model ${MODEL_PATH} \
+ --endpoint /v1/chat/completions \
+ --backend openai-chat-omni \
+ --max-concurrency ${MAX_CONCURRENCY} \
+ --num-prompts ${NUM_PROMPTS} \
+ --random-input-len 2500 \
+ --ignore-eos \
+ --percentile-metrics ttft,tpot,itl,e2el,audio_ttfp,audio_rtf \
+ --random-output-len 900 \
+ --extra_body '{"modalities": ["text","audio"]}'
+```
+
+### Qwen3-TTS
+
+#### 1) Serve with async chunk (recommended)
+
+```bash
+vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
+ --omni \
+ --port 8000
+```
+
+The default config (`qwen3_tts.yaml`) enables the full optimization stack:
+
+- Batching with `max_batch_size: 10` on the Talker stage
+- CUDA Graph on the Talker (`enforce_eager: false`)
+- Async chunk with streaming transport
+
+#### 2) Serve without async chunk (for comparison)
+
+```bash
+vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
+ --omni \
+ --port 8000 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
+```
+
+#### 3) Key config knobs
+
+```yaml
+async_chunk: true
+stage_args:
+ - stage_id: 0 # Talker (AR decoder)
+ runtime:
+ max_batch_size: 10
+ engine_args:
+ enforce_eager: false
+ max_num_batched_tokens: 512
+ custom_process_next_stage_input_func: >-
+ vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+
+ - stage_id: 1 # Code2Wav (vocoder)
+ runtime:
+ max_batch_size: 1
+ engine_args:
+ enforce_eager: true
+ max_num_batched_tokens: 8192
+
+runtime:
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ codec_streaming: true
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+```
+
+#### Reproduce Qwen3-TTS benchmarks
+
+```bash
+GPU_DEVICE=0 \
+MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
+NUM_PROMPTS=50 \
+CONCURRENCY="1 4 10" \
+bash benchmarks/qwen3-tts/vllm_omni/run_stacked_benchmark.sh
+```
+
+This cycles through four configs (Baseline -> + Batch -> + CUDA Graph -> + Async Chunk + Streaming), benchmarks each at the specified concurrency levels, and generates all comparison figures automatically.
diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md
index 1a683d174f..5dfea8d2ff 100644
--- a/docs/getting_started/installation/gpu/rocm.inc.md
+++ b/docs/getting_started/installation/gpu/rocm.inc.md
@@ -26,7 +26,7 @@ uv pip install vllm-omni
# Optional if want to run Qwen3 TTS
uv pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm
-uv pip install onnxruntime-rocm sox
+uv pip install onnxruntime-rocm
```
# --8<-- [end:pre-built-wheels]
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 0f9c8fff60..b4976b09c4 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -47,6 +47,7 @@ th {
| `Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `FluxKontextPipeline` | FLUX.1-Kontext-dev | `black-forest-labs/FLUX.1-Kontext-dev` | ✅︎ | ✅︎ | | |
| `FluxPipeline` | FLUX.1-dev | `black-forest-labs/FLUX.1-dev` | ✅︎ | ✅︎ | | ✅︎ |
+| `FluxPipeline` | FLUX.1-schnell | `black-forest-labs/FLUX.1-schnell` | ✅︎ | ✅︎ | | ✅︎ |
| `OmniGen2Pipeline` | OmniGen2 | `OmniGen2/OmniGen2` | ✅︎ | ✅︎ | | ✅︎ |
| `StableAudioPipeline` | Stable-Audio-Open | `stabilityai/stable-audio-open-1.0` | ✅︎ | ✅︎ | | ✅︎ |
| `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-CustomVoice | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
diff --git a/docs/serving/image_edit_api.md b/docs/serving/image_edit_api.md
index d254ac06ad..79303e1a69 100644
--- a/docs/serving/image_edit_api.md
+++ b/docs/serving/image_edit_api.md
@@ -104,6 +104,8 @@ Content-Type: multipart/form-data
| `guidance_scale` | float | model defaults | Classifier-free guidance scale (typically 0.0-20.0) |
| `true_cfg_scale` | float | model defaults | True CFG scale (model-specific parameter, may be ignored if not supported) |
| `seed` | integer | null | Random seed for reproducibility |
+| `reference_image` | string or array | null | Reference image for inpainting |
+| `mask_image` | string or array | null | Mask for inpainting (white areas will be inpainted) |
### Response Format
diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md
index ecbe8d9ac9..733811081a 100644
--- a/docs/serving/speech_api.md
+++ b/docs/serving/speech_api.md
@@ -15,7 +15,7 @@ Each server instance runs a single model (specified at startup via `vllm serve <
```bash
# Qwen3-TTS: CustomVoice model (predefined speakers)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -300,7 +300,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \
```bash
# Start server with VoiceDesign model first
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -322,7 +322,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \
```bash
# Start server with Base model first
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -517,15 +517,16 @@ for result in response.json()["results"]:
All items are fanned out to `generate()` concurrently. The engine's stage worker automatically batches them up to the configured `max_batch_size` and queues the rest — no client-side throttling needed.
-For best throughput, use a batch-optimized stage config with `max_batch_size > 1`:
+For best throughput, set both stages' `max_num_seqs` to ≥4 via `--stage-overrides`:
```bash
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \
- --omni --port 8091 --trust-remote-code --enforce-eager
+ --omni --port 8091 --trust-remote-code --enforce-eager \
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},
+ "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
```
-The default `qwen3_tts.yaml` uses `max_batch_size: 1` (single request). The `qwen3_tts_batch.yaml` config sets `max_batch_size: 4` for ~4x throughput.
+The bundled `qwen3_tts.yaml` uses `max_num_seqs: 1` (single request) on both stages. Bumping to 4 yields roughly 4× throughput on the talker and lets stage 1 batch chunks across in-flight requests.
## Supported Models
@@ -617,7 +618,7 @@ Enable debug logging:
```bash
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
diff --git a/docs/source/architecture/async-chunk-architecture.png b/docs/source/architecture/async-chunk-architecture.png
index 249de53bfe..7b3e95e4df 100644
Binary files a/docs/source/architecture/async-chunk-architecture.png and b/docs/source/architecture/async-chunk-architecture.png differ
diff --git a/docs/source/architecture/qwen3-omni-async-chunk.png b/docs/source/architecture/qwen3-omni-async-chunk.png
index b2d98b80f3..e73ca84b28 100644
Binary files a/docs/source/architecture/qwen3-omni-async-chunk.png and b/docs/source/architecture/qwen3-omni-async-chunk.png differ
diff --git a/docs/source/architecture/qwen3-omni-non-async-chunk.png b/docs/source/architecture/qwen3-omni-non-async-chunk.png
index da5610a11b..47a9ba66a5 100644
Binary files a/docs/source/architecture/qwen3-omni-non-async-chunk.png and b/docs/source/architecture/qwen3-omni-non-async-chunk.png differ
diff --git a/docs/source/architecture/vllm-omni-dataflow-between-stages.png b/docs/source/architecture/vllm-omni-dataflow-between-stages.png
index cdbc9a8b7b..74abc81ff0 100644
Binary files a/docs/source/architecture/vllm-omni-dataflow-between-stages.png and b/docs/source/architecture/vllm-omni-dataflow-between-stages.png differ
diff --git a/docs/usage/faq.md b/docs/usage/faq.md
index c080eae402..0539e158b0 100644
--- a/docs/usage/faq.md
+++ b/docs/usage/faq.md
@@ -4,14 +4,6 @@
A: Now, we support natively disaggregated deployment for different model stages within a model. There is a restriction that one chip can only have one AutoRegressive model stage. This is because the unified KV cache management of vLLM. Stages of other types can coexist within a chip. The restriction will be resolved in later version.
-> Q: When trying to run examples, I encounter error about backend of librosa or soundfile. How to solve it?
-
-A: If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
> Q: I see GPU OOM or "free memory is less than desired GPU memory utilization" errors. How can I fix it?
A: Refer to [GPU memory calculation and configuration](../configuration/gpu_memory_utilization.md) for guidance on tuning `gpu_memory_utilization` and related settings.
diff --git a/docs/user_guide/diffusion/frame_interpolation.md b/docs/user_guide/diffusion/frame_interpolation.md
new file mode 100644
index 0000000000..349af50c51
--- /dev/null
+++ b/docs/user_guide/diffusion/frame_interpolation.md
@@ -0,0 +1,92 @@
+# Frame Interpolation
+
+## Overview
+
+vLLM-Omni supports post-generation frame interpolation for supported video
+diffusion pipelines. This feature inserts synthesized intermediate frames
+between adjacent generated frames to improve temporal smoothness without
+rerunning the diffusion denoising loop.
+
+Frame interpolation runs in the diffusion worker post-processing path instead
+of the API server encoding path. This allows the interpolation step to reuse
+the worker's current accelerator device and keeps the FastAPI event loop free
+from heavy synchronous PyTorch work.
+
+For an input video with `N` generated frames and interpolation exponent `exp`,
+the output frame count is:
+
+```text
+(N - 1) * 2**exp + 1
+```
+
+The output FPS is multiplied by `2**exp` so the clip duration remains close to
+the original generated video.
+
+## Supported Pipelines
+
+Frame interpolation is currently supported for:
+
+- `WanPipeline` (Wan2.2 text-to-video)
+- `WanImageToVideoPipeline`
+- `Wan22TI2VPipeline`
+
+## Request Parameters
+
+The video APIs `/v1/videos` and `/v1/videos/sync` accept:
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `enable_frame_interpolation` | bool | `false` | Enable post-generation frame interpolation |
+| `frame_interpolation_exp` | int | `1` | Interpolation exponent. `1=2x`, `2=4x`, etc. |
+| `frame_interpolation_scale` | float | `1.0` | RIFE inference scale |
+| `frame_interpolation_model_path` | str | `None` | Local directory or Hugging Face repo ID containing `flownet.pkl` |
+
+## Execution Flow
+
+For supported Wan2.2 pipelines, the execution order is:
+
+1. Diffusion worker finishes denoising and decodes the raw video tensor.
+2. Worker-side model-specific post-processing runs.
+3. If frame interpolation is enabled, RIFE interpolates the decoded video
+ tensor on the worker side and records a FPS multiplier in `custom_output`.
+4. The API server receives the already-interpolated video and only performs
+ MP4 export.
+
+This design keeps interpolation close to the generated tensor and avoids
+introducing another heavyweight GPU context in the API server process.
+
+## Example
+
+Start the server:
+
+```bash
+vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091
+```
+
+Run a sync request with interpolation enabled:
+
+```bash
+curl -X POST http://localhost:8091/v1/videos/sync \
+ -F "prompt=A dog running through a park" \
+ -F "num_frames=81" \
+ -F "width=832" \
+ -F "height=480" \
+ -F "fps=16" \
+ -F "num_inference_steps=40" \
+ -F "guidance_scale=1.0" \
+ -F "guidance_scale_2=1.0" \
+ -F "enable_frame_interpolation=true" \
+ -F "frame_interpolation_exp=1" \
+ -F "frame_interpolation_scale=1.0" \
+ -F "seed=42" \
+ -o sync_t2v_interpolated.mp4
+```
+
+## Notes
+
+- This is a post-processing feature. It does not modify the diffusion denoising
+ schedule.
+- Higher interpolation exponents increase post-processing time and memory usage.
+- If the interpolation model weights are not available locally,
+ `frame_interpolation_model_path` may point to a Hugging Face repo containing
+ `flownet.pkl`.
diff --git a/docs/user_guide/diffusion/lora.md b/docs/user_guide/diffusion/lora.md
index e45c033b84..256698752a 100644
--- a/docs/user_guide/diffusion/lora.md
+++ b/docs/user_guide/diffusion/lora.md
@@ -56,6 +56,92 @@ outputs = omni.generate(
!!! note "Server-side Path Requirement"
The LoRA adapter path (`local_path`) must be readable on the **server** machine. If your client and server are on different machines, ensure the LoRA adapter is accessible via a shared mount or copied to the server.
+## Wan2.2 LightX2V Offline Assembly
+
+This workflow is LoRA-adjacent: it uses external LightX2V conversion plus
+`Wan2.2-Distill-Loras` to bake converted Wan2.2 I2V checkpoints into a local
+Diffusers directory, instead of loading LoRA adapters at runtime.
+
+### Required assets
+
+- Base model: `Wan-AI/Wan2.2-I2V-A14B`
+- Diffusers skeleton: `Wan-AI/Wan2.2-I2V-A14B-Diffusers`
+- Optional external converter from the LightX2V project (not shipped in this repository)
+- Optional LoRA weights: `lightx2v/Wan2.2-Distill-Loras`
+
+### Step 1: Optional - convert high/low-noise DiT weights with LightX2V
+
+Install or clone LightX2V from the upstream repository
+(`https://github.com/ModelTC/LightX2V`). After cloning, the converter used
+below is available at `/tools/convert/converter.py`.
+
+```bash
+python /path/to/lightx2v/tools/convert/converter.py \
+ --source /path/to/Wan2.2-I2V-A14B/high_noise_model \
+ --output /tmp/wan22_lightx2v/high_noise_out \
+ --output_ext .safetensors \
+ --output_name diffusion_pytorch_model \
+ --model_type wan_dit \
+ --direction forward \
+ --lora_path /path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors \
+ --lora_key_convert auto \
+ --single_file
+
+python /path/to/lightx2v/tools/convert/converter.py \
+ --source /path/to/Wan2.2-I2V-A14B/low_noise_model \
+ --output /tmp/wan22_lightx2v/low_noise_out \
+ --output_ext .safetensors \
+ --output_name diffusion_pytorch_model \
+ --model_type wan_dit \
+ --direction forward \
+ --lora_path /path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors \
+ --lora_key_convert auto \
+ --single_file
+```
+
+If you are not using LightX2V, skip this step and either keep the original
+Diffusers weights from the skeleton or point Step 2 at any other converted
+`transformer/` and `transformer_2/` checkpoints.
+
+### Step 2: Assemble a final Diffusers-style directory
+
+```bash
+python tools/wan22/assemble_wan22_i2v_diffusers.py \
+ --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \
+ --transformer-weight /tmp/wan22_lightx2v/high_noise_out \
+ --transformer-2-weight /tmp/wan22_lightx2v/low_noise_out \
+ --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \
+ --asset-mode symlink \
+ --overwrite
+```
+
+`--transformer-weight` and `--transformer-2-weight` are optional. If you omit
+them, the tool keeps the original weights from the Diffusers skeleton.
+
+### Step 3: Run offline inference
+
+```bash
+python examples/offline_inference/image_to_video/image_to_video.py \
+ --model /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \
+ --image /path/to/input.jpg \
+ --prompt "A cat playing with yarn" \
+ --num-frames 81 \
+ --num-inference-steps 4 \
+ --tensor-parallel-size 4 \
+ --height 480 \
+ --width 832 \
+ --flow-shift 12 \
+ --sample-solver euler \
+ --guidance-scale 1.0 \
+ --guidance-scale-high 1.0 \
+ --boundary-ratio 0.875
+```
+
+Notes:
+
+- This route avoids runtime LoRA loading changes in vLLM-Omni when you choose to bake converted weights into a local Diffusers directory.
+- Output quality and speed depend on the replacement checkpoints and sampling params you choose.
+
## See Also
diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md
index 7e08851812..7bdeede446 100644
--- a/docs/user_guide/diffusion_features.md
+++ b/docs/user_guide/diffusion_features.md
@@ -14,7 +14,7 @@ vLLM-Omni supports various advanced features for diffusion models:
- Acceleration: **cache methods**, **parallelism methods**, **startup optimizations**
- Memory optimization: **cpu offloading**, **quantization**
-- Extensions: **LoRA inference**
+- Extensions: **LoRA inference**, **frame interpolation**
- Execution modes: **step execution**
## Supported Features
@@ -69,6 +69,7 @@ Extension methods add specialized capabilities to diffusion models beyond standa
| Method | Description | Best For |
|--------|-------------|----------|
| **[LoRA Inference](diffusion/lora.md)** | Enables inference with Low-Rank Adaptation (LoRA) adapters weights | Reinforcement learning extensions |
+| **[Frame Interpolation](diffusion/frame_interpolation.md)** | Inserts intermediate video frames after generation for smoother motion | Video generation pipelines that need higher temporal smoothness |
### Execution Modes
@@ -108,17 +109,18 @@ The following tables show which models support each feature:
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
| **Bagel** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.1-dev** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
+| **FLUX.1-schnell** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
-| **FLUX.2-dev** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
+| **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| **LongCat-Image-Edit** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **LongCat-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **LongCat-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **MagiHuman** | ❌ | ❌ | ❌ | ❓ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| **OmniGen2** | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **OmniGen2** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Ovis-Image** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ |
| **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ |
@@ -138,16 +140,21 @@ The following tables show which models support each feature:
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ |
| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
-| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
-| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
+
+**Frame Interpolation Support**
+
+- **Supported**: Wan2.2 text-to-video, image-to-video, and TI2V pipelines
+- **Not supported**: Wan2.1-VACE, LTX-2, Helios, HunyuanVideo-1.5, DreamID-Omni
### AudioGen
| Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution |
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
-| **Stable-Audio-Open** | ❌ | ❌ | ❓ | ❓ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
+| **Stable-Audio-Open** | ✅ | ❌ | ❓ | ❓ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
## Feature Compatibility
@@ -258,6 +265,7 @@ Measured on NVIDIA H800:
**Extensions:**
- **[LoRA Inference Guide](diffusion/lora.md)** - Low-Rank Adaptation for style customization and fine-tuning
+- **[Frame Interpolation Guide](diffusion/frame_interpolation.md)** - Worker-side post-generation video frame interpolation for smoother motion
**Execution Modes:**
diff --git a/docs/user_guide/examples/offline_inference/bagel.md b/docs/user_guide/examples/offline_inference/bagel.md
index 5f458750b4..1fb4d40457 100644
--- a/docs/user_guide/examples/offline_inference/bagel.md
+++ b/docs/user_guide/examples/offline_inference/bagel.md
@@ -176,8 +176,6 @@ Example configuration for TP=2 on GPUs 0 and 1:
| Parameter | Value | Description |
| :-------------------- | :------ | :------------------------------- |
-| `window_size` | `-1` | Window size (-1 means unlimited) |
-| `max_inflight` | `1` | Maximum inflight requests |
| `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) |
## Using Mooncake Connector
@@ -250,13 +248,6 @@ For more details on the Mooncake connector and multi-node setup, see the [Moonca
## FAQ
-- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below.
-
-```bash
-sudo apt update
-sudo apt install ffmpeg
-```
-
- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len.
| Stage | VRAM |
diff --git a/docs/user_guide/examples/offline_inference/cosyvoice3.md b/docs/user_guide/examples/offline_inference/cosyvoice3.md
index d912f1c62e..ebb7c02efc 100644
--- a/docs/user_guide/examples/offline_inference/cosyvoice3.md
+++ b/docs/user_guide/examples/offline_inference/cosyvoice3.md
@@ -10,7 +10,7 @@ Install dependencies:
uv pip install -e .
```
-> **Note:** This includes required libraries such as `librosa`, `soundfile`,
+> **Note:** This includes required libraries such as `soundfile`,
> `onnxruntime`, `x-transformers`, and `einops` via
> `requirements/common.txt` and platform-specific requirements files.
diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md
index 7a750aeff3..6e105741a7 100644
--- a/docs/user_guide/examples/offline_inference/image_to_video.md
+++ b/docs/user_guide/examples/offline_inference/image_to_video.md
@@ -62,12 +62,13 @@ Key arguments:
- `--negative-prompt`: Optional list of artifacts to suppress.
- `--boundary-ratio`: Boundary split ratio for two-stage MoE models.
- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p).
+- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
- `--num-inference-steps`: Number of denoising steps (default 50).
- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video).
- `--output`: Path to save the generated video.
- `--vae-use-slicing`: Enable VAE slicing for memory optimization.
- `--vae-use-tiling`: Enable VAE tiling for memory optimization.
-- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
+- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md).
- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs.
@@ -78,6 +79,9 @@ Key arguments:
> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
+For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA
+assets, see the [LoRA guide](../../diffusion/lora.md#wan22-lightx2v-offline-assembly).
+
## Example materials
??? abstract "image_to_video.py"
diff --git a/docs/user_guide/examples/offline_inference/mimo_audio.md b/docs/user_guide/examples/offline_inference/mimo_audio.md
index 1a3be15d69..4e80526971 100644
--- a/docs/user_guide/examples/offline_inference/mimo_audio.md
+++ b/docs/user_guide/examples/offline_inference/mimo_audio.md
@@ -189,29 +189,6 @@ Note: This task uses hardcoded message lists in the script.
## Troubleshooting
-### Audio dependencies (soundfile, librosa)
-
-This example depends on **soundfile** (read/write WAV) and **librosa** (load audio including MP3). Install the project requirements first:
-
-```bash
-pip install -r requirements/common.txt
-# or at least: pip install soundfile>=0.13.1 librosa>=0.11.0
-```
-
-- **`soundfile` / libsndfile not found**
- `soundfile` uses the C library **libsndfile**. On Linux, install the system package before pip:
- - Debian/Ubuntu: `sudo apt-get install libsndfile1`
- - For development builds: `sudo apt-get install libsndfile1-dev`
- - Then: `pip install soundfile`
-
-- **`librosa` fails to load MP3 or reports "No backend available"**
- Loading MP3 (e.g. in `spoken_dialogue_sft_multiturn` with `.mp3` files) uses **ffmpeg** as the backend. Install ffmpeg:
- - Debian/Ubuntu: `sudo apt-get install ffmpeg`
- - macOS: `brew install ffmpeg`
-
-- **`ImportError: No module named 'soundfile'` or `ModuleNotFoundError: ... librosa`**
- Ensure you are in the same Python environment where vLLM Omni and the example dependencies are installed, and that `requirements/common.txt` (or the packages above) are installed.
-
### Tokenizer path
- **`MIMO_AUDIO_TOKENIZER_PATH` not set or model fails to find tokenizer**
diff --git a/docs/user_guide/examples/offline_inference/qwen2_5_omni.md b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md
index 07a56cf9a0..c54976b540 100644
--- a/docs/user_guide/examples/offline_inference/qwen2_5_omni.md
+++ b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md
@@ -64,14 +64,6 @@ If media file paths are not provided, the script will use default assets. Suppor
- `use_audio_in_video`: Extract audio from video
- `text`: Text-only query
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
## Example materials
??? abstract "end2end.py"
diff --git a/docs/user_guide/examples/offline_inference/qwen3_omni.md b/docs/user_guide/examples/offline_inference/qwen3_omni.md
index 6577092bbf..2d856f7380 100644
--- a/docs/user_guide/examples/offline_inference/qwen3_omni.md
+++ b/docs/user_guide/examples/offline_inference/qwen3_omni.md
@@ -112,14 +112,6 @@ python end2end_async_chunk.py \
> async_chunk example when you need the stage-level concurrency semantics
> described in PR #962 / #1151.
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
## Example materials
??? abstract "end2end.py"
diff --git a/docs/user_guide/examples/offline_inference/qwen3_tts.md b/docs/user_guide/examples/offline_inference/qwen3_tts.md
index 19fea4132c..7226ac1fe4 100644
--- a/docs/user_guide/examples/offline_inference/qwen3_tts.md
+++ b/docs/user_guide/examples/offline_inference/qwen3_tts.md
@@ -18,11 +18,11 @@ Please refer to the [stage configuration documentation](https://docs.vllm.ai/pro
### ROCm Dependencies
-You will need to install these two dependencies `onnxruntime-rocm` and `sox`.
+You will need to install the dependency `onnxruntime-rocm`.
```
pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm
-pip install onnxruntime-rocm sox
+pip install onnxruntime-rocm
```
## Quick Start
@@ -144,13 +144,13 @@ completes. This demonstrates that audio data is available progressively rather t
## Batched Decoding
-The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_num_seqs > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
+The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
```
python end2end.py --query-type CustomVoice \
--txt-prompts benchmark_prompts.txt \
--batch-size 4 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
```
**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently.
diff --git a/docs/user_guide/examples/offline_inference/x_to_video_audio.md b/docs/user_guide/examples/offline_inference/x_to_video_audio.md
index 8ea39d8115..cec8d47c59 100644
--- a/docs/user_guide/examples/offline_inference/x_to_video_audio.md
+++ b/docs/user_guide/examples/offline_inference/x_to_video_audio.md
@@ -31,9 +31,9 @@ dreamid_omni/
```
### Run the Inference
-```
+```python
python x_to_video_audio.py \
- --model /xx/dreamid_omni \
+ --model /path/to/dreamid_omni \
--prompt "Two people walking together and singing happily" \
--image-path ./example0.png ./example1.png \
--audio-path ./example0.wav ./example1.wav \
@@ -43,11 +43,33 @@ python x_to_video_audio.py \
--num-inference-steps 45 \
--height 704 \
--width 1280 \
- --output dreamid_omni.mp4
+ --output out_dreamid_omni_twoip.mp4
```
In the current test scenario (2 images + 2 audio inputs), the VRAM requirement is 72GB, regardless of whether cfg-parallel is enabled or disabled.
The VRAM usage can be reduced by enabling CPU offload via --enable-cpu-offload.
+
+You could take reference images/audios from the test cases in the official repo: https://github.com/Guoxu1233/DreamID-Omni
+
+For example, single IP ref resources can be found under https://github.com/Guoxu1233/DreamID-Omni/tree/main/test_case/oneip, you could download them correspondingly to your local and use them for testing.
+
+```python
+# Example usage for oneip, ref media from the official repo DreamID-Omni
+python x_to_video_audio.py \
+ --model /path/to/dreamid_omni \
+ --prompt ": In the frame, a woman with black long hair is identified as .\n**Overall Environment/Scene**: A lively open-kitchen café at night; stove flames flare, steam rises, and warm pendant lights swing slightly as staff move behind her. The shot is an upper-body close-up.\n**Main Characters/Subjects Appearance**: is a young woman with thick dark wavy hair and a side part. She wears a fitted black top under a light apron, a thin gold chain necklace, and small stud earrings.\n**Main Characters/Subjects Actions**: tastes the sauce with a spoon, then turns her face toward the camera while still holding the spoon, her expression shifting from focused to conflicted.\n maintains eye contact, swallows as if choosing her words, and says, I keep telling myself I’m fine,but some nights it feels like I’m just performing calm." \
+ --image-path 9.png \
+ --audio-path 9.wav \
+ --video-negative-prompt "jitter, bad hands, blur, distortion" \
+ --audio-negative-prompt "robotic, muffled, echo, distorted" \
+ --cfg-parallel-size 2 \
+ --num-inference-steps 45 \
+ --height 704 \
+ --width 1280 \
+ --output out_dreamid_omni_oneip.mp4
+```
+
+
Key arguments:
- `--prompt`: text description (string).
- `--model`: path to the model local directory.
diff --git a/docs/user_guide/examples/online_serving/bagel.md b/docs/user_guide/examples/online_serving/bagel.md
index 4a6094c089..9de31926aa 100644
--- a/docs/user_guide/examples/online_serving/bagel.md
+++ b/docs/user_guide/examples/online_serving/bagel.md
@@ -357,13 +357,6 @@ curl http://localhost:8091/v1/chat/completions \
## FAQ
-- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below.
-
-```bash
-sudo apt update
-sudo apt install ffmpeg
-```
-
- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len.
| Stage | VRAM |
diff --git a/docs/user_guide/examples/online_serving/image_to_video.md b/docs/user_guide/examples/online_serving/image_to_video.md
index 00b67d74e2..781f0c2a5e 100644
--- a/docs/user_guide/examples/online_serving/image_to_video.md
+++ b/docs/user_guide/examples/online_serving/image_to_video.md
@@ -72,6 +72,9 @@ curl -X POST http://localhost:8091/v1/videos/sync \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F "enable_frame_interpolation=true" \
+ -F "frame_interpolation_exp=1" \
+ -F "frame_interpolation_scale=1.0" \
-F "seed=42" \
-o sync_i2v_output.mp4
```
@@ -114,6 +117,9 @@ create_response=$(curl -s http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F "enable_frame_interpolation=true" \
+ -F "frame_interpolation_exp=1" \
+ -F "frame_interpolation_scale=1.0" \
-F "seed=42")
video_id=$(echo "$create_response" | jq -r '.id')
@@ -172,9 +178,35 @@ curl -X POST http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F "enable_frame_interpolation=true" \
+ -F "frame_interpolation_exp=1" \
+ -F "frame_interpolation_scale=1.0" \
-F "seed=42"
```
+Frame interpolation is also available for supported Wan2.2 I2V requests. See
+[Frame Interpolation](../../diffusion/frame_interpolation.md) for worker-side
+execution details and feature constraints.
+
+### Frame Interpolation Example
+
+```bash
+curl -X POST http://localhost:8091/v1/videos/sync \
+ -F "prompt=A bear playing with yarn, smooth motion" \
+ -F "input_reference=@/path/to/qwen-bear.png" \
+ -F "width=832" \
+ -F "height=480" \
+ -F "num_frames=33" \
+ -F "fps=16" \
+ -F "num_inference_steps=40" \
+ -F "guidance_scale=1.0" \
+ -F "guidance_scale_2=1.0" \
+ -F "enable_frame_interpolation=true" \
+ -F "frame_interpolation_exp=1" \
+ -F "frame_interpolation_scale=1.0" \
+ -o sync_i2v_interpolated.mp4
+```
+
## Create Response Format
`POST /v1/videos` returns a job record, not inline base64 video data.
diff --git a/docs/user_guide/examples/online_serving/qwen2_5_omni.md b/docs/user_guide/examples/online_serving/qwen2_5_omni.md
index 4357646924..b3a2c9f2ac 100644
--- a/docs/user_guide/examples/online_serving/qwen2_5_omni.md
+++ b/docs/user_guide/examples/online_serving/qwen2_5_omni.md
@@ -218,14 +218,6 @@ The gradio script supports the following arguments:
- `--port`: Port for Gradio server (default: 7861)
- `--share`: Share the Gradio demo publicly (creates a public link)
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
## Example materials
??? abstract "gradio_demo.py"
diff --git a/docs/user_guide/examples/online_serving/qwen3_omni.md b/docs/user_guide/examples/online_serving/qwen3_omni.md
index 69de24852f..611eb6fd3f 100644
--- a/docs/user_guide/examples/online_serving/qwen3_omni.md
+++ b/docs/user_guide/examples/online_serving/qwen3_omni.md
@@ -18,12 +18,12 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you want to open async chunking for qwen3-omni, launch the server with command below
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /vllm_omni/deploy/qwen3_omni_moe.yaml
```
If you have custom stage configs file, launch the server with command below
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
```
### Send Multi-modal Request
@@ -64,15 +64,6 @@ python openai_chat_completion_client_for_multimodal_generation.py \
bash run_curl_multimodal_generation.sh use_image
```
-
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
## Modality control
You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance.
@@ -196,7 +187,7 @@ The script supports the following arguments:
- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct)
- `--server-port`: Port for vLLM server (default: 8091)
- `--gradio-port`: Port for Gradio demo (default: 7861)
-- `--stage-configs-path`: Path to custom stage configs YAML file (optional)
+- `--deploy-config`: Path to custom deploy config YAML file (optional)
- `--server-host`: Host for vLLM server (default: 0.0.0.0)
- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1)
- `--share`: Share Gradio demo publicly (creates a public link)
@@ -211,7 +202,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you have custom stage configs file:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
```
**Step 2: Run the Gradio demo**
diff --git a/docs/user_guide/examples/online_serving/qwen3_tts.md b/docs/user_guide/examples/online_serving/qwen3_tts.md
index 156c4942cd..95f234f02d 100644
--- a/docs/user_guide/examples/online_serving/qwen3_tts.md
+++ b/docs/user_guide/examples/online_serving/qwen3_tts.md
@@ -58,7 +58,7 @@ Then open http://localhost:7860 in your browser.
```bash
# CustomVoice model (predefined speakers)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -66,7 +66,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
# VoiceDesign model
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -74,7 +74,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
# Base model (voice cloning)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -211,14 +211,6 @@ with open("output.wav", "wb") as f:
f.write(response.content)
```
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
## API Reference
### Voices Endpoint
diff --git a/docs/user_guide/examples/online_serving/text_to_video.md b/docs/user_guide/examples/online_serving/text_to_video.md
index d58296fcc7..00a9c16723 100644
--- a/docs/user_guide/examples/online_serving/text_to_video.md
+++ b/docs/user_guide/examples/online_serving/text_to_video.md
@@ -3,17 +3,28 @@
Source .
-This example demonstrates how to deploy the Wan2.2 text-to-video model for online video generation using vLLM-Omni.
+This example demonstrates how to deploy text-to-video models for online video generation using vLLM-Omni.
-## Start Server
+## Supported Models
-### Basic Start
+| Model | Model ID |
+|-------|----------|
+| Wan2.1 T2V (1.3B) | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` |
+| Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` |
+| Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
+| LTX-2 | `Lightricks/LTX-2` |
+
+## Wan2.2 T2V
+
+### Start Server
+
+#### Basic Start
```bash
vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091
```
-### Start with Parameters
+#### Start with Parameters
Or use the startup script:
@@ -154,6 +165,9 @@ curl -X POST http://localhost:8091/v1/videos \
-F "guidance_scale_2=4.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=5.0" \
+ -F "enable_frame_interpolation=true" \
+ -F "frame_interpolation_exp=1" \
+ -F "frame_interpolation_scale=1.0" \
-F "seed=42"
```
@@ -176,6 +190,35 @@ curl -X POST http://localhost:8091/v1/videos \
| `flow_shift` | float | None | Scheduler flow shift (Wan2.2) |
| `seed` | int | None | Random seed (reproducible) |
| `lora` | object | None | LoRA configuration |
+| `enable_frame_interpolation` | bool | false | Enable RIFE frame interpolation before MP4 encoding |
+| `frame_interpolation_exp` | int | 1 | Interpolation exponent; 1=2x temporal resolution, 2=4x |
+| `frame_interpolation_scale` | float | 1.0 | RIFE inference scale; use 0.5 for high-resolution inputs |
+| `frame_interpolation_model_path` | str | None | Local directory or Hugging Face repo ID with `flownet.pkl`; defaults to `elfgum/RIFE-4.22.lite` |
+
+## Frame Interpolation
+
+Frame interpolation is an optional post-processing step for `/v1/videos` and
+`/v1/videos/sync`. It synthesizes intermediate frames between generated frames
+without rerunning the diffusion model. If the generated video has `N` frames,
+the interpolated output frame count is `(N - 1) * 2**exp + 1`. The encoder FPS
+is multiplied by `2**exp` so the output duration remains close to the original.
+
+Frame interpolation runs in the diffusion worker post-processing path instead of
+the API server encoding path, so it can reuse the worker's current accelerator
+device without blocking the FastAPI event loop.
+
+Example: generate 5 frames and interpolate to 9 frames:
+
+```bash
+curl -X POST http://localhost:8091/v1/videos/sync \
+ -F "prompt=A dog running through a park" \
+ -F "num_frames=5" \
+ -F "fps=8" \
+ -F "enable_frame_interpolation=true" \
+ -F "frame_interpolation_exp=1" \
+ -F "frame_interpolation_scale=1.0" \
+ -o sync_t2v_interpolated.mp4
+```
## Create Response Format
@@ -234,8 +277,94 @@ while true; do
done
```
+## LTX-2
+
+### Start Server
+
+#### Basic Start
+
+```bash
+vllm serve Lightricks/LTX-2 --omni --port 8098 \
+ --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0
+```
+
+#### Start with Optimization Presets
+
+Use the LTX-2 startup script with built-in optimization presets:
+
+```bash
+# Baseline (1 GPU, eager)
+bash run_server_ltx2.sh baseline
+
+# 4-GPU Ulysses sequence parallelism (lossless)
+bash run_server_ltx2.sh ulysses4
+
+# Cache-DiT lossy acceleration (1 GPU, ~1.4× speedup)
+bash run_server_ltx2.sh cache-dit
+
+# Best combo: 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup)
+bash run_server_ltx2.sh best-combo
+```
+
+#### Optimization Benchmarks
+
+Benchmarked on H800, online serving (480×768, 41 frames, 20 steps, `seed=42`).
+"Inference" is the server-reported inference time; excludes HTTP/poll overhead.
+
+| Preset | Server Command | Inference (s) | Speedup | Type |
+|--------|---------------|---------------|---------|------|
+| `baseline` | `--enforce-eager` | 10.3 | 1.00× | — |
+| `compile` | *(default, no --enforce-eager)* | ~10.3 (warm) | ~1.00× | Lossless |
+| `ulysses4` | `--enforce-eager --usp 4` | ~10.3 | ~1.00× | Lossless |
+| `cache-dit` | `--enforce-eager --cache-backend cache_dit` | 7.4 avg | ~1.4× | Lossy |
+| `best-combo` | `--enforce-eager --usp 4 --cache-backend cache_dit` | 4.7 avg | **~2.2×** | Lossless + Lossy |
+
+**Observations**:
+- **torch.compile**: On H800, warm-request inference time matches the eager baseline (~10.3s).
+ The first request pays ~6s compilation overhead. Benefit depends on model architecture and GPU.
+- **Ulysses SP (4 GPU)**: No measurable speedup alone for 41-frame generation at this resolution.
+ Communication overhead outweighs gains at this sequence length.
+- **Cache-DiT**: Inference varies per request (6–10s) due to dynamic caching decisions.
+ Average is ~7.4s (~1.4× speedup) with slight quality tradeoff.
+- **Best combo**: 4-GPU Ulysses SP + Cache-DiT synergize well — Cache-DiT reduces per-step
+ computation, making the communication overhead of Ulysses SP worthwhile. Average ~4.7s
+ (~2.2× speedup).
+- **FP8 quantization**: Reduces VRAM but does not speed up LTX-2 on H800 (compute-bound).
+
+**Deployment Recommendations**:
+- For **production with quality priority**: use `baseline` with `--enforce-eager`
+- For **maximum throughput** (4 GPUs, quality tradeoff): use `best-combo` (~2.2× speedup)
+- For **single-GPU throughput**: use `cache-dit` (~1.4× speedup)
+- `--enforce-eager` is recommended to avoid torch.compile warmup latency on first request
+
+### Send Requests (curl)
+
+```bash
+# Using the provided script
+bash run_curl_ltx2.sh
+
+# Or directly
+curl -sS -X POST http://localhost:8098/v1/videos \
+ -H "Accept: application/json" \
+ -F "prompt=A serene lakeside sunrise with mist over the water." \
+ -F "width=768" \
+ -F "height=480" \
+ -F "num_frames=41" \
+ -F "fps=24" \
+ -F "num_inference_steps=20" \
+ -F "guidance_scale=3.0" \
+ -F "seed=42"
+```
+
## Example materials
+??? abstract "response.json"
+ ``````json
+ --8<-- "examples/online_serving/text_to_video/response.json"
+ ``````
+??? abstract "run_curl_ltx2.sh"
+ ``````sh
+ --8<-- "examples/online_serving/text_to_video/run_curl_ltx2.sh"
??? abstract "run_curl_hunyuan_video_15.sh"
``````sh
--8<-- "examples/online_serving/text_to_video/run_curl_hunyuan_video_15.sh"
@@ -248,6 +377,9 @@ done
``````sh
--8<-- "examples/online_serving/text_to_video/run_server.sh"
``````
+??? abstract "run_server_ltx2.sh"
+ ``````sh
+ --8<-- "examples/online_serving/text_to_video/run_server_ltx2.sh"
??? abstract "run_server_hunyuan_video_15.sh"
``````sh
--8<-- "examples/online_serving/text_to_video/run_server_hunyuan_video_15.sh"
diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md
index 226c009f79..3e653d0e3a 100644
--- a/examples/offline_inference/bagel/README.md
+++ b/examples/offline_inference/bagel/README.md
@@ -173,8 +173,6 @@ Example configuration for TP=2 on GPUs 0 and 1:
| Parameter | Value | Description |
| :-------------------- | :------ | :------------------------------- |
-| `window_size` | `-1` | Window size (-1 means unlimited) |
-| `max_inflight` | `1` | Maximum inflight requests |
| `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) |
## Using Mooncake Connector
@@ -247,13 +245,6 @@ For more details on the Mooncake connector and multi-node setup, see the [Moonca
## FAQ
-- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below.
-
-```bash
-sudo apt update
-sudo apt install ffmpeg
-```
-
- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len.
| Stage | VRAM |
diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py
index 472d748d1e..ed5fa57e8d 100644
--- a/examples/offline_inference/bagel/end2end.py
+++ b/examples/offline_inference/bagel/end2end.py
@@ -97,6 +97,24 @@ def parse_args():
default=False,
help="Enable thinking mode: AR stage decodes ... planning tokens before image generation.",
)
+ parser.add_argument(
+ "--max-think-tokens",
+ type=int,
+ default=1000,
+ help="Maximum number of tokens for thinking text generation (default: 1000).",
+ )
+ parser.add_argument(
+ "--do-sample",
+ action="store_true",
+ default=False,
+ help="Enable sampling for text generation (default: greedy).",
+ )
+ parser.add_argument(
+ "--text-temperature",
+ type=float,
+ default=0.3,
+ help="Temperature for text generation sampling (default: 0.3).",
+ )
args = parser.parse_args()
return args
@@ -108,7 +126,6 @@ def main():
model_name = args.model
prompts: list[OmniPromptType] = []
try:
- # Preferred: load from txt file (one prompt per line)
if getattr(args, "txt_prompts", None) and args.prompt_type == "text":
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
@@ -121,10 +138,8 @@ def main():
raise
if not prompts:
- # Default prompt for text2img test if none provided
prompts = ["A cute cat"]
print(f"[Info] No prompts provided, using default: {prompts}")
- omni_outputs = []
from PIL import Image
@@ -132,11 +147,13 @@ def main():
omni_kwargs = {}
stage_configs_path = args.stage_configs_path
+ is_single_stage = stage_configs_path and "single_stage" in stage_configs_path
if args.think and stage_configs_path is None:
stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml"
print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}")
if stage_configs_path:
omni_kwargs["stage_configs_path"] = stage_configs_path
+ is_single_stage = "single_stage" in stage_configs_path
omni_kwargs.update(
{
@@ -198,40 +215,61 @@ def main():
formatted_prompts.append(prompt_dict)
params_list = omni.default_sampling_params_list
+
+ # For single-stage DiT, think/text params go into the diffusion sampling params extra_args.
+ # For 2-stage, diffusion params are at index 1.
+ diffusion_params_idx = 0 if is_single_stage else (1 if len(params_list) > 1 else 0)
+ diffusion_params = params_list[diffusion_params_idx]
+
if args.modality in ("text2img", "img2img"):
- if len(params_list) > 1:
- diffusion_params = params_list[1]
- diffusion_params.num_inference_steps = args.steps # type: ignore
- diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
- if args.seed is not None:
- diffusion_params.seed = args.seed # type: ignore
- extra = {
- "cfg_text_scale": args.cfg_text_scale,
- "cfg_img_scale": args.cfg_img_scale,
- }
- if args.cfg_interval is not None:
- extra["cfg_interval"] = tuple(args.cfg_interval)
- if args.cfg_renorm_type is not None:
- extra["cfg_renorm_type"] = args.cfg_renorm_type
- if args.cfg_renorm_min is not None:
- extra["cfg_renorm_min"] = args.cfg_renorm_min
- if args.negative_prompt is not None:
- extra["negative_prompt"] = args.negative_prompt
- diffusion_params.extra_args = extra # type: ignore
+ diffusion_params.num_inference_steps = args.steps # type: ignore
+ diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
+ if args.seed is not None:
+ diffusion_params.seed = args.seed # type: ignore
+
+ extra = getattr(diffusion_params, "extra_args", {}) or {}
+ extra["cfg_text_scale"] = args.cfg_text_scale
+ extra["cfg_img_scale"] = args.cfg_img_scale
+ if args.cfg_interval is not None:
+ extra["cfg_interval"] = tuple(args.cfg_interval)
+ if args.cfg_renorm_type is not None:
+ extra["cfg_renorm_type"] = args.cfg_renorm_type
+ if args.cfg_renorm_min is not None:
+ extra["cfg_renorm_min"] = args.cfg_renorm_min
+ if args.negative_prompt is not None:
+ extra["negative_prompt"] = args.negative_prompt
+
+ needs_text_gen = is_single_stage and (args.think or args.modality in ("text2text", "img2text"))
+ if needs_text_gen:
+ if args.think:
+ extra["think"] = True
+ extra["max_think_tokens"] = args.max_think_tokens
+ extra["do_sample"] = args.do_sample
+ extra["text_temperature"] = args.text_temperature
+ diffusion_params.extra_args = extra # type: ignore
omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))
img_idx = 0
for req_output in omni_outputs:
- if args.think:
- ro = getattr(req_output, "request_output", None)
- if ro and getattr(ro, "outputs", None):
- txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
- if txt:
- print(txt)
+ # 2-stage think mode: text output from thinker stage
+ ro = getattr(req_output, "request_output", None)
+ if ro and getattr(ro, "outputs", None):
+ txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
+ if txt:
+ if args.think:
+ print(f"[Think]\n{txt}")
+ else:
+ print(f"[Output] Text:\n{txt}")
- images = getattr(req_output, "images", None)
+ # Single-stage DiT: text from custom_output
+ custom = getattr(req_output, "_custom_output", {}) or {}
+ if custom.get("think_text"):
+ print(f"[Think]\n{custom['think_text']}")
+ if custom.get("text_output"):
+ print(f"[Output] Text:\n{custom['text_output']}")
+ images = getattr(req_output, "images", None)
if not images:
continue
@@ -241,8 +279,6 @@ def main():
print(f"[Output] Saved image to {save_path}")
img_idx += 1
- print(omni_outputs)
-
if __name__ == "__main__":
main()
diff --git a/examples/offline_inference/cosyvoice3/README.md b/examples/offline_inference/cosyvoice3/README.md
index 895d3f660f..e16134e6ef 100644
--- a/examples/offline_inference/cosyvoice3/README.md
+++ b/examples/offline_inference/cosyvoice3/README.md
@@ -7,7 +7,7 @@ Install dependencies:
uv pip install -e .
```
-> **Note:** This includes required libraries such as `librosa`, `soundfile`,
+> **Note:** This includes required libraries such as `soundfile`,
> `onnxruntime`, `x-transformers`, and `einops` via
> `requirements/common.txt` and platform-specific requirements files.
diff --git a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
index 68ab72b387..6311bbc901 100644
--- a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
+++ b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
@@ -2,13 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import os
-from pathlib import Path
-import librosa
import numpy as np
import soundfile as sf
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
+from vllm.multimodal.media.audio import load_audio
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
@@ -16,22 +15,6 @@
from vllm_omni.model_executor.models.cosyvoice3.utils import extract_text_token
-def _ensure_mel_filters_asset() -> None:
- repo_root = Path(__file__).resolve().parents[3]
- filters_path = repo_root / "vllm_omni" / "model_executor" / "models" / "cosyvoice3" / "assets" / "mel_filters.npz"
- if filters_path.exists():
- return
-
- source_url = "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz"
- raise FileNotFoundError(
- "Missing CosyVoice3 mel filter asset:\n"
- f" {filters_path}\n"
- "Download it with:\n"
- f" mkdir -p {filters_path.parent} && "
- f"curl -L {source_url} -o {filters_path}"
- )
-
-
def run_e2e():
parser = argparse.ArgumentParser()
# ""FunAudioLLM/Fun-CosyVoice3-0.5B-2512
@@ -56,7 +39,6 @@ def run_e2e():
help="Path to tokenizer directory (e.g., /CosyVoice-BlankEN).",
)
args = parser.parse_args()
- _ensure_mel_filters_asset()
# Ensure tokenizer directory exists
if not os.path.exists(args.tokenizer):
raise FileNotFoundError(f"{args.tokenizer} does not exist!")
@@ -85,7 +67,7 @@ def run_e2e():
if not os.path.exists(args.audio_path):
raise FileNotFoundError(f"Audio file not found: {args.audio_path}")
# Load at native sample rate
- audio_signal, sr = librosa.load(args.audio_path, sr=None)
+ audio_signal, sr = load_audio(args.audio_path, sr=None)
# Validate sample rate before processing (similar to original CosyVoice)
min_sr = 16000
diff --git a/examples/offline_inference/hunyuan_image3/prompt_utils.py b/examples/offline_inference/hunyuan_image3/prompt_utils.py
new file mode 100644
index 0000000000..a5ef8e1536
--- /dev/null
+++ b/examples/offline_inference/hunyuan_image3/prompt_utils.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Prompt construction utilities for HunyuanImage-3.0-Instruct examples.
+
+Wraps system_prompt.get_system_prompt() with task-aware presets so that
+examples and tests don't need to manually concatenate system prompts,
+ , , and tags.
+
+Usage:
+ from prompt_utils import build_prompt
+
+ # IT2I (image editing, think+recaption mode)
+ prompt = build_prompt("Make the petals neon pink", task="it2i_think")
+
+ # I2T (image understanding)
+ prompt = build_prompt("Describe the content of the picture.", task="i2t")
+"""
+
+from __future__ import annotations
+
+from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import (
+ get_system_prompt,
+)
+
+# task → (sys_type, bot_task, trigger_tag)
+# trigger_tag: "", "", or None
+_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = {
+ # Pure text generation (text → text, no image)
+ "t2t": ("en_unified", None, None),
+ # Image understanding (image → text)
+ "i2t": ("en_unified", None, None),
+ # Image editing (image+text → image), think+recaption mode
+ "it2i_think": ("en_unified", "think", ""),
+ # Image editing, recaption-only mode
+ "it2i_recaption": ("en_unified", "recaption", ""),
+ # Text-to-image, think mode
+ "t2i_think": ("en_unified", "think", ""),
+ # Text-to-image, recaption mode
+ "t2i_recaption": ("en_unified", "recaption", ""),
+ # Text-to-image, vanilla (no CoT)
+ "t2i_vanilla": ("en_vanilla", "image", None),
+}
+
+
+def build_prompt(
+ user_prompt: str,
+ task: str = "it2i_think",
+ sys_type: str | None = None,
+ custom_system_prompt: str | None = None,
+) -> str:
+ """Build a complete HunyuanImage-3.0 prompt with auto-selected system
+ prompt and mode trigger tags.
+
+ Args:
+ user_prompt: The user's raw instruction or question.
+ task: One of the preset task keys (see _TASK_PRESETS).
+ sys_type: Override the preset's sys_type for get_system_prompt().
+ custom_system_prompt: Custom system prompt text (used when
+ sys_type="custom").
+
+ Returns:
+ Fully formatted prompt string ready for Omni.generate().
+ """
+ if task not in _TASK_PRESETS:
+ raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}")
+
+ preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
+ effective_sys_type = sys_type or preset_sys_type
+
+ system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
+ sys_text = system_prompt.strip() if system_prompt else ""
+
+ has_image_input = task.startswith("i2t") or task.startswith("it2i")
+
+ parts = ["<|startoftext|>"]
+ if sys_text:
+ parts.append(sys_text)
+ # Instruct conversation template: \n\nUser: ... \n\nAssistant:
+ parts.append("\n\nUser: ")
+ if has_image_input:
+ parts.append(" ")
+ parts.append(user_prompt)
+ parts.append("\n\nAssistant: ")
+ if trigger_tag:
+ parts.append(trigger_tag)
+
+ return "".join(parts)
diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py
index a8035a3fdc..1a7e86f13c 100644
--- a/examples/offline_inference/image_to_image/image_edit.py
+++ b/examples/offline_inference/image_to_image/image_edit.py
@@ -297,8 +297,8 @@ def parse_args() -> argparse.Namespace:
"--cfg-parallel-size",
type=int,
default=1,
- choices=[1, 2],
- help="Number of GPUs used for classifier free guidance parallel size.",
+ choices=[1, 2, 3],
+ help="Number of GPUs used for classifier free guidance parallel size (max 3 branches).",
)
parser.add_argument(
"--enforce-eager",
diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md
index 2692c76df2..a458850a02 100644
--- a/examples/offline_inference/image_to_video/README.md
+++ b/examples/offline_inference/image_to_video/README.md
@@ -59,12 +59,13 @@ Key arguments:
- `--negative-prompt`: Optional list of artifacts to suppress.
- `--boundary-ratio`: Boundary split ratio for two-stage MoE models.
- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p).
+- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
- `--num-inference-steps`: Number of denoising steps (default 50).
- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video).
- `--output`: Path to save the generated video.
- `--vae-use-slicing`: Enable VAE slicing for memory optimization.
- `--vae-use-tiling`: Enable VAE tiling for memory optimization.
-- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
+- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md).
- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs.
@@ -74,3 +75,6 @@ Key arguments:
> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
+
+For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA
+assets, see the [LoRA guide](../../../docs/user_guide/diffusion/lora.md#wan22-lightx2v-offline-assembly).
diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py
index 7e7cfbf84e..53319c8221 100644
--- a/examples/offline_inference/image_to_video/image_to_video.py
+++ b/examples/offline_inference/image_to_video/image_to_video.py
@@ -84,6 +84,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--flow-shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
)
+ parser.add_argument(
+ "--sample-solver",
+ type=str,
+ default="unipc",
+ choices=["unipc", "euler"],
+ help="Sampling solver for Wan2.2 pipelines. Use 'euler' for Lightning/Distill setups.",
+ )
parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).")
parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.")
parser.add_argument(
@@ -305,6 +312,7 @@ def main():
print(f" Model: {args.model}")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Frames: {args.num_frames}")
+ print(f" Solver: {args.sample_solver}")
print(
f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size},"
f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}"
@@ -326,9 +334,14 @@ def main():
generator=generator,
guidance_scale=guidance_scale,
guidance_scale_2=args.guidance_scale_high,
+ boundary_ratio=args.boundary_ratio,
num_inference_steps=num_inference_steps,
num_frames=num_frames,
frame_rate=frame_rate,
+ extra_args={
+ "sample_solver": args.sample_solver,
+ "flow_shift": args.flow_shift,
+ },
),
)
generation_end = time.perf_counter()
diff --git a/examples/offline_inference/mimo_audio/README.md b/examples/offline_inference/mimo_audio/README.md
index 747e734cc2..596afabeef 100644
--- a/examples/offline_inference/mimo_audio/README.md
+++ b/examples/offline_inference/mimo_audio/README.md
@@ -190,29 +190,6 @@ Note: This task uses hardcoded message lists in the script.
## Troubleshooting
-### Audio dependencies (soundfile, librosa)
-
-This example depends on **soundfile** (read/write WAV) and **librosa** (load audio including MP3). Install the project requirements first:
-
-```bash
-pip install -r requirements/common.txt
-# or at least: pip install soundfile>=0.13.1 librosa>=0.11.0
-```
-
-- **`soundfile` / libsndfile not found**
- `soundfile` uses the C library **libsndfile**. On Linux, install the system package before pip:
- - Debian/Ubuntu: `sudo apt-get install libsndfile1`
- - For development builds: `sudo apt-get install libsndfile1-dev`
- - Then: `pip install soundfile`
-
-- **`librosa` fails to load MP3 or reports "No backend available"**
- Loading MP3 (e.g. in `spoken_dialogue_sft_multiturn` with `.mp3` files) uses **ffmpeg** as the backend. Install ffmpeg:
- - Debian/Ubuntu: `sudo apt-get install ffmpeg`
- - macOS: `brew install ffmpeg`
-
-- **`ImportError: No module named 'soundfile'` or `ModuleNotFoundError: ... librosa`**
- Ensure you are in the same Python environment where vLLM Omni and the example dependencies are installed, and that `requirements/common.txt` (or the packages above) are installed.
-
### Tokenizer path
- **`MIMO_AUDIO_TOKENIZER_PATH` not set or model fails to find tokenizer**
diff --git a/examples/offline_inference/mimo_audio/message_convert.py b/examples/offline_inference/mimo_audio/message_convert.py
index ebcc59c6b4..416f21ccfa 100644
--- a/examples/offline_inference/mimo_audio/message_convert.py
+++ b/examples/offline_inference/mimo_audio/message_convert.py
@@ -5,12 +5,12 @@
import re
from collections.abc import Callable
-import librosa
import numpy as np
import torch
import torchaudio
from process_speechdata import InputSegment, StreamingInputSegment
from torchaudio.transforms import MelSpectrogram
+from vllm.multimodal.media.audio import load_audio
speech_zeroemb_idx = 151667
empty_token = "<|empty|>"
@@ -685,7 +685,7 @@ def get_audio_data(audio_url):
# File path
audio_file = audio_url
- audio_signal, sr = librosa.load(audio_file, sr=24000)
+ audio_signal, sr = load_audio(audio_file, sr=24000)
audio_data = (audio_signal.astype(np.float32), sr)
return audio_data
diff --git a/examples/offline_inference/ming_flash_omni/README.md b/examples/offline_inference/ming_flash_omni/README.md
new file mode 100644
index 0000000000..7414163fc0
--- /dev/null
+++ b/examples/offline_inference/ming_flash_omni/README.md
@@ -0,0 +1,76 @@
+# Ming-flash-omni 2.0
+
+[Ming-flash-omni-2.0](https://github.com/inclusionAI/Ming) is an omni-modal model supporting text, image, video, and audio understanding, with outputs in text, image, and audio. For now, Ming-flash-omni-2.0 in vLLM-Omni is supported with thinker stage (multi-modal understanding).
+
+## Setup
+
+Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
+
+## Run examples
+
+### Text-only
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type text
+```
+
+#### Reasoning (Thinking Mode)
+
+Reasoning (Thinking) mode is enabled via applying "detailed thinking on" when building the system prompt template (in `apply_chat_template`).
+
+In the end2end example, a default problem for thinking mode is provided, as referred to the example usage of Ming's cookbook;
+To utilize it, you have to download the example figure from https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png
+
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png
+```
+
+### Image understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image
+
+# With a local image
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image --image-path /path/to/image.jpg
+```
+
+### Audio understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio
+
+# With a local audio file
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --audio-path /path/to/audio.wav
+```
+
+### Video understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video
+
+# With a local video and custom frame count
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video --video-path /path/to/video.mp4 --num-frames 16
+```
+
+### Mixed modalities (image + audio)
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_mixed_modalities \
+ --image-path /path/to/image.jpg \
+ --audio-path /path/to/audio.wav
+```
+
+If media file paths are not provided, the script uses built-in default assets.
+
+### Modality control
+To control output modalities (e.g. text-only output):
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --modalities text
+```
+
+*For now, only text output is supported*
+
+### Custom stage config
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image \
+ --stage-configs-path /path/to/your_config.yaml
+```
+
+## Online serving
+
+For online serving via the OpenAI-compatible API, see [examples/online_serving/ming_flash_omni/README.md](../../online_serving/ming_flash_omni/README.md).
diff --git a/examples/offline_inference/ming_flash_omni/end2end.py b/examples/offline_inference/ming_flash_omni/end2end.py
new file mode 100644
index 0000000000..49cdbcc018
--- /dev/null
+++ b/examples/offline_inference/ming_flash_omni/end2end.py
@@ -0,0 +1,485 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Partial example cases are referred from
+# https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/cookbook.ipynb
+import os
+import time
+from typing import NamedTuple
+
+import librosa
+import numpy as np
+import vllm
+from PIL import Image
+from transformers import AutoProcessor
+from vllm import SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset, video_to_ndarrays
+from vllm.multimodal.image import convert_image_mode
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+import vllm_omni
+from vllm_omni.entrypoints.omni import Omni
+
+# Imports the processor also registers itself
+from vllm_omni.transformers_utils.processors.ming import MingFlashOmniProcessor # noqa: F401
+
+SEED = 42
+MODEL_NAME = "Jonathan1909/Ming-flash-omni-2.0"
+
+
+class QueryResult(NamedTuple):
+ inputs: dict
+ limit_mm_per_prompt: dict[str, int]
+
+
+def get_text_query(processor: MingFlashOmniProcessor, question: str | None = None) -> QueryResult:
+ if question is None:
+ question = "请详细介绍鹦鹉的生活习性。"
+ conversation = [{"role": "HUMAN", "content": question}]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+ return QueryResult(
+ inputs={"prompt": prompt},
+ limit_mm_per_prompt={},
+ )
+
+
+def get_image_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ image_path: str | None = None,
+) -> QueryResult:
+ if question is None:
+ question = "Describe this image in detail."
+
+ if image_path:
+ if not os.path.exists(image_path):
+ raise FileNotFoundError(f"Image file not found: {image_path}")
+ image_data = convert_image_mode(Image.open(image_path), "RGB")
+ else:
+ image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
+
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "image", "image": image_data},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"image": image_data},
+ },
+ limit_mm_per_prompt={"image": 1},
+ )
+
+
+def get_audio_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ audio_path: str | None = None,
+ sampling_rate: int = 16000,
+) -> QueryResult:
+ if question is None:
+ question = "Please recognize the language of this speech and transcribe it. Format: oral."
+
+ if audio_path:
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_data = (audio_signal.astype(np.float32), sr)
+ else:
+ audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
+
+ # Use a string for "audio" so the processor counts it as 1 audio input
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "audio", "audio": "input"},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"audio": audio_data},
+ },
+ limit_mm_per_prompt={"audio": 1},
+ )
+
+
+def get_video_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ video_path: str | None = None,
+ num_frames: int = 16,
+) -> QueryResult:
+ if question is None:
+ question = "Describe what is happening in this video."
+
+ if video_path:
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"Video file not found: {video_path}")
+ video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
+ else:
+ video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays
+
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "video"},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"video": video_frames},
+ },
+ limit_mm_per_prompt={"video": 1},
+ )
+
+
+def get_mixed_modalities_query(
+ processor: MingFlashOmniProcessor,
+ image_path: str | None = None,
+ audio_path: str | None = None,
+ sampling_rate: int = 16000,
+) -> QueryResult:
+ """Mixed image + audio understanding."""
+ question = "Describe the image, and recognize the language of this speech and transcribe it. Format: oral"
+
+ if image_path:
+ if not os.path.exists(image_path):
+ raise FileNotFoundError(f"Image file not found: {image_path}")
+ image_data = convert_image_mode(Image.open(image_path), "RGB")
+ else:
+ image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
+
+ if audio_path:
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
+ sig, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_data = (sig.astype(np.float32), sr)
+ else:
+ audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
+
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "image", "image": image_data},
+ {"type": "audio", "audio": "input"},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"image": image_data, "audio": audio_data},
+ },
+ limit_mm_per_prompt={"image": 1, "audio": 1},
+ )
+
+
+def get_reasoning_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ image_path: str | None = None,
+) -> QueryResult:
+ if question is None:
+ # NOTE: To use the following default question, input with example figure provided by Ming
+ # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png
+ # E.g.,
+ # python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png
+ # Otherwise, the problem solving might be false.
+ question = (
+ "Based on the following rules:\n•\tYou control the smiley face character\n"
+ "•\tYou can move up, down, left, and right, and only a single square at a time\n"
+ "•\tWalls are dark grey and cannot be moved into\n•\tThe brown square is a box\n•"
+ "\tThe box can be pushed by moving into it (i.e., if you are in the square "
+ "adjacent to the box to the left, and move onto the square with the box, "
+ "the box will move one square to the right).\n"
+ "•\tThe box cannot be pushed into walls\n"
+ "•\tThe blue door at the bottom is locked and cannot be passed through, "
+ "unless the box is placed on the blue square\n"
+ "•\tThe square beneath the blue door is the exit\n"
+ "•\tMoving from one square to another\n\n"
+ "Let's assume a coordinate system where the smiley face is "
+ "on the top left at (1,1) and the square below it is (1,2). "
+ "The smiley face performs the following moves: {down, right, right, right}, "
+ "such that the smiley face is at square (4,2) and the box is in square (5,2). "
+ "What are the next sequence of moves that must be done to move the box down to (5,3)? "
+ "Give your answer as a comma separated list."
+ )
+
+ if image_path:
+ if not os.path.exists(image_path):
+ raise FileNotFoundError(f"Image file not found: {image_path}")
+ image_data = convert_image_mode(Image.open(image_path), "RGB")
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "image", "image": image_data},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True)
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"image": image_data},
+ },
+ limit_mm_per_prompt={"image": 1},
+ )
+
+ conversation = [{"role": "HUMAN", "content": question}]
+ prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True)
+ return QueryResult(
+ inputs={"prompt": prompt},
+ limit_mm_per_prompt={},
+ )
+
+
+query_map = {
+ "text": get_text_query,
+ "use_audio": get_audio_query,
+ "use_image": get_image_query,
+ "use_video": get_video_query,
+ "use_mixed_modalities": get_mixed_modalities_query,
+ "reasoning": get_reasoning_query,
+}
+
+
+def main(args):
+ print(
+ "=" * 20,
+ "\n",
+ f"vllm version: {vllm.__version__}\n",
+ f"vllm-omni version: {vllm_omni.__version__}\n",
+ "=" * 20,
+ sep="",
+ )
+
+ processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
+ assert isinstance(processor, MingFlashOmniProcessor), f"Wrong processor type being used: {type(processor)}"
+
+ query_func = query_map[args.query_type]
+ if args.query_type == "use_image":
+ query_result = query_func(processor, image_path=args.image_path)
+ elif args.query_type == "use_audio":
+ query_result = query_func(processor, audio_path=args.audio_path, sampling_rate=args.sampling_rate)
+ elif args.query_type == "use_video":
+ query_result = query_func(processor, video_path=args.video_path, num_frames=args.num_frames)
+ elif args.query_type == "use_mixed_modalities":
+ query_result = query_func(
+ processor,
+ image_path=args.image_path,
+ audio_path=args.audio_path,
+ sampling_rate=args.sampling_rate,
+ )
+ elif args.query_type == "reasoning":
+ query_result = query_func(processor, image_path=args.image_path)
+ else:
+ query_result = query_func(processor)
+
+ # Initialize Omni (with thinker-only stage config)
+ omni = Omni(
+ model=MODEL_NAME,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ init_timeout=args.init_timeout,
+ stage_init_timeout=args.stage_init_timeout,
+ )
+
+ # Thinker sampling params
+ thinker_sampling_params = SamplingParams(
+ temperature=0.4,
+ top_p=0.9,
+ max_tokens=args.max_tokens,
+ repetition_penalty=1.05,
+ seed=SEED,
+ detokenize=True,
+ )
+ sampling_params_list = [thinker_sampling_params]
+
+ prompts = [query_result.inputs for _ in range(args.num_prompts)]
+
+ if args.modalities is not None:
+ output_modalities = args.modalities.split(",")
+ for prompt in prompts:
+ prompt["modalities"] = output_modalities
+
+ total_requests = len(prompts)
+ processed_count = 0
+ print(f"Query type: {args.query_type}")
+ print(f"Number of prompts: {total_requests}")
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ profiler_enabled = args.enable_profiler
+ if profiler_enabled:
+ omni.start_profile(stages=args.profiler_stages)
+
+ for stage_outputs in omni.generate(prompts, sampling_params_list):
+ output = stage_outputs.request_output
+ if stage_outputs.final_output_type == "text":
+ request_id = output.request_id
+ text_output = output.outputs[0].text
+ lines = []
+ lines.append("Prompt:\n")
+ lines.append(str(output.prompt) + "\n")
+ lines.append("Text Output:\n")
+ lines.append(str(text_output).strip() + "\n")
+ print(*lines, sep="")
+
+ # Save to file
+ out_txt = os.path.join(output_dir, f"{request_id}.txt")
+ try:
+ with open(out_txt, "w", encoding="utf-8") as f:
+ f.writelines(lines)
+ print(f"Request ID: {request_id}, text saved to {out_txt}")
+ except Exception as e:
+ print(f"Failed to write output file {out_txt}: {e}")
+
+ elif stage_outputs.final_output_type == "audio":
+ raise NotImplementedError("Add audio example after talker supported.")
+
+ processed_count += 1
+ if profiler_enabled and processed_count >= total_requests:
+ print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
+ # Stop the profiler while workers are still alive
+ omni.stop_profile(stages=args.profiler_stages)
+
+ print("[Info] Waiting 30s for workers to write trace files to disk...")
+ time.sleep(30)
+ print("[Info] Trace export wait time finished.")
+
+ omni.close()
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(description="Ming-flash-omni 2.0 offline inference example")
+ parser.add_argument(
+ "--query-type",
+ "-q",
+ type=str,
+ default="text",
+ choices=query_map.keys(),
+ help="Query type.",
+ )
+ parser.add_argument(
+ "--stage-configs-path",
+ type=str,
+ default=None,
+ help="Path to a stage configs YAML file.",
+ )
+ parser.add_argument(
+ "--log-stats",
+ action="store_true",
+ default=False,
+ help="Enable detailed statistics logging.",
+ )
+ parser.add_argument("--init-timeout", type=int, default=2000, help="Timeout for initializing in seconds.")
+ parser.add_argument(
+ "--stage-init-timeout",
+ type=int,
+ default=2000,
+ help="Timeout for initializing a single stage in seconds.",
+ )
+ parser.add_argument(
+ "--enable-profiler",
+ action="store_true",
+ default=False,
+ help="Enables profiling when set.",
+ )
+ parser.add_argument(
+ "--profiler-stages",
+ type=int,
+ nargs="*",
+ default=[0],
+ help="List of stage IDs to profile. If not set, profiles all stages.",
+ )
+ parser.add_argument(
+ "--image-path",
+ "-i",
+ type=str,
+ default=None,
+ help="Path to local image file. Uses default asset if not provided.",
+ )
+ parser.add_argument(
+ "--audio-path",
+ "-a",
+ type=str,
+ default=None,
+ help="Path to local audio file. Uses default asset if not provided.",
+ )
+ parser.add_argument(
+ "--video-path",
+ "-v",
+ type=str,
+ default=None,
+ help="Path to local video file. Uses default asset if not provided.",
+ )
+ parser.add_argument(
+ "--num-frames",
+ type=int,
+ default=16,
+ help="Number of frames to extract from video.",
+ )
+ parser.add_argument(
+ "--sampling-rate",
+ type=int,
+ default=16000,
+ help="Sampling rate for audio loading.",
+ )
+ parser.add_argument(
+ "--max-tokens",
+ type=int,
+ default=16384,
+ help="Maximum tokens to generate.",
+ )
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=1,
+ help="Number of prompts to generate.",
+ )
+ parser.add_argument(
+ "--modalities",
+ type=str,
+ default=None,
+ help="Output modalities (comma-separated).",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="output_ming",
+ help="Output directory for results.",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/offline_inference/omnivoice/end2end.py b/examples/offline_inference/omnivoice/end2end.py
index b41379b011..9371c95142 100644
--- a/examples/offline_inference/omnivoice/end2end.py
+++ b/examples/offline_inference/omnivoice/end2end.py
@@ -103,9 +103,9 @@ def run_e2e():
if not os.path.exists(args.ref_audio):
raise FileNotFoundError(f"Reference audio not found: {args.ref_audio}")
- import librosa
+ from vllm.multimodal.media.audio import load_audio
- audio_signal, sr = librosa.load(args.ref_audio, sr=None)
+ audio_signal, sr = load_audio(args.ref_audio, sr=None)
multi_modal_data["audio"] = (audio_signal.astype(np.float32), sr)
mm_processor_kwargs["ref_text"] = args.ref_text or ""
mm_processor_kwargs["sample_rate"] = sr
diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md
index 20740a0da0..e2eae8a96b 100644
--- a/examples/offline_inference/qwen2_5_omni/README.md
+++ b/examples/offline_inference/qwen2_5_omni/README.md
@@ -60,11 +60,3 @@ If media file paths are not provided, the script will use default assets. Suppor
- `mixed_modalities`: Audio + image + video
- `use_audio_in_video`: Extract audio from video
- `text`: Text-only query
-
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py
index 7bba599830..dfe124700d 100644
--- a/examples/offline_inference/qwen2_5_omni/end2end.py
+++ b/examples/offline_inference/qwen2_5_omni/end2end.py
@@ -9,7 +9,6 @@
import time
from typing import NamedTuple
-import librosa
import numpy as np
import soundfile as sf
from PIL import Image
@@ -17,6 +16,7 @@
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
+from vllm.multimodal.media.audio import load_audio
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -96,7 +96,7 @@ def get_mixed_modalities_query(
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -130,7 +130,7 @@ def get_use_audio_in_video_query(
raise FileNotFoundError(f"Video file not found: {video_path}")
video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
# Extract audio from video file
- audio_signal, sr = librosa.load(video_path, sr=sampling_rate)
+ audio_signal, sr = load_audio(video_path, sr=sampling_rate)
audio = (audio_signal.astype(np.float32), sr)
else:
asset = VideoAsset(name="baby_reading", num_frames=num_frames)
@@ -165,7 +165,7 @@ def get_multi_audios_query(audio_path: str | None = None, sampling_rate: int = 1
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
# Use the provided audio as the first audio, default as second
audio_list = [
@@ -261,7 +261,7 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -320,14 +320,7 @@ def main(args):
query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate)
else:
query_result = query_func()
- omni = Omni(
- model=model_name,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- batch_timeout=args.batch_timeout,
- init_timeout=args.init_timeout,
- shm_threshold_bytes=args.shm_threshold_bytes,
- )
+ omni = Omni.from_cli_args(args, model=model_name)
thinker_sampling_params = SamplingParams(
temperature=0.0, # Deterministic - no randomness
top_p=1.0, # Disable nucleus sampling
diff --git a/examples/offline_inference/qwen3_omni/README.md b/examples/offline_inference/qwen3_omni/README.md
index b3e8592532..0710faa133 100644
--- a/examples/offline_inference/qwen3_omni/README.md
+++ b/examples/offline_inference/qwen3_omni/README.md
@@ -70,8 +70,8 @@ For true stage-level concurrency -- where downstream stages (Talker, Code2Wav)
start **before** the upstream stage (Thinker) finishes -- use the async_chunk
example. This requires:
-1. A stage config YAML with ``async_chunk: true`` (e.g.
- ``qwen3_omni_moe_async_chunk.yaml``).
+1. A deploy config YAML with ``async_chunk: true`` (e.g.
+ ``qwen3_omni_moe.yaml``).
2. Hardware that matches the config (e.g. 2x H100 for the default 3-stage
config).
@@ -101,18 +101,10 @@ python end2end_async_chunk.py --query-type text --modalities text
```bash
python end2end_async_chunk.py \
--query-type use_audio \
- --stage-configs-path /path/to/your_async_chunk.yaml
+ --deploy-config /path/to/your_deploy_config.yaml
```
> **Note**: The synchronous ``end2end.py`` (using ``Omni``) is still the
> recommended entry point for non-async-chunk workflows. Only use the
> async_chunk example when you need the stage-level concurrency semantics
> described in PR #962 / #1151.
-
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py
index 155eca4ed9..f028c32aa1 100644
--- a/examples/offline_inference/qwen3_omni/end2end.py
+++ b/examples/offline_inference/qwen3_omni/end2end.py
@@ -9,7 +9,6 @@
import time
from typing import NamedTuple
-import librosa
import numpy as np
import soundfile as sf
import vllm
@@ -19,6 +18,7 @@
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
+from vllm.multimodal.media.audio import load_audio
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm_omni.entrypoints.omni import Omni
@@ -129,7 +129,7 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -183,7 +183,7 @@ def get_mixed_modalities_query(
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -294,14 +294,7 @@ def main(args):
else:
query_result = query_func()
- omni = Omni(
- model=model_name,
- dtype=args.dtype,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- init_timeout=args.init_timeout,
- )
+ omni = Omni.from_cli_args(args, model=model_name)
thinker_sampling_params = SamplingParams(
temperature=0.9,
diff --git a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
index 8adbae9eb6..f38922e943 100644
--- a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
+++ b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
@@ -14,7 +14,7 @@
Usage
-----
python end2end_async_chunk.py --query-type use_audio \
- --stage-configs-path
+ --deploy-config
See ``--help`` for all options.
"""
@@ -32,13 +32,13 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-import librosa
from PIL import Image
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
+from vllm.multimodal.media.audio import load_audio
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm_omni.entrypoints.async_omni import AsyncOmni
@@ -89,7 +89,7 @@ def get_audio_query(
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -179,20 +179,26 @@ def clone_prompt_for_request(template: dict) -> dict:
return cloned
-def _default_async_chunk_stage_configs_path() -> str | None:
- """Best-effort default stage config for running Qwen3-Omni with async_chunk.
+def _default_deploy_config_path() -> str | None:
+ """Best-effort default deploy config for running Qwen3-Omni with async_chunk.
- When this example is executed from within the repository, we resolve the
- default YAML path relative to this file. When installed elsewhere, the
- file may not exist and callers should pass --stage-configs-path explicitly.
+ The default ``vllm_omni/deploy/qwen3_omni_moe.yaml`` ships with
+ ``async_chunk: true`` at the top level, so loading it is enough to
+ enable async-chunk semantics. To disable it, copy the YAML and set
+ ``async_chunk: false`` (or pass ``--deploy-config`` to a YAML that
+ overrides the flag).
+
+ When this example is executed from within the repository, we resolve
+ the default YAML path relative to this file. When installed elsewhere,
+ the file may not exist and callers should pass ``--deploy-config``
+ explicitly.
"""
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
candidate = os.path.join(
repo_root,
"vllm_omni",
- "model_executor",
- "stage_configs",
- "qwen3_omni_moe_async_chunk.yaml",
+ "deploy",
+ "qwen3_omni_moe.yaml",
)
return candidate if os.path.exists(candidate) else None
@@ -374,15 +380,16 @@ async def run_all(args):
prompt["modalities"] = output_modalities
# Create AsyncOmni
- print(f"[Info] Creating AsyncOmni with stage_configs_path={args.stage_configs_path}")
+ print(f"[Info] Creating AsyncOmni with deploy_config={args.deploy_config}")
async_omni = None
try:
- async_omni = AsyncOmni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
+ # ``from_cli_args`` expands vars(args) into kwargs and auto-captures
+ # ``_cli_explicit_keys`` from ``sys.argv[1:]`` so argparse defaults
+ # do not silently override deploy YAML values. Mirrors the
+ # ``EngineArgs.from_cli_args`` pattern used throughout vllm /
+ # vllm-omni. ``deploy_config=None`` (the default) falls through to
+ # the bundled ``vllm_omni/deploy/qwen3_omni_moe.yaml``.
+ async_omni = AsyncOmni.from_cli_args(args)
# Use default sampling params from stage config (they are pre-configured
# in the YAML for each stage).
@@ -470,11 +477,11 @@ def parse_args():
help="Query type.",
)
parser.add_argument(
- "--stage-configs-path",
+ "--deploy-config",
type=str,
- default=_default_async_chunk_stage_configs_path(),
+ default=_default_deploy_config_path(),
help=(
- "Path to an async_chunk stage config YAML. "
+ "Path to a deploy config YAML. "
"If not set, uses the model's default config "
"(make sure it has async_chunk: true)."
),
diff --git a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
index 809054867c..2f2be20915 100755
--- a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
+++ b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
@@ -17,7 +17,7 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
python "${SCRIPT_DIR}/end2end_async_chunk.py" \
--query-type text \
--txt-prompts "${SCRIPT_DIR}/text_prompts_10.txt" \
- --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \
+ --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \
--output-dir output_audio_async_chunk \
--max-in-flight 2 \
"$@"
diff --git a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
index 918c7ee4fd..9ef69293cb 100755
--- a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
+++ b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
@@ -6,13 +6,13 @@
# achieving true stage-level concurrency via chunk-level streaming.
#
# Prerequisites:
-# - An async_chunk stage config YAML (e.g. qwen3_omni_moe_async_chunk.yaml)
+# - A deploy config YAML (e.g. qwen3_omni_moe.yaml)
# - Hardware matching the config (e.g. 2x H100 for the default 3-stage config)
#
# Usage:
# bash run_single_prompt_async_chunk.sh
# bash run_single_prompt_async_chunk.sh --query-type text --modalities text
-# bash run_single_prompt_async_chunk.sh --stage-configs-path /path/to/custom.yaml
+# bash run_single_prompt_async_chunk.sh --deploy-config /path/to/custom.yaml
set -euo pipefail
@@ -21,6 +21,6 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
python "${SCRIPT_DIR}/end2end_async_chunk.py" \
--query-type use_audio \
- --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \
+ --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \
--output-dir output_audio_async_chunk \
"$@"
diff --git a/examples/offline_inference/qwen3_tts/README.md b/examples/offline_inference/qwen3_tts/README.md
index bf59dc9ba4..2971ad716a 100644
--- a/examples/offline_inference/qwen3_tts/README.md
+++ b/examples/offline_inference/qwen3_tts/README.md
@@ -15,11 +15,11 @@ Please refer to the [stage configuration documentation](https://docs.vllm.ai/pro
### ROCm Dependencies
-You will need to install these two dependencies `onnxruntime-rocm` and `sox`.
+You will need to install the dependency `onnxruntime-rocm`.
```
pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm
-pip install onnxruntime-rocm sox
+pip install onnxruntime-rocm
```
## Quick Start
@@ -104,13 +104,13 @@ completes. This demonstrates that audio data is available progressively rather t
## Batched Decoding
-The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_num_seqs > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
+The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
```
python end2end.py --query-type CustomVoice \
--txt-prompts benchmark_prompts.txt \
--batch-size 4 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
```
**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently.
diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py
index 901418c39b..77da356b4f 100644
--- a/examples/offline_inference/qwen3_tts/end2end.py
+++ b/examples/offline_inference/qwen3_tts/end2end.py
@@ -366,12 +366,7 @@ def main(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
- omni = Omni(
- model=model_name,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
+ omni = Omni.from_cli_args(args, model=model_name)
batch_size = args.batch_size
for batch_start in range(0, len(inputs), batch_size):
@@ -387,12 +382,7 @@ async def main_streaming(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
- omni = AsyncOmni(
- model=model_name,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
+ omni = AsyncOmni.from_cli_args(args, model=model_name)
for i, prompt in enumerate(inputs):
request_id = str(i)
diff --git a/examples/offline_inference/text_to_audio/README.md b/examples/offline_inference/text_to_audio/README.md
index 7edc38092a..50bab3e2f2 100644
--- a/examples/offline_inference/text_to_audio/README.md
+++ b/examples/offline_inference/text_to_audio/README.md
@@ -23,6 +23,7 @@ python text_to_audio.py \
--guidance-scale 7.0 \
--audio-length 10.0 \
--num-inference-steps 100 \
+ --cache-backend tea_cache \
--output stable_audio_output.wav
```
@@ -34,4 +35,5 @@ Key arguments:
- `--guidance-scale`: classifier-free guidance scale.
- `--audio-length`: audio duration in seconds.
- `--num-inference-steps`: diffusion sampling steps.(more steps = higher quality, slower).
+- `--cache-backend`: cache acceleration backend. Stable Audio currently supports `tea_cache`.
- `--output`: path to save the generated WAV file.
diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py
index a6968c419f..3adb3ad53a 100644
--- a/examples/offline_inference/text_to_audio/text_to_audio.py
+++ b/examples/offline_inference/text_to_audio/text_to_audio.py
@@ -11,6 +11,7 @@
python text_to_audio.py --prompt "The sound of a dog barking"
python text_to_audio.py --prompt "A piano playing a gentle melody" --audio-length 10.0
python text_to_audio.py --prompt "Thunder and rain sounds" --negative-prompt "Low quality"
+ python text_to_audio.py --prompt "A soft synth pad" --cache-backend tea_cache
"""
import argparse
@@ -90,6 +91,23 @@ def parse_args() -> argparse.Namespace:
default=44100,
help="Sample rate for output audio (Stable Audio uses 44100 Hz).",
)
+ parser.add_argument(
+ "--cache-backend",
+ type=str,
+ default=None,
+ choices=["tea_cache"],
+ help=(
+ "Cache backend to use for acceleration. "
+ "Stable Audio currently supports 'tea_cache'. "
+ "Default: None (no cache acceleration)."
+ ),
+ )
+ parser.add_argument(
+ "--tea-cache-rel-l1-thresh",
+ type=float,
+ default=0.2,
+ help="[tea_cache] Threshold for accumulated relative L1 distance.",
+ )
parser.add_argument(
"--enable-diffusion-pipeline-profiler",
action="store_true",
@@ -124,6 +142,11 @@ def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 4410
def main():
args = parse_args()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)
+ cache_config = None
+ if args.cache_backend == "tea_cache":
+ cache_config = {
+ "rel_l1_thresh": args.tea_cache_rel_l1_thresh,
+ }
print(f"\n{'=' * 60}")
print("Stable Audio Open - Text-to-Audio Generation")
@@ -134,12 +157,15 @@ def main():
print(f" Audio length: {args.audio_length}s")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Guidance scale: {args.guidance_scale}")
+ print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}")
print(f" Seed: {args.seed}")
print(f"{'=' * 60}\n")
# Initialize Omni with Stable Audio model
omni = Omni(
model=args.model,
+ cache_backend=args.cache_backend,
+ cache_config=cache_config,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
)
diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md
index cc295e8279..c71773972b 100644
--- a/examples/offline_inference/text_to_image/README.md
+++ b/examples/offline_inference/text_to_image/README.md
@@ -29,7 +29,8 @@ This folder provides several entrypoints for experimenting with text-to-image di
| `AIDC-AI/Ovis-Image-7B` | 1024 x 1024 | 71.8 | 17.1 |
| `OmniGen2/OmniGen2` | 1024 x 1024 | 20.1 | 14.7 |
| `stabilityai/stable-diffusion-3.5-medium` | 1024 x 1024 | 20.1 | 15.6 |
-| `black-forest-labs/FLUX.1-dev` | 1024 x 1024 | 77.6 | 31.4 |
+| `black-forest-labs/FLUX.1-dev` | 1024 x 1024 | 33.9 | 31.4 |
+| `black-forest-labs/FLUX.1-schnell` | 1024 x 1024 | 33.9 | 31.4 |
| `black-forest-labs/FLUX.2-klein-4B` | 1024 x 1024 | 72.7 | 14.9 |
| `black-forest-labs/FLUX.2-klein-9B` | 1024 x 1024 | 37.1 | 32.3 |
| `black-forest-labs/FLUX.2-dev` | 1024 x 1024 | 65.7 | >80 (CPU offload required) |
diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py
index 615e4067ed..3b3f8e77cf 100644
--- a/examples/offline_inference/text_to_image/text_to_image.py
+++ b/examples/offline_inference/text_to_image/text_to_image.py
@@ -242,6 +242,18 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable logging of diffusion pipeline stats.",
)
+ parser.add_argument(
+ "--init-timeout",
+ type=int,
+ default=600,
+ help="Timeout for initializing a single stage in seconds (default: 600s)",
+ )
+ parser.add_argument(
+ "--stage-init-timeout",
+ type=int,
+ default=600,
+ help="Timeout for initializing a single stage in seconds (default: 600s)",
+ )
parser.add_argument(
"--use-system-prompt",
type=str,
@@ -346,6 +358,8 @@ def main():
"mode": "text-to-image",
"log_stats": args.log_stats,
"enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler,
+ "init_timeout": args.init_timeout,
+ "stage_init_timeout": args.stage_init_timeout,
**lora_args,
**quant_kwargs,
}
diff --git a/examples/offline_inference/voxcpm/README.md b/examples/offline_inference/voxcpm/README.md
new file mode 100644
index 0000000000..1eaea9b0db
--- /dev/null
+++ b/examples/offline_inference/voxcpm/README.md
@@ -0,0 +1,123 @@
+# VoxCPM Offline Example
+
+This directory contains the minimal offline VoxCPM example for vLLM Omni.
+
+`end2end.py` is intentionally small and only covers:
+
+- single text-to-speech
+- single voice cloning with `ref_audio` + `ref_text`
+- non-streaming with `vllm_omni/model_executor/stage_configs/voxcpm.yaml`
+- streaming with `vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml`
+
+Advanced workflows were moved out of the getting-started example:
+
+- `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py`: warmup, batch prompts, profiler, offline TTFP / RTF
+- `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`: fixed offline smoke matrix
+- `benchmarks/voxcpm/`: benchmark scripts and benchmark docs
+
+## Prerequisites
+
+Install VoxCPM in one of these ways:
+
+```bash
+pip install voxcpm
+```
+
+or point vLLM Omni to the local VoxCPM source tree:
+
+```bash
+export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src
+```
+
+The example writes WAV files with `soundfile`:
+
+```bash
+pip install soundfile
+```
+
+## Model Path
+
+Pass the native VoxCPM model directory directly:
+
+```bash
+export VOXCPM_MODEL=/path/to/voxcpm-model
+```
+
+If the native VoxCPM `config.json` does not contain HuggingFace metadata such as
+`model_type`, prepare a persistent HF-compatible config directory and point the
+stage configs to it with `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH`:
+
+```bash
+export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config
+mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"
+cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json"
+cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true
+python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)'
+```
+
+If the model directory itself already has `model_type`, this extra directory is
+not required.
+
+## Quick Start
+
+Single text-to-speech, non-streaming:
+
+```bash
+python examples/offline_inference/voxcpm/end2end.py \
+ --model "$VOXCPM_MODEL" \
+ --text "This is a split-stage VoxCPM synthesis example running on vLLM Omni."
+```
+
+Single voice cloning, non-streaming:
+
+```bash
+python examples/offline_inference/voxcpm/end2end.py \
+ --model "$VOXCPM_MODEL" \
+ --text "This sentence is synthesized with a cloned voice." \
+ --ref-audio /path/to/reference.wav \
+ --ref-text "The exact transcript spoken in reference.wav."
+```
+
+Streaming:
+
+```bash
+python examples/offline_inference/voxcpm/end2end.py \
+ --model "$VOXCPM_MODEL" \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \
+ --text "This is a split-stage VoxCPM streaming example running on vLLM Omni."
+```
+
+By default, `end2end.py` writes to `output_audio/` for non-streaming and
+`output_audio_streaming/` for streaming.
+
+## Advanced Workflows
+
+Use `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py` when you need:
+
+- warmup runs
+- prompt files
+- batch JSONL inputs
+- profiler injection
+- offline TTFP / RTF emission
+
+Use `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py` when you need the fixed offline smoke matrix that previously lived in `test.py`.
+
+Full matrix benchmark example:
+
+```bash
+python benchmarks/voxcpm/vllm_omni/run_offline_matrix.py \
+ --model "$VOXCPM_MODEL" \
+ --ref-audio /path/to/reference.wav \
+ --ref-text "The exact transcript spoken in reference.wav."
+```
+
+For online serving examples, see [examples/online_serving/voxcpm](../../online_serving/voxcpm/README.md).
+
+For benchmark reporting, see [benchmarks/voxcpm](../../../benchmarks/voxcpm/README.md).
+
+## Notes
+
+- `voxcpm.yaml` is the default non-streaming stage config.
+- `voxcpm_async_chunk.yaml` is the streaming stage config.
+- Streaming is currently single-request oriented; the fixed smoke matrix now lives in `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`.
+- `ref_text` must be the real transcript of the reference audio. Mismatched text usually causes obvious quality degradation.
diff --git a/examples/offline_inference/voxcpm/end2end.py b/examples/offline_inference/voxcpm/end2end.py
new file mode 100644
index 0000000000..980410feae
--- /dev/null
+++ b/examples/offline_inference/voxcpm/end2end.py
@@ -0,0 +1,206 @@
+"""Minimal offline VoxCPM example for vLLM Omni."""
+
+from __future__ import annotations
+
+import asyncio
+import time
+from pathlib import Path
+from typing import Any
+
+import soundfile as sf
+import torch
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+from vllm_omni import AsyncOmni, Omni
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+DEFAULT_SYNC_STAGE_CONFIG = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
+
+
+def _build_prompt(args) -> dict[str, Any]:
+ additional_information: dict[str, list[Any]] = {
+ "text": [args.text],
+ "cfg_value": [args.cfg_value],
+ "inference_timesteps": [args.inference_timesteps],
+ "min_len": [args.min_len],
+ "max_new_tokens": [args.max_new_tokens],
+ }
+ if args.streaming_prefix_len is not None:
+ additional_information["streaming_prefix_len"] = [args.streaming_prefix_len]
+ if args.ref_audio is not None:
+ additional_information["ref_audio"] = [args.ref_audio]
+ if args.ref_text is not None:
+ additional_information["ref_text"] = [args.ref_text]
+ return {
+ "prompt_token_ids": [1],
+ "additional_information": additional_information,
+ }
+
+
+def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor:
+ audio = mm.get("audio", mm.get("model_outputs"))
+ if audio is None:
+ raise ValueError("No audio output found in multimodal output.")
+ if isinstance(audio, list):
+ parts = [torch.as_tensor(item).float().cpu().reshape(-1) for item in audio]
+ audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0)
+ if not isinstance(audio, torch.Tensor):
+ audio = torch.as_tensor(audio)
+ return audio.float().cpu().reshape(-1)
+
+
+def _extract_sample_rate(mm: dict[str, Any]) -> int:
+ sr_raw = mm.get("sr", 24000)
+ if isinstance(sr_raw, list) and sr_raw:
+ sr_raw = sr_raw[-1]
+ if hasattr(sr_raw, "item"):
+ return int(sr_raw.item())
+ return int(sr_raw)
+
+
+def _is_streaming_stage_config(stage_config_path: str) -> bool:
+ return "async_chunk" in Path(stage_config_path).stem
+
+
+def _save_audio(audio: torch.Tensor, sample_rate: int, output_dir: Path, request_id: str) -> Path:
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_path = output_dir / f"output_{request_id}.wav"
+ sf.write(
+ output_path,
+ audio.float().cpu().clamp(-1.0, 1.0).numpy(),
+ sample_rate,
+ format="WAV",
+ subtype="PCM_16",
+ )
+ return output_path
+
+
+async def _run_streaming(args) -> Path:
+ prompt = _build_prompt(args)
+ output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio_streaming")
+ request_id = "streaming_example"
+ sample_rate = 24000
+ buffered_samples = 0
+ chunks: list[torch.Tensor] = []
+ started = time.perf_counter()
+ omni = AsyncOmni(
+ model=args.model,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ )
+ try:
+ async for stage_output in omni.generate(prompt, request_id=request_id):
+ mm = getattr(stage_output, "multimodal_output", None)
+ if not isinstance(mm, dict):
+ request_output = getattr(stage_output, "request_output", None)
+ if request_output is None:
+ continue
+ mm = getattr(request_output, "multimodal_output", None)
+ if not isinstance(mm, dict) and getattr(request_output, "outputs", None):
+ mm = getattr(request_output.outputs[0], "multimodal_output", None)
+ if not isinstance(mm, dict):
+ continue
+ audio = _extract_audio_tensor(mm)
+ if audio.numel() == 0:
+ continue
+ sample_rate = _extract_sample_rate(mm)
+ if audio.numel() > buffered_samples:
+ delta = audio[buffered_samples:]
+ buffered_samples = int(audio.numel())
+ else:
+ delta = audio
+ buffered_samples += int(delta.numel())
+ if delta.numel() > 0:
+ chunks.append(delta)
+ if not chunks:
+ raise RuntimeError("No streaming audio chunks received from VoxCPM.")
+ output_audio = torch.cat(chunks, dim=0)
+ output_path = _save_audio(output_audio, sample_rate, output_dir, request_id)
+ print(f"Saved streaming audio to: {output_path} ({time.perf_counter() - started:.2f}s)")
+ return output_path
+ finally:
+ omni.shutdown()
+
+
+def _run_sync(args) -> Path:
+ prompt = _build_prompt(args)
+ output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio")
+ request_id = "sync_example"
+ started = time.perf_counter()
+ last_mm: dict[str, Any] | None = None
+ omni = Omni(
+ model=args.model,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ )
+ for stage_outputs in omni.generate(prompt):
+ request_output = getattr(stage_outputs, "request_output", None)
+ if request_output is None:
+ continue
+ outputs = getattr(request_output, "outputs", None)
+ if outputs:
+ for output in outputs:
+ mm = getattr(output, "multimodal_output", None)
+ if isinstance(mm, dict):
+ last_mm = mm
+ mm = getattr(request_output, "multimodal_output", None)
+ if isinstance(mm, dict):
+ last_mm = mm
+ if last_mm is None:
+ raise RuntimeError("No audio output received from VoxCPM.")
+ output_path = _save_audio(
+ _extract_audio_tensor(last_mm),
+ _extract_sample_rate(last_mm),
+ output_dir,
+ request_id,
+ )
+ print(f"Saved audio to: {output_path} ({time.perf_counter() - started:.2f}s)")
+ return output_path
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(description="Minimal offline VoxCPM example for vLLM Omni.")
+ parser.add_argument("--model", type=str, required=True, help="Local VoxCPM model directory.")
+ parser.add_argument(
+ "--stage-configs-path",
+ type=str,
+ default=str(DEFAULT_SYNC_STAGE_CONFIG),
+ help=("Stage config path. Use voxcpm.yaml for non-streaming or voxcpm_async_chunk.yaml for streaming."),
+ )
+ parser.add_argument("--text", type=str, required=True, help="Input text for synthesis.")
+ parser.add_argument("--ref-audio", type=str, default=None, help="Reference audio path for voice cloning.")
+ parser.add_argument("--ref-text", type=str, default=None, help="Transcript of the reference audio.")
+ parser.add_argument("--output-dir", type=str, default=None, help="Output directory for generated wav files.")
+ parser.add_argument("--cfg-value", type=float, default=2.0, help="Guidance value passed to VoxCPM.")
+ parser.add_argument("--inference-timesteps", type=int, default=10, help="Number of diffusion timesteps.")
+ parser.add_argument("--min-len", type=int, default=2, help="Minimum latent length.")
+ parser.add_argument("--max-new-tokens", type=int, default=4096, help="Maximum latent length.")
+ parser.add_argument(
+ "--streaming-prefix-len",
+ type=int,
+ default=3,
+ help="Streaming prefix length used by voxcpm_async_chunk.yaml.",
+ )
+ parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.")
+ parser.add_argument("--log-stats", action="store_true", help="Enable vLLM Omni stats logging.")
+ args = parser.parse_args()
+ if (args.ref_audio is None) != (args.ref_text is None):
+ raise ValueError("Voice cloning requires --ref-audio and --ref-text together.")
+ return args
+
+
+def main(args) -> None:
+ route = "streaming" if _is_streaming_stage_config(args.stage_configs_path) else "sync"
+ print(f"Model: {args.model}")
+ print(f"Stage config: {args.stage_configs_path}")
+ print(f"Route: {route}")
+ if route == "streaming":
+ asyncio.run(_run_streaming(args))
+ else:
+ _run_sync(args)
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/examples/offline_inference/voxcpm2/README.md b/examples/offline_inference/voxcpm2/README.md
new file mode 100644
index 0000000000..e982730799
--- /dev/null
+++ b/examples/offline_inference/voxcpm2/README.md
@@ -0,0 +1,83 @@
+# VoxCPM2 Offline Inference (Native AR)
+
+VoxCPM2 is a 2B-parameter tokenizer-free diffusion AR TTS model. It produces 48kHz audio and supports 30+ languages with a single-stage native AR pipeline backed by MiniCPM4.
+
+## Prerequisites
+
+Install the `voxcpm` package, or set the environment variable pointing to the source tree:
+
+```bash
+# Option A: install package
+pip install voxcpm
+
+# Option B: use source checkout
+export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/voxcpm
+```
+
+## Quick Start
+
+Zero-shot synthesis:
+
+```bash
+python examples/offline_inference/voxcpm2/end2end.py \
+ --model openbmb/VoxCPM2 \
+ --text "Hello, this is a VoxCPM2 demo." \
+ --output-dir output_audio
+```
+
+Voice cloning with a reference audio:
+
+```bash
+python examples/offline_inference/voxcpm2/end2end.py \
+ --text "Hello, this is a voice clone demo." \
+ --reference-audio /path/to/reference.wav \
+ --output-dir output_clone
+```
+
+Prompt continuation (matched audio + text prefix):
+
+```bash
+python examples/offline_inference/voxcpm2/end2end.py \
+ --text "Continuation target sentence." \
+ --prompt-audio /path/to/prompt.wav \
+ --prompt-text "Transcript of the prompt audio." \
+ --output-dir output_cont
+```
+
+The script accepts the following arguments:
+
+| Argument | Default | Description |
+|---|---|---|
+| `--model` | `openbmb/VoxCPM2` | HuggingFace repo ID or local path |
+| `--text` | (example sentence) | Text to synthesize |
+| `--output-dir` | `output_audio` | Directory for output WAV files |
+| `--stage-configs-path` | `voxcpm2.yaml` | Stage config YAML path |
+| `--reference-audio` | `None` | Reference audio for voice cloning (isolated) |
+| `--prompt-audio` | `None` | Prompt audio for continuation mode |
+| `--prompt-text` | `None` | Transcript matching `--prompt-audio` |
+
+## Performance
+
+Measured on a single H20 GPU (80 GB):
+
+| Input length | RTF | Sample rate |
+|---|---|---|
+| Short (~10 tokens) | ~0.28 | 48 kHz |
+| Long (~100 tokens) | ~0.34 | 48 kHz |
+
+RTF < 1.0 means faster than real time.
+
+## Architecture
+
+VoxCPM2 uses a single-stage native AR pipeline:
+
+```
+feat_encoder
+└─► MiniCPM4 (base LM)
+ └─► FSQ (finite scalar quantization)
+ └─► residual_lm (residual AR)
+ └─► LocDiT (local diffusion transformer)
+ └─► AudioVAE → 48 kHz waveform
+```
+
+All stages are fused into one vllm-native execution graph via `voxcpm2.yaml`, eliminating inter-stage coordination overhead and enabling true end-to-end batching.
diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py
new file mode 100644
index 0000000000..6b6bf78ddf
--- /dev/null
+++ b/examples/offline_inference/voxcpm2/end2end.py
@@ -0,0 +1,171 @@
+"""Offline VoxCPM2 inference example (native AR pipeline).
+
+Uses the single-stage native AR config (voxcpm2.yaml).
+Requires the `voxcpm` package or VLLM_OMNI_VOXCPM_CODE_PATH env var.
+"""
+
+from __future__ import annotations
+
+import os
+import time
+from pathlib import Path
+
+import soundfile as sf
+import torch
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+from vllm_omni import Omni
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+DEFAULT_STAGE_CONFIGS_PATH = str(REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm2.yaml")
+SAMPLE_RATE = 48_000
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(description="Offline VoxCPM2 native AR inference")
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="openbmb/VoxCPM2",
+ help="VoxCPM2 model path or HuggingFace repo ID.",
+ )
+ parser.add_argument(
+ "--text",
+ type=str,
+ default="This is a VoxCPM2 native AR synthesis example running on vLLM Omni.",
+ help="Text to synthesize.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="output_audio",
+ help="Directory for output WAV files.",
+ )
+ parser.add_argument(
+ "--stage-configs-path",
+ type=str,
+ default=DEFAULT_STAGE_CONFIGS_PATH,
+ help="Path to the stage config YAML file.",
+ )
+ parser.add_argument(
+ "--reference-audio",
+ type=str,
+ default=None,
+ help="Path to reference audio for voice cloning (isolated ref mode).",
+ )
+ parser.add_argument(
+ "--prompt-audio",
+ type=str,
+ default=None,
+ help="Path to prompt audio for continuation mode (requires --prompt-text).",
+ )
+ parser.add_argument(
+ "--prompt-text",
+ type=str,
+ default=None,
+ help="Text matching --prompt-audio for continuation mode.",
+ )
+ parser.add_argument(
+ "--ref-text",
+ type=str,
+ default=None,
+ help="Optional transcript of --reference-audio (enables ref_continuation mode).",
+ )
+ return parser.parse_args()
+
+
+def extract_audio(multimodal_output: dict) -> torch.Tensor:
+ """Extract the final complete audio tensor from multimodal output.
+
+ The output processor concatenates per-step delta tensors under
+ ``model_outputs``. Falls back to ``audio`` for backwards compat.
+ """
+ audio = multimodal_output.get("model_outputs")
+ if audio is None:
+ audio = multimodal_output.get("audio")
+ if audio is None:
+ raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}")
+
+ if isinstance(audio, list):
+ # Defensive: usually the output processor consolidates into a single
+ # tensor at request completion, but concatenate here too in case the
+ # caller consumes intermediate (pre-consolidation) outputs.
+ valid = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio if a is not None]
+ if not valid:
+ raise ValueError("Audio list is empty or all elements are None.")
+ return torch.cat(valid, dim=0) if len(valid) > 1 else valid[0]
+
+ return torch.as_tensor(audio).float().cpu().reshape(-1)
+
+
+def main():
+ args = parse_args()
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ engine = Omni(
+ model=args.model,
+ stage_configs_path=args.stage_configs_path,
+ )
+
+ from transformers import AutoTokenizer
+
+ from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import (
+ build_cjk_split_map,
+ build_voxcpm2_prompt,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
+ split_map = build_cjk_split_map(tokenizer)
+ hf_config = engine.engine.stage_vllm_configs[0].model_config.hf_config
+
+ ref_audio_arg = args.reference_audio or args.prompt_audio
+ ref_text_arg = args.ref_text or args.prompt_text
+ ref_wav, ref_sr = (None, None)
+ if ref_audio_arg:
+ ref_wav_arr, ref_sr = sf.read(ref_audio_arg)
+ ref_wav = ref_wav_arr.mean(axis=-1).tolist() if ref_wav_arr.ndim > 1 else ref_wav_arr.tolist()
+
+ prompt = build_voxcpm2_prompt(
+ hf_config=hf_config,
+ tokenizer=tokenizer,
+ split_map=split_map,
+ text=args.text,
+ ref_audio=ref_wav,
+ ref_sr=ref_sr,
+ ref_text=ref_text_arg,
+ )
+
+ print(f"Model : {args.model}")
+ print(f"Text : {args.text}")
+ if ref_audio_arg:
+ print(f"Ref audio : {ref_audio_arg}")
+ if ref_text_arg:
+ print(f"Ref text : {ref_text_arg}")
+ print(f"Output dir : {output_dir}")
+
+ t_start = time.perf_counter()
+ outputs = engine.generate([prompt])
+ elapsed = time.perf_counter() - t_start
+
+ # outputs[0].outputs[0].multimodal_output["audio"] is a list of tensors
+ request_output = outputs[0]
+ mm = request_output.outputs[0].multimodal_output
+ audio = extract_audio(mm)
+
+ duration = audio.numel() / SAMPLE_RATE
+ rtf = elapsed / duration if duration > 0 else float("inf")
+
+ output_path = output_dir / "output.wav"
+ sf.write(str(output_path), audio.numpy(), SAMPLE_RATE, format="WAV")
+
+ print(f"Saved : {output_path}")
+ print(f"Duration : {duration:.2f}s")
+ print(f"Inference : {elapsed:.2f}s")
+ print(f"RTF : {rtf:.3f}")
+
+
+if __name__ == "__main__":
+ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ main()
diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.md b/examples/offline_inference/x_to_video_audio/x_to_video_audio.md
index 4b5188f41b..13f2cfe7c0 100644
--- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.md
+++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.md
@@ -30,9 +30,9 @@ dreamid_omni/
```
### Run the Inference
-```
+```python
python x_to_video_audio.py \
- --model /xx/dreamid_omni \
+ --model /path/to/dreamid_omni \
--prompt "Two people walking together and singing happily" \
--image-path ./example0.png ./example1.png \
--audio-path ./example0.wav ./example1.wav \
@@ -42,11 +42,33 @@ python x_to_video_audio.py \
--num-inference-steps 45 \
--height 704 \
--width 1280 \
- --output dreamid_omni.mp4
+ --output out_dreamid_omni_twoip.mp4
```
In the current test scenario (2 images + 2 audio inputs), the VRAM requirement is 72GB, regardless of whether cfg-parallel is enabled or disabled.
The VRAM usage can be reduced by enabling CPU offload via --enable-cpu-offload.
+
+You could take reference images/audios from the test cases in the official repo: https://github.com/Guoxu1233/DreamID-Omni
+
+For example, single IP ref resources can be found under https://github.com/Guoxu1233/DreamID-Omni/tree/main/test_case/oneip, you could download them correspondingly to your local and use them for testing.
+
+```python
+# Example usage for oneip, ref media from the official repo DreamID-Omni
+python x_to_video_audio.py \
+ --model /path/to/dreamid_omni \
+ --prompt ": In the frame, a woman with black long hair is identified as .\n**Overall Environment/Scene**: A lively open-kitchen café at night; stove flames flare, steam rises, and warm pendant lights swing slightly as staff move behind her. The shot is an upper-body close-up.\n**Main Characters/Subjects Appearance**: is a young woman with thick dark wavy hair and a side part. She wears a fitted black top under a light apron, a thin gold chain necklace, and small stud earrings.\n**Main Characters/Subjects Actions**: tastes the sauce with a spoon, then turns her face toward the camera while still holding the spoon, her expression shifting from focused to conflicted.\n maintains eye contact, swallows as if choosing her words, and says, I keep telling myself I’m fine,but some nights it feels like I’m just performing calm." \
+ --image-path 9.png \
+ --audio-path 9.wav \
+ --video-negative-prompt "jitter, bad hands, blur, distortion" \
+ --audio-negative-prompt "robotic, muffled, echo, distorted" \
+ --cfg-parallel-size 2 \
+ --num-inference-steps 45 \
+ --height 704 \
+ --width 1280 \
+ --output out_dreamid_omni_oneip.mp4
+```
+
+
Key arguments:
- `--prompt`: text description (string).
- `--model`: path to the model local directory.
diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
index e0424add69..497284ceb9 100644
--- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
+++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
@@ -5,10 +5,12 @@
import re
import time
-import librosa
+import numpy as np
from PIL import Image
+from vllm.multimodal.media.audio import load_audio
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -36,8 +38,8 @@ def parse_args() -> argparse.Namespace:
"--cfg-parallel-size",
type=int,
default=1,
- choices=[1, 2],
- help="Number of GPUs used for classifier free guidance parallel size.",
+ choices=[1, 2, 3, 4],
+ help="Number of GPUs used for classifier free guidance parallel size (max 4 branches).",
)
parser.add_argument(
"--video-negative-prompt",
@@ -56,6 +58,11 @@ def parse_args() -> argparse.Namespace:
default=False,
help="Enable CPU offloading for diffusion models.",
)
+ parser.add_argument(
+ "--enable-layerwise-offload",
+ action="store_true",
+ help="Enable layerwise (blockwise) offloading on DiT modules.",
+ )
return parser.parse_args()
@@ -69,7 +76,7 @@ def load_image_and_audio(image_paths, audio_paths):
image.append(img)
for path in audio_paths:
- audio_array, sr = librosa.load(path, sr=16000)
+ audio_array, sr = load_audio(path, sr=16000)
audio_array = audio_array[int(sr * 1) : int(sr * 3)]
audio.append(audio_array)
return image, audio
@@ -124,6 +131,7 @@ def main() -> None:
parallel_config=parallel_config,
model_type=args.model_type,
enable_cpu_offload=args.enable_cpu_offload,
+ enable_layerwise_offload=args.enable_layerwise_offload,
)
start = time.perf_counter()
outputs = omni.generate(prompt, sampling_params)
@@ -131,15 +139,35 @@ def main() -> None:
if not outputs:
raise RuntimeError("No output returned from DreamID-Omni.")
- output = outputs[0].request_output
- generated_video = output.images[0][0]
- generated_audio = output.images[0][1]
- try:
- from dreamid_omni.utils.io_utils import save_video
- except Exception as e:
- raise RuntimeError(f"Failed to extract video and audio from DreamID-Omni output. Error: {e}")
+ result = outputs[0]
+ if not result.images:
+ raise RuntimeError("No video frames found in DreamID-Omni output.")
+ generated_video = result.images[0]
+ mm = result.multimodal_output or {}
+ generated_audio = mm.get("audio")
+ fps = int(mm.get("fps", 24))
+ sample_rate = int(mm.get("audio_sample_rate", 16000))
+
+ # DreamID-Omni returns video as (C, F, H, W) float32 in [-1, 1].
+ # mux_video_audio_bytes expects (F, H, W, C) uint8.
+ if not isinstance(generated_video, np.ndarray) or generated_video.ndim != 4:
+ raise RuntimeError(f"Unexpected video shape: {getattr(generated_video, 'shape', None)}")
+ frames = generated_video.transpose(1, 2, 3, 0)
+ frames = (np.clip((frames + 1.0) / 2.0, 0.0, 1.0) * 255.0).round().astype(np.uint8)
+
+ audio_np = None
+ if generated_audio is not None:
+ audio_np = np.squeeze(np.asarray(generated_audio)).astype(np.float32)
+
output_path = args.output
- save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
+ video_bytes = mux_video_audio_bytes(
+ frames,
+ audio_np,
+ fps=float(fps),
+ audio_sample_rate=sample_rate,
+ )
+ with open(output_path, "wb") as f:
+ f.write(video_bytes)
print(f"Saved generated video to {output_path}")
print(f"Total time: {elapsed:.2f}s")
diff --git a/examples/online_serving/bagel/README.md b/examples/online_serving/bagel/README.md
index 9b74acae10..0939bc5f38 100644
--- a/examples/online_serving/bagel/README.md
+++ b/examples/online_serving/bagel/README.md
@@ -354,13 +354,6 @@ curl http://localhost:8091/v1/chat/completions \
## FAQ
-- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below.
-
-```bash
-sudo apt update
-sudo apt install ffmpeg
-```
-
- If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len.
| Stage | VRAM |
diff --git a/examples/online_serving/image_to_video/README.md b/examples/online_serving/image_to_video/README.md
index 49283bd9a0..285eeb2798 100644
--- a/examples/online_serving/image_to_video/README.md
+++ b/examples/online_serving/image_to_video/README.md
@@ -26,6 +26,23 @@ The script allows overriding:
- `CACHE_BACKEND` (default: `none`)
- `ENABLE_CACHE_DIT_SUMMARY` (default: `0`)
+### Ascend / Local LightX2V Example
+
+For a local Wan2.2-LightX2V Diffusers directory on Ascend/NPU, you can start the server like this:
+
+```bash
+vllm serve /path/to/Wan2.2-I2V-A14B-LightX2V-Diffusers-Lightning \
+ --omni \
+ --port 8091 \
+ --flow-shift 12 \
+ --cfg-parallel-size 1 \
+ --ulysses-degree 4 \
+ --use-hsdp \
+ --trust-remote-code \
+ --allowed-local-media-path / \
+ --seed 42
+```
+
## Async Job Behavior
`POST /v1/videos` is asynchronous. It creates a video job and immediately
@@ -69,10 +86,35 @@ curl -X POST http://localhost:8091/v1/videos/sync \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42" \
-o sync_i2v_output.mp4
```
+For Wan Lightning/Distill checkpoints, pass `{"sample_solver":"euler"}` via `extra_params`. The default solver is `unipc`.
+
+Example matching the local LightX2V deployment above:
+
+```bash
+curl -sS -X POST http://localhost:8091/v1/videos/sync \
+ -H "Accept: video/mp4" \
+ -F "prompt=A cat playing with yarn" \
+ -F "input_reference=@/path/to/input.jpg" \
+ -F "width=832" \
+ -F "height=480" \
+ -F "num_frames=81" \
+ -F "fps=16" \
+ -F "num_inference_steps=4" \
+ -F "guidance_scale=1.0" \
+ -F "guidance_scale_2=1.0" \
+ -F "boundary_ratio=0.875" \
+ -F "seed=42" \
+ -F 'extra_params={"sample_solver":"euler"}' \
+ -o ./output.mp4
+```
+
+Use `/v1/videos/sync` if you want to write the MP4 directly to a file. `POST /v1/videos` is async and returns job metadata, not inline `b64_json`.
+
## Storage
Generated video files are stored on local disk by the async video API.
@@ -96,6 +138,9 @@ export VLLM_OMNI_STORAGE_MAX_CONCURRENCY=8
# Basic image-to-video generation
bash run_curl_image_to_video.sh
+# Wan Lightning/Distill checkpoints
+SAMPLE_SOLVER=euler bash run_curl_image_to_video.sh
+
# Or execute directly (OpenAI-style multipart)
create_response=$(curl -s http://localhost:8091/v1/videos \
-H "Accept: application/json" \
@@ -111,6 +156,7 @@ create_response=$(curl -s http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42")
video_id=$(echo "$create_response" | jq -r '.id')
@@ -169,9 +215,12 @@ curl -X POST http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
+ -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42"
```
+`sample_solver` is supported by Wan2.2 online serving through the existing `extra_params` field, which is merged into the pipeline `extra_args`. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
+
## Create Response Format
`POST /v1/videos` returns a job record, not inline base64 video data.
diff --git a/examples/online_serving/image_to_video/run_curl_image_to_video.sh b/examples/online_serving/image_to_video/run_curl_image_to_video.sh
index f4c1496a69..6f6a6f96d5 100644
--- a/examples/online_serving/image_to_video/run_curl_image_to_video.sh
+++ b/examples/online_serving/image_to_video/run_curl_image_to_video.sh
@@ -7,6 +7,7 @@ INPUT_IMAGE="${INPUT_IMAGE:-../../offline_inference/image_to_video/qwen-bear.png
BASE_URL="${BASE_URL:-http://localhost:8099}"
OUTPUT_PATH="${OUTPUT_PATH:-wan22_i2v_output.mp4}"
NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-}"
+SAMPLE_SOLVER="${SAMPLE_SOLVER:-}"
POLL_INTERVAL="${POLL_INTERVAL:-2}"
if [ ! -f "$INPUT_IMAGE" ]; then
@@ -34,6 +35,10 @@ if [ -n "${NEGATIVE_PROMPT}" ]; then
create_cmd+=(-F "negative_prompt=${NEGATIVE_PROMPT}")
fi
+if [ -n "${SAMPLE_SOLVER}" ]; then
+ create_cmd+=(-F "extra_params={\"sample_solver\":\"${SAMPLE_SOLVER}\"}")
+fi
+
create_response="$("${create_cmd[@]}")"
video_id="$(echo "${create_response}" | jq -r '.id')"
if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then
diff --git a/examples/online_serving/ming_flash_omni/README.md b/examples/online_serving/ming_flash_omni/README.md
new file mode 100644
index 0000000000..502232725c
--- /dev/null
+++ b/examples/online_serving/ming_flash_omni/README.md
@@ -0,0 +1,204 @@
+# Ming-flash-omni 2.0
+
+## Installation
+
+Please refer to [README.md](../../../README.md)
+
+## Run examples (Ming-flash-omni 2.0)
+
+### Launch the Server
+
+```bash
+vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091
+```
+
+If you have custom stage configs file, launch the server with command below
+```bash
+vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+```
+
+### Send Multi-modal Request
+
+#### Send request via python
+
+```bash
+python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py --model Jonathan1909/Ming-flash-omni-2.0 --query-type use_mixed_modalities --port 8091 --host "localhost" --modalities text
+```
+
+The Python client supports the following command-line arguments:
+
+- `--query-type` (or `-q`): Query type. Options: `text`, `use_audio`, `use_image`, `use_video`, `use_mixed_modalities`
+- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type uses video, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4`
+- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type uses image, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png`
+- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3`
+- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"`
+- `--modalities`: Output modalities. For now, only `text` is supported. Example: `--modalities text`
+
+
+#### Send request via curl
+
+```bash
+bash run_curl_multimodal_generation.sh text
+bash run_curl_multimodal_generation.sh use_image
+bash run_curl_multimodal_generation.sh use_audio
+bash run_curl_multimodal_generation.sh use_video
+bash run_curl_multimodal_generation.sh use_mixed_modalities
+```
+
+## Modality control
+
+Ming-flash-omni 2.0 currently supports text output only (thinker stage).
+
+| Modalities | Output |
+|------------|--------|
+| `["text"]` | Text only |
+| Not specified | Text only (default) |
+
+### Using curl
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Jonathan1909/Ming-flash-omni-2.0",
+ "messages": [
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}
+ ],
+ "modalities": ["text"]
+ }'
+```
+
+### Using OpenAI Python SDK
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"},
+ ],
+ modalities=["text"],
+)
+print(response.choices[0].message.content)
+```
+
+### Multi-modal input with OpenAI Python SDK
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {
+ "role": "user",
+ "content": [
+ {"type": "image_url", "image_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"}},
+ {"type": "text", "text": "Describe this image in detail."},
+ ],
+ },
+ ],
+ modalities=["text"],
+)
+print(response.choices[0].message.content)
+```
+
+## Streaming Output
+
+To enable streaming output:
+
+```bash
+python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \
+ --query-type use_image \
+ --model Jonathan1909/Ming-flash-omni-2.0 \
+ --modalities text \
+ --stream
+```
+
+Or with the OpenAI Python SDK:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"},
+ ],
+ modalities=["text"],
+ stream=True,
+)
+for chunk in response:
+ for choice in chunk.choices:
+ if hasattr(choice, "delta") and choice.delta.content:
+ print(choice.delta.content, end="", flush=True)
+print()
+```
+
+Or using curl:
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Jonathan1909/Ming-flash-omni-2.0",
+ "messages": [
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}
+ ],
+ "modalities": ["text"],
+ "stream": true,
+ }'
+```
+
+
+## Reasoning (Thinking Mode)
+
+To enable reasoning/thinking mode, change `detailed thinking off` to `detailed thinking on` in the system prompt:
+
+### Using curl
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Jonathan1909/Ming-flash-omni-2.0",
+ "messages": [
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]},
+ {"role": "user", "content": [
+ {"type": "image_url", "image_url": {"url": "https://example.com/math_problem.png"}},
+ {"type": "text", "text": "Solve this math problem step by step."}
+ ]}
+ ],
+ "modalities": ["text"]
+ }'
+```
+
+### Using OpenAI Python SDK
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]},
+ {"role": "user", "content": "If a train travels 120 km in 2 hours, what is its average speed?"},
+ ],
+ modalities=["text"],
+)
+print(response.choices[0].message.content)
+```
diff --git a/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh
new file mode 100755
index 0000000000..768a424e45
--- /dev/null
+++ b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh
@@ -0,0 +1,145 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Server port
+PORT="${PORT:-8091}"
+# Default query type
+QUERY_TYPE="${1:-text}"
+
+# Validate query type
+if [[ ! "$QUERY_TYPE" =~ ^(text|use_audio|use_image|use_video|use_mixed_modalities)$ ]]; then
+ echo "Error: Invalid query type '$QUERY_TYPE'"
+ echo "Usage: $0 [text|use_audio|use_image|use_video|use_mixed_modalities]"
+ echo " text: Text-only query"
+ echo " use_audio: Audio + Text query"
+ echo " use_image: Image + Text query"
+ echo " use_video: Video + Text query"
+ echo " use_mixed_modalities: Audio + Image + Video + Text query"
+ exit 1
+fi
+
+thinker_sampling_params='{
+ "temperature": 0.4,
+ "top_p": 0.9,
+ "top_k": -1,
+ "max_tokens": 16384,
+ "seed": 42,
+ "detokenize": true,
+ "repetition_penalty": 1.05
+}'
+# Above is optional, it has a default setting in stage_configs of the corresponding model.
+
+# Define URLs for assets
+MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg"
+CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"
+SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4"
+
+# Build user content based on query type
+case "$QUERY_TYPE" in
+ text)
+ user_content='[
+ {
+ "type": "text",
+ "text": "请详细介绍鹦鹉的生活习性。"
+ }
+ ]'
+ ;;
+ use_image)
+ user_content='[
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Describe this image in detail."
+ }
+ ]'
+ ;;
+ use_audio)
+ user_content='[
+ {
+ "type": "audio_url",
+ "audio_url": {
+ "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Please recognize the language of this speech and transcribe it. Format: oral."
+ }
+ ]'
+ ;;
+ use_video)
+ user_content='[
+ {
+ "type": "video_url",
+ "video_url": {
+ "url": "'"$SAMPLE_VIDEO_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Describe what is happening in this video."
+ }
+ ]'
+ ;;
+ use_mixed_modalities)
+ user_content='[
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'"
+ }
+ },
+ {
+ "type": "audio_url",
+ "audio_url": {
+ "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Describe the image, and recognize the language of this speech and transcribe it. Format: oral"
+ }
+ ]'
+ ;;
+esac
+
+echo "Running query type: $QUERY_TYPE"
+echo ""
+
+request_body=$(cat < **Note on `--no-async-chunk`**: Flips the deploy yaml's `async_chunk:`
+> bool. Pipelines that implement alternate processor functions for
+> chunked vs end-to-end modes (e.g. qwen3_tts code2wav) dispatch
+> automatically based on that bool — no extra flag or variant yaml is
+> needed.
+
+> ⚠️ **For multi-stage models that share GPUs (qwen3_omni_moe by default
+> shares cuda:1 between stages 1 and 2), avoid using global memory flags.**
+> A global `--gpu-memory-utilization 0.85` would apply to every stage and
+> oversubscribe the shared device. Use per-stage overrides instead — see
+> below.
+
+#### 2. Per-stage overrides via `--stage-overrides` (recommended for memory)
+
+```bash
+# Lower stage 1's memory budget; leave others at the YAML default
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
+ --stage-overrides '{
+ "1": {"gpu_memory_utilization": 0.5},
+ "2": {"max_num_batched_tokens": 65536}
+ }'
+```
+
+Per-stage values are always treated as explicit and beat YAML defaults for
+the named stage. Other stages keep their YAML values.
+
+#### 3. Custom deploy YAML
+
+When per-stage overrides get long, write a small overlay YAML that inherits
+from the bundled default:
+
+```yaml
+# my_qwen3_omni_overrides.yaml
+base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
+
+stages:
+ - stage_id: 0
+ max_num_batched_tokens: 65536
+ enforce_eager: true
+ - stage_id: 1
+ gpu_memory_utilization: 0.5
+ - stage_id: 2
+ max_model_len: 8192
```
+Then start the server with `--deploy-config my_qwen3_omni_overrides.yaml`.
+The `base_config:` line tells the loader to inherit everything else (stages,
+connectors, edges, platforms section) from the bundled production YAML, so
+you only need to spell out the deltas.
+
+#### 4. Multi-node deployment (cross-host transfer connector)
+
+The bundled `qwen3_omni_moe.yaml` uses `SharedMemoryConnector` between stages,
+which only works when all stages run on the same physical host. For
+**cross-node** deployments, write a small overlay YAML that swaps in a
+network-capable connector (e.g. `MooncakeStoreConnector`) and re-points each
+stage's connector wiring at it. The connector spec carries your own server
+addresses — there is no checked-in default because every cluster is
+different.
+
+```yaml
+# my_qwen3_omni_multinode.yaml
+base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
+
+connectors:
+ mooncake_connector:
+ name: MooncakeStoreConnector
+ extra:
+ host: "127.0.0.1"
+ metadata_server: "http://YOUR_METADATA_HOST:8080/metadata"
+ master: "YOUR_MASTER_HOST:50051"
+ segment: 512000000 # 512 MB transfer segment
+ localbuf: 64000000 # 64 MB local buffer
+ proto: "tcp"
+
+stages:
+ - stage_id: 0
+ output_connectors:
+ to_stage_1: mooncake_connector
+ - stage_id: 1
+ input_connectors:
+ from_stage_0: mooncake_connector
+ output_connectors:
+ to_stage_2: mooncake_connector
+ - stage_id: 2
+ input_connectors:
+ from_stage_1: mooncake_connector
+```
+
+Then launch with `--deploy-config my_qwen3_omni_multinode.yaml`. Same
+pattern works for Qwen2.5-Omni — replace `base_config:` with the path to
+`vllm_omni/deploy/qwen2_5_omni.yaml`.
+
+> ⚠️ Replace `YOUR_METADATA_HOST` / `YOUR_MASTER_HOST` with the actual
+> mooncake server addresses for your cluster. The `base_config:` overlay
+> inherits all stage budgets, devices, and edges from the bundled prod
+> YAML — you only need to spell out the connector swap.
+
### Send Multi-modal Request
Get into the example folder
@@ -38,38 +180,43 @@ python examples/online_serving/openai_chat_completion_client_for_multimodal_gene
#### Realtime WebSocket client (`openai_realtime_client.py`)
-[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, uploads a local audio file as **PCM16 mono @ 16 kHz** chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and prints **streaming transcription** (`transcription.delta` / `transcription.done`).
+[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, streams a local WAV as **PCM16 mono @ 16 kHz** in fixed-size chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and receives **`response.audio.delta`** (incremental PCM for the reply) plus **`transcription.*`** events. By default it concatenates audio deltas and writes **`--output-wav`** (model output is typically **24 kHz**). Optional **`--delta-dump-dir`** saves each delta as `delta_000001.wav`, … for debugging.
+
+Streaming input works well for translation-style use cases; if the Thinker runs while input is still incomplete, consider limiting **`max_tokens`** in your session / server defaults to avoid over-generation.
**Dependencies:**
```bash
-pip install websockets librosa numpy
+pip install websockets
```
-(ffmpeg may be required by `librosa` for some formats; see the FAQ below.)
-
**From this directory** (`examples/online_serving/qwen3_omni`):
```bash
python openai_realtime_client.py \
- --host localhost \
- --port 8091 \
+ --url ws://localhost:8091/v1/realtime \
--model Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --audio_path /path/to/your.wav
+ --input-wav /path/to/input_16k_mono.wav \
+ --output-wav realtime_output.wav \
+ --delta-dump-dir ./rt_delta_wavs
```
-If `--audio_path` is omitted, the script uses a bundled default clip (`mary_had_lamb` via vLLM assets).
-
**Arguments:**
| Flag | Default | Description |
|------|---------|-------------|
-| `--host` | `localhost` | API server host |
-| `--port` | `8000` | API server port (match your `vllm serve` port, e.g. `8091`) |
-| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (also sent in `session.update`) |
-| `--audio_path` | *(optional)* | Path to input audio; resampled to 16 kHz mono inside the client |
-
-Ensure the vLLM-Omni server is running with realtime support for this endpoint, for example:
+| `--url` | `ws://localhost:8091/v1/realtime` | Full WebSocket URL including path |
+| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (sent in `session.update`) |
+| `--input-wav` | *(required)* | Input WAV: mono, 16-bit PCM, **16 kHz** |
+| `--output-wav` | `realtime_output.wav` | Output path for concatenated reply audio |
+| `--output-text` | *(optional)* | If set, write final transcription text to this path |
+| `--chunk-ms` | `200` | Size of each uploaded audio chunk (milliseconds of audio) |
+| `--send-delay-ms` | `0` | Delay between chunk sends (simulate realtime upload) |
+| `--delta-dump-dir` | *(optional)* | Directory to write per-`response.audio.delta` WAV files |
+| `--num-requests` | `1` | Number of sequential sessions (see `--concurrency`) |
+| `--concurrency` | `1` | Max concurrent WebSocket sessions when `--num-requests` > 1 |
+
+Ensure the server is running **without** `async_chunk` if you use `/v1/realtime`, for example:
```bash
vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
@@ -105,12 +252,6 @@ bash run_curl_multimodal_generation.sh use_image
### FAQ
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
## Modality control
You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance.
@@ -284,7 +425,7 @@ The script supports the following arguments:
- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct)
- `--server-port`: Port for vLLM server (default: 8091)
- `--gradio-port`: Port for Gradio demo (default: 7861)
-- `--stage-configs-path`: Path to custom stage configs YAML file (optional)
+- `--deploy-config`: Path to custom deploy config YAML file (optional)
- `--server-host`: Host for vLLM server (default: 0.0.0.0)
- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1)
- `--share`: Share Gradio demo publicly (creates a public link)
@@ -299,7 +440,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you have custom stage configs file:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
```
**Step 2: Run the Gradio demo**
diff --git a/examples/online_serving/qwen3_omni/openai_realtime_client.py b/examples/online_serving/qwen3_omni/openai_realtime_client.py
index 4fa043c481..79e30a3f50 100644
--- a/examples/online_serving/qwen3_omni/openai_realtime_client.py
+++ b/examples/online_serving/qwen3_omni/openai_realtime_client.py
@@ -1,81 +1,118 @@
-"""
-This script demonstrates how to use the vLLM-Omni Realtime WebSocket API to perform
-audio transcription by uploading an audio file.
+"""Realtime client for vLLM-Omni /v1/realtime (audio + text events).
+
+This client:
+1) Reads a local WAV file (must be mono, 16-bit PCM, 16kHz),
+2) Streams PCM16 chunks to /v1/realtime with OpenAI-style events,
+3) Receives response.audio.* and transcription.* events,
+4) Saves synthesized audio to an output WAV file and optional text file.
-Before running this script, you must start the vLLM-Omni server with a realtime-capable
-model, for example:
+By default each ``response.audio.delta`` is treated as an **incremental PCM**
+chunk and all chunks are concatenated into the final ``--output-wav``.
- vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni
+Optional debugging: pass ``--delta-dump-dir DIR`` to write every
+``response.audio.delta`` payload as ``delta_000001.wav``, ``delta_000002.wav``, …
-Requirements:
-- vllm with audio support
-- websockets
-- librosa
-- numpy
+Usage:
+ python openai_realtime_client.py \
+ --url ws://localhost:8091/v1/realtime \
+ --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --input-wav input_16k_mono.wav \
+ --output-wav realtime_output.wav \
+ --delta-dump-dir ./rt_delta_wavs
-The script:
-1. Connects to the Realtime WebSocket endpoint
-2. Converts an audio file to PCM16 @ 16kHz
-3. Sends audio chunks to the server
-4. Receives and prints transcription as it streams
+Dependencies:
+ pip install websockets
"""
+from __future__ import annotations
+
import argparse
import asyncio
import base64
import json
+import wave
+from pathlib import Path
+
+try:
+ import websockets
+except ImportError:
+ print("Please install websockets: pip install websockets")
+ raise SystemExit(1)
+
+
+def _read_wav_pcm16(path: Path) -> bytes:
+ with wave.open(str(path), "rb") as wf:
+ nchannels = wf.getnchannels()
+ sampwidth = wf.getsampwidth()
+ framerate = wf.getframerate()
+ comptype = wf.getcomptype()
+ nframes = wf.getnframes()
+
+ if nchannels != 1:
+ raise ValueError(f"Input WAV must be mono (got {nchannels} channels).")
+ if sampwidth != 2:
+ raise ValueError(f"Input WAV must be 16-bit PCM (got sample width={sampwidth}).")
+ if framerate != 16000:
+ raise ValueError(f"Input WAV must be 16kHz (got {framerate} Hz).")
+ if comptype != "NONE":
+ raise ValueError(f"Input WAV must be uncompressed PCM (got comptype={comptype}).")
+ if nframes <= 0:
+ raise ValueError("Input WAV has no audio frames.")
+
+ return wf.readframes(nframes)
+
+
+def _write_wav_pcm16(path: Path, pcm16_bytes: bytes, sample_rate_hz: int) -> None:
+ with wave.open(str(path), "wb") as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(sample_rate_hz)
+ wf.writeframes(pcm16_bytes)
+
+
+async def run_client(
+ url: str,
+ model: str,
+ input_wav: Path,
+ output_wav: Path,
+ output_text: Path | None,
+ chunk_ms: int,
+ send_delay_ms: int,
+ delta_dump_dir: Path | None,
+ request_idx: int = 1,
+ total_requests: int = 1,
+) -> None:
+ log_prefix = f"[req {request_idx:02d}/{total_requests:02d}] " if total_requests > 1 else ""
+ pcm16 = _read_wav_pcm16(input_wav)
+ bytes_per_ms = 16000 * 2 // 1000 # mono PCM16 at 16kHz
+ chunk_bytes = max(bytes_per_ms * chunk_ms, 2)
-import librosa
-import numpy as np
-import websockets
-from vllm.assets.audio import AudioAsset
-
-
-def audio_to_pcm16_base64(audio_path: str) -> str:
- """
- Load an audio file and convert it to base64-encoded PCM16 @ 16kHz.
- """
- # Load audio and resample to 16kHz mono
- audio, _ = librosa.load(audio_path, sr=16000, mono=True)
- # Convert to PCM16
- pcm16 = (audio * 32767).astype(np.int16)
- # Encode as base64
- return base64.b64encode(pcm16.tobytes()).decode("utf-8")
-
-
-async def realtime_transcribe(audio_path: str, host: str, port: int, model: str):
- """
- Connect to the Realtime API and transcribe an audio file.
- """
- uri = f"ws://{host}:{port}/v1/realtime"
-
- async with websockets.connect(uri) as ws:
- # Wait for session.created
- response = json.loads(await ws.recv())
- if response["type"] == "session.created":
- print(f"Session created: {response['id']}")
- else:
- print(f"Unexpected response: {response}")
- return
-
- # Validate model
- await ws.send(json.dumps({"type": "session.update", "model": model}))
-
- # Signal ready to start
- await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
-
- # Convert audio file to base64 PCM16
- print(f"Loading audio from: {audio_path}")
- audio_base64 = audio_to_pcm16_base64(audio_path)
-
- # Send audio in chunks (4KB of raw audio = ~8KB base64)
- chunk_size = 4096
- audio_bytes = base64.b64decode(audio_base64)
- total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size
-
- print(f"Sending {total_chunks} audio chunks...")
- for i in range(0, len(audio_bytes), chunk_size):
- chunk = audio_bytes[i : i + chunk_size]
+ incremental_pcm_parts: list[bytes] = []
+ output_sample_rate = 24000
+ delta_index = 0
+ text_chunks: list[str] = []
+ final_text: str = ""
+
+ if delta_dump_dir is not None:
+ delta_dump_dir.mkdir(parents=True, exist_ok=True)
+
+ async with websockets.connect(url, max_size=64 * 1024 * 1024) as ws:
+ # 1) Validate model.
+ await ws.send(
+ json.dumps(
+ {
+ "type": "session.update",
+ "model": model,
+ }
+ )
+ )
+
+ # 2) Start generation once (non-final commit).
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False}))
+
+ # 3) Stream audio chunks.
+ for i in range(0, len(pcm16), chunk_bytes):
+ chunk = pcm16[i : i + chunk_bytes]
await ws.send(
json.dumps(
{
@@ -84,63 +121,212 @@ async def realtime_transcribe(audio_path: str, host: str, port: int, model: str)
}
)
)
+ if send_delay_ms > 0:
+ await asyncio.sleep(send_delay_ms / 1000.0)
- # Signal all audio is sent
+ # 4) Final commit closes input stream.
await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
- print("Audio sent. Waiting for transcription...\n")
- # Receive transcription
- print("Transcription: ", end="", flush=True)
+ # 5) Receive server events until audio done.
while True:
- response = json.loads(await ws.recv())
- if response["type"] == "transcription.delta":
- print(response["delta"], end="", flush=True)
- elif response["type"] == "transcription.done":
- print(f"\n\nFinal transcription: {response['text']}")
- if response.get("usage"):
- print(f"Usage: {response['usage']}")
- break
- elif response["type"] == "error":
- print(f"\nError: {response['error']}")
+ message = await ws.recv()
+ if isinstance(message, bytes):
+ # We only expect JSON text frames.
+ continue
+
+ event = json.loads(message)
+ event_type = event.get("type")
+
+ if event_type == "session.created":
+ continue
+
+ if event_type == "response.audio.delta":
+ sr = event.get("sample_rate_hz")
+ if isinstance(sr, int) and sr > 0:
+ output_sample_rate = sr
+ audio_b64 = event.get("audio", "")
+ if audio_b64:
+ pcm_delta = base64.b64decode(audio_b64)
+ incremental_pcm_parts.append(pcm_delta)
+ if delta_dump_dir is not None and pcm_delta:
+ delta_index += 1
+ dump_path = delta_dump_dir / f"delta_{delta_index:06d}.wav"
+ _write_wav_pcm16(dump_path, pcm_delta, output_sample_rate)
+ print(
+ f"{log_prefix}delta dump #{delta_index}: {dump_path} "
+ f"(pcm bytes={len(pcm_delta)}, sr={output_sample_rate})"
+ )
+ continue
+
+ if event_type == "transcription.delta":
+ delta = event.get("delta", "")
+ if delta:
+ text_chunks.append(delta)
+ print(delta, end="", flush=True)
+ continue
+
+ if event_type == "transcription.done":
+ final_text = event.get("text", "") or "".join(text_chunks)
+ usage = event.get("usage")
+ final_text_with_tag = f"Final transcription: {final_text}"
+ if text_chunks:
+ print()
+ print(f"{log_prefix}{final_text_with_tag}")
+ if usage:
+ print(f"{log_prefix}text usage: {usage}")
+ continue
+
+ if event_type == "response.audio.done":
break
+ if event_type == "error":
+ raise RuntimeError(f"Server error: {event}")
-def main(args):
- if args.audio_path:
- audio_path = args.audio_path
- else:
- # Use default audio asset
- audio_path = str(AudioAsset("mary_had_lamb").get_local_path())
- print(f"No audio path provided, using default: {audio_path}")
+ all_pcm16 = b"".join(incremental_pcm_parts)
+ if not all_pcm16:
+ raise RuntimeError("No audio received from server.")
- asyncio.run(realtime_transcribe(audio_path, args.host, args.port, args.model))
+ output_wav.parent.mkdir(parents=True, exist_ok=True)
+ _write_wav_pcm16(output_wav, all_pcm16, output_sample_rate)
+ print(f"{log_prefix}Saved realtime audio to: {output_wav} (incremental chunks joined)")
+ if output_text is not None:
+ text_to_save = final_text if final_text else "".join(text_chunks)
+ output_text.parent.mkdir(parents=True, exist_ok=True)
+ output_text.write_text(text_to_save, encoding="utf-8")
+ print(f"{log_prefix}Saved realtime text to: {output_text}")
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Realtime WebSocket Transcription Client")
+
+def _indexed_output_path(path: Path | None, index: int, total: int) -> Path | None:
+ if path is None or total <= 1:
+ return path
+ return path.with_name(f"{path.stem}_{index:02d}{path.suffix}")
+
+
+async def run_clients_concurrent(
+ *,
+ url: str,
+ model: str,
+ input_wav: Path,
+ output_wav: Path,
+ output_text: Path | None,
+ chunk_ms: int,
+ send_delay_ms: int,
+ delta_dump_dir: Path | None,
+ num_requests: int,
+ concurrency: int,
+) -> None:
+ sem = asyncio.Semaphore(concurrency)
+
+ async def _run_one(index: int) -> tuple[int, bool, str | None]:
+ per_output_wav = _indexed_output_path(output_wav, index, num_requests)
+ per_output_text = _indexed_output_path(output_text, index, num_requests)
+ per_delta_dir = None
+ if delta_dump_dir is not None:
+ per_delta_dir = delta_dump_dir / f"req_{index:02d}"
+ async with sem:
+ try:
+ await run_client(
+ url=url,
+ model=model,
+ input_wav=input_wav,
+ output_wav=per_output_wav,
+ output_text=per_output_text,
+ chunk_ms=chunk_ms,
+ send_delay_ms=send_delay_ms,
+ delta_dump_dir=per_delta_dir,
+ request_idx=index,
+ total_requests=num_requests,
+ )
+ return index, True, None
+ except Exception as exc:
+ return index, False, str(exc)
+
+ tasks = [asyncio.create_task(_run_one(i), name=f"rt-client-{i}") for i in range(1, num_requests + 1)]
+ results = await asyncio.gather(*tasks)
+
+ failed = [(idx, err) for idx, ok, err in results if not ok]
+ succeeded = num_requests - len(failed)
+ print(f"[summary] succeeded={succeeded}, failed={len(failed)}, total={num_requests}")
+ if failed:
+ for idx, err in failed:
+ print(f"[summary] req {idx:02d} failed: {err}")
+ raise RuntimeError(f"{len(failed)} concurrent request(s) failed")
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Realtime audio/text client for vLLM-Omni")
+ parser.add_argument("--url", default="ws://localhost:8091/v1/realtime", help="WebSocket URL")
parser.add_argument(
"--model",
- type=str,
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
- help="Model that is served and should be pinged.",
+ help="Model name for session.update",
)
+ parser.add_argument("--input-wav", required=True, type=Path, help="Input WAV (mono, PCM16, 16kHz)")
+ parser.add_argument("--output-wav", default=Path("realtime_output.wav"), type=Path, help="Output WAV path")
parser.add_argument(
- "--audio_path",
- type=str,
+ "--output-text",
default=None,
- help="Path to the audio file to transcribe.",
+ type=Path,
+ help="Optional output text path for final transcription",
)
+ parser.add_argument("--chunk-ms", type=int, default=200, help="Input chunk size in milliseconds")
parser.add_argument(
- "--host",
- type=str,
- default="localhost",
- help="vLLM-Omni server host (default: localhost)",
+ "--send-delay-ms",
+ type=int,
+ default=0,
+ help="Delay between chunk sends; set >0 to simulate realtime upload",
)
parser.add_argument(
- "--port",
+ "--delta-dump-dir",
+ type=Path,
+ default=None,
+ help="If set, each response.audio.delta is saved as delta_NNNNNN.wav under this directory",
+ )
+ parser.add_argument("--num-requests", type=int, default=1, help="Total number of requests to send")
+ parser.add_argument(
+ "--concurrency",
type=int,
- default=8000,
- help="vLLM-Omni server port (default: 8000)",
+ default=1,
+ help="Maximum number of concurrent websocket requests",
)
args = parser.parse_args()
- main(args)
+
+ if args.num_requests <= 0:
+ raise ValueError("--num-requests must be >= 1")
+ if args.concurrency <= 0:
+ raise ValueError("--concurrency must be >= 1")
+ concurrency = min(args.concurrency, args.num_requests)
+
+ if args.num_requests == 1:
+ asyncio.run(
+ run_client(
+ url=args.url,
+ model=args.model,
+ input_wav=args.input_wav,
+ output_wav=args.output_wav,
+ output_text=args.output_text,
+ chunk_ms=args.chunk_ms,
+ send_delay_ms=args.send_delay_ms,
+ delta_dump_dir=args.delta_dump_dir,
+ )
+ )
+ else:
+ asyncio.run(
+ run_clients_concurrent(
+ url=args.url,
+ model=args.model,
+ input_wav=args.input_wav,
+ output_wav=args.output_wav,
+ output_text=args.output_text,
+ chunk_ms=args.chunk_ms,
+ send_delay_ms=args.send_delay_ms,
+ delta_dump_dir=args.delta_dump_dir,
+ num_requests=args.num_requests,
+ concurrency=concurrency,
+ )
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md
index 5504b5737a..350fcb71ca 100644
--- a/examples/online_serving/qwen3_tts/README.md
+++ b/examples/online_serving/qwen3_tts/README.md
@@ -43,7 +43,7 @@ Then open http://localhost:7860 in your browser.
### Launch the Server
-The default stage config is located at `vllm_omni/model_executor/stage_configs/qwen3_tts.yaml`. For other platforms (e.g., NPU), refer to `vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml`.
+The default deploy config is located at `vllm_omni/deploy/qwen3_tts.yaml` and is loaded automatically by the model registry — no `--deploy-config` flag needed for default use. Platform-specific deltas (NPU, ROCm, XPU) are merged in automatically from the `platforms:` block of the same YAML based on the detected runtime.
```bash
# CustomVoice model (predefined speakers)
@@ -70,6 +70,22 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
--port 8091
```
+#### Sync vs async-chunk mode
+
+Qwen3-TTS supports both **chunked streaming** (default, lower latency) and
+**synchronous end-to-end** modes from the same deploy YAML. The bundled
+`qwen3_tts.yaml` ships with `async_chunk: true`; flip with `--no-async-chunk`
+and the pipeline automatically dispatches to the end-to-end codec processor:
+
+```bash
+vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice --omni --port 8091 \
+ --no-async-chunk
+```
+
+No variant YAML or extra flag is needed — the `StagePipelineConfig` on each
+stage declares both processor functions and the runtime picks based on the
+`async_chunk:` bool.
+
Alternatively, use the convenience script:
```bash
./run_server.sh # Default: CustomVoice model
@@ -192,14 +208,6 @@ with open("output.wav", "wb") as f:
f.write(response.content)
```
-### FAQ
-
-If you encounter error about backend of librosa, try to install ffmpeg with command below.
-```
-sudo apt update
-sudo apt install ffmpeg
-```
-
## API Reference
### Voices Endpoint
@@ -386,6 +394,54 @@ Server -> Client:
{"type": "session.done", "total_sentences": 1}
```
+## Choosing an Execution Backend: Uniproc vs Multiprocessing
+
+Qwen3-TTS stage configs support two execution backends controlled by the
+`distributed_executor_backend` engine arg. The performance tradeoff between
+them is **both hardware- and task-dependent**, so there is no single best
+default (see [#2603](https://github.com/vllm-project/vllm-omni/issues/2603),
+[#2604](https://github.com/vllm-project/vllm-omni/pull/2604) for the full
+investigation).
+
+| Backend | Stage config setting | Behaviour |
+| ------- | -------------------- | --------- |
+| **Uniproc** (default, world_size=1) | `distributed_executor_backend` omitted | Both stages run inside the orchestrator process. Avoids IPC serialisation, D2H copies, and msgpack overhead between stages. |
+| **Multiprocessing** | `distributed_executor_backend: "mp"` | Each stage runs in its own subprocess. The Talker can continue decoding while Code2Wav runs the vocoder in parallel, improving pipeline utilisation under concurrency. |
+
+> **Note:** When `distributed_executor_backend` is omitted and `world_size=1`,
+> vLLM [automatically uses the uniproc executor](https://github.com/vllm-project/vllm/blob/main/vllm/config/parallel.py#L825).
+> When `world_size > 1`, it defaults to `mp`.
+
+### When uniproc wins
+
+The uniproc path eliminates inter-process data transfer (D2H copies,
+msgpack serialisation/deserialisation, tensor detaching). This matters most
+when per-request processing is heavy relative to autoregressive decode.
+
+The Base cloning task involves reference-audio encoding on every request, making IPC
+overhead a larger fraction of total cost. Qwen3-Omni shows a similar pattern.
+
+### When multiprocessing (`mp`) wins
+
+For lighter per-request workloads, process-level parallelism between the
+Talker and Code2Wav stages dominates.
+
+CustomVoice is lighter per-request (no reference audio encoding), so the
+process-level parallelism of `mp` outweighs its serialisation cost at
+concurrency ≥ 4.
+
+### How to switch
+
+To use the uniproc executor on a single-GPU setup, pass the
+`qwen3_tts_uniproc.yaml` stage config:
+
+```bash
+vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
+ --omni \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml \
+ --port 8091
+```
+
## Limitations
- **Single request**: Batch processing is not yet optimized for online serving.
diff --git a/examples/online_serving/qwen3_tts/batch_speech_client.py b/examples/online_serving/qwen3_tts/batch_speech_client.py
index 7d48e650f8..47fdc3691c 100644
--- a/examples/online_serving/qwen3_tts/batch_speech_client.py
+++ b/examples/online_serving/qwen3_tts/batch_speech_client.py
@@ -5,11 +5,13 @@
batch level and generate many utterances in the cloned voice without repeating
the reference for each item.
-Start the server (with batch-optimized config for best throughput):
+Start the server (with batch-optimized stage settings for best throughput):
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \
- --trust-remote-code
+ --omni \
+ --trust-remote-code \
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},
+ "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
Examples:
# Batch with a predefined voice
diff --git a/examples/online_serving/qwen3_tts/run_gradio_demo.sh b/examples/online_serving/qwen3_tts/run_gradio_demo.sh
index bcc0ddb7cf..d79be3c2ab 100644
--- a/examples/online_serving/qwen3_tts/run_gradio_demo.sh
+++ b/examples/online_serving/qwen3_tts/run_gradio_demo.sh
@@ -127,7 +127,7 @@ echo "Starting vLLM server..."
LOG_FILE="/tmp/vllm_tts_server_${SERVER_PORT}.log"
vllm-omni serve "$MODEL" \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--host "$SERVER_HOST" \
--port "$SERVER_PORT" \
--gpu-memory-utilization 0.9 \
diff --git a/examples/online_serving/qwen3_tts/run_server.sh b/examples/online_serving/qwen3_tts/run_server.sh
index 6f4aa83a0b..78dd2c305d 100755
--- a/examples/online_serving/qwen3_tts/run_server.sh
+++ b/examples/online_serving/qwen3_tts/run_server.sh
@@ -31,7 +31,7 @@ esac
echo "Starting Qwen3-TTS server with model: $MODEL"
vllm-omni serve "$MODEL" \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--host 0.0.0.0 \
--port 8091 \
--gpu-memory-utilization 0.9 \
diff --git a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
index e6786f8869..7790fa5127 100644
--- a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
+++ b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
@@ -5,7 +5,7 @@
using SLERP and sends the result to the /v1/audio/speech API.
Requirements:
- pip install torch librosa soundfile numpy httpx
+ pip install torch soundfile numpy httpx
Examples:
# Extract and save an embedding
@@ -143,17 +143,18 @@ def _load_speaker_encoder_weights(encoder: torch.nn.Module, model_path: str) ->
def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:
"""Compute 128-bin mel spectrogram matching Qwen3-TTS's extraction pipeline."""
- import librosa
+ from vllm.multimodal.audio import AudioResampler
# Resample to 24kHz if needed
if sr != 24000:
- audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=24000)
+ resampler = AudioResampler(target_sr=24000)
+ audio = resampler.resample(audio.astype(np.float32), orig_sr=sr)
y = torch.from_numpy(audio).unsqueeze(0).float()
- from librosa.filters import mel as librosa_mel_fn
+ from vllm_omni.utils.audio import mel_filter_bank
- mel_basis = torch.from_numpy(librosa_mel_fn(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)).float()
+ mel_basis = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
n_fft = 1024
hop_size = 256
@@ -180,9 +181,9 @@ def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:
@torch.inference_mode()
def extract_embedding(encoder: torch.nn.Module, audio_path: str, device: str = "cpu") -> np.ndarray:
"""Extract a 1024-dim speaker embedding from an audio file."""
- import librosa
+ from vllm.multimodal.media.audio import load_audio
- audio, sr = librosa.load(audio_path, sr=None, mono=True)
+ audio, sr = load_audio(audio_path, sr=None, mono=True)
mel = compute_mel_spectrogram(audio, sr).to(device)
embedding = encoder(mel.to(next(encoder.parameters()).dtype))[0]
return embedding.float().cpu().numpy()
diff --git a/examples/online_serving/text_to_video/README.md b/examples/online_serving/text_to_video/README.md
index 44e676671f..c01e0602ff 100644
--- a/examples/online_serving/text_to_video/README.md
+++ b/examples/online_serving/text_to_video/README.md
@@ -1,16 +1,27 @@
# Text-To-Video
-This example demonstrates how to deploy the Wan2.2 text-to-video model for online video generation using vLLM-Omni.
+This example demonstrates how to deploy text-to-video models for online video generation using vLLM-Omni.
-## Start Server
+## Supported Models
-### Basic Start
+| Model | Model ID |
+|-------|----------|
+| Wan2.1 T2V (1.3B) | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` |
+| Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` |
+| Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
+| LTX-2 | `Lightricks/LTX-2` |
+
+## Wan2.2 T2V
+
+### Start Server
+
+#### Basic Start
```bash
vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091
```
-### Start with Parameters
+#### Start with Parameters
Or use the startup script:
@@ -230,3 +241,82 @@ while true; do
sleep 2
done
```
+
+## LTX-2
+
+### Start Server
+
+#### Basic Start
+
+```bash
+vllm serve Lightricks/LTX-2 --omni --port 8098 \
+ --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0
+```
+
+#### Start with Optimization Presets
+
+Use the LTX-2 startup script with built-in optimization presets:
+
+```bash
+# Baseline (1 GPU, eager)
+bash run_server_ltx2.sh baseline
+
+# 4-GPU Ulysses sequence parallelism (lossless)
+bash run_server_ltx2.sh ulysses4
+
+# Cache-DiT lossy acceleration (1 GPU, ~1.4× speedup)
+bash run_server_ltx2.sh cache-dit
+
+# Best combo: 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup)
+bash run_server_ltx2.sh best-combo
+```
+
+#### Optimization Benchmarks
+
+Benchmarked on H800, online serving (480×768, 41 frames, 20 steps, `seed=42`).
+"Inference" is the server-reported inference time; excludes HTTP/poll overhead.
+
+| Preset | Server Command | Inference (s) | Speedup | Type |
+|--------|---------------|---------------|---------|------|
+| `baseline` | `--enforce-eager` | 10.3 | 1.00× | — |
+| `compile` | *(default, no --enforce-eager)* | ~10.3 (warm) | ~1.00× | Lossless |
+| `ulysses4` | `--enforce-eager --usp 4` | ~10.3 | ~1.00× | Lossless |
+| `cache-dit` | `--enforce-eager --cache-backend cache_dit` | 7.4 avg | ~1.4× | Lossy |
+| `best-combo` | `--enforce-eager --usp 4 --cache-backend cache_dit` | 4.7 avg | **~2.2×** | Lossless + Lossy |
+
+**Observations**:
+- **torch.compile**: On H800, warm-request inference time matches the eager baseline (~10.3s).
+ The first request pays ~6s compilation overhead. Benefit depends on model architecture and GPU.
+- **Ulysses SP (4 GPU)**: No measurable speedup alone for 41-frame generation at this resolution.
+ Communication overhead outweighs gains at this sequence length.
+- **Cache-DiT**: Inference varies per request (6–10s) due to dynamic caching decisions.
+ Average is ~7.4s (~1.4× speedup) with slight quality tradeoff.
+- **Best combo**: 4-GPU Ulysses SP + Cache-DiT synergize well — Cache-DiT reduces per-step
+ computation, making the communication overhead of Ulysses SP worthwhile. Average ~4.7s
+ (~2.2× speedup).
+- **FP8 quantization**: Reduces VRAM but does not speed up LTX-2 on H800 (compute-bound).
+
+**Deployment Recommendations**:
+- For **production with quality priority**: use `baseline` with `--enforce-eager`
+- For **maximum throughput** (4 GPUs, quality tradeoff): use `best-combo` (~2.2× speedup)
+- For **single-GPU throughput**: use `cache-dit` (~1.4× speedup)
+- `--enforce-eager` is recommended to avoid torch.compile warmup latency on first request
+
+### Send Requests (curl)
+
+```bash
+# Using the provided script
+bash run_curl_ltx2.sh
+
+# Or directly
+curl -sS -X POST http://localhost:8098/v1/videos \
+ -H "Accept: application/json" \
+ -F "prompt=A serene lakeside sunrise with mist over the water." \
+ -F "width=768" \
+ -F "height=480" \
+ -F "num_frames=41" \
+ -F "fps=24" \
+ -F "num_inference_steps=20" \
+ -F "guidance_scale=3.0" \
+ -F "seed=42"
+```
diff --git a/examples/online_serving/text_to_video/run_curl_ltx2.sh b/examples/online_serving/text_to_video/run_curl_ltx2.sh
new file mode 100644
index 0000000000..b82f672eaa
--- /dev/null
+++ b/examples/online_serving/text_to_video/run_curl_ltx2.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# LTX-2 text-to-video curl example using the async video job API.
+# Start the server first: bash run_server_ltx2.sh best-combo
+
+set -euo pipefail
+
+BASE_URL="${BASE_URL:-http://localhost:8098}"
+OUTPUT_PATH="${OUTPUT_PATH:-ltx2_output.mp4}"
+POLL_INTERVAL="${POLL_INTERVAL:-2}"
+
+PROMPT="${PROMPT:-A serene lakeside sunrise with mist over the water.}"
+
+create_response=$(
+ curl -sS -X POST "${BASE_URL}/v1/videos" \
+ -H "Accept: application/json" \
+ -F "prompt=${PROMPT}" \
+ -F "width=768" \
+ -F "height=480" \
+ -F "num_frames=41" \
+ -F "fps=24" \
+ -F "num_inference_steps=20" \
+ -F "guidance_scale=3.0" \
+ -F "seed=42"
+)
+
+video_id="$(echo "${create_response}" | jq -r '.id')"
+if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then
+ echo "Failed to create video job:"
+ echo "${create_response}" | jq .
+ exit 1
+fi
+
+echo "Created video job ${video_id}"
+echo "${create_response}" | jq .
+
+while true; do
+ status_response="$(curl -sS "${BASE_URL}/v1/videos/${video_id}")"
+ status="$(echo "${status_response}" | jq -r '.status')"
+
+ case "${status}" in
+ queued|in_progress)
+ echo "Video job ${video_id} status: ${status}"
+ sleep "${POLL_INTERVAL}"
+ ;;
+ completed)
+ echo "${status_response}" | jq .
+ break
+ ;;
+ failed)
+ echo "Video generation failed:"
+ echo "${status_response}" | jq .
+ exit 1
+ ;;
+ *)
+ echo "Unexpected status response:"
+ echo "${status_response}" | jq .
+ exit 1
+ ;;
+ esac
+done
+
+curl -sS -L "${BASE_URL}/v1/videos/${video_id}/content" -o "${OUTPUT_PATH}"
+echo "Saved video to ${OUTPUT_PATH}"
diff --git a/examples/online_serving/text_to_video/run_server_ltx2.sh b/examples/online_serving/text_to_video/run_server_ltx2.sh
new file mode 100644
index 0000000000..f4597d3cd2
--- /dev/null
+++ b/examples/online_serving/text_to_video/run_server_ltx2.sh
@@ -0,0 +1,84 @@
+#!/bin/bash
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# LTX-2 online serving startup script with optimization presets.
+#
+# Usage:
+# bash run_server_ltx2.sh # baseline (1 GPU, eager)
+# bash run_server_ltx2.sh ulysses4 # 4-GPU Ulysses SP
+# bash run_server_ltx2.sh cache-dit # 1 GPU + Cache-DiT
+# bash run_server_ltx2.sh best-combo # 4-GPU Ulysses SP + Cache-DiT
+#
+# Online serving benchmarks on H800 (480×768, 41 frames, 20 steps):
+# baseline : 10.3s inference (1.00×)
+# compile : ~10.3s warm (~1.00×) first request +6s warmup
+# ulysses4 : ~10.3s (~1.00×) no gain at 41 frames
+# cache-dit : 7.4s avg (~1.4×) lossy, variable per request
+# best-combo : 4.7s avg (~2.2×) 4-GPU ulysses + cache-dit
+
+set -euo pipefail
+
+MODEL="${MODEL:-Lightricks/LTX-2}"
+PORT="${PORT:-8098}"
+FLOW_SHIFT="${FLOW_SHIFT:-1.0}"
+BOUNDARY_RATIO="${BOUNDARY_RATIO:-1.0}"
+
+PRESET="${1:-baseline}"
+
+EXTRA_ARGS=()
+case "$PRESET" in
+ baseline)
+ echo "=== LTX-2 Preset: baseline (1 GPU, enforce-eager) ==="
+ EXTRA_ARGS+=(--enforce-eager)
+ ;;
+ ulysses2)
+ echo "=== LTX-2 Preset: 2-GPU Ulysses SP (lossless) ==="
+ EXTRA_ARGS+=(--enforce-eager --usp 2)
+ ;;
+ ulysses4)
+ echo "=== LTX-2 Preset: 4-GPU Ulysses SP (lossless) ==="
+ EXTRA_ARGS+=(--enforce-eager --usp 4)
+ ;;
+ cache-dit)
+ echo "=== LTX-2 Preset: Cache-DiT (1 GPU, lossy) ==="
+ EXTRA_ARGS+=(--enforce-eager --cache-backend cache_dit)
+ ;;
+ best-combo)
+ echo "=== LTX-2 Preset: 4-GPU Ulysses SP + Cache-DiT (best combo) ==="
+ EXTRA_ARGS+=(--enforce-eager --usp 4 --cache-backend cache_dit)
+ ;;
+ compile)
+ echo "=== LTX-2 Preset: torch.compile (1 GPU, lossless) ==="
+ # torch.compile is the default (no --enforce-eager)
+ ;;
+ *)
+ echo "Usage: $0 {baseline|ulysses2|ulysses4|cache-dit|best-combo|compile}"
+ echo ""
+ echo "Presets:"
+ echo " baseline - 1 GPU, eager execution (reference)"
+ echo " ulysses2 - 2-GPU Ulysses SP (lossless)"
+ echo " ulysses4 - 4-GPU Ulysses SP (lossless)"
+ echo " cache-dit - 1 GPU + Cache-DiT (lossy, ~1.4× speedup)"
+ echo " best-combo - 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup)"
+ echo " compile - 1 GPU + torch.compile (slower first request)"
+ echo ""
+ echo "Environment variables:"
+ echo " MODEL - Model path (default: Lightricks/LTX-2)"
+ echo " PORT - Server port (default: 8098)"
+ echo " FLOW_SHIFT - Scheduler flow shift (default: 1.0)"
+ echo " BOUNDARY_RATIO - Boundary ratio (default: 1.0)"
+ exit 1
+ ;;
+esac
+
+echo "Model: $MODEL"
+echo "Port: $PORT"
+echo "Flow shift: $FLOW_SHIFT"
+echo "Boundary ratio: $BOUNDARY_RATIO"
+
+vllm serve "$MODEL" --omni \
+ --port "$PORT" \
+ --flow-shift "$FLOW_SHIFT" \
+ --boundary-ratio "$BOUNDARY_RATIO" \
+ "${EXTRA_ARGS[@]}"
diff --git a/examples/online_serving/voxcpm/README.md b/examples/online_serving/voxcpm/README.md
new file mode 100644
index 0000000000..78e1bf4aaa
--- /dev/null
+++ b/examples/online_serving/voxcpm/README.md
@@ -0,0 +1,166 @@
+# VoxCPM
+
+## Prerequisites
+
+Install VoxCPM in one of these ways:
+
+```bash
+pip install voxcpm
+```
+
+or point vLLM-Omni to a local VoxCPM source tree:
+
+```bash
+export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src
+```
+
+If the native VoxCPM `config.json` lacks HF metadata such as `model_type`,
+prepare a persistent HF-compatible config directory and export:
+
+```bash
+export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config
+mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"
+cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json"
+cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true
+python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)'
+```
+
+The VoxCPM stage configs read `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH` directly. The `python3 -c` form above avoids heredoc/indentation issues in interactive shells.
+
+## Launch the Server
+
+Use the async-chunk stage config by default:
+
+```bash
+export VOXCPM_MODEL=/path/to/voxcpm-model
+cd examples/online_serving/voxcpm
+./run_server.sh
+```
+
+Use the non-streaming stage config:
+
+```bash
+./run_server.sh sync
+```
+
+You can also launch the server directly:
+
+```bash
+vllm serve "$VOXCPM_MODEL" \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \
+ --trust-remote-code \
+ --enforce-eager \
+ --omni \
+ --port 8091
+```
+
+## Send Requests
+
+### Basic text-to-speech
+
+```bash
+python openai_speech_client.py \
+ --model "$VOXCPM_MODEL" \
+ --text "This is a VoxCPM online text-to-speech example."
+```
+
+### Voice cloning
+
+```bash
+python openai_speech_client.py \
+ --model "$VOXCPM_MODEL" \
+ --text "This sentence is synthesized with a cloned voice." \
+ --ref-audio /path/to/reference.wav \
+ --ref-text "The exact transcript spoken in reference.wav."
+```
+
+`ref_text` must be the real transcript of the reference audio. Placeholder text or mismatched text will usually degrade quality badly.
+
+### Streaming PCM output
+
+```bash
+python openai_speech_client.py \
+ --model "$VOXCPM_MODEL" \
+ --text "This is a streaming VoxCPM request." \
+ --stream \
+ --output voxcpm_stream.pcm
+```
+
+### Using curl
+
+```bash
+curl -X POST http://localhost:8091/v1/audio/speech \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "OpenBMB/VoxCPM1.5",
+ "input": "Hello from VoxCPM online serving.",
+ "response_format": "wav"
+ }' --output output.wav
+```
+
+Voice cloning:
+
+```bash
+curl -X POST http://localhost:8091/v1/audio/speech \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "OpenBMB/VoxCPM1.5",
+ "input": "This sentence uses a cloned voice.",
+ "ref_audio": "https://example.com/reference.wav",
+ "ref_text": "The exact transcript spoken in the reference audio.",
+ "response_format": "wav"
+ }' --output cloned.wav
+```
+
+Streaming PCM:
+
+```bash
+curl -X POST http://localhost:8091/v1/audio/speech \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "OpenBMB/VoxCPM1.5",
+ "input": "This is a streaming VoxCPM request.",
+ "stream": true,
+ "response_format": "pcm"
+ }' --output output.pcm
+```
+
+## Supported Request Shape
+
+VoxCPM online serving currently supports:
+
+- plain text-to-speech
+- voice cloning with `ref_audio` + `ref_text`
+- `stream=true` with `response_format=pcm` or `wav`
+
+VoxCPM online serving does not use these generic TTS fields:
+
+- `voice`
+- `instructions`
+- `language`
+- `speaker_embedding`
+- `x_vector_only_mode`
+
+## Streaming vs Non-Streaming
+
+- `voxcpm_async_chunk.yaml` enables async-chunk streaming and is best for single-request streaming latency.
+- `voxcpm.yaml` performs one-shot latent generation then VAE decode.
+
+Like native VoxCPM, the async streaming path should be treated as single-request. If you need stable throughput benchmarking, prefer `voxcpm.yaml`.
+
+Do not use `voxcpm_async_chunk.yaml` for concurrent online streaming or `/v1/audio/speech/batch`. For multiple requests, prefer `voxcpm.yaml`.
+
+## Benchmark
+
+The serving benchmark reports TTFP and RTF:
+
+```bash
+python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \
+ --host 127.0.0.1 \
+ --port 8091 \
+ --num-prompts 10 \
+ --max-concurrency 1 \
+ --result-dir /tmp/voxcpm_bench
+```
+
+For the async-chunk server, keep `--max-concurrency 1`.
diff --git a/examples/online_serving/voxcpm/openai_speech_client.py b/examples/online_serving/voxcpm/openai_speech_client.py
new file mode 100644
index 0000000000..c400114e8b
--- /dev/null
+++ b/examples/online_serving/voxcpm/openai_speech_client.py
@@ -0,0 +1,155 @@
+"""OpenAI-compatible client for VoxCPM via /v1/audio/speech.
+
+Examples:
+ # Basic text-to-speech
+ python openai_speech_client.py --text "Hello from VoxCPM"
+
+ # Voice cloning
+ python openai_speech_client.py \
+ --text "This sentence uses the cloned voice." \
+ --ref-audio /path/to/reference.wav \
+ --ref-text "The exact transcript spoken in the reference audio."
+
+ # Streaming PCM output
+ python openai_speech_client.py \
+ --text "This is a streaming VoxCPM request." \
+ --stream \
+ --output output.pcm
+"""
+
+import argparse
+import base64
+import os
+
+import httpx
+
+DEFAULT_API_BASE = "http://localhost:8091"
+DEFAULT_API_KEY = "EMPTY"
+DEFAULT_MODEL = "OpenBMB/VoxCPM1.5"
+
+
+def encode_audio_to_base64(audio_path: str) -> str:
+ """Encode a local audio file to base64 data URL."""
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
+
+ ext = audio_path.lower().rsplit(".", 1)[-1]
+ mime_map = {
+ "wav": "audio/wav",
+ "mp3": "audio/mpeg",
+ "flac": "audio/flac",
+ "ogg": "audio/ogg",
+ }
+ mime_type = mime_map.get(ext, "audio/wav")
+
+ with open(audio_path, "rb") as f:
+ audio_b64 = base64.b64encode(f.read()).decode("utf-8")
+ return f"data:{mime_type};base64,{audio_b64}"
+
+
+def build_payload(args) -> dict[str, object]:
+ payload: dict[str, object] = {
+ "model": args.model,
+ "input": args.text,
+ "response_format": "pcm" if args.stream else args.response_format,
+ }
+
+ if args.ref_audio:
+ if args.ref_audio.startswith(("http://", "https://", "data:")):
+ payload["ref_audio"] = args.ref_audio
+ else:
+ payload["ref_audio"] = encode_audio_to_base64(args.ref_audio)
+ if args.ref_text:
+ payload["ref_text"] = args.ref_text
+ if args.max_new_tokens is not None:
+ payload["max_new_tokens"] = args.max_new_tokens
+ if args.stream:
+ payload["stream"] = True
+
+ return payload
+
+
+def run_tts(args) -> None:
+ payload = build_payload(args)
+ api_url = f"{args.api_base}/v1/audio/speech"
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {args.api_key}",
+ }
+
+ print(f"Model: {args.model}")
+ print(f"Text: {args.text}")
+ if args.ref_audio:
+ print("Mode: voice cloning")
+ print(f"Reference audio: {args.ref_audio}")
+ else:
+ print("Mode: text-to-speech")
+
+ if args.stream:
+ output_path = args.output or "voxcpm_output.pcm"
+ with httpx.Client(timeout=300.0) as client:
+ with client.stream("POST", api_url, json=payload, headers=headers) as response:
+ if response.status_code != 200:
+ print(f"Error: {response.status_code}")
+ print(response.read().decode("utf-8", errors="ignore"))
+ return
+
+ total_bytes = 0
+ with open(output_path, "wb") as f:
+ for chunk in response.iter_bytes():
+ if not chunk:
+ continue
+ f.write(chunk)
+ total_bytes += len(chunk)
+ print(f"Streamed {total_bytes} bytes to: {output_path}")
+ return
+
+ with httpx.Client(timeout=300.0) as client:
+ response = client.post(api_url, json=payload, headers=headers)
+
+ if response.status_code != 200:
+ print(f"Error: {response.status_code}")
+ print(response.text)
+ return
+
+ try:
+ text = response.content.decode("utf-8")
+ if text.startswith('{"error"'):
+ print(f"Error: {text}")
+ return
+ except UnicodeDecodeError:
+ pass
+
+ output_path = args.output or "voxcpm_output.wav"
+ with open(output_path, "wb") as f:
+ f.write(response.content)
+ print(f"Audio saved to: {output_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="VoxCPM OpenAI-compatible speech client")
+ parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="API base URL")
+ parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key")
+ parser.add_argument("--model", "-m", default=DEFAULT_MODEL, help="Model name or path")
+ parser.add_argument("--text", required=True, help="Text to synthesize")
+ parser.add_argument("--ref-audio", default=None, help="Reference audio path, URL, or data URL")
+ parser.add_argument(
+ "--ref-text",
+ default=None,
+ help="The exact transcript spoken in the reference audio",
+ )
+ parser.add_argument("--stream", action="store_true", help="Enable streaming PCM output")
+ parser.add_argument(
+ "--response-format",
+ default="wav",
+ choices=["wav", "pcm", "flac", "mp3", "aac", "opus"],
+ help="Audio format for non-streaming mode (default: wav)",
+ )
+ parser.add_argument("--max-new-tokens", type=int, default=None, help="Maximum tokens to generate")
+ parser.add_argument("--output", "-o", default=None, help="Output file path")
+ args = parser.parse_args()
+ run_tts(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/online_serving/voxcpm/run_server.sh b/examples/online_serving/voxcpm/run_server.sh
new file mode 100755
index 0000000000..ab4b6fe854
--- /dev/null
+++ b/examples/online_serving/voxcpm/run_server.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+# Launch vLLM-Omni server for VoxCPM online speech serving.
+#
+# Usage:
+# ./run_server.sh # default: async_chunk stage config
+# ./run_server.sh async # async_chunk stage config
+# ./run_server.sh sync # no-async-chunk stage config
+# VOXCPM_MODEL=/path/to/model ./run_server.sh
+
+set -e
+
+MODE="${1:-async}"
+MODEL="${VOXCPM_MODEL:-OpenBMB/VoxCPM1.5}"
+
+case "$MODE" in
+ async)
+ STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml"
+ ;;
+ sync)
+ STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm.yaml"
+ ;;
+ *)
+ echo "Unknown mode: $MODE"
+ echo "Supported: async, sync"
+ exit 1
+ ;;
+esac
+
+echo "Starting VoxCPM server with model: $MODEL"
+echo "Stage config: $STAGE_CONFIG"
+
+vllm serve "$MODEL" \
+ --stage-configs-path "$STAGE_CONFIG" \
+ --host 0.0.0.0 \
+ --port 8091 \
+ --trust-remote-code \
+ --enforce-eager \
+ --omni
diff --git a/examples/online_serving/voxcpm2/README.md b/examples/online_serving/voxcpm2/README.md
new file mode 100644
index 0000000000..8735180f0a
--- /dev/null
+++ b/examples/online_serving/voxcpm2/README.md
@@ -0,0 +1,42 @@
+# VoxCPM2 Online Serving
+
+Serve VoxCPM2 TTS via the OpenAI-compatible `/v1/audio/speech` endpoint.
+
+## Start the Server
+
+```bash
+python -m vllm_omni.entrypoints.openai.api_server \
+ --model openbmb/VoxCPM2 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
+ --host 0.0.0.0 --port 8000
+```
+
+## Zero-shot Synthesis
+
+```bash
+python openai_speech_client.py --text "Hello, this is VoxCPM2."
+```
+
+Or with curl:
+
+```bash
+curl -X POST http://localhost:8000/v1/audio/speech \
+ -H "Content-Type: application/json" \
+ -d '{"model": "voxcpm2", "input": "Hello, this is VoxCPM2.", "voice": "default"}' \
+ --output output.wav
+```
+
+## Voice Cloning
+
+Clone a speaker's voice using a reference audio file:
+
+```bash
+python openai_speech_client.py \
+ --text "This should sound like the reference speaker." \
+ --ref-audio /path/to/reference.wav
+```
+
+The `--ref-audio` parameter accepts:
+- Local file path (auto-encoded to base64)
+- URL (`https://...`)
+- Base64 data URI (`data:audio/wav;base64,...`)
diff --git a/examples/online_serving/voxcpm2/gradio_demo.py b/examples/online_serving/voxcpm2/gradio_demo.py
new file mode 100644
index 0000000000..a33a2d9245
--- /dev/null
+++ b/examples/online_serving/voxcpm2/gradio_demo.py
@@ -0,0 +1,602 @@
+"""Gradio demo for VoxCPM2 TTS with gapless streaming audio playback.
+
+Uses a custom AudioWorklet-based player for gap-free streaming
+(adapted from the Qwen3-TTS demo). Audio is streamed from the vLLM
+server through a same-origin proxy and played via the Web Audio API's
+AudioWorklet, which maintains a FIFO buffer queue and plays samples at
+the audio clock rate.
+
+Usage:
+ # Start the vLLM server first:
+ python -m vllm_omni.entrypoints.openai.api_server \
+ --model openbmb/VoxCPM2 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
+ --host 0.0.0.0 --port 8000
+
+ # Then launch the demo:
+ python gradio_demo.py --api-base http://localhost:8000
+"""
+
+from __future__ import annotations
+
+import argparse
+import base64
+import io
+import json
+import logging
+
+import gradio as gr
+import httpx
+import numpy as np
+import soundfile as sf
+from fastapi import FastAPI, Request
+from fastapi.responses import Response, StreamingResponse
+
+logger = logging.getLogger(__name__)
+
+SAMPLE_RATE = 48000
+
+# ── AudioWorklet processor (loaded in browser via Blob URL) ──────────
+WORKLET_JS = r"""
+class TTSPlaybackProcessor extends AudioWorkletProcessor {
+ constructor() {
+ super();
+ this.queue = [];
+ this.buf = null;
+ this.pos = 0;
+ this.playing = false;
+ this.played = 0;
+ this.port.onmessage = (e) => {
+ if (e.data && e.data.type === 'clear') {
+ this.queue = []; this.buf = null; this.pos = 0; this.played = 0;
+ if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped'}); }
+ return;
+ }
+ this.queue.push(e.data);
+ };
+ }
+ process(inputs, outputs) {
+ const out = outputs[0][0];
+ for (let i = 0; i < out.length; i++) {
+ if (!this.buf || this.pos >= this.buf.length) {
+ if (this.queue.length > 0) {
+ this.buf = this.queue.shift(); this.pos = 0;
+ } else {
+ for (let j = i; j < out.length; j++) out[j] = 0;
+ if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped', played:this.played}); }
+ return true;
+ }
+ }
+ out[i] = this.buf[this.pos++] / 32768;
+ this.played++;
+ }
+ if (!this.playing) { this.playing = true; this.port.postMessage({type:'started'}); }
+ return true;
+ }
+}
+registerProcessor('tts-playback-processor', TTSPlaybackProcessor);
+"""
+
+PLAYER_HTML = """
+
+"""
+
+
+def _build_player_js() -> str:
+ return f"""
+
+"""
+
+
+def _encode_audio(audio_data: tuple) -> str:
+ sr, audio_np = audio_data
+ if audio_np.dtype in (np.float32, np.float64):
+ audio_np = np.clip(audio_np, -1.0, 1.0)
+ audio_np = (audio_np * 32767).astype(np.int16)
+ elif audio_np.dtype != np.int16:
+ audio_np = audio_np.astype(np.int16)
+ buf = io.BytesIO()
+ sf.write(buf, audio_np, sr, format="WAV")
+ return f"data:audio/wav;base64,{base64.b64encode(buf.getvalue()).decode()}"
+
+
+def create_app(api_base: str):
+ app = FastAPI()
+ _pending: dict[str, dict] = {}
+
+ @app.post("/proxy/v1/audio/speech")
+ async def proxy_speech(request: Request):
+ body = await request.json()
+ req_id = body.get("_req_id")
+ if req_id and req_id in _pending:
+ body = _pending.pop(req_id)
+ logger.info("Proxy: %s", {k: (f"<{len(str(v))} chars>" if k == "ref_audio" else v) for k, v in body.items()})
+ try:
+ client = httpx.AsyncClient(timeout=300)
+ resp = await client.send(
+ client.build_request(
+ "POST",
+ f"{api_base}/v1/audio/speech",
+ json=body,
+ headers={"Authorization": "Bearer EMPTY", "Content-Type": "application/json"},
+ ),
+ stream=True,
+ )
+ except Exception as exc:
+ logger.exception("Proxy connection error")
+ await client.aclose()
+ return Response(content=str(exc), status_code=502)
+ if resp.status_code != 200:
+ content = await resp.aread()
+ await resp.aclose()
+ await client.aclose()
+ return Response(content=content, status_code=resp.status_code)
+
+ async def relay():
+ try:
+ async for chunk in resp.aiter_bytes():
+ yield chunk
+ finally:
+ await resp.aclose()
+ await client.aclose()
+
+ return StreamingResponse(relay(), media_type="application/octet-stream")
+
+ css = """
+ #generate-btn button { width: 100%; }
+ #streaming-player { border: 1px solid var(--border-color-primary) !important; border-radius: var(--block-radius) !important; padding: var(--block-padding) !important; }
+ """
+ theme = gr.themes.Default(
+ primary_hue=gr.themes.Color(
+ c50="#f0f5ff",
+ c100="#dce6f9",
+ c200="#b8cef3",
+ c300="#8eb2eb",
+ c400="#6496e0",
+ c500="#4A90D9",
+ c600="#3a7bc8",
+ c700="#2d66b0",
+ c800="#1f4f8f",
+ c900="#163a6e",
+ c950="#0e2650",
+ ),
+ )
+
+ with gr.Blocks(title="VoxCPM2 TTS Demo") as demo:
+ gr.HTML(f"""
+
+
+
+
VoxCPM2 Streaming Demo
+
+ Served by vLLM-Omni
+ · {api_base}
+ · 48 kHz
+
+
+
+ """)
+
+ gr.Markdown(
+ "**Three modes:** "
+ "**Voice Design** (control instruction only) · "
+ "**Controllable Cloning** (ref audio + optional style control) · "
+ "**Ultimate Cloning** (ref audio + transcript for audio continuation)"
+ )
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ text_input = gr.Textbox(
+ label="Target Text",
+ placeholder="Enter text to synthesize...",
+ lines=4,
+ )
+ control_instruction = gr.Textbox(
+ label="Control Instruction (optional)",
+ placeholder="e.g. A warm young woman / Excited and fast-paced",
+ lines=2,
+ info="Describe voice style, emotion, pace. Works for both Voice Design and Controllable Cloning.",
+ )
+
+ with gr.Accordion("Voice Cloning", open=False):
+ ref_audio = gr.Audio(
+ label="Reference Audio (upload for cloning)",
+ type="numpy",
+ sources=["upload", "microphone"],
+ )
+ ref_audio_url = gr.Textbox(
+ label="or Reference Audio URL",
+ placeholder="https://example.com/reference.wav",
+ )
+ ultimate_clone = gr.Checkbox(
+ label="Ultimate Cloning Mode",
+ value=False,
+ info="Provide transcript of ref audio for audio continuation (disables control instruction)",
+ )
+ prompt_text = gr.Textbox(
+ label="Reference Audio Transcript",
+ placeholder="Transcript of your reference audio (for ultimate cloning)",
+ lines=2,
+ visible=False,
+ )
+
+ with gr.Row():
+ stream_checkbox = gr.Checkbox(
+ label="Stream (gapless)",
+ value=True,
+ info="AudioWorklet streaming",
+ )
+ with gr.Row():
+ generate_btn = gr.Button(
+ "Generate Speech",
+ variant="primary",
+ size="lg",
+ elem_id="generate-btn",
+ scale=3,
+ )
+ reset_btn = gr.Button("Reset", variant="secondary", size="lg", scale=1)
+
+ with gr.Column(scale=2):
+ player_html = gr.HTML(
+ value=PLAYER_HTML,
+ visible=True,
+ label="streaming player",
+ elem_id="streaming-player",
+ )
+ audio_output = gr.Audio(
+ label="generated audio",
+ interactive=False,
+ autoplay=True,
+ visible=False,
+ )
+ gr.Examples(
+ examples=[
+ ["Hello, this is a VoxCPM2 demo running on vLLM-Omni.", ""],
+ [
+ "I have a dream that my four little children will one day live in a nation "
+ "where they will not be judged by the color of their skin but by the content "
+ "of their character.",
+ "",
+ ],
+ [
+ "I never asked you to stay. It's not like I care or anything. "
+ "But why does it still hurt so much now that you're gone?",
+ "A young girl with a soft, sweet voice. Speaks slowly with a melancholic tone.",
+ ],
+ ],
+ inputs=[text_input, control_instruction],
+ label="examples",
+ )
+ gr.HTML("""
+
+ """)
+
+ hidden_payload = gr.Textbox(visible=False, elem_id="tts-payload")
+
+ def on_ultimate_toggle(checked):
+ return (
+ gr.update(visible=checked), # prompt_text
+ gr.update(interactive=not checked), # control_instruction
+ )
+
+ ultimate_clone.change(
+ fn=on_ultimate_toggle,
+ inputs=[ultimate_clone],
+ outputs=[prompt_text, control_instruction],
+ )
+
+ def on_stream_change(stream: bool):
+ if stream:
+ return gr.update(visible=True), gr.update(visible=False)
+ return gr.update(visible=False), gr.update(visible=True)
+
+ stream_checkbox.change(
+ fn=on_stream_change,
+ inputs=[stream_checkbox],
+ outputs=[player_html, audio_output],
+ )
+
+ def on_reset():
+ return "", "", None, "", False, "", PLAYER_HTML
+
+ reset_btn.click(
+ fn=on_reset,
+ outputs=[
+ text_input,
+ control_instruction,
+ audio_output,
+ hidden_payload,
+ ultimate_clone,
+ prompt_text,
+ player_html,
+ ],
+ js="() => { if (window.ttsStop) window.ttsStop(); }",
+ )
+
+ def on_generate(stream_enabled, text, ctrl_instr, ref_a, ref_url, ult_clone, p_text):
+ import time as _time
+
+ if not text or not text.strip():
+ raise gr.Error("Please enter text to synthesize.")
+
+ # VoxCPM2 uses "(instruction)text" format for control
+ ctrl = ctrl_instr.strip() if ctrl_instr and not ult_clone else ""
+ final_text = f"({ctrl}){text.strip()}" if ctrl else text.strip()
+
+ payload: dict = {
+ "input": final_text,
+ "voice": "default",
+ "response_format": "pcm" if stream_enabled else "wav",
+ "stream": stream_enabled,
+ }
+
+ # Reference audio for cloning
+ ref_url_s = ref_url.strip() if ref_url else ""
+ if ref_url_s:
+ payload["ref_audio"] = ref_url_s
+ elif ref_a is not None:
+ payload["ref_audio"] = _encode_audio(ref_a)
+
+ # Ultimate cloning: prompt_audio + prompt_text for continuation
+ if ult_clone and p_text and p_text.strip():
+ if ref_url_s:
+ payload["prompt_audio"] = ref_url_s
+ elif ref_a is not None:
+ payload["prompt_audio"] = payload.get("ref_audio", "")
+ payload["prompt_text"] = p_text.strip()
+
+ if stream_enabled:
+ if ref_a is not None and not ref_url_s:
+ req_id = f"req-{int(_time.time() * 1000)}"
+ _pending[req_id] = payload
+ browser_payload = {"_req_id": req_id, "_nonce": int(_time.time() * 1000)}
+ return json.dumps(browser_payload), gr.update()
+ payload["_nonce"] = int(_time.time() * 1000)
+ return json.dumps(payload), gr.update()
+ else:
+ try:
+ with httpx.Client(timeout=300.0) as client:
+ resp = client.post(
+ f"{api_base}/v1/audio/speech",
+ json=payload,
+ headers={"Content-Type": "application/json", "Authorization": "Bearer EMPTY"},
+ )
+ except httpx.ConnectError:
+ raise gr.Error(f"Cannot connect to server at {api_base}.")
+ if resp.status_code != 200:
+ raise gr.Error(f"Server error ({resp.status_code}): {resp.text[:200]}")
+ audio_np, sr = sf.read(io.BytesIO(resp.content))
+ if audio_np.ndim > 1:
+ audio_np = audio_np[:, 0]
+ return "", (sr, audio_np.astype(np.float32))
+
+ generate_btn.click(
+ fn=on_generate,
+ inputs=[
+ stream_checkbox,
+ text_input,
+ control_instruction,
+ ref_audio,
+ ref_audio_url,
+ ultimate_clone,
+ prompt_text,
+ ],
+ outputs=[hidden_payload, audio_output],
+ ).then(
+ fn=lambda p: p,
+ inputs=[hidden_payload],
+ outputs=[hidden_payload],
+ js="(p) => { if (p && p.trim()) { const d = JSON.parse(p); delete d._nonce; window.ttsGenerate(d); } return p; }",
+ )
+
+ demo.queue()
+
+ return gr.mount_gradio_app(app, demo, path="/", css=css, theme=theme, head=_build_player_js())
+
+
+def main():
+ parser = argparse.ArgumentParser(description="VoxCPM2 streaming Gradio demo")
+ parser.add_argument("--api-base", default="http://localhost:8000", help="vLLM API server URL")
+ parser.add_argument("--host", default="0.0.0.0", help="Gradio server host")
+ parser.add_argument("--port", type=int, default=7860, help="Gradio server port")
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.INFO)
+ print(f"Connecting to vLLM server at: {args.api_base}")
+
+ import uvicorn
+
+ uvicorn.run(create_app(args.api_base), host=args.host, port=args.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/online_serving/voxcpm2/openai_speech_client.py b/examples/online_serving/voxcpm2/openai_speech_client.py
new file mode 100644
index 0000000000..a117d24fd1
--- /dev/null
+++ b/examples/online_serving/voxcpm2/openai_speech_client.py
@@ -0,0 +1,108 @@
+"""OpenAI-compatible client for VoxCPM2 TTS via /v1/audio/speech endpoint.
+
+Examples:
+ # Zero-shot synthesis
+ python openai_speech_client.py --text "Hello, this is VoxCPM2."
+
+ # Voice cloning with a local reference audio file
+ python openai_speech_client.py --text "Hello world" \
+ --ref-audio /path/to/reference.wav
+
+ # Voice cloning with a URL
+ python openai_speech_client.py --text "Hello world" \
+ --ref-audio "https://example.com/reference.wav"
+
+Server setup:
+ python -m vllm_omni.entrypoints.openai.api_server \
+ --model openbmb/VoxCPM2 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
+ --host 0.0.0.0 --port 8000
+"""
+
+from __future__ import annotations
+
+import argparse
+import base64
+import os
+
+import httpx
+
+DEFAULT_API_BASE = "http://localhost:8000"
+DEFAULT_API_KEY = "sk-empty"
+
+
+def encode_audio_to_base64(audio_path: str) -> str:
+ """Encode a local audio file to a base64 data URL."""
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
+
+ ext = audio_path.lower().rsplit(".", 1)[-1]
+ mime = {
+ "wav": "audio/wav",
+ "mp3": "audio/mpeg",
+ "flac": "audio/flac",
+ "ogg": "audio/ogg",
+ }.get(ext, "audio/wav")
+
+ with open(audio_path, "rb") as f:
+ b64 = base64.b64encode(f.read()).decode("utf-8")
+ return f"data:{mime};base64,{b64}"
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="VoxCPM2 OpenAI speech client")
+ parser.add_argument("--text", type=str, required=True, help="Text to synthesize")
+ parser.add_argument(
+ "--ref-audio",
+ type=str,
+ default=None,
+ help="Reference audio for voice cloning (local path, URL, or data: URI)",
+ )
+ parser.add_argument("--model", type=str, default="voxcpm2")
+ parser.add_argument("--output", type=str, default="output.wav")
+ parser.add_argument("--api-base", type=str, default=DEFAULT_API_BASE)
+ parser.add_argument("--api-key", type=str, default=DEFAULT_API_KEY)
+ parser.add_argument("--response-format", type=str, default="wav")
+ args = parser.parse_args()
+
+ # VoxCPM2 has no predefined voices. The "voice" field is required by
+ # the OpenAI API schema but ignored by VoxCPM2 — use any placeholder.
+ # For voice cloning, pass --ref-audio instead.
+ payload: dict = {
+ "model": args.model,
+ "input": args.text,
+ "voice": "default",
+ "response_format": args.response_format,
+ }
+
+ if args.ref_audio:
+ ref = args.ref_audio
+ if ref.startswith(("http://", "https://", "data:")):
+ payload["ref_audio"] = ref
+ else:
+ payload["ref_audio"] = encode_audio_to_base64(ref)
+
+ url = f"{args.api_base}/v1/audio/speech"
+ print(f"POST {url}")
+ print(f" text: {args.text}")
+ if args.ref_audio:
+ print(f" ref_audio: {args.ref_audio[:80]}...")
+
+ with httpx.Client(timeout=300) as client:
+ resp = client.post(
+ url,
+ json=payload,
+ headers={"Authorization": f"Bearer {args.api_key}"},
+ )
+
+ if resp.status_code != 200:
+ print(f"Error {resp.status_code}: {resp.text[:500]}")
+ return
+
+ with open(args.output, "wb") as f:
+ f.write(resp.content)
+ print(f"Saved: {args.output} ({len(resp.content):,} bytes)")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyproject.toml b/pyproject.toml
index e49aa6e325..9b034a7c8e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,12 +55,24 @@ dev = [
"pyttsx3>=2.99",
"opencc>=1.2.0",
"mistune>=3.2.0", # for example tests
+ "torchmetrics>=1.4.0", # for accuracy similarity metrics
]
demo = [
"gradio>=6.7.0",
]
+# Seed-TTS serve benchmark WER (BytedanceSpeech/seed-tts-eval run_wer.py protocol).
+seed-tts-eval = [
+ "jiwer>=3.0.0",
+ "zhon>=2.0.0",
+ "zhconv>=1.4.2",
+ "scipy>=1.10.0",
+ "soundfile>=0.12.0",
+ "transformers>=4.36.0",
+ "funasr>=1.0.0",
+]
+
docs = [
"mkdocs>=1.5.0",
"mkdocs-api-autonav",
@@ -182,6 +194,7 @@ markers = [
"H100: Tests that require H100 GPU",
"L4: Tests that require L4 GPU",
"MI325: Tests that require MI325 GPU (AMD/ROCm)",
+ "B60: Tests that require Intel Arc Pro B60 XPU",
"S5000: Tests that require S5000 GPU (Moore Threads/MUSA)",
"A2: Tests that require A2 NPU",
"A3: Tests that require A3 NPU",
diff --git a/recipes/Qwen/Qwen3-Omni.md b/recipes/Qwen/Qwen3-Omni.md
new file mode 100644
index 0000000000..081e1453d3
--- /dev/null
+++ b/recipes/Qwen/Qwen3-Omni.md
@@ -0,0 +1,90 @@
+# Qwen3-Omni for multimodal chat on 1x A100 80GB
+
+## Summary
+
+- Vendor: Qwen
+- Model: `Qwen/Qwen3-Omni-30B-A3B-Instruct`
+- Task: Multimodal chat with text, image, audio, or video input
+- Mode: Online serving with the OpenAI-compatible API
+- Maintainer: Community
+
+## When to use this recipe
+
+Use this recipe when you want a known-good starting point for serving
+`Qwen/Qwen3-Omni-30B-A3B-Instruct` with vLLM-Omni on a single 80 GB A100 and
+validate the deployment with the existing multimodal client examples in this
+repository.
+
+## References
+
+- Upstream or canonical docs:
+ [`docs/user_guide/examples/online_serving/qwen3_omni.md`](../../docs/user_guide/examples/online_serving/qwen3_omni.md)
+- Related example under `examples/`:
+ [`examples/online_serving/qwen3_omni/README.md`](../../examples/online_serving/qwen3_omni/README.md)
+- Related issue or discussion:
+ [RFC: add recipes folder](https://github.com/vllm-project/vllm-omni/issues/2645)
+
+## Hardware Support
+
+This recipe currently documents one tested-style reference configuration for
+CUDA GPU serving. Add more sections for other hardware as community validation
+lands.
+
+## GPU
+
+### 1x A100 80GB
+
+#### Environment
+
+- OS: Linux
+- Python: 3.10+
+- Driver / runtime: NVIDIA CUDA environment with an A100 80 GB GPU
+- vLLM version: Match the repository requirements for your checkout
+- vLLM-Omni version or commit: Use the commit you are deploying from
+
+#### Command
+
+Start the server from the repository root:
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
+```
+
+To enable async chunking, use the bundled stage config:
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --omni \
+ --port 8091 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+```
+
+#### Verification
+
+Run one of the existing example clients after the server is ready:
+
+```bash
+python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \
+ --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --query-type use_image \
+ --port 8091 \
+ --host localhost
+```
+
+For a quick API smoke test, request text-only output:
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "messages": [{"role": "user", "content": "Describe vLLM in brief."}],
+ "modalities": ["text"]
+ }'
+```
+
+#### Notes
+
+- Memory usage: Size depends on runtime options and output modalities; leave headroom for multimodal workloads.
+- Key flags: `--omni` is required; `--stage-configs-path` is optional for custom or async-chunk stage configs.
+- Known limitations: This starter recipe is intentionally narrow and focuses on the single-GPU online-serving path already documented in the repo examples.
diff --git a/recipes/README.md b/recipes/README.md
new file mode 100644
index 0000000000..5b3dfb5430
--- /dev/null
+++ b/recipes/README.md
@@ -0,0 +1,35 @@
+# Community Recipes
+
+This directory contains community-maintained recipes for answering a
+practical user question:
+
+> How do I run model X on hardware Y for task Z?
+
+Add recipes for this repository under this in-repo `recipes/` directory. To
+keep naming and layout consistent, organize recipes by model vendor in a way
+that is aligned with
+[`vllm-project/recipes`](https://github.com/vllm-project/recipes), but treat
+that external repository as a reference for structure rather than the place to
+add files for this repo. Use one Markdown file per model family by default.
+
+Example layout:
+
+```text
+recipes/
+ Qwen/
+ Qwen3-Omni.md
+ Qwen3-TTS.md
+ Tencent-Hunyuan/
+ HunyuanVideo.md
+```
+
+## Available Recipes
+
+- [`Qwen/Qwen3-Omni.md`](./Qwen/Qwen3-Omni.md): online serving recipe for
+ multimodal chat on `1x A100 80GB`
+
+Within a single recipe file, include different hardware support sections such
+as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like
+`1x A100 80GB` or `2x L40S` inside those sections when applicable.
+
+See [TEMPLATE.md](./TEMPLATE.md) for the recommended format.
diff --git a/recipes/TEMPLATE.md b/recipes/TEMPLATE.md
new file mode 100644
index 0000000000..9bf8cb9c75
--- /dev/null
+++ b/recipes/TEMPLATE.md
@@ -0,0 +1,82 @@
+# Recipe Title
+
+> Example: Qwen3-Omni for speech chat on 1x A100 80GB
+
+## Summary
+
+- Vendor:
+- Model:
+- Task:
+- Mode:
+- Maintainer:
+
+## When to use this recipe
+
+Briefly describe the concrete scenario this recipe covers.
+
+## References
+
+- Upstream or canonical docs:
+- Related example under `examples/`:
+- Related issue or discussion:
+
+## Hardware Support
+
+Add one section per platform, such as `GPU`, `ROCm`, or `NPU`. Under each
+platform section, document one or more tested hardware configurations.
+
+## GPU
+
+### 1x A100 80GB
+
+#### Environment
+
+- OS:
+- Python:
+- Driver / runtime:
+- vLLM version:
+- vLLM-Omni version or commit:
+
+#### Command
+
+```bash
+# Add the exact command(s) here
+```
+
+#### Verification
+
+```bash
+# Add a quick validation command or expected output here
+```
+
+#### Notes
+
+- Memory usage:
+- Key flags:
+- Known limitations:
+
+### 2x L40S
+
+Repeat the same structure for other hardware setups as needed.
+
+## ROCm
+
+### Example hardware configuration
+
+Repeat the same nested structure for ROCm setups as needed:
+
+- `#### Environment`
+- `#### Command`
+- `#### Verification`
+- `#### Notes`
+
+## NPU
+
+### Example hardware configuration
+
+Repeat the same nested structure for NPU setups as needed:
+
+- `#### Environment`
+- `#### Command`
+- `#### Verification`
+- `#### Notes`
diff --git a/requirements/common.txt b/requirements/common.txt
index 89eaac32bc..63e16d580f 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -1,8 +1,6 @@
# Common dependencies for all platforms
av>=14.0.0
omegaconf>=2.3.0
-librosa>=0.11.0
-resampy>=0.4.3
diffusers>=0.36.0
accelerate==1.12.0
soundfile>=0.13.1
@@ -11,7 +9,6 @@ tqdm>=4.66.0
torchsde>=0.2.6
openai-whisper>=20250625
imageio[ffmpeg]>=2.37.2
-sox>=1.5.0
x-transformers>=2.12.2
einops>=0.8.1
prettytable>=3.8.0
diff --git a/tests/comfyui/conftest.py b/tests/comfyui/conftest.py
index 0b4565e946..4280d3506f 100644
--- a/tests/comfyui/conftest.py
+++ b/tests/comfyui/conftest.py
@@ -9,8 +9,8 @@
import os
import sys
+from types import ModuleType, SimpleNamespace
from typing import BinaryIO, TypedDict
-from unittest.mock import MagicMock
def pytest_configure(config):
@@ -58,15 +58,15 @@ def save_to(self, file: str | BinaryIO):
else:
file.write(self._data)
- mock_comfy_api = MagicMock()
- mock_comfy_api_input = MagicMock()
+ mock_comfy_api = ModuleType("comfy_api")
+ mock_comfy_api_input = ModuleType("comfy_api.input")
mock_comfy_api_input.AudioInput = AudioInput
mock_comfy_api_input.VideoInput = VideoInput
mock_comfy_api.input = mock_comfy_api_input
- mock_comfy_api_latest = MagicMock()
- mock_comfy_api_latest.Types.VideoComponents = MagicMock(side_effect=lambda **kwargs: kwargs)
- mock_comfy_api_latest.InputImpl.VideoFromComponents = MagicMock(
- side_effect=lambda _: VideoInput(b"mock_video_from_components")
+ mock_comfy_api_latest = ModuleType("comfy_api.latest")
+ mock_comfy_api_latest.Types = SimpleNamespace(VideoComponents=lambda **kwargs: kwargs)
+ mock_comfy_api_latest.InputImpl = SimpleNamespace(
+ VideoFromComponents=lambda _: VideoInput(b"mock_video_from_components")
)
mock_comfy_api.latest = mock_comfy_api_latest
@@ -76,8 +76,8 @@ def mock_load(_: str | BinaryIO):
sample_rate = 24000
return waveform, sample_rate
- mock_comfy_extras = MagicMock()
- mock_nodes_audio = MagicMock()
+ mock_comfy_extras = ModuleType("comfy_extras")
+ mock_nodes_audio = ModuleType("comfy_extras.nodes_audio")
mock_nodes_audio.load = mock_load
mock_comfy_extras.nodes_audio = mock_nodes_audio
diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py
index f6ce82f9b2..5164f3b9ac 100644
--- a/tests/comfyui/test_comfyui_integration.py
+++ b/tests/comfyui/test_comfyui_integration.py
@@ -13,7 +13,6 @@
from enum import StrEnum, auto
from types import SimpleNamespace
from typing import Any, NamedTuple
-from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import requests
@@ -28,6 +27,7 @@
)
from comfyui_vllm_omni.utils.types import AutoregressionSamplingParams, DiffusionSamplingParams, WanModelSpecificParams
from PIL import Image
+from pytest_mock import MockerFixture
from vllm import SamplingParams
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -217,9 +217,10 @@ def _build_diffusion_video_output() -> OmniRequestOutput:
def _build_diffusion_image_output_for_chat_endpoint() -> OmniRequestOutput:
- request_output = MagicMock()
- request_output.images = [_build_image_output(color="blue")]
- request_output.finished = True
+ request_output = SimpleNamespace(
+ images=[_build_image_output(color="blue")],
+ finished=True,
+ )
return OmniRequestOutput(
request_id="test_req_img_chat",
finished=True,
@@ -389,51 +390,55 @@ def sampling_case(request) -> SamplingCase:
@pytest.fixture
-def mock_async_omni(server_case: ServerCase, sampling_case: SamplingCase):
+def mock_async_omni(
+ server_case: ServerCase,
+ sampling_case: SamplingCase,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+):
async def _mock_preprocess_chat(self, *args, **kwargs):
return ([{"role": "user", "content": "test"}], [{"prompt": "test prompt"}])
# Need to mock AsyncOmni itself (not only its generate method) because
# 1. The API layer uses its stage_list and stage_configs attributes
# 2. Its __init__ method has slow side effects (model & config loading).
- with (
- patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni") as MockAsyncOmni,
- patch(
- "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat",
- new=_mock_preprocess_chat,
- ),
- ):
- mock_instance = AsyncMock(spec=RealAsyncOmni)
- mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case)
-
- mock_instance.stage_list = server_case.stage_list
- mock_instance.stage_configs = server_case.stage_configs
- mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs)
- mock_instance.default_sampling_params_list = [
- SamplingParams() if _stage_type(stage) != "diffusion" else MagicMock()
- for stage in server_case.stage_configs
- ]
- mock_instance.errored = False
- mock_instance.dead_error = RuntimeError("Mock engine error")
- mock_instance.model_config = MagicMock(
- max_model_len=4096,
- io_processor_plugin=None,
- allowed_local_media_path=None,
- allowed_media_domains=None,
- )
- # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes.
- mock_instance.model_config.hf_config = MagicMock()
- mock_instance.model_config.hf_config.talker_config = MagicMock()
- mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0}
- mock_instance.io_processor = MagicMock()
- mock_instance.input_processor = MagicMock()
- mock_instance.shutdown = MagicMock()
- mock_instance.get_vllm_config = AsyncMock(return_value=None)
- mock_instance.get_supported_tasks = AsyncMock(return_value=["generate"])
- mock_instance.get_tokenizer = AsyncMock(return_value=None)
+ mock_async_omni_cls = mocker.patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni")
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat",
+ _mock_preprocess_chat,
+ )
+
+ mock_instance = mocker.AsyncMock(spec=RealAsyncOmni)
+ mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case)
+
+ mock_instance.stage_list = server_case.stage_list
+ mock_instance.stage_configs = server_case.stage_configs
+ mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs)
+ mock_instance.default_sampling_params_list = [
+ SamplingParams() if _stage_type(stage) != "diffusion" else mocker.MagicMock()
+ for stage in server_case.stage_configs
+ ]
+ mock_instance.errored = False
+ mock_instance.dead_error = RuntimeError("Mock engine error")
+ mock_instance.model_config = mocker.MagicMock(
+ max_model_len=4096,
+ io_processor_plugin=None,
+ allowed_local_media_path=None,
+ allowed_media_domains=None,
+ )
+ # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes.
+ mock_instance.model_config.hf_config = mocker.MagicMock()
+ mock_instance.model_config.hf_config.talker_config = mocker.MagicMock()
+ mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0}
+ mock_instance.io_processor = mocker.MagicMock()
+ mock_instance.input_processor = mocker.MagicMock()
+ mock_instance.shutdown = mocker.MagicMock()
+ mock_instance.get_vllm_config = mocker.AsyncMock(return_value=None)
+ mock_instance.get_supported_tasks = mocker.AsyncMock(return_value=["generate"])
+ mock_instance.get_tokenizer = mocker.AsyncMock(return_value=None)
- MockAsyncOmni.return_value = mock_instance
- yield MockAsyncOmni
+ mock_async_omni_cls.return_value = mock_instance
+ yield mock_async_omni_cls
@pytest.fixture
@@ -518,6 +523,7 @@ def run_server():
"Qwen/Qwen-Image-Edit",
True,
id="image-to-image-dalle-endpoint",
+ marks=pytest.mark.skip(reason="Temporarily disabled due to failure."),
),
pytest.param(
ServerCase(
@@ -583,9 +589,9 @@ async def test_image_generation_node(api_server: str, model: str, image_input: b
ServerCase(
served_model="Qwen/Qwen2.5-Omni-7B",
stage_list=[
- MagicMock(is_comprehension=True, model_stage="llm"),
- MagicMock(is_comprehension=False, model_stage="llm"),
- MagicMock(is_comprehension=False, model_stage="llm"),
+ SimpleNamespace(is_comprehension=True, model_stage="llm"),
+ SimpleNamespace(is_comprehension=False, model_stage="llm"),
+ SimpleNamespace(is_comprehension=False, model_stage="llm"),
],
stage_configs=[
_make_stage_config("llm", is_comprehension=True, model_stage="thinker"),
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/config/test_pipeline_registry.py b/tests/config/test_pipeline_registry.py
new file mode 100644
index 0000000000..3483d530c6
--- /dev/null
+++ b/tests/config/test_pipeline_registry.py
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for the central pipeline registry (2.5/N)."""
+
+from __future__ import annotations
+
+import pytest
+
+from vllm_omni.config.pipeline_registry import (
+ _DIFFUSION_PIPELINES,
+ _OMNI_PIPELINES,
+ _VLLM_OMNI_PIPELINES,
+)
+from vllm_omni.config.stage_config import (
+ _PIPELINE_REGISTRY,
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+ register_pipeline,
+)
+
+
+class TestCentralRegistryDeclarations:
+ """Every in-tree pipeline must be declared exactly once in the central registry."""
+
+ def test_union_contains_all_omni(self):
+ for key in _OMNI_PIPELINES:
+ assert key in _VLLM_OMNI_PIPELINES
+
+ def test_union_contains_all_diffusion(self):
+ for key in _DIFFUSION_PIPELINES:
+ assert key in _VLLM_OMNI_PIPELINES
+
+ def test_no_duplicate_model_type_between_omni_and_diffusion(self):
+ overlap = set(_OMNI_PIPELINES) & set(_DIFFUSION_PIPELINES)
+ assert not overlap, f"Duplicate model_types across omni/diffusion: {overlap}"
+
+ def test_expected_omni_pipelines_present(self):
+ # Guard against accidental removal during future refactors.
+ assert "qwen2_5_omni" in _OMNI_PIPELINES
+ assert "qwen2_5_omni_thinker_only" in _OMNI_PIPELINES
+ assert "qwen3_omni_moe" in _OMNI_PIPELINES
+ assert "qwen3_tts" in _OMNI_PIPELINES
+
+
+class TestLazyLoading:
+ """Pipelines are imported only on first access."""
+
+ def test_contains_without_import(self):
+ # ``in`` hits the lazy map, not the loaded cache.
+ assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
+
+ def test_getitem_loads_correct_pipeline(self):
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ assert pipeline.model_type == "qwen3_omni_moe"
+ assert pipeline.model_arch == "Qwen3OmniMoeForConditionalGeneration"
+
+ def test_unknown_model_type_returns_none_via_get(self):
+ assert _PIPELINE_REGISTRY.get("not_a_real_pipeline") is None
+
+ def test_unknown_model_type_raises_keyerror_via_getitem(self):
+ with pytest.raises(KeyError):
+ _PIPELINE_REGISTRY["not_a_real_pipeline"]
+
+ def test_iteration_yields_registered_pipelines(self):
+ keys = set(_PIPELINE_REGISTRY)
+ assert "qwen2_5_omni" in keys
+ assert "qwen3_omni_moe" in keys
+
+
+class TestDynamicRegistration:
+ """``register_pipeline()`` still works for plugins and tests."""
+
+ def test_register_adds_to_registry(self):
+ custom = PipelineConfig(
+ model_type="_test_dynamic_registration",
+ model_arch="DynamicTestModel",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="test",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ ),
+ ),
+ )
+ register_pipeline(custom)
+ try:
+ assert "_test_dynamic_registration" in _PIPELINE_REGISTRY
+ assert _PIPELINE_REGISTRY["_test_dynamic_registration"] is custom
+ finally:
+ # Don't leak the test registration into other tests.
+ if "_test_dynamic_registration" in _PIPELINE_REGISTRY:
+ del _PIPELINE_REGISTRY["_test_dynamic_registration"]
+
+ def test_dynamic_registration_overrides_lazy_entry(self):
+ # Build a substitute for qwen3_omni_moe that we can distinguish.
+ original = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ override = PipelineConfig(
+ model_type="qwen3_omni_moe",
+ model_arch="OverriddenArch",
+ stages=original.stages,
+ )
+ register_pipeline(override)
+ try:
+ assert _PIPELINE_REGISTRY["qwen3_omni_moe"].model_arch == "OverriddenArch"
+ finally:
+ # Remove the dynamic override so later tests see the original.
+ if "qwen3_omni_moe" in _PIPELINE_REGISTRY._loaded:
+ del _PIPELINE_REGISTRY["qwen3_omni_moe"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 27833fe282..83752521f2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+import atexit
import base64
import datetime
import io
@@ -46,6 +47,7 @@
from vllm.logger import init_logger
from vllm.utils.network_utils import get_open_port
+from vllm_omni.config.stage_config import resolve_deploy_yaml
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -75,6 +77,7 @@ class OmniServerParams(NamedTuple):
use_omni: bool = True
use_stage_cli: bool = False
init_timeout: int | None = None
+ stage_init_timeout: int | None = None # None defers to the server's own default (300 s)
def assert_image_diffusion_response(
@@ -166,7 +169,6 @@ def assert_audio_diffusion_response(
Validate audio diffusion response.
"""
raise NotImplementedError("Audio validation is not implemented yet")
- # consider using assert_audio_valid defined above
def _maybe_int(value: Any) -> int | None:
@@ -276,15 +278,32 @@ def assert_video_valid(
pass
-def assert_audio_valid(path: Path, *, sample_rate: int, channels: int, duration_s: float) -> None:
- """Assert the WAV has the expected sample rate, channel count, and duration."""
+def assert_audio_valid(
+ audio_or_path: Path | np.ndarray,
+ *,
+ sample_rate: int,
+ channels: int,
+ duration_s: float,
+) -> None:
+ """Assert WAV file or (batch, channels, samples) ndarray matches expected audio format."""
+ expected_samples = int(duration_s * sample_rate)
+ if isinstance(audio_or_path, np.ndarray):
+ audio = audio_or_path
+ assert audio.ndim == 3, f"Expected audio ndim=3 (batch, channels, samples), got shape {audio.shape}"
+ assert audio.shape[0] == 1, f"Expected batch size 1, got {audio.shape[0]}"
+ assert audio.shape[1] == channels, f"Expected {channels} channels, got {audio.shape[1]}"
+ assert audio.shape[2] == expected_samples, (
+ f"Expected {expected_samples} samples ({duration_s}s @ {sample_rate} Hz), got {audio.shape[2]}"
+ )
+ return
+
+ path = audio_or_path
assert path.exists(), f"Audio not found: {path}"
info = sf.info(str(path))
assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}"
assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}"
- expected_frames = int(duration_s * sample_rate)
- assert info.frames == expected_frames, (
- f"Expected {expected_frames} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}"
+ assert info.frames == expected_samples, (
+ f"Expected {expected_samples} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}"
)
@@ -1321,12 +1340,14 @@ def delete_by_path(config_dict: dict, path: str) -> None:
else:
print(f"Path {path} does not exist")
+ _stage_key = "stages" if "stages" in config else "stage_args"
+
# Apply deletions first
if deletes:
for key, value in deletes.items():
- if key == "stage_args":
+ if key in ("stage_args", "stages"):
if value and isinstance(value, dict):
- stage_args = config.get("stage_args", [])
+ stage_args = config.get(_stage_key, [])
if not stage_args:
raise ValueError("stage_args does not exist in config")
@@ -1345,9 +1366,10 @@ def delete_by_path(config_dict: dict, path: str) -> None:
continue
# Delete specified paths in this stage
- for path in delete_paths:
- if path: # Skip empty paths
- delete_by_path(target_stage, path)
+ # Avoid shadowing the original YAML Path used for the output filename below.
+ for delete_path in delete_paths:
+ if delete_path: # Skip empty paths
+ delete_by_path(target_stage, delete_path)
elif "." in key:
# Delete using dot-separated path
delete_by_path(config, key)
@@ -1358,9 +1380,9 @@ def delete_by_path(config_dict: dict, path: str) -> None:
# Apply updates
if updates:
for key, value in updates.items():
- if key == "stage_args":
+ if key in ("stage_args", "stages"):
if value and isinstance(value, dict):
- stage_args = config.get("stage_args", [])
+ stage_args = config.get(_stage_key, [])
if not stage_args:
raise ValueError("stage_args does not exist in config")
@@ -1377,15 +1399,15 @@ def delete_by_path(config_dict: dict, path: str) -> None:
raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
# Apply updates to this stage
- for path, val in stage_updates.items():
+ for update_path, val in stage_updates.items():
# Check if this is a simple key (not dot-separated)
# Example: 'engine_input_source' vs 'engine_args.max_model_len'
- if "." not in path:
+ if "." not in update_path:
# Direct key assignment (e.g., updating a list value)
- target_stage[path] = val
+ target_stage[update_path] = val
else:
# Dot-separated path (e.g., nested dict access)
- apply_update(target_stage, path, val)
+ apply_update(target_stage, update_path, val)
elif "." in key:
# Apply using dot-separated path
apply_update(config, key, value)
@@ -1397,13 +1419,14 @@ def delete_by_path(config_dict: dict, path: str) -> None:
# within the same second (e.g. test_qwen3_omni_expansion imports both
# get_chunk_config and get_batch_token_config). int(time.time()) would collide
# and the later write would overwrite the earlier YAML on disk.
- base_name = yaml_path.rsplit(".", 1)[0] if "." in yaml_path else yaml_path
- output_path = f"{base_name}_{time.time_ns()}.yaml"
+ # Keep generated configs outside the repo and delete them when pytest exits.
+ output_fd, output_path = tempfile.mkstemp(prefix=f"{path.stem}_", suffix=".yaml")
+ atexit.register(Path(output_path).unlink, missing_ok=True)
- with open(output_path, "w", encoding="utf-8") as f:
+ with os.fdopen(output_fd, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2)
- return output_path
+ return str(output_path)
class OmniServer:
@@ -1565,32 +1588,46 @@ def __init__(
self.stage_config_path = stage_config_path
self.master_port = get_open_port()
self.visible_device_list = self._load_visible_device_list(env_dict)
- self.stage_runtime_devices = self._load_stage_runtime_devices(stage_config_path)
- self.stage_ids = stage_ids or self._load_stage_ids(stage_config_path)
+ resolved_cfg = resolve_deploy_yaml(stage_config_path)
+ # Dump the resolved deploy config so CI logs show each stage's
+ # gpu_memory_utilization / max_model_len / max_num_seqs after
+ # base_config inheritance and overlay merge — essential when
+ # diagnosing OOMs that depend on the merged values.
+ print(
+ f"[OmniServerStageCli] Resolved deploy config from {stage_config_path}:\n"
+ f"{yaml.safe_dump(resolved_cfg, sort_keys=False, default_flow_style=False)}",
+ flush=True,
+ )
+ self.stage_runtime_devices = self._load_stage_runtime_devices(resolved_cfg)
+ self.stage_ids = stage_ids or self._load_stage_ids(resolved_cfg)
if 0 not in self.stage_ids:
raise ValueError(f"Stage CLI test requires stage_id=0 in config: {stage_config_path}")
self.stage_procs: dict[int, subprocess.Popen] = {}
self.proc = None
@staticmethod
- def _load_stage_ids(stage_config_path: str) -> list[int]:
- with open(stage_config_path, encoding="utf-8") as f:
- cfg = yaml.safe_load(f) or {}
+ def _stage_entries(cfg: dict) -> list[dict]:
+ """Return the list of stage entries from either legacy (``stage_args``)
+ or new-schema (``stages``) deploy YAMLs."""
+ return cfg.get("stage_args") or cfg.get("stages") or []
- stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage]
+ @staticmethod
+ def _load_stage_ids(resolved_config: dict) -> list[int]:
+ stage_ids = [
+ stage["stage_id"] for stage in OmniServerStageCli._stage_entries(resolved_config) if "stage_id" in stage
+ ]
if not stage_ids:
- raise ValueError(f"No stage IDs found in config: {stage_config_path}")
+ raise ValueError("No stage IDs found in resolved config")
return stage_ids
@staticmethod
- def _load_stage_runtime_devices(stage_config_path: str) -> dict[int, str]:
- with open(stage_config_path, encoding="utf-8") as f:
- cfg = yaml.safe_load(f) or {}
-
+ def _load_stage_runtime_devices(resolved_config: dict) -> dict[int, str]:
runtime_devices: dict[int, str] = {}
- for stage in cfg.get("stage_args", []):
+ for stage in OmniServerStageCli._stage_entries(resolved_config):
stage_id = stage.get("stage_id")
- devices = stage.get("runtime", {}).get("devices")
+ # New schema: stage.devices is flat at stage level.
+ # Legacy schema: stage.runtime.devices is nested.
+ devices = stage.get("devices") or stage.get("runtime", {}).get("devices")
if stage_id is not None and devices:
runtime_devices[int(stage_id)] = str(devices)
return runtime_devices
@@ -1676,10 +1713,21 @@ def _launch_stage(self, stage_id: int, *, headless: bool) -> None:
cmd = self._build_stage_cmd(stage_id, headless=headless)
print(f"Launching OmniServerStageCli stage {stage_id}: {' '.join(cmd)}")
+ # Capture each subprocess's stdout+stderr to a per-stage log file so
+ # debugging "Stage N exited before API server ready" doesn't rely on
+ # guessing; the file is surfaced in the RuntimeError message.
+ log_path = Path(tempfile.gettempdir()) / f"omni_stage_{stage_id}_{self.master_port}.log"
+ self._stage_log_paths = getattr(self, "_stage_log_paths", {})
+ self._stage_log_paths[stage_id] = log_path
+ log_fh = open(log_path, "w", buffering=1) # noqa: SIM115 - closed in __exit__
+ self._stage_log_files = getattr(self, "_stage_log_files", {})
+ self._stage_log_files[stage_id] = log_fh
proc = subprocess.Popen(
cmd,
env=env,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+ stdout=log_fh,
+ stderr=subprocess.STDOUT,
)
self.stage_procs[stage_id] = proc
if stage_id == 0:
@@ -1689,7 +1737,18 @@ def _ensure_stage_processes_alive(self) -> None:
for stage_id, proc in self.stage_procs.items():
ret = proc.poll()
if ret is not None:
- raise RuntimeError(f"Stage {stage_id} exited with code {ret} before API server became ready.")
+ log_path = getattr(self, "_stage_log_paths", {}).get(stage_id)
+ tail = ""
+ if log_path and log_path.exists():
+ try:
+ with open(log_path, encoding="utf-8", errors="replace") as f:
+ lines = f.readlines()
+ tail = "\n=== Last 60 lines of stage {} log ({}) ===\n{}".format(
+ stage_id, log_path, "".join(lines[-60:]) or ""
+ )
+ except Exception as exc: # pragma: no cover - diagnostic only
+ tail = f"\n"
+ raise RuntimeError(f"Stage {stage_id} exited with code {ret} before API server became ready.{tail}")
def _start_server(self) -> None:
ordered_stage_ids = [0, *[stage_id for stage_id in self.stage_ids if stage_id != 0]]
@@ -1715,7 +1774,46 @@ def _start_server(self) -> None:
raise RuntimeError(f"OmniServerStageCli failed to start within {max_wait} seconds")
+ def _dump_stage_logs_for_debug(self, head_lines: int = 300, tail_lines: int = 500) -> None:
+ """Tail each stage's subprocess log back to stdout on teardown.
+
+ Stage subprocesses redirect stdout/stderr to ``/tmp/omni_stage_*.log``
+ so we don't spam the main CI stream while tests run; but that also
+ hides engine init (KV cache size, Available KV cache memory, vLLM
+ engine config) when things go wrong. Dump them here so buildkite
+ captures them post-run. Head covers engine init; tail covers
+ whatever state the stage was in when it was torn down.
+ """
+ log_paths = getattr(self, "_stage_log_paths", {}) or {}
+ for stage_id in sorted(log_paths):
+ log_path = log_paths[stage_id]
+ if not log_path or not log_path.exists():
+ continue
+ try:
+ with open(log_path, encoding="utf-8", errors="replace") as f:
+ lines = f.readlines()
+ except Exception as exc: # pragma: no cover - diagnostic only
+ print(f"[OmniServerStageCli] stage {stage_id} log read failed: {exc}", flush=True)
+ continue
+ total = len(lines)
+ if total <= head_lines + tail_lines:
+ head_chunk = lines
+ tail_chunk = []
+ elided = 0
+ else:
+ head_chunk = lines[:head_lines]
+ tail_chunk = lines[-tail_lines:]
+ elided = total - head_lines - tail_lines
+ print(f"\n=== stage {stage_id} log HEAD ({log_path}) ===", flush=True)
+ print("".join(head_chunk).rstrip("\n"), flush=True)
+ if tail_chunk:
+ print(f"\n... [{elided} lines elided] ...", flush=True)
+ print(f"\n=== stage {stage_id} log TAIL ({log_path}) ===", flush=True)
+ print("".join(tail_chunk).rstrip("\n"), flush=True)
+ print(f"=== end stage {stage_id} log ===\n", flush=True)
+
def __exit__(self, exc_type, exc_val, exc_tb):
+ self._dump_stage_logs_for_debug()
for stage_id in sorted(self.stage_procs, reverse=True):
proc = self.stage_procs[stage_id]
if proc.poll() is None:
@@ -1761,22 +1859,35 @@ def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: st
if run_level == "advanced_model" and stage_config_path is not None:
with open(stage_config_path, encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
- stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage]
+ # Strip ``load_format: dummy`` (CI overlay default) so advanced_model
+ # tests use real weights. New schema (``stages:``) writes the field
+ # flat at stage level; legacy schema (``stage_args:``) nests it as
+ # ``engine_args.load_format``. Handle both.
+ new_schema_stages = cfg.get("stages")
+ stage_key = "stages" if new_schema_stages is not None else "stage_args"
+ delete_path = "load_format" if new_schema_stages is not None else "engine_args.load_format"
+ stage_entries = cfg.get(stage_key, [])
+ stage_ids = [stage["stage_id"] for stage in stage_entries if "stage_id" in stage]
stage_config_path = modify_stage_config(
stage_config_path,
- deletes={"stage_args": {stage_id: ["engine_args.load_format"] for stage_id in stage_ids}},
+ deletes={stage_key: {stage_id: [delete_path] for stage_id in stage_ids}},
)
server_args = params.server_args or []
- if params.use_omni:
- server_args = ["--stage-init-timeout", "120", *server_args]
+ if params.use_omni and params.stage_init_timeout is not None:
+ server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)]
+ else:
+ server_args = [*server_args, "--stage-init-timeout", "600"]
if params.init_timeout is not None:
server_args = [*server_args, "--init-timeout", str(params.init_timeout)]
+ else:
+ server_args = [*server_args, "--init-timeout", "900"]
if params.use_stage_cli:
if not params.use_omni:
raise ValueError("omni_server with use_stage_cli=True requires use_omni=True")
if stage_config_path is None:
raise ValueError("omni_server with use_stage_cli=True requires a stage_config_path")
+ server_args += ["--stage-configs-path", stage_config_path]
with OmniServerStageCli(
model,
@@ -1826,6 +1937,7 @@ class OmniResponse:
e2e_latency: float | None = None
success: bool = False
error_message: str | None = None
+ cached_tokens: int | None = None
@dataclass
@@ -2321,6 +2433,11 @@ def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse:
if hasattr(choice.message, "content") and choice.message.content is not None:
text_content = choice.message.content
+ # Extract cached_tokens for prefix caching tests
+ usage = getattr(chat_completion, "usage", None)
+ if usage and (details := getattr(usage, "prompt_tokens_details", None)):
+ result.cached_tokens = details.cached_tokens
+
# Calculate end-to-end latency
result.e2e_latency = time.perf_counter() - start_time
@@ -2373,7 +2490,7 @@ def _process_diffusion_response(self, chat_completion) -> DiffusionResponse:
image_url = item.get("image_url", {}).get("url")
else:
image_url_obj = getattr(item, "image_url", None)
- image_url = hasattr(image_url_obj, "url", None) if image_url_obj else None
+ image_url = getattr(image_url_obj, "url", None) if image_url_obj else None
if image_url and image_url.startswith("data:image"):
b64_data = image_url.split(",", 1)[1]
img = decode_b64_image(b64_data)
@@ -2679,7 +2796,7 @@ def _stream_task():
return responses
- def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[DiffusionResponse]:
"""
Send OpenAI requests for diffusion models.
@@ -2687,9 +2804,9 @@ def send_diffusion_request(self, request_config: dict[str, Any], request_num: in
request_config: Request configuration dictionary containing parameters like model, messages
request_num: Number of requests to send concurrently, defaults to 1 (single request)
Returns:
- List[OmniResponse]: List of response objects
+ List[DiffusionResponse]: List of response objects
"""
- responses = []
+ responses: list[DiffusionResponse] = []
stream = request_config.get("stream", False)
modalities = request_config.get("modalities", omit) # Most diffusion models don't require modalities param
extra_body = request_config.get("extra_body", None)
@@ -2869,9 +2986,9 @@ def __init__(
self,
model_name: str,
seed: int = 42,
- stage_init_timeout: int = 300,
+ stage_init_timeout: int = 600,
batch_timeout: int = 10,
- init_timeout: int = 300,
+ init_timeout: int = 900,
shm_threshold_bytes: int = 65536,
log_stats: bool = False,
stage_configs_path: str | None = None,
@@ -2996,6 +3113,10 @@ def get_omni_inputs(
video_padding_token = "<|video_pad|>"
image_padding_token = "<|image_pad|>"
audio_padding_token = "<|audio_pad|>"
+ elif "Ming-flash-omni" in self.model_name:
+ video_padding_token = ""
+ image_padding_token = ""
+ audio_padding_token = ""
if isinstance(prompts, str):
prompts = [prompts]
diff --git a/tests/core/sched/test_chunk_scheduling_coordinator.py b/tests/core/sched/test_chunk_scheduling_coordinator.py
new file mode 100644
index 0000000000..5e19465e22
--- /dev/null
+++ b/tests/core/sched/test_chunk_scheduling_coordinator.py
@@ -0,0 +1,690 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for OmniSchedulingCoordinator (formerly ChunkSchedulingCoordinator).
+
+These tests use mock request objects and mock queues. They do not require
+GPU, vLLM runtime, or any connector.
+"""
+
+from __future__ import annotations
+
+import unittest
+from types import SimpleNamespace
+
+import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod
+from vllm_omni.core.sched.omni_scheduling_coordinator import (
+ ChunkSchedulingCoordinator,
+ OmniSchedulingCoordinator,
+)
+
+# ------------------------------------------------------------------ #
+# Mock helpers
+# ------------------------------------------------------------------ #
+
+
+class _RequestStatus:
+ WAITING = "waiting"
+ RUNNING = "running"
+ WAITING_FOR_CHUNK = "waiting_for_chunk"
+ WAITING_FOR_INPUT = "waiting_for_input"
+ FINISHED_STOPPED = "finished_stopped"
+
+
+# Patch RequestStatus for tests that don't import vllm
+try:
+ from vllm.v1.request import RequestStatus
+except ImportError:
+ RequestStatus = _RequestStatus # type: ignore[misc,assignment]
+
+if not hasattr(RequestStatus, "WAITING_FOR_INPUT"):
+ coord_mod.RequestStatus = _RequestStatus # type: ignore[assignment]
+ RequestStatus = _RequestStatus # type: ignore[misc,assignment]
+
+
+def _make_request(req_id: str, status: str = "waiting") -> SimpleNamespace:
+ return SimpleNamespace(
+ request_id=req_id,
+ external_req_id=req_id,
+ status=status,
+ additional_information=None,
+ prompt_token_ids=[],
+ num_prompt_tokens=0,
+ num_computed_tokens=0,
+ _all_token_ids=[],
+ _output_token_ids=[],
+ )
+
+
+class MockQueue:
+ """Simplified queue that mimics the Scheduler waiting queue interface."""
+
+ def __init__(self, items: list | None = None):
+ self._items: list = list(items or [])
+
+ def __iter__(self):
+ return iter(self._items)
+
+ def __len__(self):
+ return len(self._items)
+
+ def __contains__(self, item):
+ return item in self._items
+
+ def add_request(self, request):
+ self._items.append(request)
+
+ def prepend_requests(self, requests):
+ self._items = list(requests) + self._items
+
+ def remove(self, request):
+ self._items.remove(request)
+
+ def remove_requests(self, requests):
+ remove_set = set(id(r) for r in requests)
+ self._items = [r for r in self._items if id(r) not in remove_set]
+
+
+# ------------------------------------------------------------------ #
+# Tests
+# ------------------------------------------------------------------ #
+
+
+class TestChunkCoordinatorStateTransition(unittest.TestCase):
+ """Test 5: process_pending_chunks transitions WAITING_FOR_CHUNK → target."""
+
+ def test_ready_request_transitions_to_waiting(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING)
+ self.assertIn("r1", coord.requests_with_ready_chunks)
+
+ def test_non_ready_stays_waiting_for_chunk(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+ def test_stage_0_is_noop(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0)
+ req = _make_request("r1")
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids=set(),
+ )
+ self.assertNotEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+
+class TestChunkCoordinatorRestoreQueues(unittest.TestCase):
+ """Test 6: restore_queues returns waiting-for-chunk requests."""
+
+ def test_restore(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ r1 = _make_request("r1")
+ r2 = _make_request("r2")
+ coord._waiting_for_chunk_waiting.append(r1)
+ coord._waiting_for_chunk_running.append(r2)
+
+ waiting = MockQueue()
+ running: list = []
+
+ coord.restore_queues(waiting, running)
+
+ self.assertIn(r1, waiting)
+ self.assertIn(r2, running)
+ self.assertEqual(len(coord._waiting_for_chunk_waiting), 0)
+ self.assertEqual(len(coord._waiting_for_chunk_running), 0)
+
+
+class TestChunkCoordinatorFinishedSignal(unittest.TestCase):
+ """Test 8: chunk_finished_req_ids → finished_requests."""
+
+ def test_finished_signal(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids={"r1"},
+ )
+
+ self.assertIn("r1", coord.finished_requests)
+
+
+class TestChunkCoordinatorUpdateRequestMetadata(unittest.TestCase):
+ """Test update_request_metadata applies scheduling metadata to requests."""
+
+ def test_ar_mode_no_longer_sets_additional_information(self):
+ """AR mode only processes scheduling metadata, not full payloads."""
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1")
+ requests = {"r1": req}
+
+ # Only scheduling metadata is passed now (full payload stays in model runner)
+ request_metadata = {"r1": {"next_stage_prompt_len": 50}}
+
+ coord.update_request_metadata(requests, request_metadata, model_mode="ar")
+
+ # next_stage_prompt_len should update prompt_token_ids
+ self.assertEqual(len(req.prompt_token_ids), 50)
+ self.assertEqual(req.num_prompt_tokens, 50)
+ # additional_information should NOT be set
+ self.assertIsNone(getattr(req, "additional_information", None))
+
+ def test_generation_mode(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1")
+ req.prompt_token_ids = [0, 0, 0]
+ requests = {"r1": req}
+
+ request_metadata = {
+ "r1": {
+ "code_predictor_codes": [10, 20, 30],
+ "left_context_size": 25,
+ }
+ }
+
+ coord.update_request_metadata(requests, request_metadata, model_mode="generation")
+
+ self.assertEqual(req.prompt_token_ids, [10, 20, 30])
+ self.assertEqual(req.num_computed_tokens, 0)
+ self.assertIsNone(req.additional_information)
+ self.assertEqual(req._omni_initial_model_buffer, {"left_context_size": 25})
+
+
+class TestChunkCoordinatorPostprocess(unittest.TestCase):
+ """Test postprocess_scheduler_output clears ready chunks."""
+
+ def test_clear_ready(self):
+ coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+ coord.requests_with_ready_chunks = {"r1", "r2"}
+
+ new_req = SimpleNamespace(req_id="r1")
+ cached_reqs = SimpleNamespace(req_ids=["r2"])
+ scheduler_output = SimpleNamespace(
+ scheduled_new_reqs=[new_req],
+ scheduled_cached_reqs=cached_reqs,
+ )
+
+ coord.postprocess_scheduler_output(scheduler_output)
+
+ self.assertEqual(coord.requests_with_ready_chunks, set())
+
+
+class TestWaitingForInputTransition(unittest.TestCase):
+ """Test B8: process_pending_full_payload_inputs transitions WAITING_FOR_INPUT."""
+
+ def test_transition_on_recv(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids={"r1"},
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING)
+
+ def test_stays_waiting_for_input_if_not_received(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
+ self.assertEqual(len(coord._waiting_for_input), 1)
+
+ def test_stage_0_is_noop(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids={"r1"},
+ )
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
+
+ def test_restore_queues_includes_waiting_for_input(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ r1 = _make_request("r1")
+ coord._waiting_for_input.append(r1)
+
+ waiting = MockQueue()
+ running: list = []
+
+ coord.restore_queues(waiting, running)
+
+ self.assertIn(r1, waiting)
+ self.assertEqual(len(coord._waiting_for_input), 0)
+
+ def test_full_payload_mode_auto_transitions_waiting_to_waiting_for_input(self):
+ """In full_payload_mode (async_chunk=False), fresh WAITING requests on
+ non-Stage-0 should be transitioned to WAITING_FOR_INPUT."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=False,
+ )
+
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
+ self.assertEqual(len(coord._waiting_for_input), 1)
+ self.assertEqual(len(coord.pending_input_registrations), 1)
+
+ def test_async_chunk_mode_does_not_auto_transition(self):
+ """In async_chunk mode, fresh WAITING requests should NOT be
+ transitioned to WAITING_FOR_INPUT."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(req.status, RequestStatus.WAITING)
+
+ def test_pending_input_registrations(self):
+ coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
+
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ running,
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertEqual(len(coord.pending_input_registrations), 1)
+ self.assertEqual(coord.pending_input_registrations[0].request_id, "r1")
+
+
+class TestTimeoutDetection(unittest.TestCase):
+ """Regression tests for orphaned pending-recv timeout detection.
+
+ Covers the full lifecycle:
+ 1. Request enters WAITING_FOR_CHUNK from either waiting or running queue
+ 2. restore_queues() moves it back to the scheduler queue
+ 3. Timeout fires via collect_timed_out_request_ids()
+ 4. Scheduler removes from both queues and calls _free_request()
+ """
+
+ def test_waiting_since_recorded_on_chunk_wait(self):
+ """_waiting_since is set when a request enters WAITING_FOR_CHUNK."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+
+ coord.process_pending_chunks(
+ waiting,
+ [],
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertIn("r1", coord._waiting_since)
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+ def test_waiting_since_cleared_on_chunk_arrival(self):
+ """_waiting_since is cleared when a chunk arrives."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ waiting = MockQueue([req])
+
+ coord.process_pending_chunks(
+ waiting,
+ [],
+ chunk_ready_req_ids={"r1"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertNotIn("r1", coord._waiting_since)
+
+ def test_waiting_since_recorded_on_input_wait(self):
+ """_waiting_since is set when a request enters WAITING_FOR_INPUT."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=False,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ [],
+ stage_recv_req_ids=set(),
+ )
+
+ self.assertIn("r1", coord._waiting_since)
+
+ def test_waiting_since_cleared_on_input_arrival(self):
+ """_waiting_since is cleared when input data arrives."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=False,
+ )
+ req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
+ coord._waiting_for_input.append(req)
+ coord._waiting_since["r1"] = 0.0
+
+ waiting = MockQueue()
+ coord.process_pending_full_payload_inputs(
+ waiting,
+ [],
+ stage_recv_req_ids={"r1"},
+ )
+
+ self.assertNotIn("r1", coord._waiting_since)
+ self.assertEqual(req.status, RequestStatus.WAITING)
+
+ def test_collect_timed_out_request_ids_no_timeout(self):
+ """No IDs returned when nothing has timed out."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ import time
+
+ coord._waiting_since["r1"] = time.monotonic()
+
+ result = coord.collect_timed_out_request_ids(timeout_s=300.0)
+ self.assertEqual(result, set())
+
+ def test_collect_timed_out_request_ids_expired(self):
+ """Timed-out IDs are returned and _waiting_since is cleared."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ coord._waiting_since["r1"] = 0.0 # epoch → definitely expired
+ coord._waiting_since["r2"] = 0.0
+
+ import time
+
+ coord._waiting_since["r3"] = time.monotonic() + 9999 # far future
+
+ result = coord.collect_timed_out_request_ids(timeout_s=1.0)
+
+ self.assertEqual(result, {"r1", "r2"})
+ self.assertNotIn("r1", coord._waiting_since)
+ self.assertNotIn("r2", coord._waiting_since)
+ self.assertIn("r3", coord._waiting_since)
+
+ def test_collect_removes_from_coordinator_queues(self):
+ """Timed-out requests are defensively removed from internal queues."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ r1 = _make_request("r1")
+ r2 = _make_request("r2")
+ coord._waiting_for_chunk_waiting.append(r1)
+ coord._waiting_for_input.append(r2)
+ coord._waiting_since["r1"] = 0.0
+ coord._waiting_since["r2"] = 0.0
+
+ result = coord.collect_timed_out_request_ids(timeout_s=1.0)
+
+ self.assertEqual(result, {"r1", "r2"})
+ self.assertEqual(len(coord._waiting_for_chunk_waiting), 0)
+ self.assertEqual(len(coord._waiting_for_input), 0)
+
+ def test_free_finished_request_clears_waiting_since(self):
+ """free_finished_request clears _waiting_since."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ )
+ coord._waiting_since["r1"] = 0.0
+ coord.free_finished_request("r1")
+ self.assertNotIn("r1", coord._waiting_since)
+
+ def test_timeout_from_running_queue_full_lifecycle(self):
+ """End-to-end: request from running → WAITING_FOR_CHUNK → restore →
+ timeout → removed from running list.
+
+ This is the critical regression case: WAITING_FOR_CHUNK requests
+ that originated from self.running are placed back into self.running
+ by restore_queues(), but their status remains WAITING_FOR_CHUNK.
+ The scheduler must remove from BOTH queues unconditionally.
+ """
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ # 1) Request starts in running queue with WAITING status
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ running = [req]
+ waiting = MockQueue()
+
+ # 2) process_pending_chunks: moves to WAITING_FOR_CHUNK
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+ self.assertIn("r1", coord._waiting_since)
+ self.assertEqual(len(coord._waiting_for_chunk_running), 1)
+
+ # 3) restore_queues: back to running (status stays WAITING_FOR_CHUNK)
+ coord.restore_queues(waiting, running)
+ self.assertIn(req, running)
+ self.assertEqual(len(coord._waiting_for_chunk_running), 0)
+ self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
+
+ # 4) Force timeout by setting _waiting_since to epoch
+ coord._waiting_since["r1"] = 0.0
+
+ timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0)
+ self.assertEqual(timed_out_ids, {"r1"})
+
+ # 5) Scheduler removes from both queues (simulating the scheduler path)
+ timed_out_id_set = {id(req)}
+ running = [r for r in running if id(r) not in timed_out_id_set]
+ waiting.remove_requests([req])
+
+ self.assertNotIn(req, running)
+ self.assertEqual(len(waiting), 0)
+
+ def test_timeout_from_waiting_queue_full_lifecycle(self):
+ """End-to-end: request from waiting → WAITING_FOR_CHUNK → restore →
+ timeout → removed from waiting queue."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=10,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ req = _make_request("r1", status=RequestStatus.WAITING)
+ waiting = MockQueue([req])
+ running: list = []
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids=set(),
+ chunk_finished_req_ids=set(),
+ )
+ self.assertEqual(len(coord._waiting_for_chunk_waiting), 1)
+
+ coord.restore_queues(waiting, running)
+ self.assertIn(req, waiting)
+
+ coord._waiting_since["r1"] = 0.0
+ timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0)
+ self.assertEqual(timed_out_ids, {"r1"})
+
+ waiting.remove_requests([req])
+ self.assertEqual(len(waiting), 0)
+
+
+class TestOverflowPreemption(unittest.TestCase):
+ """Tests for P1-1: overflow requests must get WAITING status.
+
+ Overflow happens when multiple WAITING_FOR_CHUNK requests in
+ ``_waiting_for_chunk_running`` receive their chunk in the same cycle.
+ ``_process_chunk_queue`` restores them to RUNNING (``continue``
+ path) while RUNNING requests without chunks are moved out. If the
+ net result exceeds ``scheduler_max_num_seqs``, the tail is pushed
+ to ``waiting_queue`` and must have status == WAITING.
+ """
+
+ def test_overflow_sets_waiting_status(self):
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=1,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ # r1 is currently RUNNING in the queue.
+ # r2, r3 were previously moved to _waiting_for_chunk_running.
+ r1 = _make_request("r1", status=RequestStatus.RUNNING)
+ r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK)
+ r3 = _make_request("r3", status=RequestStatus.WAITING_FOR_CHUNK)
+
+ running = [r1]
+ waiting = MockQueue([])
+ coord._waiting_for_chunk_running.extend([r2, r3])
+
+ # restore_queues puts r2, r3 back into running
+ coord.restore_queues(waiting, running)
+ self.assertEqual(len(running), 3)
+
+ # Now process_pending_chunks with r2, r3 chunks ready:
+ # _process_chunk_queue will:
+ # r1 (RUNNING) → no chunk → move to _waiting_for_chunk_running
+ # r2 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running
+ # r3 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running
+ # running = [r2, r3], len=2 > max=1 → overflow
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r2", "r3"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(len(running), 1)
+ self.assertEqual(len(waiting), 1)
+ overflow_req = list(waiting)[0]
+ self.assertEqual(
+ overflow_req.status,
+ RequestStatus.WAITING,
+ f"Overflowed request should have WAITING status, got {overflow_req.status}",
+ )
+
+ def test_overflow_does_not_strand_request(self):
+ """Without the fix, the overflowed request would keep its
+ RUNNING status in the waiting queue and never be re-scheduled."""
+ coord = OmniSchedulingCoordinator(
+ scheduler_max_num_seqs=1,
+ stage_id=1,
+ async_chunk=True,
+ )
+
+ r1 = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
+ r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK)
+ coord._waiting_for_chunk_running.extend([r1, r2])
+
+ running: list = []
+ waiting = MockQueue([])
+
+ coord.restore_queues(waiting, running)
+ self.assertEqual(len(running), 2)
+
+ coord.process_pending_chunks(
+ waiting,
+ running,
+ chunk_ready_req_ids={"r1", "r2"},
+ chunk_finished_req_ids=set(),
+ )
+
+ self.assertEqual(len(running), 1)
+ self.assertEqual(len(waiting), 1)
+ for req in waiting:
+ self.assertNotEqual(req.status, RequestStatus.RUNNING, "Overflowed request must not keep RUNNING status")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/sched/test_generation_scheduler_restore.py b/tests/core/sched/test_generation_scheduler_restore.py
index 154f40b399..5cc1cab702 100644
--- a/tests/core/sched/test_generation_scheduler_restore.py
+++ b/tests/core/sched/test_generation_scheduler_restore.py
@@ -6,7 +6,6 @@
those requests are permanently orphaned.
"""
-import unittest
from collections import deque
import pytest
@@ -39,7 +38,7 @@ def postprocess_scheduler_output(self, output):
pass
-class TestRestoreQueuesOnError(unittest.TestCase):
+class TestRestoreQueuesOnError:
"""Verify that restore_queues is called even when rewrapping raises."""
def test_requests_not_lost_on_exception(self):
@@ -52,8 +51,8 @@ def test_requests_not_lost_on_exception(self):
# Step 1: process_pending_chunks moves req-B out
adapter.process_pending_chunks(waiting=[], running=running)
- self.assertEqual(running, ["req-A"])
- self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 1)
+ assert running == ["req-A"]
+ assert len(adapter.waiting_for_chunk_running_requests) == 1
# Step 2: simulate the try/except/finally pattern
try:
@@ -65,9 +64,9 @@ def test_requests_not_lost_on_exception(self):
adapter.restore_queues(waiting=[], running=running)
# Step 3: verify request is restored
- self.assertTrue(adapter.restore_called)
- self.assertIn("req-B", running)
- self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 0)
+ assert adapter.restore_called is True
+ assert "req-B" in running
+ assert len(adapter.waiting_for_chunk_running_requests) == 0
def test_requests_lost_without_fix(self):
"""Demonstrate the bug: without restore in except, request is lost."""
@@ -76,7 +75,7 @@ def test_requests_lost_without_fix(self):
running = ["req-A", "req-B"]
adapter.process_pending_chunks(waiting=[], running=running)
- self.assertEqual(running, ["req-A"])
+ assert running == ["req-A"]
# Simulate the BUGGY code: except without restore
try:
@@ -85,8 +84,8 @@ def test_requests_lost_without_fix(self):
pass # Bug: no restore_queues call
# Request is lost!
- self.assertNotIn("req-B", running)
- self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 1)
+ assert "req-B" not in running
+ assert len(adapter.waiting_for_chunk_running_requests) == 1
def test_happy_path_restores_via_finally(self):
"""When no exception, restore_queues is still called via finally."""
@@ -102,9 +101,5 @@ def test_happy_path_restores_via_finally(self):
finally:
adapter.restore_queues(waiting=[], running=running)
- self.assertTrue(adapter.restore_called)
- self.assertIn("req-B", running)
-
-
-if __name__ == "__main__":
- unittest.main()
+ assert adapter.restore_called is True
+ assert "req-B" in running
diff --git a/tests/core/sched/test_omni_scheduler_mixin.py b/tests/core/sched/test_omni_scheduler_mixin.py
new file mode 100644
index 0000000000..e04a9c39fb
--- /dev/null
+++ b/tests/core/sched/test_omni_scheduler_mixin.py
@@ -0,0 +1,129 @@
+"""Unit tests for OmniSchedulerMixin streaming session replacement.
+
+These tests pin the behavior of `_replace_session_with_streaming_update` against
+current vLLM `Request` / `StreamingUpdate` (and Omni patches). When upgrading
+vLLM, failures here should highlight incompatible changes to request state or
+update payloads early.
+"""
+
+from __future__ import annotations
+
+from dataclasses import replace
+
+import pytest
+
+# Imports must run in this order: vllm_omni applies patches to vllm.v1.request before
+# Request / StreamingUpdate are bound in this module. Ruff isort would reorder them.
+# isort: off
+import vllm_omni # noqa: F401 - import for side effects (patch vLLM)
+from vllm.sampling_params import SamplingParams
+from vllm.v1.engine import EngineCoreEventType
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
+
+# isort: on
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class _SchedulerStub(OmniSchedulerMixin):
+ """Minimal scheduler surface required by OmniSchedulerMixin."""
+
+ def __init__(self, *, log_stats: bool = False) -> None:
+ self.num_waiting_for_streaming_input = 0
+ self.log_stats = log_stats
+
+
+def _make_request(**kwargs) -> Request:
+ sp = SamplingParams(max_tokens=8)
+ defaults = dict(
+ request_id="req-mixin-test",
+ prompt_token_ids=[1, 2, 3],
+ sampling_params=sp,
+ pooling_params=None,
+ arrival_time=100.0,
+ block_hasher=None,
+ )
+ defaults.update(kwargs)
+ return Request(**defaults)
+
+
+def _make_update(**kwargs) -> StreamingUpdate:
+ sp_new = SamplingParams(max_tokens=16)
+ defaults = dict(
+ mm_features=None,
+ prompt_token_ids=[10, 20],
+ max_tokens=32,
+ arrival_time=200.0,
+ sampling_params=sp_new,
+ )
+ defaults.update(kwargs)
+ return StreamingUpdate(**defaults)
+
+
+class TestReplaceSessionWithStreamingUpdate:
+ def test_resets_tokens_and_prompt_from_update(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ session.append_output_token_ids([7, 8])
+ session.num_computed_tokens = 99
+ session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
+
+ update = _make_update(prompt_token_ids=[40, 41, 42])
+ sched.num_waiting_for_streaming_input = 3
+ sched._replace_session_with_streaming_update(session, update)
+
+ assert session._output_token_ids == []
+ assert list(session._all_token_ids) == [40, 41, 42]
+ assert session.prompt_token_ids == [40, 41, 42]
+ assert session.num_computed_tokens == 0
+ assert session.num_prompt_tokens == 3
+ assert session.arrival_time == 200.0
+ assert session.sampling_params is update.sampling_params
+ assert session.status == RequestStatus.WAITING
+ assert sched.num_waiting_for_streaming_input == 2
+
+ def test_none_prompt_token_ids_becomes_empty(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ session.status = RequestStatus.RUNNING
+ update = _make_update(prompt_token_ids=None)
+ sched._replace_session_with_streaming_update(session, update)
+
+ assert session.prompt_token_ids == ()
+ assert list(session._all_token_ids) == []
+ assert session.num_prompt_tokens == 0
+ assert sched.num_waiting_for_streaming_input == 0
+
+ def test_additional_information_cleared_when_update_omits_it(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ if not hasattr(session, "additional_information"):
+ pytest.skip("Request has no additional_information (Omni patch inactive?)")
+ session.additional_information = {"keep": True}
+ session.status = RequestStatus.RUNNING
+
+ base = _make_update()
+ if not hasattr(base, "additional_information"):
+ pytest.skip("StreamingUpdate has no additional_information (Omni patch inactive?)")
+ update = replace(base, additional_information=None)
+
+ sched._replace_session_with_streaming_update(session, update)
+ assert session.additional_information is None
+
+ def test_does_not_decrement_waiting_when_not_streaming_status(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ session.status = RequestStatus.RUNNING
+ sched.num_waiting_for_streaming_input = 5
+ sched._replace_session_with_streaming_update(session, _make_update())
+ assert sched.num_waiting_for_streaming_input == 5
+
+ def test_records_queued_event_when_log_stats_enabled(self) -> None:
+ sched = _SchedulerStub(log_stats=True)
+ session = _make_request()
+ session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
+ sched._replace_session_with_streaming_update(session, _make_update())
+
+ assert session.events
+ assert session.events[-1].type == EngineCoreEventType.QUEUED
diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py
new file mode 100644
index 0000000000..c3d8c1ff92
--- /dev/null
+++ b/tests/core/test_prefix_cache.py
@@ -0,0 +1,347 @@
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+
+from vllm_omni.core.prefix_cache import OmniTensorPrefixCache
+
+DEFAULT_SEQ_LEN = 15
+NUM_BLOCKS = 10
+BLOCK_SIZE = 4
+HIDDEN_SIZE = 2
+DTYPE = torch.float32
+OTHER_DTYPE = torch.float16
+DEFAULT_SHAPE = torch.Size([NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE])
+
+
+class MockInputBatch:
+ def __init__(self, num_computed_tokens_cpu):
+ self.req_ids = ["req1", "req2"]
+ self.req_id_to_index = {req_id: i for i, req_id in enumerate(self.req_ids)}
+ self.num_computed_tokens_cpu = num_computed_tokens_cpu
+ # Block table is only mocked for validation of length;
+ # we don't actually need to add valid values here since
+ # we patch the table when testing.
+ self.block_table = Mock()
+ self.block_table.block_tables = [None]
+
+
+def get_omni_pcache_with_mm_tensors(feat_dims, seq_len) -> OmniTensorPrefixCache:
+ """Build an OmniTensorPrefixCache and init mm tensors."""
+ cache = get_omni_pcache()
+ mm_outputs = get_multimodal_outputs(feat_dims, seq_len)
+ cache.maybe_init_missing_mm_cache_keys(mm_outputs, seq_len)
+ return cache
+
+
+def get_omni_pcache() -> OmniTensorPrefixCache:
+ """Build an OmniTensorPrefixCache, but don't init mm tensors."""
+ cache = OmniTensorPrefixCache(
+ num_blocks=NUM_BLOCKS,
+ block_size=BLOCK_SIZE,
+ hidden_size=HIDDEN_SIZE,
+ hs_dtype=DTYPE,
+ )
+ return cache
+
+
+def get_multimodal_outputs(feat_dims: dict[str, int], seq_len: int) -> dict[str, torch.Tensor]:
+ fake_mm_inputs = {}
+ for mm_key, feat_dim in feat_dims.items():
+ fake_mm_inputs[mm_key] = torch.rand((seq_len, feat_dim), dtype=DTYPE)
+ return fake_mm_inputs
+
+
+### Tests for initialization
+def test_initialization_simple():
+ """Check default initialization only creates the hidden states."""
+ cache = get_omni_pcache()
+ assert isinstance(cache.hidden_states_cache, torch.Tensor)
+ assert cache.hidden_states_cache.shape == DEFAULT_SHAPE
+ assert len(cache.mm_outputs_cache) == 0
+ assert len(cache.mm_cache_keys) == 0
+
+
+def test_initialization_with_multimodal():
+ """Check initialization + registration of multimodal outputs."""
+ cache = get_omni_pcache()
+ feat_dims = {"foo": 100, "bar": 50, "baz": 10}
+ mm_outputs = get_multimodal_outputs(
+ feat_dims,
+ seq_len=DEFAULT_SEQ_LEN,
+ )
+ # Cast one of the keys to a different dtype; the dtype of the tensor
+ # that is used to initialize the cache dictates the cache dtype.
+ mm_outputs["foo"] = mm_outputs["foo"].to(OTHER_DTYPE)
+
+ cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
+ assert len(cache.mm_cache_keys) == 3
+ assert set(cache.mm_cache_keys) == set(feat_dims.keys())
+ for mm_key in cache.mm_cache_keys:
+ cache_tensor = cache.mm_outputs_cache[mm_key]
+ assert isinstance(cache_tensor, torch.Tensor)
+ assert cache_tensor.shape[-1] == feat_dims[mm_key]
+ assert mm_outputs[mm_key].dtype == cache_tensor.dtype
+
+
+def test_init_missing_mm_cache_keys_is_idempotent():
+ """Ensure that the cache doesn't reinitialize old keys."""
+ cache = get_omni_pcache()
+ mm_key = "foo"
+ feat_dims = {mm_key: 100}
+ mm_outputs = get_multimodal_outputs(
+ feat_dims,
+ seq_len=DEFAULT_SEQ_LEN,
+ )
+ cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
+ assert len(cache.mm_cache_keys) == 1
+ assert mm_key in cache.mm_cache_keys
+
+ # Cache is initialized to 0 - fill it with 1s
+ cache.mm_outputs_cache[mm_key].fill_(1)
+
+ # Ensure that running another initialization
+ # doesn't zero out our cache values
+ cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
+ assert len(cache.mm_cache_keys) == 1
+ assert mm_key in cache.mm_cache_keys
+ assert torch.all(cache.mm_outputs_cache[mm_key] == 1)
+
+
+### Tests for Update
+def test_update_no_multimodal():
+ """Test that slot mappings act as row indices hidden states."""
+ cache = get_omni_pcache()
+
+ num_tokens_unpadded = 8
+ slot_offset = 8
+ slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded)
+ new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE)
+
+ cache.update_omni_tensor_prefix_cache(
+ hidden_states=new_hidden_states,
+ multimodal_outputs=None,
+ num_tokens_unpadded=num_tokens_unpadded,
+ slot_mapping=slot_mapping,
+ )
+
+ # Ensure that if we reshape our 3D cache back to 2D, we can use the
+ # indices in our slot mappings to access the hidden states as expected
+ hs_rows = cache.hidden_states_cache.view(NUM_BLOCKS * BLOCK_SIZE, HIDDEN_SIZE)
+ for slot_idx, new_states in zip(slot_mapping, new_hidden_states):
+ slot_states = hs_rows[slot_idx]
+ assert torch.all(slot_states == new_states)
+
+
+@pytest.mark.parametrize(
+ "feat_dims",
+ [
+ {"foo": 100, "bar": 100},
+ {"foo": 100, "bar": 50, "baz": 10},
+ ],
+)
+def test_update_with_multimodal_outputs(feat_dims):
+ """Test that slot mappings are correct for multimodal tensors."""
+ cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN)
+
+ num_tokens_unpadded = 8
+ slot_offset = 8
+ slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded)
+ feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()}
+ mm_outputs = {key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys}
+ cache.update_omni_tensor_prefix_cache(
+ hidden_states=None,
+ multimodal_outputs=mm_outputs,
+ num_tokens_unpadded=num_tokens_unpadded,
+ slot_mapping=slot_mapping,
+ )
+
+ for mm_key in feat_dims.keys():
+ assert mm_key in cache.mm_outputs_cache
+ key_feat_dim = feature_dims[mm_key]
+ mm_state_rows = cache.mm_outputs_cache[mm_key].view(NUM_BLOCKS * BLOCK_SIZE, key_feat_dim)
+
+ # Similar to hidden states, but for each key in the dict;
+ # Different tensors may have different feature dims
+ new_mm_outputs = mm_outputs[mm_key]
+ for slot_idx, new_output in zip(slot_mapping, new_mm_outputs):
+ slot_states = mm_state_rows[slot_idx]
+ assert torch.all(slot_states == new_output)
+
+
+### Tests for Merging
+def fake_get_cached_block_ids(self, req_idx, *args, **kwargs):
+ """Fake block table lookup.
+
+ Assumption:
+ req_idx 0 is a cache hit with slots 8, 9, ..., 15
+ req_idx 1 is a cache miss
+ """
+ assert req_idx < 2
+ if req_idx == 0:
+ # With the slot offset we provided (8), the corresponding
+ # blocks IDs are 2 & 3 because the block size is 4.
+ return torch.tensor([2, 3], dtype=torch.long)
+ return torch.tensor([], dtype=torch.long)
+
+
+@pytest.mark.parametrize("num_tokens_padded", [None, 16])
+def test_get_merged_hidden_states(num_tokens_padded):
+ """Ensure that hidden states are merged correctly."""
+ cache = get_omni_pcache()
+
+ orig_num_tokens_unpadded = 8
+ slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15
+ orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded)
+ orig_hidden_states = torch.rand((orig_num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE)
+
+ cache.update_omni_tensor_prefix_cache(
+ hidden_states=orig_hidden_states,
+ multimodal_outputs=None,
+ num_tokens_unpadded=orig_num_tokens_unpadded,
+ slot_mapping=orig_slot_mapping,
+ num_tokens_padded=num_tokens_padded,
+ )
+
+ # Say that we have two requests, but only one of them is a cache hit
+ num_new_toks_req1 = 3
+ num_new_toks_req2 = 2
+ cache.add_prefix_cached_new_req_id("req1")
+
+ num_scheduled_tokens = {
+ "req1": num_new_toks_req1,
+ "req2": num_new_toks_req2,
+ }
+ new_hidden_states = torch.rand(
+ (num_new_toks_req1 + num_new_toks_req2, HIDDEN_SIZE),
+ dtype=DTYPE,
+ )
+ req1_new_states = new_hidden_states[:num_new_toks_req1]
+ req2_new_states = new_hidden_states[-num_new_toks_req2:]
+
+ input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
+
+ with patch(
+ "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
+ new=fake_get_cached_block_ids,
+ ):
+ merged_states = cache.get_merged_hidden_states(
+ query_start_loc=[0, num_new_toks_req1],
+ input_batch=input_batch,
+ hidden_states=new_hidden_states,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
+
+ assert "req1" in merged_states and "req2" in merged_states
+ req1_merged_states = merged_states["req1"]
+ req2_merged_states = merged_states["req2"]
+
+ # First, check the cache hit case
+ assert req1_merged_states.shape == torch.Size([orig_num_tokens_unpadded + num_new_toks_req1, HIDDEN_SIZE])
+ # Ensure that the req1 merged states are the cached states + the new req1 states
+ assert torch.all(req1_merged_states[:orig_num_tokens_unpadded] == orig_hidden_states)
+ assert torch.all(req1_merged_states[-num_new_toks_req1:] == req1_new_states)
+
+ # Next, ensure that the cache miss case only has the new states
+ assert req2_merged_states.shape == torch.Size([num_new_toks_req2, HIDDEN_SIZE])
+ assert torch.all(req2_merged_states == req2_new_states)
+
+
+@pytest.mark.parametrize("num_tokens_padded", [None, 16])
+@pytest.mark.parametrize(
+ "feat_dims",
+ [
+ {"foo": 100, "bar": 100},
+ {"foo": 100, "bar": 50, "baz": 10},
+ ],
+)
+def test_get_merged_multimodal_outputs(feat_dims, num_tokens_padded):
+ cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN)
+
+ orig_num_tokens_unpadded = 8
+ slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15
+ orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded)
+ feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()}
+ orig_mm_outputs = {
+ key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys
+ }
+
+ cache.update_omni_tensor_prefix_cache(
+ hidden_states=None,
+ multimodal_outputs=orig_mm_outputs,
+ num_tokens_unpadded=orig_num_tokens_unpadded,
+ slot_mapping=orig_slot_mapping,
+ num_tokens_padded=num_tokens_padded,
+ )
+
+ # Similar to hs test- say that we have two requests, but only one of them is a cache hit
+ num_new_toks_req1 = 3
+ num_new_toks_req2 = 2
+ cache.add_prefix_cached_new_req_id("req1")
+
+ num_scheduled_tokens = {
+ "req1": num_new_toks_req1,
+ "req2": num_new_toks_req2,
+ }
+
+ new_mm_outputs = {}
+ for mm_key in cache.mm_cache_keys:
+ new_mm_outputs[mm_key] = torch.rand(
+ (num_new_toks_req1 + num_new_toks_req2, feature_dims[mm_key]),
+ dtype=DTYPE,
+ )
+ # We also want to make sure passthrough data (outside of our keys) isn't dropped
+ new_mm_outputs["passthrough_data"] = "Something else"
+ # Lists are a special case because we can't split them yet if we want to match
+ # the nonprefix cache behavior, because this runs before post process.
+ new_mm_outputs["passthrough_list"] = ["should", "not", "split"]
+
+ input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
+
+ with patch(
+ "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
+ new=fake_get_cached_block_ids,
+ ):
+ merged_mm_outputs = cache.get_merged_multimodal_states(
+ query_start_loc=[0, num_new_toks_req1],
+ input_batch=input_batch,
+ multimodal_outputs=new_mm_outputs,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
+
+ # Ensure the passthrough data wasn't dropped
+ assert "passthrough_data" in merged_mm_outputs
+ assert "passthrough_list" in merged_mm_outputs
+
+ for mm_key, mm_output in merged_mm_outputs.items():
+ # Ensure passthrough data is just forwarded normally and not duplicated
+ assert isinstance(mm_output, dict)
+ assert "req1" in mm_output and "req2" in mm_output
+ if mm_key == "passthrough_data":
+ assert mm_key not in cache.mm_cache_keys
+ assert new_mm_outputs[mm_key] == mm_output["req1"]
+ assert new_mm_outputs[mm_key] == mm_output["req2"]
+ elif mm_key == "passthrough_list":
+ assert mm_key not in cache.mm_cache_keys
+ assert new_mm_outputs[mm_key] == mm_output["req1"]
+ assert new_mm_outputs[mm_key] == mm_output["req2"]
+ else:
+ assert mm_key in cache.mm_cache_keys
+ curr_feat_dim = feature_dims[mm_key]
+ # Ensure that req1 (cache hit) merged the mm data
+ req1_merged_mm_outputs = mm_output["req1"]
+ req1_new_mm_outputs = new_mm_outputs[mm_key][:num_new_toks_req1]
+
+ assert req1_merged_mm_outputs.shape == torch.Size(
+ [orig_num_tokens_unpadded + num_new_toks_req1, curr_feat_dim]
+ )
+ # Ensure that the req1 merged mm data are the cached data + the new data
+ assert torch.all(req1_merged_mm_outputs[:orig_num_tokens_unpadded] == orig_mm_outputs[mm_key])
+ assert torch.all(req1_merged_mm_outputs[-num_new_toks_req1:] == req1_new_mm_outputs)
+
+ # Ensure that req2 (cache miss) only has the new mm data
+ req2_merged_mm_outputs = mm_output["req2"]
+ req2_new_mm_outputs = new_mm_outputs[mm_key][-num_new_toks_req2:]
+
+ assert req2_merged_mm_outputs.shape == torch.Size([num_new_toks_req2, curr_feat_dim])
+ assert torch.all(req2_merged_mm_outputs == req2_new_mm_outputs)
diff --git a/tests/dfx/conftest.py b/tests/dfx/conftest.py
index e54141b344..b8edeba9d5 100644
--- a/tests/dfx/conftest.py
+++ b/tests/dfx/conftest.py
@@ -2,6 +2,8 @@
from pathlib import Path
from typing import Any
+import pytest
+
from tests.conftest import modify_stage_config
@@ -38,22 +40,32 @@ def modify_stage(default_path, updates, deletes):
def create_unique_server_params(
configs: list[dict[str, Any]],
stage_configs_dir: Path,
-) -> list[tuple[str, str, str]]:
+) -> list[tuple[str, str, str | None, str | None, tuple[str, ...]]]:
unique_params = []
seen = set()
for config in configs:
test_name = config["test_name"]
- model = config["server_params"]["model"]
- stage_config_name = config["server_params"].get("stage_config_name")
+ server_params = config["server_params"]
+ model = server_params["model"]
+ stage_config_name = server_params.get("stage_config_name")
if stage_config_name:
stage_config_path = str(stage_configs_dir / stage_config_name)
- delete = config["server_params"].get("delete", None)
- update = config["server_params"].get("update", None)
+ delete = server_params.get("delete", None)
+ update = server_params.get("update", None)
stage_config_path = modify_stage(stage_config_path, update, delete)
else:
stage_config_path = None
- server_param = (test_name, model, stage_config_path)
+ stage_overrides = server_params.get("stage_overrides")
+ stage_overrides_json = json.dumps(stage_overrides) if stage_overrides else None
+
+ # ``extra_cli_args`` passes raw CLI flags straight through to
+ # ``vllm_omni.entrypoints.cli.main serve`` — used for flags that
+ # don't map to stage-level overrides, e.g. ``--async-chunk`` /
+ # ``--no-async-chunk`` toggling the deploy-level async_chunk bool.
+ extra_cli_args = tuple(server_params.get("extra_cli_args") or ())
+
+ server_param = (test_name, model, stage_config_path, stage_overrides_json, extra_cli_args)
if server_param not in seen:
seen.add(server_param)
unique_params.append(server_param)
@@ -95,3 +107,13 @@ def create_benchmark_indices(
indices.append((test_name, idx))
return indices
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+ """Register shared CLI options for DFX benchmark suites."""
+ parser.addoption(
+ "--test-config-file",
+ action="store",
+ default=None,
+ help=("Path to benchmark config JSON. Example: --test-config-file tests/dfx/perf/tests/test_tts.json"),
+ )
diff --git a/tests/dfx/perf/scripts/diffusion_result_template.json b/tests/dfx/perf/scripts/diffusion_result_template.json
new file mode 100644
index 0000000000..86bdf1bc7a
--- /dev/null
+++ b/tests/dfx/perf/scripts/diffusion_result_template.json
@@ -0,0 +1,86 @@
+[
+ {
+ "test_name": null,
+ "backend": null,
+ "timestamp": null,
+ "server_params": {
+ "model": null,
+ "serve_args": {
+ "enable-diffusion-pipeline-profiler": false
+ }
+ },
+ "benchmark_params": {
+ "name": null,
+ "dataset": null,
+ "task": null,
+ "width": 0,
+ "height": 0,
+ "num-inference-steps": 0,
+ "num-prompts": 0,
+ "max-concurrency": 0,
+ "num-input-images": 0,
+ "enable-negative-prompt": false,
+ "baseline": {
+ "throughput_qps": 0,
+ "latency_mean": 0,
+ "peak_memory_mb_max": 0,
+ "peak_memory_mb_mean": 0
+ }
+ },
+ "result": {
+ "duration": 0,
+ "completed_requests": 0,
+ "failed_requests": 0,
+ "throughput_qps": 0,
+ "latency_mean": 0,
+ "latency_median": 0,
+ "latency_p99": 0,
+ "latency_p95": 0,
+ "latency_p50": 0,
+ "peak_memory_mb_max": 0,
+ "peak_memory_mb_mean": 0,
+ "peak_memory_mb_median": 0,
+ "stage_durations_mean": {},
+ "stage_durations_p50": {},
+ "stage_durations_p99": {},
+ "backend": null,
+ "model": null,
+ "dataset": null,
+ "task": null
+ },
+ "log_file": null,
+ "Model": null,
+ "Framework": null,
+ "Hardware": null,
+ "Deployment": null,
+ "Task": null,
+ "Dataset": null,
+ "resolution": null,
+ "Parallelism": null,
+ "max_concurrency": 0,
+ "Cache": null,
+ "Quantization": null,
+ "offload": null,
+ "compile": null,
+ "Attn_backend": null,
+ "num_inference_steps": 0,
+ "completed": 0,
+ "failed": 0,
+ "throughput_qps": 0,
+ "latency_mean": 0,
+ "latency_median": 0,
+ "latency_p99": 0,
+ "latency_p95": 0,
+ "latency_p50": 0,
+ "peak_memory_mb_max": 0,
+ "peak_memory_mb_mean": 0,
+ "peak_memory_mb_median": 0,
+ "stage_durations_mean": {},
+ "stage_durations_p50": {},
+ "stage_durations_p99": {},
+ "commit_sha": null,
+ "build_id": null,
+ "build_url": null,
+ "source_file": null
+ }
+]
diff --git a/tests/dfx/perf/scripts/result_omni_template.json b/tests/dfx/perf/scripts/result_omni_template.json
new file mode 100644
index 0000000000..1d61321407
--- /dev/null
+++ b/tests/dfx/perf/scripts/result_omni_template.json
@@ -0,0 +1,55 @@
+{
+ "date": null,
+ "endpoint_type": null,
+ "backend": null,
+ "label": null,
+ "model_id": null,
+ "tokenizer_id": null,
+ "num_prompts": 0,
+ "request_rate": null,
+ "burstiness": 0,
+ "max_concurrency": 0,
+ "duration": 0,
+ "completed": 0,
+ "failed": 0,
+ "total_input_tokens": 0,
+ "total_output_tokens": 0,
+ "request_throughput": 0,
+ "request_goodput": null,
+ "output_throughput": 0,
+ "total_token_throughput": 0,
+ "total_audio_duration_s": 0,
+ "total_audio_frames": 0,
+ "audio_throughput": 0,
+ "max_output_tokens_per_s": 0,
+ "max_concurrent_requests": 0,
+ "rtfx": 0,
+ "mean_ttft_ms": 0,
+ "median_ttft_ms": 0,
+ "p99_ttft_ms": 0,
+ "mean_tpot_ms": 0,
+ "median_tpot_ms": 0,
+ "p99_tpot_ms": 0,
+ "mean_itl_ms": 0,
+ "median_itl_ms": 0,
+ "p99_itl_ms": 0,
+ "mean_e2el_ms": 0,
+ "median_e2el_ms": 0,
+ "p99_e2el_ms": 0,
+ "mean_audio_rtf": 0,
+ "median_audio_rtf": 0,
+ "p99_audio_rtf": 0,
+ "mean_audio_ttfp_ms": 0,
+ "median_audio_ttfp_ms": 0,
+ "p99_audio_ttfp_ms": 0,
+ "mean_audio_duration_s": 0,
+ "median_audio_duration_s": 0,
+ "p99_audio_duration_s": 0,
+ "baseline": {
+ "mean_ttft_ms": 0,
+ "mean_audio_ttfp_ms": 0,
+ "mean_audio_rtf": 0
+ },
+ "random_input_len": 0,
+ "random_output_len": 0
+}
diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py
index c625239e5c..0de60c6a54 100644
--- a/tests/dfx/perf/scripts/run_benchmark.py
+++ b/tests/dfx/perf/scripts/run_benchmark.py
@@ -21,12 +21,35 @@
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-CONFIG_FILE_PATH = str(Path(__file__).parent.parent / "tests" / "test.json")
+def _get_config_file_from_argv() -> str | None:
+ """Read ``--test-config-file`` from ``sys.argv`` at import time so parametrization can use it."""
+ import sys
+
+ for i, arg in enumerate(sys.argv):
+ if arg == "--test-config-file" and i + 1 < len(sys.argv):
+ return sys.argv[i + 1]
+ if arg.startswith("--test-config-file="):
+ return arg.split("=", 1)[1]
+ return None
+
+
+_PERF_TESTS_DIR = Path(__file__).resolve().parent.parent / "tests"
+_DEFAULT_CONFIG_FILE = str(_PERF_TESTS_DIR / "test_qwen_omni.json")
+
+CONFIG_FILE_PATH = _get_config_file_from_argv()
+if CONFIG_FILE_PATH is None:
+ print(
+ "No --test-config-file in argv, using default: tests/dfx/perf/tests/test_qwen_omni.json "
+ "(override with e.g. --test-config-file tests/dfx/perf/tests/test_tts.json)"
+ )
+ CONFIG_FILE_PATH = _DEFAULT_CONFIG_FILE
+
BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
+OMNI_RESULT_TEMPLATE_PATH = Path(__file__).parent / "result_omni_template.json"
-STAGE_CONFIGS_DIR = Path(__file__).parent.parent / "stage_configs"
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, STAGE_CONFIGS_DIR)
+DEPLOY_CONFIGS_DIR = Path(__file__).parent.parent / "deploy"
+test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR)
server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS)
_omni_server_lock = threading.Lock()
@@ -39,13 +62,19 @@ def omni_server(request):
Multi-stage initialization can take 10-20+ minutes.
"""
with _omni_server_lock:
- test_name, model, stage_config_path = request.param
+ test_name, model, stage_config_path, stage_overrides, extra_cli_args = request.param
print(f"Starting OmniServer with test: {test_name}, model: {model}")
- server_args = ["--stage-init-timeout", "120", "--init-timeout", "900"]
+ server_args = ["--stage-init-timeout", "600", "--init-timeout", "900"]
+ # --deploy-config and --stage-overrides compose at the CLI (see vllm_omni/entrypoints/utils.py):
+ # deploy-config sets the base; stage-overrides are applied on top. Both can be set.
if stage_config_path:
- server_args = ["--stage-configs-path", stage_config_path] + server_args
+ server_args = ["--deploy-config", stage_config_path] + server_args
+ if stage_overrides:
+ server_args = ["--stage-overrides", stage_overrides] + server_args
+ if extra_cli_args:
+ server_args = list(extra_cli_args) + server_args
with OmniServer(model, server_args) as server:
server.test_name = test_name
print("OmniServer started successfully")
@@ -55,16 +84,41 @@ def omni_server(request):
print("OmniServer stopped")
+def _safe_filename_token(value: Any | None, *, default: str = "na") -> str:
+ """Make a single path segment safe for result filenames on common filesystems."""
+ if value is None:
+ return default
+ s = str(value).strip()
+ for bad in ("/", "\\", ":", "*", "?", '"', "<", ">", "|"):
+ s = s.replace(bad, "_")
+ return s if s else default
+
+
def run_benchmark(
args: list,
test_name: str,
flow,
dataset_name: str,
num_prompt,
+ *,
+ baseline_config: dict[str, Any] | None = None,
+ sweep_index: int | None = None,
+ request_rate: Any | None = None,
+ max_concurrency: Any | None = None,
+ random_input_len: Any | None = None,
+ random_output_len: Any | None = None,
) -> Any:
- """Run a single benchmark iteration and return the parsed result JSON."""
+ """Run a single benchmark iteration and return the parsed result JSON.
+
+ After ``vllm bench`` writes the JSON, ``result["baseline"]`` holds the same
+ per-metric resolved thresholds as ``assert_result`` (via ``_baseline_thresholds_for_step``).
+ When ``random_input_len`` / ``random_output_len`` are set, they are also written into the result JSON;
+ omitted keys when not configured.
+ """
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
- result_filename = f"result_{test_name}_{dataset_name}_{flow}_{num_prompt}_{current_dt}.json"
+ ri = _safe_filename_token(random_input_len)
+ ro = _safe_filename_token(random_output_len)
+ result_filename = f"result_{test_name}_{dataset_name}_{flow}_{num_prompt}_in{ri}_out{ro}_{current_dt}.json"
if "--result-filename" in args:
print(f"The result file will be overwritten by {result_filename}")
command = (
@@ -94,8 +148,34 @@ def run_benchmark(
else:
result_dir = "./"
- with open(os.path.join(result_dir, result_filename), encoding="utf-8") as f:
- result = json.load(f)
+ result_path = os.path.join(result_dir, result_filename)
+ if not os.path.exists(result_path):
+ with open(OMNI_RESULT_TEMPLATE_PATH, encoding="utf-8") as f:
+ template_result: dict[str, Any] = json.load(f)
+ Path(result_path).parent.mkdir(parents=True, exist_ok=True)
+ with open(result_path, "w", encoding="utf-8") as f:
+ json.dump(template_result, f, ensure_ascii=False, indent=2)
+ print(f"Benchmark result file not generated, fallback to template: {result_path}")
+ result = template_result
+ else:
+ with open(result_path, encoding="utf-8") as f:
+ result = json.load(f)
+
+ if baseline_config:
+ result["baseline"] = _baseline_thresholds_for_step(
+ baseline_config,
+ sweep_index=sweep_index,
+ request_rate=request_rate,
+ max_concurrency=max_concurrency,
+ )
+ else:
+ result["baseline"] = {}
+ if random_input_len is not None:
+ result["random_input_len"] = random_input_len
+ if random_output_len is not None:
+ result["random_output_len"] = random_output_len
+ with open(result_path, "w", encoding="utf-8") as f:
+ json.dump(result, f, ensure_ascii=False, indent=2)
return result
@@ -165,6 +245,25 @@ def _resolve_baseline_value(
return baseline_raw
+def _baseline_thresholds_for_step(
+ baseline_data: dict[str, Any],
+ *,
+ sweep_index: int | None = None,
+ max_concurrency: Any = None,
+ request_rate: Any = None,
+) -> dict[str, Any]:
+ """Resolve ``test.json`` ``baseline`` block to one threshold per metric (same as ``assert_result``)."""
+ return {
+ metric_name: _resolve_baseline_value(
+ baseline_raw,
+ sweep_index=sweep_index,
+ max_concurrency=max_concurrency,
+ request_rate=request_rate,
+ )
+ for metric_name, baseline_raw in baseline_data.items()
+ }
+
+
def assert_result(
result,
params,
@@ -255,6 +354,12 @@ def to_list(value, default=None):
flow=qps,
dataset_name=dataset_name,
num_prompt=num_prompt,
+ baseline_config=params.get("baseline"),
+ sweep_index=i,
+ request_rate=qps,
+ max_concurrency=None,
+ random_input_len=params.get("random_input_len"),
+ random_output_len=params.get("random_output_len"),
)
assert_result(
result,
@@ -273,6 +378,12 @@ def to_list(value, default=None):
flow=concurrency,
dataset_name=dataset_name,
num_prompt=num_prompt,
+ baseline_config=params.get("baseline"),
+ sweep_index=i,
+ request_rate=None,
+ max_concurrency=concurrency,
+ random_input_len=params.get("random_input_len"),
+ random_output_len=params.get("random_output_len"),
)
assert_result(
result,
diff --git a/tests/dfx/perf/scripts/run_diffusion_benchmark.py b/tests/dfx/perf/scripts/run_diffusion_benchmark.py
index 1bd9bf1a14..23efa8bb0f 100644
--- a/tests/dfx/perf/scripts/run_diffusion_benchmark.py
+++ b/tests/dfx/perf/scripts/run_diffusion_benchmark.py
@@ -5,8 +5,8 @@
- vllm-omni (default): starts DiffusionServer via vllm_omni.entrypoints.cli.main,
benchmarks with diffusion_benchmark_serving.py --backend vllm-omni
-A config JSON file is REQUIRED via --config-file:
- pytest run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
+A config JSON file is REQUIRED via --test-config-file:
+ pytest run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
JSON config entries use a "server_type" field, and this runner executes
the vllm-omni path.
@@ -27,13 +27,14 @@
import time
from datetime import datetime
from pathlib import Path
-from typing import Any
+from typing import Any, cast
import psutil
import pytest
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+os.environ.setdefault("DIFFUSION_ATTENTION_BACKEND", "FLASH_ATTN")
# ---------------------------------------------------------------------------
# Paths
@@ -50,19 +51,21 @@
# Populated lazily after CONFIG_FILE_PATH is resolved.
_SESSION_TIMESTAMP = datetime.now().strftime("%Y%m%d-%H%M%S")
_RESULT_LOCK = threading.Lock()
+_BRANCHPOINT_COMMIT_SHA: str | None = None
+DIFFUSION_RESULT_TEMPLATE_PATH = Path(__file__).parent / "diffusion_result_template.json"
def _get_config_file_from_argv() -> str | None:
- """Read --config-file from sys.argv at import time so pytest parametrize can use it.
+ """Read --test-config-file from sys.argv at import time so pytest parametrize can use it.
pytest_addoption (below) registers the same flag so pytest does not reject it.
- Supports both ``--config-file path`` and ``--config-file=path`` forms.
+ Supports both ``--test-config-file path`` and ``--test-config-file=path`` forms.
Returns None if the flag is not present; callers must handle the missing case.
"""
for i, arg in enumerate(sys.argv):
- if arg == "--config-file" and i + 1 < len(sys.argv):
+ if arg == "--test-config-file" and i + 1 < len(sys.argv):
return sys.argv[i + 1]
- if arg.startswith("--config-file="):
+ if arg.startswith("--test-config-file="):
return arg.split("=", 1)[1]
return None
@@ -110,7 +113,7 @@ def load_configs(config_path: str) -> list[dict[str, Any]]:
BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
_config_stem = Path(CONFIG_FILE_PATH).stem # e.g. "test_qwen_image_vllm_omni"
-AGGREGATED_RESULT_FILE = BENCHMARK_RESULT_DIR / f"benchmark_results_{_config_stem}_{_SESSION_TIMESTAMP}.json"
+AGGREGATED_RESULT_FILE = BENCHMARK_RESULT_DIR / f"diffusion_result_{_config_stem}_{_SESSION_TIMESTAMP}.json"
def _append_to_aggregated_file(record: dict[str, Any]) -> None:
@@ -131,19 +134,6 @@ def _append_to_aggregated_file(record: dict[str, Any]) -> None:
json.dump(records, f, indent=2, ensure_ascii=False)
-# Register --config-file with pytest so it does not reject the argument.
-def pytest_addoption(parser: pytest.Parser) -> None:
- parser.addoption(
- "--config-file",
- action="store",
- default=None,
- help=(
- "Path to the benchmark config JSON file (required). "
- "Example: --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json"
- ),
- )
-
-
_server_lock = threading.Lock()
# ---------------------------------------------------------------------------
@@ -232,13 +222,13 @@ class DiffusionServer:
def __init__(
self,
- model: str,
- serve_args: list[str],
+ server_cfg: dict[str, Any],
*,
port: int | None = None,
) -> None:
- self.model = model
- self.serve_args = serve_args
+ self.server_cfg: dict[str, Any] = server_cfg
+ self.model = server_cfg["model"]
+ self.serve_args = server_cfg["serve_args"]
self.host = "127.0.0.1"
self.port = port if port is not None else _get_open_port()
self.proc: subprocess.Popen | None = None
@@ -299,6 +289,95 @@ def _build_serve_args(serve_args_dict: dict[str, Any]) -> list[str]:
return args
+def _get_branchpoint_commit_sha() -> str:
+ """Return the branch-point commit SHA against main.
+
+ Uses git command: ``git merge-base HEAD origin/main``.
+ """
+ global _BRANCHPOINT_COMMIT_SHA
+ if _BRANCHPOINT_COMMIT_SHA is not None:
+ return _BRANCHPOINT_COMMIT_SHA
+
+ repo_root = Path(__file__).parent.parent.parent.parent
+ try:
+ sha = (
+ subprocess.check_output(
+ ["git", "merge-base", "HEAD", "origin/main"],
+ cwd=str(repo_root),
+ stderr=subprocess.STDOUT,
+ text=True,
+ )
+ .strip()
+ .splitlines()[0]
+ )
+ _BRANCHPOINT_COMMIT_SHA = sha
+ except Exception as e:
+ print(f"Warning: failed to get branch-point commit SHA: {e}")
+ _BRANCHPOINT_COMMIT_SHA = ""
+ return _BRANCHPOINT_COMMIT_SHA
+
+
+def _to_resolution_string(params: dict[str, Any]) -> str:
+ width = params.get("width", "unknown width")
+ height = params.get("height", "unknown height")
+ return f"{width}x{height}"
+
+
+def _to_parallelism_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
+ parts: list[str] = []
+ if framework == "vllm-omni":
+ keys = [
+ "num-gpus",
+ "usp",
+ "ulysses-degree",
+ "ring",
+ "ring-degree",
+ "cfg-parallel-size",
+ "vae-patch-parallel-size",
+ "vae-use-tiling",
+ "tensor-parallel-size",
+ ]
+ for key in keys:
+ if key in serve_args_dict:
+ parts.append(f"{key}={serve_args_dict[key]}")
+ return ",".join(parts) if parts else "none"
+
+
+def _to_cache_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
+ if framework == "vllm-omni":
+ if "cache-backend" in serve_args_dict:
+ return str(serve_args_dict["cache-backend"])
+ return "disabled"
+
+
+def _to_offload_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
+ selected: list[str] = []
+ if framework == "vllm-omni":
+ offload_keys = [
+ "enable-cpu-offload",
+ "enable-layerwise-offload",
+ ]
+ for key in offload_keys:
+ if key in serve_args_dict:
+ selected.append(key)
+ return f"enabled({';'.join(selected)})" if selected else "disabled"
+
+
+def _to_compile_value(framework: str, serve_args_dict: dict[str, Any]) -> str:
+ if framework == "vllm-omni":
+ if "enforce-eager" in serve_args_dict:
+ return "disabled"
+ return "enabled"
+ return "disabled"
+
+
+def _to_quantization_value(framework: str, serve_args_dict: dict[str, Any]) -> str:
+ if framework == "vllm-omni":
+ quant = serve_args_dict.get("quantization")
+ return str(quant) if quant else "disabled"
+ return "disabled"
+
+
def _unique_server_params(configs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Return one server-config dict per unique test_name."""
seen: set[str] = set()
@@ -310,12 +389,14 @@ def _unique_server_params(configs: list[dict[str, Any]]) -> list[dict[str, Any]]
seen.add(test_name)
if cfg.get("server_type", "vllm-omni") != "vllm-omni":
raise ValueError(f"Unsupported server_type in config: {cfg.get('server_type')}")
+ serve_args_dict = cfg["server_params"].get("serve_args", {})
result.append(
{
"test_name": test_name,
"server_type": "vllm-omni",
"model": cfg["server_params"]["model"],
- "serve_args": _build_serve_args(cfg["server_params"].get("serve_args", {})),
+ "serve_args_dict": serve_args_dict,
+ "serve_args": _build_serve_args(serve_args_dict),
"benchmark_backend": "vllm-omni",
"server_params": cfg["server_params"],
}
@@ -334,9 +415,7 @@ def _test_param_mapping(configs: list[dict[str, Any]]) -> dict[str, list[dict]]:
def _make_server(server_cfg: dict[str, Any]) -> DiffusionServer:
"""Factory: return a vLLM-Omni diffusion server instance for the config."""
- model = server_cfg["model"]
- serve_args = server_cfg["serve_args"]
- return DiffusionServer(model=model, serve_args=serve_args)
+ return DiffusionServer(server_cfg=server_cfg)
# ---------------------------------------------------------------------------
@@ -364,7 +443,6 @@ def diffusion_server(request):
print(f"\nStarting {server_type} server for test: {test_name}")
with _make_server(server_cfg) as server:
server.test_name = test_name
- server.server_params = server_cfg["server_params"]
print(f"{server_type} server started successfully")
yield server
print(f"{server_type} server stopping…")
@@ -402,16 +480,18 @@ def run_benchmark(
params: dict[str, Any],
test_name: str,
backend: str = "vllm-omni",
- server_params: dict[str, Any] | None = None,
+ server_cfg: dict[str, Any] | None = None,
+ source_file: str = "",
) -> dict[str, Any]:
"""Run diffusion_benchmark_serving.py as a subprocess and return parsed metrics.
The raw metrics are written to a temporary file by the subprocess. After
the run completes the metrics are merged with full metadata (test_name,
- backend, benchmark_params, timestamp) and appended to the session-wide
- aggregated JSON file (AGGREGATED_RESULT_FILE). The temporary file is
- removed afterwards. Subprocess stdout/stderr are tee'd to a .log file
- under BENCHMARK_RESULT_DIR/logs/; its path is stored in the record.
+ backend, benchmark_params, timestamp, flat reporting fields) and appended
+ to the session-wide aggregated JSON file (AGGREGATED_RESULT_FILE). The
+ temporary file is removed afterwards. Subprocess stdout/stderr are tee'd
+ to a .log file under BENCHMARK_RESULT_DIR/logs/; its path is stored in
+ the record.
"""
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -484,10 +564,17 @@ def run_benchmark(
if process.returncode != 0:
tmp_result_file.unlink(missing_ok=True)
- raise RuntimeError(f"Benchmark script exited with code {process.returncode}")
+ print(f"ERROR:Benchmark script exited with code {process.returncode}")
if not tmp_result_file.exists():
- raise FileNotFoundError(f"Benchmark result file not found: {tmp_result_file}")
+ with open(DIFFUSION_RESULT_TEMPLATE_PATH, encoding="utf-8") as f:
+ template_payload = json.load(f)
+ # Template schema is fixed and owned by this repo:
+ # ``diffusion_result_template.json`` is a one-item list and metrics live at [0]["result"].
+ template_metrics: dict[str, Any] = template_payload[0]["result"]
+ with open(tmp_result_file, "w", encoding="utf-8") as f:
+ json.dump(template_metrics, f, ensure_ascii=False, indent=2)
+ print(f"Benchmark result file not generated, fallback to template: {tmp_result_file}")
try:
with open(tmp_result_file, encoding="utf-8") as f:
@@ -495,14 +582,55 @@ def run_benchmark(
finally:
tmp_result_file.unlink(missing_ok=True)
+ server_cfg = server_cfg or {}
+ serve_args_dict = server_cfg.get("serve_args_dict", {})
+ if not isinstance(serve_args_dict, dict):
+ serve_args_dict = {}
+
+ completed = metrics.get("completed_requests", metrics.get("completed", 0))
+ failed = metrics.get("failed_requests", metrics.get("failed", 0))
+
record: dict[str, Any] = {
"test_name": test_name,
"backend": backend,
"timestamp": timestamp,
- "server_params": server_params,
+ "server_params": server_cfg.get("server_params"),
"benchmark_params": params,
"result": metrics,
"log_file": str(log_file),
+ "Model": model,
+ "Framework": backend,
+ "Hardware": "",
+ "Deployment": "",
+ "Task": params.get("task", "t2i"),
+ "Dataset": params.get("dataset", "random"),
+ "resolution": _to_resolution_string(params),
+ "Parallelism": _to_parallelism_string(backend, serve_args_dict),
+ "max_concurrency": params.get("max-concurrency", ""),
+ "Cache": _to_cache_string(backend, serve_args_dict),
+ "Quantization": _to_quantization_value(backend, serve_args_dict),
+ "offload": _to_offload_string(backend, serve_args_dict),
+ "compile": _to_compile_value(backend, serve_args_dict),
+ "Attn_backend": os.environ.get("DIFFUSION_ATTENTION_BACKEND", ""),
+ "num_inference_steps": params.get("num-inference-steps", ""),
+ "completed": completed,
+ "failed": failed,
+ "throughput_qps": metrics.get("throughput_qps"),
+ "latency_mean": metrics.get("latency_mean"),
+ "latency_median": metrics.get("latency_median"),
+ "latency_p99": metrics.get("latency_p99"),
+ "latency_p95": metrics.get("latency_p95"),
+ "latency_p50": metrics.get("latency_p50"),
+ "peak_memory_mb_max": metrics.get("peak_memory_mb_max"),
+ "peak_memory_mb_mean": metrics.get("peak_memory_mb_mean"),
+ "peak_memory_mb_median": metrics.get("peak_memory_mb_median"),
+ "stage_durations_mean": metrics.get("stage_durations_mean"),
+ "stage_durations_p50": metrics.get("stage_durations_p50"),
+ "stage_durations_p99": metrics.get("stage_durations_p99"),
+ "commit_sha": _get_branchpoint_commit_sha(),
+ "build_id": os.environ.get("BUILDKITE_BUILD_ID", ""),
+ "build_url": os.environ.get("BUILDKITE_BUILD_URL", ""),
+ "source_file": source_file,
}
_append_to_aggregated_file(record)
print(f"\n Result appended to: {AGGREGATED_RESULT_FILE}")
@@ -565,7 +693,8 @@ def test_diffusion_performance_benchmark(diffusion_server, benchmark_params):
params=params,
test_name=test_name,
backend=backend,
- server_params=diffusion_server.server_params,
+ server_cfg=getattr(diffusion_server, "server_cfg", {}),
+ source_file=cast(str, CONFIG_FILE_PATH),
)
print(f"\n{'=' * 60}")
diff --git a/tests/dfx/perf/stage_configs/qwen3_omni.yaml b/tests/dfx/perf/stage_configs/qwen3_omni.yaml
deleted file mode 100644
index 2add22b873..0000000000
--- a/tests/dfx/perf/stage_configs/qwen3_omni.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 100000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/dfx/perf/stage_configs/qwen3_tts.yaml b/tests/dfx/perf/stage_configs/qwen3_tts.yaml
deleted file mode 100644
index 97b3090560..0000000000
--- a/tests/dfx/perf/stage_configs/qwen3_tts.yaml
+++ /dev/null
@@ -1,96 +0,0 @@
-# Stage config for running Qwen3-TTS with 2-stage architecture
-# Stage 0: Talker (text -> 8-layer RVQ codec codes)
-# Stage 1: Code2Wav (codec codes -> audio waveform)
-#
-# The following config has been verified on 1x H100-80G GPU.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 32768
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 4
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 72
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/tests/dfx/perf/tests/test.json b/tests/dfx/perf/tests/test.json
deleted file mode 100644
index fe7e380469..0000000000
--- a/tests/dfx/perf/tests/test.json
+++ /dev/null
@@ -1,236 +0,0 @@
-[
- {
- "test_name": "test_qwen3_omni",
- "server_params": {
- "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml"
- },
- "benchmark_params": [
- {
- "dataset_name": "random",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "max_concurrency": [
- 1,
- 4,
- 10
- ],
- "random_input_len": 100,
- "random_output_len": 100,
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [1000, 3000, 5000],
- "mean_audio_ttfp_ms": [8000, 10000, 13000],
- "mean_audio_rtf": [0.2, 0.25, 0.45]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "request_rate": [
- 0.1,
- 0.3,
- 0.5
- ],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 3,
- "random_mm_num_mm_items_range_ratio": 0,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1,
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(32, 32, 1)": 0.5,
- "(0, 1, 1)": 0.1,
- "(32, 32, 2)": 0.4
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [2000, 4000, 6000],
- "mean_audio_ttfp_ms": [10000, 13000, 15000],
- "mean_audio_rtf": [0.25, 0.35, 0.45]
- }
- },
- {
- "dataset_name": "random",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [
- 4,
- 16
- ],
- "max_concurrency": [
- 1,
- 4
- ],
- "random_input_len": 2500,
- "random_output_len": 900,
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [1000, 3000],
- "mean_audio_ttfp_ms": [30000, 60000],
- "mean_audio_rtf": [0.35, 0.45]
- }
- }
- ]
- },
- {
- "test_name": "test_qwen3_omni_chunk",
- "server_params": {
- "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml",
- "update": {
- "async_chunk": true,
- "stage_args": {
- "0": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
- },
- "1": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- }
- }
- },
- "delete": {
- "stage_args": {
- "2": [
- "custom_process_input_func"
- ]
- }
- }
- },
- "benchmark_params": [
- {
- "dataset_name": "random",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "max_concurrency": [
- 1,
- 4,
- 10
- ],
- "random_input_len": 100,
- "random_output_len": 100,
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [1000, 3000, 5000],
- "mean_audio_ttfp_ms": [1000, 3000, 5000],
- "mean_audio_rtf": [0.2, 0.35, 0.6]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [
- 10,
- 40,
- 100
- ],
- "request_rate": [
- 0.1,
- 0.3,
- 0.5
- ],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 3,
- "random_mm_num_mm_items_range_ratio": 0,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1,
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(32, 32, 1)": 0.5,
- "(0, 1, 1)": 0.1,
- "(32, 32, 2)": 0.4
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [2000, 4000, 6000],
- "mean_audio_ttfp_ms": [2000, 4000, 6000],
- "mean_audio_rtf": [0.25, 0.4, 0.7]
- }
- },
- {
- "dataset_name": "random",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [
- 4,
- 16
- ],
- "max_concurrency": [
- 1,
- 4
- ],
- "random_input_len": 2500,
- "random_output_len": 900,
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [1000, 3000],
- "mean_audio_ttfp_ms": [1000, 3000],
- "mean_audio_rtf": [0.35, 0.45]
- }
- }
- ]
- },
- {
- "test_name": "test_qwen3_tts",
- "server_params": {
- "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
- },
- "benchmark_params": [
- {
- "dataset_name": "random",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "num_prompts": [
- 10,
- 40
- ],
- "max_concurrency": [
- 1,
- 4
- ],
- "random_input_len": 100,
- "random_output_len": 100,
- "extra_body": {
- "voice": "Vivian",
- "language": "English"
- },
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_audio_ttfp_ms": [6000, 6000],
- "mean_audio_rtf": [0.3, 0.3]
- }
- }
- ]
- }
-]
diff --git a/tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json
new file mode 100644
index 0000000000..7d1fbbfa70
--- /dev/null
+++ b/tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json
@@ -0,0 +1,167 @@
+[
+ {
+ "test_name": "test_qwen_image_edit_2509_single_device",
+ "description": "Single-device baseline (two input images)",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image-Edit-2509",
+ "serve_args": {
+ "enable-diffusion-pipeline-profiler": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "512x512_steps20_i2i_2img",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 512,
+ "height": 512,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "num-input-images": 2,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.05,
+ "latency_mean": 18,
+ "peak_memory_mb_max": 78500,
+ "peak_memory_mb_mean": 78500
+ }
+ },
+ {
+ "name": "1536x1536_steps35_i2i_2img",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 1536,
+ "height": 1536,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "num-input-images": 2,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.01,
+ "latency_mean": 70,
+ "peak_memory_mb_max": 81000,
+ "peak_memory_mb_mean": 81000
+ }
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen_image_edit_2509_ulysses2_cfg2_vae_patch4",
+ "description": "Ulysses SP=2 + CFG=2 + VAE patch parallel=4",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image-Edit-2509",
+ "serve_args": {
+ "ulysses-degree": 2,
+ "cfg-parallel-size": 2,
+ "vae-patch-parallel-size": 4,
+ "vae-use-tiling": true,
+ "enable-diffusion-pipeline-profiler": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "512x512_steps20_i2i_2img",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 512,
+ "height": 512,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "num-input-images": 2,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.1,
+ "latency_mean": 12,
+ "peak_memory_mb_max": 69000,
+ "peak_memory_mb_mean": 69000
+ }
+ },
+ {
+ "name": "1536x1536_steps35_i2i_2img",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 1536,
+ "height": 1536,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "num-input-images": 2,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.03,
+ "latency_mean": 28,
+ "peak_memory_mb_max": 69000,
+ "peak_memory_mb_mean": 69000
+ }
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen_image_edit_2509_ulysses2_cfg2_cache_dit",
+ "description": "Ulysses SP=2 + CFG=2 + CacheDiT",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image-Edit-2509",
+ "serve_args": {
+ "ulysses-degree": 2,
+ "cfg-parallel-size": 2,
+ "cache-backend": "cache_dit",
+ "cache-config": {
+ "Fn_compute_blocks": 1,
+ "Bn_compute_blocks": 0,
+ "max_warmup_steps": 4,
+ "residual_diff_threshold": 0.24,
+ "max_continuous_cached_steps": 3,
+ "enable_taylorseer": false,
+ "taylorseer_order": 1,
+ "scm_steps_mask_policy": null,
+ "scm_steps_policy": "dynamic"
+ },
+ "enable-diffusion-pipeline-profiler": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "512x512_steps20_i2i_2img",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 512,
+ "height": 512,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "num-input-images": 2,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.10,
+ "latency_mean": 12,
+ "peak_memory_mb_max": 73000,
+ "peak_memory_mb_mean": 73000
+ }
+ },
+ {
+ "name": "1536x1536_steps35_i2i_2img",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 1536,
+ "height": 1536,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "num-input-images": 2,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.05,
+ "latency_mean": 20,
+ "peak_memory_mb_max": 81000,
+ "peak_memory_mb_mean": 81000
+ }
+ }
+ ]
+ }
+]
diff --git a/tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json
new file mode 100644
index 0000000000..f68201db5f
--- /dev/null
+++ b/tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json
@@ -0,0 +1,161 @@
+[
+ {
+ "test_name": "test_qwen_image_edit_single_device",
+ "description": "Single-device baseline",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image-Edit",
+ "serve_args": {
+ "enable-diffusion-pipeline-profiler": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "512x512_steps20_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 512,
+ "height": 512,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.05,
+ "latency_mean": 15.0,
+ "peak_memory_mb_max": 72500,
+ "peak_memory_mb_mean": 72500
+ }
+ },
+ {
+ "name": "1536x1536_steps35_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 1536,
+ "height": 1536,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.01,
+ "latency_mean": 65.6,
+ "peak_memory_mb_max": 80777,
+ "peak_memory_mb_mean": 80777
+ }
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen_image_edit_ulysses2_cfg2_vae_patch4",
+ "description": "Ulysses SP=2 + CFG=2 + VAE patch parallel=4",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image-Edit",
+ "serve_args": {
+ "ulysses-degree": 2,
+ "cfg-parallel-size": 2,
+ "vae-patch-parallel-size": 4,
+ "vae-use-tiling": true,
+ "enable-diffusion-pipeline-profiler": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "512x512_steps20_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 512,
+ "height": 512,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.10,
+ "latency_mean": 7.2,
+ "peak_memory_mb_max": 68100,
+ "peak_memory_mb_mean": 68100
+ }
+ },
+ {
+ "name": "1536x1536_steps35_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 1536,
+ "height": 1536,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.03,
+ "latency_mean": 24.0,
+ "peak_memory_mb_max": 68100,
+ "peak_memory_mb_mean": 68100
+ }
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen_image_edit_ulysses2_cfg2_cache_dit",
+ "description": "Ulysses SP=2 + CFG=2 + CacheDiT",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image-Edit",
+ "serve_args": {
+ "ulysses-degree": 2,
+ "cfg-parallel-size": 2,
+ "cache-backend": "cache_dit",
+ "cache-config": {
+ "Fn_compute_blocks": 1,
+ "Bn_compute_blocks": 0,
+ "max_warmup_steps": 4,
+ "residual_diff_threshold": 0.24,
+ "max_continuous_cached_steps": 3,
+ "enable_taylorseer": false,
+ "taylorseer_order": 1,
+ "scm_steps_mask_policy": null,
+ "scm_steps_policy": "dynamic"
+ },
+ "enable-diffusion-pipeline-profiler": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "512x512_steps20_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 512,
+ "height": 512,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.1,
+ "latency_mean": 6.5,
+ "peak_memory_mb_max": 72600,
+ "peak_memory_mb_mean": 72600
+ }
+ },
+ {
+ "name": "1536x1536_steps35_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 1536,
+ "height": 1536,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.05,
+ "latency_mean": 16.0,
+ "peak_memory_mb_max": 81000,
+ "peak_memory_mb_mean": 81000
+ }
+ }
+ ]
+ }
+]
diff --git a/tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json
new file mode 100644
index 0000000000..3cf13509c8
--- /dev/null
+++ b/tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json
@@ -0,0 +1,49 @@
+[
+ {
+ "test_name": "test_qwen_image_layered_single_device",
+ "description": "Single-device baseline",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image-Layered",
+ "serve_args": {
+ "enable-diffusion-pipeline-profiler": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "640x640_steps20_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 640,
+ "height": 640,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.02,
+ "latency_mean": 40.0,
+ "peak_memory_mb_max": 70000,
+ "peak_memory_mb_mean": 70000
+ }
+ },
+ {
+ "name": "1024x1024_steps35_i2i",
+ "dataset": "random",
+ "task": "i2i",
+ "width": 1024,
+ "height": 1024,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.005,
+ "latency_mean": 80.0,
+ "peak_memory_mb_max": 70000,
+ "peak_memory_mb_mean": 70000
+ }
+ }
+ ]
+ }
+]
diff --git a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
index 387e874ad5..5ec7f1cc2b 100644
--- a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
+++ b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
@@ -44,7 +44,52 @@
}
]
},
-
+ {
+ "test_name": "test_qwen_image_single_device_step_execution",
+ "description": "Single-device baseline (no parallelism) with step execution",
+ "server_type": "vllm-omni",
+ "server_params": {
+ "model": "Qwen/Qwen-Image",
+ "serve_args": {
+ "enable-diffusion-pipeline-profiler": true,
+ "step-execution": true
+ }
+ },
+ "benchmark_params": [
+ {
+ "name": "512x512_steps20",
+ "dataset": "random",
+ "task": "t2i",
+ "width": 512,
+ "height": 512,
+ "num-inference-steps": 20,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.30,
+ "latency_mean": 3.50,
+ "peak_memory_mb_mean": 67000
+ }
+ },
+ {
+ "name": "1536x1536_steps35",
+ "dataset": "random",
+ "task": "t2i",
+ "width": 1536,
+ "height": 1536,
+ "num-inference-steps": 35,
+ "num-prompts": 10,
+ "max-concurrency": 1,
+ "enable-negative-prompt": true,
+ "baseline": {
+ "throughput_qps": 0.037,
+ "latency_mean": 27.0,
+ "peak_memory_mb_mean": 74000
+ }
+ }
+ ]
+ },
{
"test_name": "test_qwen_image_ulysses2_cfg2_vae_patch4",
"description": "Ulysses SP=2 + CFG-parallel=2 + VAE Patch Parallel=4",
@@ -72,7 +117,7 @@
"enable-negative-prompt": true,
"baseline": {
"throughput_qps": 0.1,
- "latency_mean": 2.34,
+ "latency_mean": 2.7,
"peak_memory_mb_mean": 61000
}
},
@@ -94,7 +139,6 @@
}
]
},
-
{
"test_name": "test_qwen_image_ulysses2_cfg2_cache_dit",
"description": "Ulysses SP=2 + CFG-parallel=2 + CacheDiT acceleration",
diff --git a/tests/dfx/perf/tests/test_qwen_omni.json b/tests/dfx/perf/tests/test_qwen_omni.json
new file mode 100644
index 0000000000..ca3eb55570
--- /dev/null
+++ b/tests/dfx/perf/tests/test_qwen_omni.json
@@ -0,0 +1,315 @@
+[
+ {
+ "test_name": "test_qwen3_omni",
+ "server_params": {
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "extra_cli_args": ["--no-async-chunk"]
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [4, 16, 40],
+ "max_concurrency": [1, 4, 10],
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000, 5000],
+ "mean_audio_ttfp_ms": [30000, 60000, 90000],
+ "mean_audio_rtf": [0.35, 0.45, 0.55]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [10],
+ "request_rate": [0.1],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 1,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(0, 60, 3)": 1.0
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [2000],
+ "mean_audio_ttfp_ms": [10000],
+ "mean_audio_rtf": [0.25]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [40],
+ "request_rate": [0.3],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 2,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.5,
+ "(720, 1280, 2)": 0.5
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [4000],
+ "mean_audio_ttfp_ms": [13000],
+ "mean_audio_rtf": [0.35]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [100],
+ "request_rate": [0.5],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 3,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1,
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.34,
+ "(720, 1280, 2)": 0.33,
+ "(0, 60, 3)": 0.33
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [6000],
+ "mean_audio_ttfp_ms": [15000],
+ "mean_audio_rtf": [0.45]
+ }
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen3_omni_chunk",
+ "server_params": {
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "extra_cli_args": ["--async-chunk"]
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [4, 16, 40],
+ "max_concurrency": [1, 4, 10],
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000, 5000],
+ "mean_audio_ttfp_ms": [1000, 3000, 5000],
+ "mean_audio_rtf": [0.2, 0.35, 0.6]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [10],
+ "request_rate": [0.1],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 1,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(0, 60, 3)": 1.0
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [2000],
+ "mean_audio_ttfp_ms": [2000],
+ "mean_audio_rtf": [0.25]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [40],
+ "request_rate": [0.3],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 2,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.5,
+ "(720, 1280, 2)": 0.5
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [4000],
+ "mean_audio_ttfp_ms": [4000],
+ "mean_audio_rtf": [0.4]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [100],
+ "request_rate": [0.5],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 3,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1,
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.34,
+ "(720, 1280, 2)": 0.33,
+ "(0, 60, 3)": 0.33
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [6000],
+ "mean_audio_ttfp_ms": [6000],
+ "mean_audio_rtf": [0.7]
+ }
+ },
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [4, 16, 40],
+ "max_concurrency": [1, 4, 10],
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000, 5000]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [10],
+ "request_rate": [0.1],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "random_mm_base_items_per_request": 1,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(0, 60, 3)": 1.0
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
+ "baseline": {
+ "mean_ttft_ms": [2000]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [40],
+ "request_rate": [0.3],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "random_mm_base_items_per_request": 2,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.5,
+ "(720, 1280, 2)": 0.5
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
+ "baseline": {
+ "mean_ttft_ms": [4000]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [100],
+ "request_rate": [0.5],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "extra_body": {
+ "modalities": ["text"]
+ },
+ "random_mm_base_items_per_request": 3,
+ "random_mm_num_mm_items_range_ratio": 0.5,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1,
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(256, 256, 1)": 0.34,
+ "(720, 1280, 2)": 0.33,
+ "(0, 60, 3)": 0.33
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el",
+ "baseline": {
+ "mean_ttft_ms": [6000]
+ }
+ }
+ ]
+ }
+]
diff --git a/tests/dfx/perf/tests/test_tts.json b/tests/dfx/perf/tests/test_tts.json
new file mode 100644
index 0000000000..3583b45b4f
--- /dev/null
+++ b/tests/dfx/perf/tests/test_tts.json
@@ -0,0 +1,34 @@
+[
+ {
+ "test_name": "test_qwen3_tts",
+ "server_params": {
+ "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-audio-speech",
+ "endpoint": "/v1/audio/speech",
+ "num_prompts": [
+ 10,
+ 40
+ ],
+ "max_concurrency": [
+ 1,
+ 4
+ ],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "extra_body": {
+ "voice": "Vivian",
+ "language": "English"
+ },
+ "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_audio_ttfp_ms": [6000, 6000],
+ "mean_audio_rtf": [0.3, 0.3]
+ }
+ }
+ ]
+ }
+]
diff --git a/tests/dfx/stability/scripts/test_benchmark_stability.py b/tests/dfx/stability/scripts/test_benchmark_stability.py
index e8568652d1..3d6b41e762 100644
--- a/tests/dfx/stability/scripts/test_benchmark_stability.py
+++ b/tests/dfx/stability/scripts/test_benchmark_stability.py
@@ -35,7 +35,7 @@
from tests.dfx.perf.scripts.run_benchmark import run_benchmark
STABILITY_DIR = Path(__file__).resolve().parent.parent
-STAGE_CONFIGS_DIR = STABILITY_DIR / "stage_configs"
+DEPLOY_CONFIGS_DIR = STABILITY_DIR / "deploy"
CONFIG_FILE_PATH = str(STABILITY_DIR / "tests" / "test.json")
DEFAULT_NUM_PROMPTS_PER_BATCH = 20
@@ -45,7 +45,7 @@
except FileNotFoundError:
BENCHMARK_CONFIGS = []
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, STAGE_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
+test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS) if BENCHMARK_CONFIGS else {}
_omni_server_lock = threading.Lock()
@@ -112,6 +112,8 @@ def _run_one_benchmark_batch(
flow=flow,
dataset_name=dataset_name,
num_prompt=num_prompts,
+ random_input_len=params.get("random_input_len"),
+ random_output_len=params.get("random_output_len"),
)
return result
except (FileNotFoundError, OSError) as e:
@@ -217,11 +219,20 @@ def omni_server(request):
Multi-stage initialization can take 10-20+ minutes.
"""
with _omni_server_lock:
- test_name, model, stage_config_path = request.param
+ test_name, model, stage_config_path, stage_overrides, extra_cli_args = request.param
print(f"Starting OmniServer with test: {test_name}, model: {model}")
- with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server:
+ server_args = ["--stage-init-timeout", "120"]
+ # --deploy-config and --stage-overrides compose at the CLI (see vllm_omni/entrypoints/utils.py):
+ # deploy-config sets the base; stage-overrides are applied on top. Both can be set.
+ if stage_config_path:
+ server_args = ["--deploy-config", stage_config_path] + server_args
+ if stage_overrides:
+ server_args = ["--stage-overrides", stage_overrides] + server_args
+ if extra_cli_args:
+ server_args = list(extra_cli_args) + server_args
+ with OmniServer(model, server_args) as server:
server.test_name = test_name
print("OmniServer started successfully")
yield server
diff --git a/tests/dfx/stability/stage_configs/qwen3_omni.yaml b/tests/dfx/stability/stage_configs/qwen3_omni.yaml
deleted file mode 100644
index 802f8dd249..0000000000
--- a/tests/dfx/stability/stage_configs/qwen3_omni.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type to launch OmniLLM
- runtime:
- devices: "0"
- max_batch_size: 64
- engine_args:
- model_stage: thinker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type to launch OmniLLM
- runtime:
- devices: "1"
- max_batch_size: 64
- engine_args:
- model_stage: talker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type to launch OmniLLM
- runtime:
- devices: "1"
- max_batch_size: 64
- engine_args:
- model_stage: code2wav
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/dfx/stability/tests/test.json b/tests/dfx/stability/tests/test.json
index 95993c9c55..255cd5b109 100644
--- a/tests/dfx/stability/tests/test.json
+++ b/tests/dfx/stability/tests/test.json
@@ -3,7 +3,11 @@
"test_name": "test_qwen3_omni_stability",
"server_params": {
"model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml"
+ "stage_overrides": {
+ "2": {
+ "max_num_batched_tokens": 1000000
+ }
+ }
},
"benchmark_params": [
{
@@ -36,25 +40,12 @@
"test_name": "test_qwen3_omni_stability_async_chunk",
"server_params": {
"model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml",
- "update": {
- "async_chunk": true,
- "stage_args": {
- "0": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
- },
- "1": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- }
+ "stage_overrides": {
+ "2": {
+ "max_num_batched_tokens": 1000000
}
},
- "delete": {
- "stage_args": {
- "2": [
- "custom_process_input_func"
- ]
- }
- }
+ "extra_cli_args": ["--async-chunk"]
},
"benchmark_params": [
{
diff --git a/tests/diffusion/cache/test_cache_dit.py b/tests/diffusion/cache/test_cache_dit.py
new file mode 100644
index 0000000000..0b7ef72358
--- /dev/null
+++ b/tests/diffusion/cache/test_cache_dit.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Model specific tests for CacheDiT enablement.
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+import vllm_omni.diffusion.cache.cache_dit_backend as cd_backend
+from vllm_omni.diffusion.data import DiffusionCacheConfig
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+SEPARATE_CFG_ENABLERS = [
+ cd_backend.enable_cache_for_ltx2,
+ cd_backend.enable_cache_for_wan22,
+ cd_backend.enable_cache_for_longcat_image,
+]
+
+SAMPLE_CACHE_CONFIG = DiffusionCacheConfig()
+
+
+@pytest.mark.parametrize("enabler", SEPARATE_CFG_ENABLERS)
+@patch("vllm_omni.diffusion.cache.cache_dit_backend.BlockAdapter")
+@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit")
+def test_separate_cfg(mock_cache_dit, mock_block_adapter, enabler):
+ """Ensure that custom enablers for models with separate CFG pass
+ the param through to cache_dit correctly.
+
+ Regression test for: https://github.com/vllm-project/vllm-omni/pull/2860
+ """
+ mock_pipeline = Mock()
+ enabler(mock_pipeline, SAMPLE_CACHE_CONFIG)
+
+ mock_cache_dit.enable_cache.assert_called_once()
+ adapter_kwargs = mock_block_adapter.call_args.kwargs
+ assert adapter_kwargs["has_separate_cfg"] is True
diff --git a/tests/diffusion/cache/test_teacache_extractors.py b/tests/diffusion/cache/test_teacache_extractors.py
index a52e11b3d4..c22a60e227 100644
--- a/tests/diffusion/cache/test_teacache_extractors.py
+++ b/tests/diffusion/cache/test_teacache_extractors.py
@@ -22,7 +22,7 @@
import torch
from tests.utils import hardware_test
-from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_klein_context
+from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_context, extract_flux2_klein_context
from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
Flux2Transformer2DModel,
)
@@ -174,3 +174,106 @@ def test_invalid_module_raises_error(self):
img_ids=torch.randint(0, 64, (1, 1024, 4)),
txt_ids=torch.randint(0, 64, (1, 512, 4)),
)
+
+
+class TestFlux2Extractor(BaseExtractorTest):
+ """Test extract_flux2_context function."""
+
+ def get_extractor(self):
+ return extract_flux2_context
+
+ @pytest.fixture
+ def flux2_module(self):
+ """Create a minimal Flux2Transformer2DModel for testing."""
+ from vllm_omni.diffusion.models.flux2.flux2_transformer import Flux2Transformer2DModel
+
+ model = Flux2Transformer2DModel(
+ num_layers=2,
+ num_single_layers=2,
+ num_attention_heads=48,
+ attention_head_dim=128,
+ joint_attention_dim=15360,
+ )
+ return model
+
+ def get_module(self, flux2_module):
+ return flux2_module
+
+ @pytest.fixture
+ def sample_inputs(self):
+ """Create sample input tensors for Flux2.
+
+ Note: hidden_states uses in_channels=128 (default for Flux2),
+ not inner_dim=6144. The x_embedder projects from 128 -> 6144.
+ encoder_hidden_states uses joint_attention_dim=15360 (model default),
+ which then gets projected to inner_dim=6144 by context_embedder.
+ """
+ batch_size = 1
+ img_seq_len = 1024
+ txt_seq_len = 512
+ in_channels = 128 # Model default in_channels
+ txt_dim = 15360 # Model default joint_attention_dim
+
+ return {
+ "hidden_states": torch.randn(batch_size, img_seq_len, in_channels),
+ "encoder_hidden_states": torch.randn(batch_size, txt_seq_len, txt_dim),
+ "timestep": torch.tensor([500]),
+ "img_ids": torch.randint(0, 64, (batch_size, img_seq_len, 4)),
+ "txt_ids": torch.randint(0, 64, (batch_size, txt_seq_len, 4)),
+ "guidance": torch.tensor([3.5]),
+ }
+
+ def get_sample_inputs(self, sample_inputs):
+ return sample_inputs
+
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_modulated_input_shape(self, flux2_module, sample_inputs):
+ """Test that modulated_input has correct shape matching the model's inner_dim.
+
+ Note: After x_embedder projection, hidden_states are projected from
+ in_channels (128) to inner_dim (6144), so modulated_input should match
+ the projected shape, not the input shape.
+ """
+ context = extract_flux2_klein_context(flux2_module, **sample_inputs)
+
+ batch_size, img_seq_len, _ = sample_inputs["hidden_states"].shape
+ inner_dim = flux2_module.inner_dim
+ assert context.modulated_input.shape == (batch_size, img_seq_len, inner_dim)
+
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_run_transformer_blocks_callable(self, flux2_module, sample_inputs):
+ """Test that run_transformer_blocks is callable."""
+ context = extract_flux2_context(flux2_module, **sample_inputs)
+ assert callable(context.run_transformer_blocks)
+
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_postprocess_callable(self, flux2_module, sample_inputs):
+ """Test that postprocess is callable."""
+ context = extract_flux2_context(flux2_module, **sample_inputs)
+ assert callable(context.postprocess)
+
+ def test_without_guidance(self, flux2_module, sample_inputs):
+ """Test context extraction works without guidance (no CFG)."""
+ inputs = sample_inputs.copy()
+ inputs["guidance"] = None
+
+ context = extract_flux2_context(flux2_module, **inputs)
+
+ assert context is not None
+ assert context.temb is not None
+
+ @pytest.mark.cpu
+ def test_invalid_module_raises_error(self):
+ """Test that invalid module without transformer_blocks raises ValueError."""
+ invalid_module = Mock()
+ invalid_module.transformer_blocks = []
+
+ with pytest.raises(ValueError, match="Module must have transformer_blocks"):
+ extract_flux2_context(
+ invalid_module,
+ hidden_states=torch.randn(1, 1024, 6144),
+ encoder_hidden_states=torch.randn(1, 512, 15360),
+ timestep=torch.tensor([500]),
+ img_ids=torch.randint(0, 64, (1, 1024, 4)),
+ txt_ids=torch.randint(0, 64, (1, 512, 4)),
+ )
diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py
index 79dbe9e6dd..bf709618de 100644
--- a/tests/diffusion/distributed/test_cfg_parallel.py
+++ b/tests/diffusion/distributed/test_cfg_parallel.py
@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for CFG (Classifier-Free Guidance) parallel functionality.
-This test verifies that predict_noise_maybe_with_cfg produces numerically
-equivalent results with and without CFG parallel using fixed random inputs.
+This test verifies that predict_noise_maybe_with_cfg and
+predict_noise_with_multi_branch_cfg produce numerically equivalent results
+with and without CFG parallel using fixed random inputs.
"""
import os
@@ -429,3 +430,340 @@ def test_predict_noise_without_cfg(dtype: torch.dtype):
assert noise_pred.shape == (1, 4, 16, 16)
print(f"✓ Test passed: predict_noise without CFG (dtype={dtype})")
+
+
+class MultiBranchTestPipeline(CFGParallelMixin):
+ """Test pipeline with custom 3-branch combine logic (like OmniGen2)."""
+
+ def __init__(self, in_channels: int = 4, hidden_dim: int = 128, seed: int = 42):
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+
+ self.transformer = SimpleTransformer(in_channels, hidden_dim)
+
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ for param in self.transformer.parameters():
+ torch.nn.init.normal_(param, mean=0.0, std=0.02)
+
+ def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
+ """N-branch combine with weighted sum for testing.
+
+ - 2-branch: standard CFG formula (true_cfg_scale is float)
+ - 3-branch: OmniGen2-style dual guidance scale (true_cfg_scale is dict)
+ - 4-branch: DreamID-style weighted sum (true_cfg_scale is dict)
+ """
+ if len(predictions) == 4:
+ text_scale = true_cfg_scale["text"]
+ image_scale = true_cfg_scale["image"]
+ vid_ref_scale = true_cfg_scale["vid_ref"]
+ pos, neg, vid_neg, audio_neg = predictions
+ combined = (
+ audio_neg
+ + vid_ref_scale * (vid_neg - audio_neg)
+ + image_scale * (neg - vid_neg)
+ + text_scale * (pos - neg)
+ )
+ elif len(predictions) == 3:
+ text_scale = true_cfg_scale["text"]
+ image_scale = true_cfg_scale["image"]
+ pos, ref, uncond = predictions
+ combined = uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)
+ else:
+ pos, neg = predictions[0], predictions[1]
+ combined = neg + true_cfg_scale * (pos - neg)
+
+ if cfg_normalize:
+ combined = self.cfg_normalize_function(pos, combined)
+ return combined
+
+
+def _test_multi_branch_parallel_worker(
+ local_rank: int,
+ world_size: int,
+ cfg_parallel_size: int,
+ dtype: torch.dtype,
+ test_config: dict,
+ result_queue: torch.multiprocessing.Queue,
+):
+ """Worker function for multi-branch CFG parallel test."""
+ device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
+ current_omni_platform.set_device(device)
+
+ update_environment_variables(
+ {
+ "RANK": str(local_rank),
+ "LOCAL_RANK": str(local_rank),
+ "WORLD_SIZE": str(world_size),
+ "MASTER_ADDR": "localhost",
+ "MASTER_PORT": "29504",
+ }
+ )
+
+ init_distributed_environment()
+ initialize_model_parallel(cfg_parallel_size=cfg_parallel_size)
+
+ cfg_rank = get_classifier_free_guidance_rank()
+ cfg_world_size = get_classifier_free_guidance_world_size()
+ assert cfg_world_size == cfg_parallel_size
+
+ pipeline = MultiBranchTestPipeline(
+ in_channels=test_config["channels"],
+ hidden_dim=test_config["hidden_dim"],
+ seed=test_config["model_seed"],
+ )
+ pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
+ pipeline.transformer.eval()
+
+ n_branches = test_config["n_branches"]
+ batch_size = test_config["batch_size"]
+ channels = test_config["channels"]
+ height = test_config["height"]
+ width = test_config["width"]
+
+ # Create N branch inputs with distinct seeds
+ branches_kwargs = []
+ for b in range(n_branches):
+ torch.manual_seed(test_config["input_seed"] + b)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(test_config["input_seed"] + b)
+ x = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
+ branches_kwargs.append({"x": x})
+
+ with torch.no_grad():
+ noise_pred = pipeline.predict_noise_with_multi_branch_cfg(
+ do_true_cfg=True,
+ true_cfg_scale=test_config["cfg_scale"],
+ branches_kwargs=branches_kwargs,
+ cfg_normalize=test_config["cfg_normalize"],
+ )
+
+ assert noise_pred is not None
+ result_queue.put((cfg_rank, noise_pred.cpu()))
+
+ destroy_distributed_env()
+
+
+def _test_multi_branch_sequential_worker(
+ local_rank: int,
+ world_size: int,
+ dtype: torch.dtype,
+ test_config: dict,
+ result_queue: torch.multiprocessing.Queue,
+):
+ """Worker function for sequential multi-branch CFG test (baseline)."""
+ device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
+ current_omni_platform.set_device(device)
+
+ update_environment_variables(
+ {
+ "RANK": str(local_rank),
+ "LOCAL_RANK": str(local_rank),
+ "WORLD_SIZE": str(world_size),
+ "MASTER_ADDR": "localhost",
+ "MASTER_PORT": "29505",
+ }
+ )
+
+ init_distributed_environment()
+ initialize_model_parallel(cfg_parallel_size=1)
+
+ cfg_world_size = get_classifier_free_guidance_world_size()
+ assert cfg_world_size == 1
+
+ pipeline = MultiBranchTestPipeline(
+ in_channels=test_config["channels"],
+ hidden_dim=test_config["hidden_dim"],
+ seed=test_config["model_seed"],
+ )
+ pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
+ pipeline.transformer.eval()
+
+ n_branches = test_config["n_branches"]
+ batch_size = test_config["batch_size"]
+ channels = test_config["channels"]
+ height = test_config["height"]
+ width = test_config["width"]
+
+ branches_kwargs = []
+ for b in range(n_branches):
+ torch.manual_seed(test_config["input_seed"] + b)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(test_config["input_seed"] + b)
+ x = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
+ branches_kwargs.append({"x": x})
+
+ with torch.no_grad():
+ noise_pred = pipeline.predict_noise_with_multi_branch_cfg(
+ do_true_cfg=True,
+ true_cfg_scale=test_config["cfg_scale"],
+ branches_kwargs=branches_kwargs,
+ cfg_normalize=test_config["cfg_normalize"],
+ )
+
+ assert noise_pred is not None
+ result_queue.put(noise_pred.cpu())
+
+ destroy_distributed_env()
+
+
+@pytest.mark.parametrize(
+ "cfg_parallel_size,n_branches",
+ [
+ (2, 2), # 2 branches on 2 GPUs: [[0],[1]]
+ (2, 3), # 3 branches on 2 GPUs: [[0,2],[1]]
+ (3, 3), # 3 branches on 3 GPUs: [[0],[1],[2]]
+ (2, 4), # 4 branches on 2 GPUs: [[0,2],[1,3]]
+ ],
+)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("cfg_normalize", [False, True])
+def test_predict_noise_with_multi_branch_cfg(
+ cfg_parallel_size: int,
+ n_branches: int,
+ dtype: torch.dtype,
+ batch_size: int,
+ cfg_normalize: bool,
+):
+ """
+ Test that predict_noise_with_multi_branch_cfg produces identical results
+ with and without CFG parallel for N-branch models.
+
+ Args:
+ cfg_parallel_size: Number of GPUs for CFG parallel
+ n_branches: Number of CFG branches
+ dtype: Data type for computation
+ batch_size: Batch size for testing
+ cfg_normalize: Whether to normalize CFG output
+ """
+ available_gpus = current_omni_platform.get_device_count()
+ if available_gpus < cfg_parallel_size:
+ pytest.skip(f"Test requires {cfg_parallel_size} GPUs but only {available_gpus} available")
+
+ if n_branches == 2:
+ cfg_scale = 5.0
+ elif n_branches == 3:
+ cfg_scale = {"text": 5.0, "image": 2.0}
+ else:
+ cfg_scale = {"text": 5.0, "image": 2.0, "vid_ref": 1.5}
+
+ test_config = {
+ "batch_size": batch_size,
+ "channels": 4,
+ "height": 16,
+ "width": 16,
+ "hidden_dim": 128,
+ "cfg_scale": cfg_scale,
+ "cfg_normalize": cfg_normalize,
+ "model_seed": 42,
+ "input_seed": 123,
+ "n_branches": n_branches,
+ }
+
+ mp_context = torch.multiprocessing.get_context("spawn")
+ manager = mp_context.Manager()
+ baseline_queue = manager.Queue()
+ cfg_parallel_queue = manager.Queue()
+
+ # Run baseline (sequential, cfgp=1)
+ torch.multiprocessing.spawn(
+ _test_multi_branch_sequential_worker,
+ args=(1, dtype, test_config, baseline_queue),
+ nprocs=1,
+ )
+
+ # Run CFG parallel
+ torch.multiprocessing.spawn(
+ _test_multi_branch_parallel_worker,
+ args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config, cfg_parallel_queue),
+ nprocs=cfg_parallel_size,
+ )
+
+ baseline_output = baseline_queue.get()
+ cfg_parallel_outputs = [cfg_parallel_queue.get() for _ in range(cfg_parallel_size)]
+ cfg_parallel_outputs.sort(key=lambda item: item[0])
+ cfg_parallel_output = cfg_parallel_outputs[0][1]
+
+ # All ranks should produce identical output
+ for cfg_rank, rank_output in cfg_parallel_outputs[1:]:
+ torch.testing.assert_close(
+ rank_output,
+ cfg_parallel_output,
+ rtol=0,
+ atol=0,
+ msg=f"Multi-branch CFG parallel ranks differ (rank 0 vs rank {cfg_rank})",
+ )
+
+ assert baseline_output.shape == cfg_parallel_output.shape, (
+ f"Shape mismatch: baseline {baseline_output.shape} vs CFG parallel {cfg_parallel_output.shape}"
+ )
+
+ if dtype == torch.float32:
+ rtol, atol = 1e-5, 1e-5
+ elif dtype == torch.bfloat16:
+ rtol, atol = 1e-2, 1e-2
+ else:
+ rtol, atol = 1e-3, 1e-3
+
+ torch.testing.assert_close(
+ cfg_parallel_output,
+ baseline_output,
+ rtol=rtol,
+ atol=atol,
+ msg=(
+ f"Multi-branch CFG parallel output differs from sequential\n"
+ f" n_branches={n_branches}, cfg_parallel_size={cfg_parallel_size}\n"
+ f" dtype={dtype}, cfg_normalize={cfg_normalize}\n"
+ f" Max diff: {(cfg_parallel_output - baseline_output).abs().max().item():.6e}"
+ ),
+ )
+
+ print(
+ f"✓ Test passed: multi_branch n_branches={n_branches}, "
+ f"cfg_size={cfg_parallel_size}, dtype={dtype}, cfg_normalize={cfg_normalize}"
+ )
+
+
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_multi_branch_without_cfg(dtype: torch.dtype):
+ """
+ Test predict_noise_with_multi_branch_cfg when do_true_cfg=False.
+
+ When CFG is disabled, only the first branch (positive) should be computed.
+ This test runs on a single GPU without distributed environment.
+ """
+ available_gpus = current_omni_platform.get_device_count()
+ if available_gpus < 1:
+ pytest.skip("Test requires at least 1 GPU")
+
+ device = torch.device(f"{current_omni_platform.device_type}:0")
+ current_omni_platform.set_device(device)
+
+ pipeline = MultiBranchTestPipeline(in_channels=4, hidden_dim=128, seed=42)
+ pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
+ pipeline.transformer.eval()
+
+ # Create 3 branch inputs (only first should be used)
+ branches_kwargs = []
+ for b in range(3):
+ torch.manual_seed(123 + b)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(123 + b)
+ x = torch.randn(1, 4, 16, 16, dtype=dtype, device=device)
+ branches_kwargs.append({"x": x})
+
+ with torch.no_grad():
+ noise_pred = pipeline.predict_noise_with_multi_branch_cfg(
+ do_true_cfg=False, # No CFG
+ true_cfg_scale=5.0,
+ branches_kwargs=branches_kwargs,
+ cfg_normalize=False,
+ )
+
+ assert noise_pred is not None
+ assert noise_pred.shape == (1, 4, 16, 16)
+
+ print(f"✓ Test passed: multi_branch predict_noise without CFG (dtype={dtype})")
diff --git a/tests/diffusion/distributed/test_distributed_vae_executor.py b/tests/diffusion/distributed/test_distributed_vae_executor.py
index dc491dcdaf..b2ee7c10d3 100644
--- a/tests/diffusion/distributed/test_distributed_vae_executor.py
+++ b/tests/diffusion/distributed/test_distributed_vae_executor.py
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock, patch
+from types import SimpleNamespace
import pytest
import torch
@@ -61,40 +61,31 @@ def merge(self, coord_tensor_map, grid_spec):
class DummyMixin(DistributedVaeMixin):
def __init__(self):
self.use_tiling = True
- self.distributed_executor = MagicMock()
- self.distributed_executor.parallel_size = 2
- self.distributed_executor.group = None
+ self.distributed_executor = SimpleNamespace(parallel_size=2, group=None)
@pytest.fixture(autouse=True)
-def mock_dist():
- with (
- patch.object(dist, "get_world_size", return_value=2),
- patch.object(dist, "get_rank", return_value=0),
- patch.object(dist, "is_initialized", return_value=True),
- patch.object(dist, "all_reduce", return_value=None),
- patch.object(dist, "gather", return_value=None),
- patch.object(dist, "broadcast", return_value=None),
- ):
- yield
+def mock_dist(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(dist, "get_world_size", lambda *args, **kwargs: 2)
+ monkeypatch.setattr(dist, "get_rank", lambda *args, **kwargs: 0)
+ monkeypatch.setattr(dist, "is_initialized", lambda: True)
+ monkeypatch.setattr(dist, "all_reduce", lambda *args, **kwargs: None)
+ monkeypatch.setattr(dist, "gather", lambda *args, **kwargs: None)
+ monkeypatch.setattr(dist, "broadcast", lambda *args, **kwargs: None)
@pytest.fixture(autouse=True)
-def mock_dit_group():
- with patch(
+def mock_dit_group(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(
"vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor.get_dit_group",
- new=MagicMock(return_value=None),
- ):
- yield
+ lambda: None,
+ )
@pytest.fixture(autouse=True)
-def mock_dist_vae_executor():
- with (
- patch.object(DistributedVaeExecutor, "gather_tensors", side_effect=lambda x: [x]),
- patch.object(DistributedVaeExecutor, "broadcast_tensor", side_effect=lambda x: x),
- ):
- yield
+def mock_dist_vae_executor(monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr(DistributedVaeExecutor, "gather_tensors", lambda self, x: [x])
+ monkeypatch.setattr(DistributedVaeExecutor, "broadcast_tensor", lambda self, x: x)
# ============================
diff --git a/tests/diffusion/hooks/test_hook_registry.py b/tests/diffusion/hooks/test_hook_registry.py
new file mode 100644
index 0000000000..6c8535cfec
--- /dev/null
+++ b/tests/diffusion/hooks/test_hook_registry.py
@@ -0,0 +1,164 @@
+"""
+Tests for hook registry.
+
+NOTE: The hook registry is also tested indirectly through a lot of
+other tests, e.g., tests/diffusion/distributed/test_sp_plan_hooks.py
+"""
+
+from typing import Any
+
+import pytest
+from torch import nn
+
+from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook
+
+DEFAULT_OUT = "ECHO"
+OVERRIDE_OUT = "OVERRIDE"
+INPUT_KWARG = "inp"
+
+
+class EchoModule(nn.Module):
+ """Just echo the input."""
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+
+ def forward(self, *args, **kwargs):
+ input_val = kwargs[INPUT_KWARG]
+ return input_val + DEFAULT_OUT
+
+
+class AppendHook(ModelHook):
+ """Append an echo value to the input string on pre / post forward."""
+
+ def __init__(self, echo_val: str):
+ self.echo_val = echo_val
+
+ def pre_forward(self, module: nn.Module, *args, **kwargs):
+ input_val = kwargs[INPUT_KWARG]
+ return (), {INPUT_KWARG: input_val + self.echo_val}
+
+ def post_forward(self, module: nn.Module, output):
+ return output + self.echo_val
+
+
+class OverrideAppendHook(AppendHook):
+ """Same as AppendHook, but replace the forward call with a different string."""
+
+ def new_forward(self, module: nn.Module, *args, **kwargs):
+ return kwargs[INPUT_KWARG] + OVERRIDE_OUT
+
+
+def test_register_no_fwd_override_hooks():
+ """Ensure registration is correct with no forward hooks."""
+ mod = EchoModule()
+ registry = HookRegistry.get_or_create(mod)
+ first_hook = AppendHook("1")
+ second_hook = AppendHook("2")
+ sorted_no_fwd_hooks = [first_hook, second_hook]
+
+ # Will add and sort the hook by key
+ registry.register_hook(name="b", hook=second_hook)
+ registry.register_hook(name="a", hook=first_hook)
+
+ assert len(registry._hooks) == 2
+ assert len(registry._sorted_hooks) == 2
+ assert registry._new_fwd_impl_hook is None
+ # Ensure registering a new hook sorting alphabetically
+ for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
+ assert actual_hook is expected_hook
+
+
+def test_register_with_forward_hooks():
+ """Ensure registration is correct with a forward hooks."""
+ mod = EchoModule()
+ registry = HookRegistry.get_or_create(mod)
+ first_hook = AppendHook("1")
+ second_hook = AppendHook("2")
+ exec_hook = OverrideAppendHook("3")
+ sorted_no_fwd_hooks = [first_hook, second_hook]
+
+ # Will add and sort the hook by key
+ registry.register_hook(name="b", hook=second_hook)
+ registry.register_hook(name="a", hook=first_hook)
+ registry.register_hook(name="c", hook=exec_hook)
+
+ assert len(registry._hooks) == 3
+ assert len(registry._sorted_hooks) == 3
+ assert registry._new_fwd_impl_hook is exec_hook
+ # Ensure registering a new hook sorting alphabetically
+ for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
+ assert actual_hook is expected_hook
+
+
+def test_register_fails_with_multiple_forward_hooks():
+ """Ensure registration only allows one hook overriding new_forward"""
+ mod = EchoModule()
+ registry = HookRegistry.get_or_create(mod)
+
+ registry.register_hook(name="foo", hook=OverrideAppendHook("1"))
+ with pytest.raises(RuntimeError):
+ registry.register_hook(name="bar", hook=OverrideAppendHook("2"))
+
+
+def test_remove_hooks():
+ """Ensure removal sorts hooks."""
+ mod = EchoModule()
+ registry = HookRegistry.get_or_create(mod)
+
+ first_hook = AppendHook("1")
+ second_hook = AppendHook("2")
+ exec_hook = OverrideAppendHook("3")
+
+ registry.register_hook(name="b", hook=second_hook)
+ registry.register_hook(name="a", hook=first_hook)
+ registry.register_hook(name="c", hook=exec_hook)
+ # Explicitly reorder our hooks to be in the wrong order, since register
+ # forces them to be sorted too. Ensure that remove the hook will also
+ # enforce the sorted order.
+ registry._sorted_hooks = [second_hook, first_hook]
+
+ assert registry._new_fwd_impl_hook is exec_hook
+ registry.remove_hook("c")
+ assert registry._new_fwd_impl_hook is None
+
+ sorted_no_fwd_hooks = [first_hook, second_hook]
+ for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
+ assert actual_hook is expected_hook
+
+
+def test_dispatch_no_fwd_override_hooks():
+ """Ensure dispatch runs hooks in deterministic sorted order."""
+ mod = EchoModule()
+ registry = HookRegistry.get_or_create(mod)
+
+ first_hook = AppendHook("1")
+ second_hook = AppendHook("2")
+
+ # Register will sort the hooks, so hook 1 will run first
+ # on preprocess and last in post process
+ registry.register_hook(name="2", hook=second_hook)
+ registry.register_hook(name="1", hook=first_hook)
+ res = registry.dispatch(inp="")
+ assert isinstance(res, str)
+ assert res == f"12{DEFAULT_OUT}21"
+
+
+def test_dispatch_with_fwd_hooks():
+ """Ensure dispatch runs hooks in deterministic sorted order."""
+ mod = EchoModule()
+ registry = HookRegistry.get_or_create(mod)
+
+ first_hook = AppendHook("1")
+ second_hook = AppendHook("2")
+ exec_hook = OverrideAppendHook("3")
+
+ # Register will sort the hooks, so hook 1 will run first on preprocess and last in
+ # post process. Since the override hook mutates forward, it will run last even
+ # though the name of the exec_hook is alphabetically before the second hook.
+ registry.register_hook(name="c", hook=second_hook)
+ registry.register_hook(name="a", hook=first_hook)
+ registry.register_hook(name="b", hook=exec_hook)
+ res = registry.dispatch(inp="")
+ assert isinstance(res, str)
+ assert res == f"123{OVERRIDE_OUT}321"
diff --git a/tests/diffusion/layers/test_norm.py b/tests/diffusion/layers/test_norm.py
new file mode 100644
index 0000000000..e420415285
--- /dev/null
+++ b/tests/diffusion/layers/test_norm.py
@@ -0,0 +1,453 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for LayerNorm and RMSNorm custom ops in diffusion layers."""
+
+import pytest
+import torch
+
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
+
+
+# ── Import tests ──
+
+
+def test_layernorm_import():
+ """Verify LayerNorm can be imported from the norm module."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm # noqa: F401
+
+
+def test_rmsnorm_import():
+ """Verify RMSNorm can be imported from the norm module."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm # noqa: F401
+
+
+# ── LayerNorm tests ──
+
+
+def test_layernorm_forward_shape():
+ """LayerNorm produces correct output shapes."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ batch = 2
+ seq_len = 4
+ norm = LayerNorm(dim)
+
+ x = torch.randn(batch, seq_len, dim)
+ out = norm(x)
+
+ assert out.shape == (batch, seq_len, dim)
+
+
+def test_layernorm_forward_shape_2d():
+ """LayerNorm works with 2D input tensors."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ batch = 2
+ norm = LayerNorm(dim)
+
+ x = torch.randn(batch, dim)
+ out = norm(x)
+
+ assert out.shape == (batch, dim)
+
+
+def test_layernorm_preserves_dtype_fp32():
+ """LayerNorm preserves float32 dtype."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ norm = LayerNorm(dim)
+
+ x = torch.randn(2, 4, dim, dtype=torch.float32)
+ out = norm(x)
+
+ assert out.dtype == torch.float32
+
+
+def test_layernorm_preserves_dtype_fp16():
+ """LayerNorm preserves float16 dtype."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ norm = LayerNorm(dim)
+
+ x = torch.randn(2, 4, dim, dtype=torch.float16)
+ out = norm(x)
+
+ assert out.dtype == torch.float16
+
+
+def test_layernorm_preserves_dtype_bf16():
+ """LayerNorm preserves bfloat16 dtype."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ norm = LayerNorm(dim)
+
+ x = torch.randn(2, 4, dim, dtype=torch.bfloat16)
+ out = norm(x)
+
+ assert out.dtype == torch.bfloat16
+
+
+def test_layernorm_without_elementwise_affine():
+ """LayerNorm works without elementwise_affine (no learned parameters)."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ norm = LayerNorm(dim, elementwise_affine=False)
+
+ assert norm.weight is None
+ assert norm.bias is None
+
+ x = torch.randn(2, 4, dim)
+ out = norm(x)
+
+ assert out.shape == (2, 4, dim)
+
+
+def test_layernorm_custom_eps():
+ """LayerNorm accepts custom epsilon value."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ eps = 1e-5
+ norm = LayerNorm(dim, eps=eps)
+
+ assert norm.eps == eps
+
+
+def test_layernorm_has_learnable_parameters():
+ """LayerNorm has learnable weight and bias by default."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ norm = LayerNorm(dim)
+
+ assert norm.weight is not None
+ assert norm.bias is not None
+ assert norm.weight.shape == (dim,)
+ assert norm.bias.shape == (dim,)
+
+
+def test_layernorm_matches_fp32_reference():
+ """Verify LayerNorm produces identical output to FP32 nn.LayerNorm."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ eps = 1e-6
+ torch.manual_seed(42)
+
+ ours = LayerNorm(dim, eps=eps)
+ ref = torch.nn.LayerNorm(dim, eps=eps)
+
+ # Copy weights
+ ref.weight.data.copy_(ours.weight.data)
+ ref.bias.data.copy_(ours.bias.data)
+
+ x = torch.randn(2, 4, dim)
+
+ out_ours = ours(x)
+ out_ref = ref(x.float()).to(x.dtype)
+
+ torch.testing.assert_close(out_ours, out_ref, atol=1e-5, rtol=1e-5)
+
+
+def test_layernorm_matches_diffusers_fp32layernorm():
+ """Verify LayerNorm produces identical output to diffusers FP32LayerNorm."""
+ from diffusers.models.normalization import FP32LayerNorm
+
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ eps = 1e-6
+ torch.manual_seed(42)
+
+ ours = LayerNorm(dim, eps=eps)
+ ref = FP32LayerNorm(dim, eps=eps)
+
+ # Copy weights
+ ref.weight.data.copy_(ours.weight.data)
+ ref.bias.data.copy_(ours.bias.data)
+
+ # Test with fp16 input to verify FP32 computation
+ x = torch.randn(2, 4, dim, dtype=torch.float16)
+
+ out_ours = ours(x)
+ out_ref = ref(x)
+
+ torch.testing.assert_close(out_ours, out_ref, atol=1e-3, rtol=1e-3)
+
+
+# ── RMSNorm tests ──
+
+
+def test_rmsnorm_forward_shape():
+ """RMSNorm produces correct output shapes."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ batch = 2
+ seq_len = 4
+ norm = RMSNorm(hidden_size)
+
+ x = torch.randn(batch, seq_len, hidden_size)
+ out = norm(x)
+
+ assert out.shape == (batch, seq_len, hidden_size)
+
+
+def test_rmsnorm_forward_shape_2d():
+ """RMSNorm works with 2D input tensors."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ batch = 2
+ norm = RMSNorm(hidden_size)
+
+ x = torch.randn(batch, hidden_size)
+ out = norm(x)
+
+ assert out.shape == (batch, hidden_size)
+
+
+def test_rmsnorm_preserves_dtype_fp32():
+ """RMSNorm preserves float32 dtype."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ norm = RMSNorm(hidden_size)
+
+ x = torch.randn(2, 4, hidden_size, dtype=torch.float32)
+ out = norm(x)
+
+ assert out.dtype == torch.float32
+
+
+def test_rmsnorm_preserves_dtype_fp16():
+ """RMSNorm preserves float16 dtype."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ norm = RMSNorm(hidden_size)
+
+ x = torch.randn(2, 4, hidden_size, dtype=torch.float16)
+ out = norm(x)
+
+ assert out.dtype == torch.float16
+
+
+def test_rmsnorm_preserves_dtype_bf16():
+ """RMSNorm preserves bfloat16 dtype."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ norm = RMSNorm(hidden_size)
+
+ x = torch.randn(2, 4, hidden_size, dtype=torch.bfloat16)
+ out = norm(x)
+
+ assert out.dtype == torch.bfloat16
+
+
+def test_rmsnorm_custom_eps():
+ """RMSNorm accepts custom epsilon value."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ eps = 1e-5
+ norm = RMSNorm(hidden_size, eps=eps)
+
+ assert norm.variance_epsilon == eps
+
+
+def test_rmsnorm_has_weight_parameter():
+ """RMSNorm has learnable weight parameter initialized to ones."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ norm = RMSNorm(hidden_size)
+
+ assert norm.weight is not None
+ assert norm.weight.shape == (hidden_size,)
+ torch.testing.assert_close(norm.weight, torch.ones(hidden_size))
+
+
+def test_rmsnorm_numerical_correctness():
+ """Verify RMSNorm produces numerically correct output."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ eps = 1e-6
+ torch.manual_seed(42)
+
+ norm = RMSNorm(hidden_size, eps=eps)
+ x = torch.randn(2, 4, hidden_size)
+
+ # Compute expected output manually
+ x_fp32 = x.to(torch.float32)
+ variance = x_fp32.pow(2).mean(-1, keepdim=True)
+ expected = x_fp32 * torch.rsqrt(variance + eps)
+ expected = norm.weight.to(torch.float32) * expected
+ expected = expected.to(x.dtype)
+
+ out = norm(x)
+
+ torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)
+
+
+def test_rmsnorm_matches_reference_implementation():
+ """Verify RMSNorm matches a reference implementation."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ def reference_rmsnorm(x, weight, eps):
+ """Reference RMSNorm implementation."""
+ input_dtype = x.dtype
+ x = x.to(torch.float32)
+ variance = x.pow(2).mean(-1, keepdim=True)
+ out = x * torch.rsqrt(variance + eps)
+ out = weight.to(torch.float32) * out
+ return out.to(input_dtype)
+
+ hidden_size = 128
+ eps = 1e-6
+ torch.manual_seed(123)
+
+ norm = RMSNorm(hidden_size, eps=eps)
+
+ # Test with various dtypes
+ for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+ x = torch.randn(4, 8, hidden_size, dtype=dtype)
+ expected = reference_rmsnorm(x, norm.weight, eps)
+ out = norm(x)
+ torch.testing.assert_close(out, expected, atol=1e-3, rtol=1e-3)
+
+
+# ── CustomOp dispatch tests ──
+
+
+def test_layernorm_inherits_from_customop():
+ """LayerNorm inherits from CustomOp for platform dispatch."""
+ from vllm_omni.diffusion.layers.custom_op import CustomOp
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ norm = LayerNorm(64)
+ assert isinstance(norm, CustomOp)
+
+
+def test_rmsnorm_inherits_from_customop():
+ """RMSNorm inherits from CustomOp for platform dispatch."""
+ from vllm_omni.diffusion.layers.custom_op import CustomOp
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ norm = RMSNorm(64)
+ assert isinstance(norm, CustomOp)
+
+
+def test_layernorm_has_platform_methods():
+ """LayerNorm has forward methods for each platform."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ norm = LayerNorm(64)
+
+ assert hasattr(norm, "forward_cuda")
+ assert hasattr(norm, "forward_hip")
+ assert hasattr(norm, "forward_xpu")
+ assert hasattr(norm, "forward_npu")
+ assert hasattr(norm, "forward_native")
+
+
+def test_rmsnorm_has_platform_methods():
+ """RMSNorm has forward methods for each platform."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ norm = RMSNorm(64)
+
+ assert hasattr(norm, "forward_cuda")
+ assert hasattr(norm, "forward_hip")
+ assert hasattr(norm, "forward_xpu")
+ assert hasattr(norm, "forward_npu")
+ assert hasattr(norm, "forward_native")
+
+
+def test_layernorm_forward_native_directly():
+ """LayerNorm.forward_native can be called directly."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ norm = LayerNorm(dim)
+ x = torch.randn(2, 4, dim)
+
+ out = norm.forward_native(x)
+
+ assert out.shape == (2, 4, dim)
+
+
+def test_rmsnorm_forward_native_directly():
+ """RMSNorm.forward_native can be called directly."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ norm = RMSNorm(hidden_size)
+ x = torch.randn(2, 4, hidden_size)
+
+ out = norm.forward_native(x)
+
+ assert out.shape == (2, 4, hidden_size)
+
+
+# ── Edge case tests ──
+
+
+def test_layernorm_with_large_dim():
+ """LayerNorm works with large hidden dimensions."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 4096
+ norm = LayerNorm(dim)
+ x = torch.randn(1, 16, dim)
+
+ out = norm(x)
+
+ assert out.shape == (1, 16, dim)
+
+
+def test_rmsnorm_with_large_dim():
+ """RMSNorm works with large hidden dimensions."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 4096
+ norm = RMSNorm(hidden_size)
+ x = torch.randn(1, 16, hidden_size)
+
+ out = norm(x)
+
+ assert out.shape == (1, 16, hidden_size)
+
+
+def test_layernorm_with_single_element_batch():
+ """LayerNorm works with batch size of 1."""
+ from vllm_omni.diffusion.layers.norm import LayerNorm
+
+ dim = 64
+ norm = LayerNorm(dim)
+ x = torch.randn(1, 1, dim)
+
+ out = norm(x)
+
+ assert out.shape == (1, 1, dim)
+
+
+def test_rmsnorm_with_single_element_batch():
+ """RMSNorm works with batch size of 1."""
+ from vllm_omni.diffusion.layers.norm import RMSNorm
+
+ hidden_size = 64
+ norm = RMSNorm(hidden_size)
+ x = torch.randn(1, 1, hidden_size)
+
+ out = norm(x)
+
+ assert out.shape == (1, 1, hidden_size)
diff --git a/tests/diffusion/layers/test_rotary_emb_equivalence.py b/tests/diffusion/layers/test_rotary_emb_equivalence.py
new file mode 100644
index 0000000000..2fbb7a31f5
--- /dev/null
+++ b/tests/diffusion/layers/test_rotary_emb_equivalence.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Numerical equivalence tests for rotary embedding implementations (#2436).
+
+Verifies that the optimized stack+flatten RoPE produces bit-identical results
+to the original strided-slice implementation across various tensor shapes and
+dtypes, ensuring the refactor is safe.
+"""
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+
+def _apply_rotary_emb_helios_original(
+ hidden_states: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> torch.Tensor:
+ """Original Helios RoPE using strided slice assignment (pre-#2436)."""
+ x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
+ out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
+ return out.type_as(hidden_states)
+
+
+def _apply_rotary_emb_helios_optimized(
+ hidden_states: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> torch.Tensor:
+ """Optimized Helios RoPE using stack+flatten (post-#2436)."""
+ x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
+ rotated = torch.stack(
+ (
+ x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2],
+ x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2],
+ ),
+ dim=-1,
+ )
+ return rotated.flatten(-2, -1).type_as(hidden_states)
+
+
+def _make_inputs(
+ batch: int,
+ seq_len: int,
+ num_heads: int,
+ head_dim: int,
+ dtype: torch.dtype = torch.float32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Generate random hidden_states and freqs_cis for testing."""
+ torch.manual_seed(42)
+ hidden_states = torch.randn(batch, seq_len, num_heads, head_dim, dtype=dtype)
+ # freqs_cis: [B, seq, head_dim*2] — cos and sin concatenated along last dim
+ freqs_cis = torch.randn(batch, seq_len, head_dim * 2, dtype=dtype)
+ return hidden_states, freqs_cis
+
+
+class TestHeliosRoPEEquivalence:
+ """Verify optimized Helios RoPE is numerically identical to original."""
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+ def test_equivalence_across_dtypes(self, dtype: torch.dtype) -> None:
+ """Optimized output must be bit-identical to original across dtypes."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64, dtype=dtype)
+ original = _apply_rotary_emb_helios_original(hidden, freqs)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ torch.testing.assert_close(optimized, original, atol=0, rtol=0)
+
+ @pytest.mark.parametrize(
+ "batch,seq_len,num_heads,head_dim",
+ [
+ (1, 8, 1, 32), # minimal: single batch, single head
+ (2, 16, 8, 64), # typical transformer config
+ (1, 8192, 4, 64), # video-scale patch tokens (720p DiT)
+ (4, 32, 16, 128), # large head_dim
+ ],
+ )
+ def test_equivalence_across_shapes(self, batch: int, seq_len: int, num_heads: int, head_dim: int) -> None:
+ """Equivalence must hold across different tensor shapes."""
+ hidden, freqs = _make_inputs(batch, seq_len, num_heads, head_dim)
+ original = _apply_rotary_emb_helios_original(hidden, freqs)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ torch.testing.assert_close(optimized, original, atol=0, rtol=0)
+
+ def test_output_contiguous(self) -> None:
+ """Optimized output should be contiguous in memory."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ assert optimized.is_contiguous()
+
+ def test_output_shape_preserved(self) -> None:
+ """Output shape must match input shape."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ assert optimized.shape == hidden.shape
+
+ def test_output_dtype_preserved(self) -> None:
+ """Output dtype must match input dtype."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64, dtype=torch.float16)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ assert optimized.dtype == hidden.dtype
+
+ def test_odd_head_dim_raises(self) -> None:
+ """Odd head_dim should fail at unflatten (not a valid RoPE config)."""
+ hidden = torch.randn(1, 4, 2, 63)
+ freqs = torch.randn(1, 4, 126)
+ with pytest.raises(RuntimeError):
+ _apply_rotary_emb_helios_optimized(hidden, freqs)
diff --git a/tests/diffusion/models/bagel/test_trajectory_recording.py b/tests/diffusion/models/bagel/test_trajectory_recording.py
index 80b3f9d9ba..345eac1078 100644
--- a/tests/diffusion/models/bagel/test_trajectory_recording.py
+++ b/tests/diffusion/models/bagel/test_trajectory_recording.py
@@ -4,10 +4,10 @@
import types
from dataclasses import dataclass
-from unittest.mock import MagicMock, patch
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm_omni.diffusion.models.bagel.bagel_transformer import (
Bagel,
@@ -23,9 +23,9 @@
EXPECTED_STEPS = NUM_TIMESTEPS - 1
-def _make_mock_bagel():
+def _make_mock_bagel(mocker: MockerFixture):
"""Create a mock Bagel with forward returning constant velocity."""
- mock = MagicMock(spec=Bagel)
+ mock = mocker.MagicMock(spec=Bagel)
mock._sp_size = 1
# forward returns a small constant velocity so x_t changes each step
@@ -78,18 +78,22 @@ def _make_generate_args(num_tokens=NUM_TOKENS, hidden_dim=HIDDEN_DIM, cfg=False)
@pytest.fixture(params=[False, True], ids=["no_cfg", "batched_cfg"])
-def bagel_and_args(request):
+def bagel_and_args(
+ request,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+):
"""Mock Bagel instance and generate_image arguments.
Parametrized over CFG mode so every test runs on both the no-CFG
and batched-CFG code paths.
"""
cfg = request.param
- with patch(
+ monkeypatch.setattr(
"vllm_omni.diffusion.models.bagel.bagel_transformer.get_classifier_free_guidance_world_size",
- return_value=1,
- ):
- yield _make_mock_bagel(), _make_generate_args(cfg=cfg)
+ lambda: 1,
+ )
+ yield _make_mock_bagel(mocker), _make_generate_args(cfg=cfg)
class TestTrajectoryRecording:
@@ -188,12 +192,16 @@ class TestTrajectoryLogProbs:
"""Tests for log-prob recording when a scheduler is provided."""
@pytest.fixture()
- def bagel_scheduler_args(self):
- with patch(
+ def bagel_scheduler_args(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ):
+ monkeypatch.setattr(
"vllm_omni.diffusion.models.bagel.bagel_transformer.get_classifier_free_guidance_world_size",
- return_value=1,
- ):
- yield _make_mock_bagel(), _make_generate_args(), _MockScheduler()
+ lambda: 1,
+ )
+ yield _make_mock_bagel(mocker), _make_generate_args(), _MockScheduler()
def test_log_probs_recorded_with_scheduler(self, bagel_scheduler_args):
bagel, args, scheduler = bagel_scheduler_args
diff --git a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
index faad08afd1..54dda1dd07 100644
--- a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
+++ b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
@@ -1,7 +1,6 @@
-from unittest.mock import MagicMock, patch
-
import pytest
import torch
+from pytest_mock import MockerFixture
from tests.utils import hardware_test
from vllm_omni.diffusion.models.flux2.flux2_transformer import (
@@ -12,14 +11,17 @@
# Initialize TP group before tests
@pytest.fixture(scope="function", autouse=True)
-def setup_tp_group():
+def setup_tp_group(mocker: MockerFixture):
"""Set up TP group for each test function"""
- with patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=2):
- with patch("vllm.distributed.parallel_state.get_tp_group") as mock_get_tp_group:
- mock_tp_group = MagicMock()
- mock_tp_group.world_size = 2
- mock_get_tp_group.return_value = mock_tp_group
- yield
+ mocker.patch(
+ "vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size",
+ return_value=2,
+ )
+ mock_get_tp_group = mocker.patch("vllm.distributed.parallel_state.get_tp_group")
+ mock_tp_group = mocker.MagicMock()
+ mock_tp_group.world_size = 2
+ mock_get_tp_group.return_value = mock_tp_group
+ yield
class TestFlux2TransformerWeightLoading:
diff --git a/tests/diffusion/models/glm_image/test_glm_image_sp.py b/tests/diffusion/models/glm_image/test_glm_image_sp.py
new file mode 100644
index 0000000000..1b1c8d7a75
--- /dev/null
+++ b/tests/diffusion/models/glm_image/test_glm_image_sp.py
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for GLM-Image Sequence Parallelism support."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+
+@pytest.fixture(scope="function", autouse=True)
+def setup_sp_groups():
+ """Set up SP and TP groups for each test function."""
+ with patch("vllm_omni.diffusion.distributed.parallel_state.get_sp_group") as mock_get_sp_group:
+ with patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=1):
+ with patch("vllm.distributed.parallel_state.get_tp_group") as mock_get_tp_group:
+ mock_sp_group = MagicMock()
+ mock_sp_group.world_size = 4
+ mock_get_sp_group.return_value = mock_sp_group
+
+ mock_tp_group = MagicMock()
+ mock_tp_group.world_size = 1
+ mock_get_tp_group.return_value = mock_tp_group
+ yield
+
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def test_glm_image_sp_plan_defined():
+ """Test that _sp_plan is properly defined on GlmImageTransformer2DModel."""
+ from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
+ GlmImageTransformer2DModel,
+ )
+
+ assert hasattr(GlmImageTransformer2DModel, "_sp_plan")
+ plan = GlmImageTransformer2DModel._sp_plan
+ assert plan is not None
+
+ # Verify plan structure
+ assert "prepare" in plan
+ assert "proj_out" in plan
+
+
+def test_glm_image_sp_plan_valid():
+ """Validate _sp_plan structure."""
+ from vllm_omni.diffusion.distributed.sp_plan import validate_sp_plan
+ from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
+ GlmImageTransformer2DModel,
+ )
+
+ plan = GlmImageTransformer2DModel._sp_plan
+ validate_sp_plan(plan)
+
+
+def test_glm_image_prepare_module_exists():
+ """Test that GlmImagePrepare module exists."""
+ from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
+ GlmImagePrepare,
+ )
+
+ assert GlmImagePrepare is not None
+
+
+def test_glm_image_attention_accepts_parallel_config():
+ """Test that GlmImageAttention accepts parallel_config parameter."""
+ from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
+ GlmImageAttention,
+ )
+
+ parallel_config = DiffusionParallelConfig(
+ ulysses_degree=2,
+ ring_degree=2,
+ tensor_parallel_size=1,
+ sequence_parallel_size=4,
+ )
+
+ attn = GlmImageAttention(
+ dim=2560,
+ num_heads=64,
+ head_dim=40,
+ parallel_config=parallel_config,
+ )
+
+ assert attn.parallel_config is not None
+ assert attn.parallel_config.sequence_parallel_size == 4
+
+
+def test_glm_image_transformer_block_accepts_parallel_config():
+ """Test that GlmImageTransformerBlock accepts parallel_config parameter."""
+ from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
+ GlmImageTransformerBlock,
+ )
+
+ parallel_config = DiffusionParallelConfig(
+ ulysses_degree=2,
+ ring_degree=2,
+ tensor_parallel_size=1,
+ sequence_parallel_size=4,
+ )
+
+ block = GlmImageTransformerBlock(
+ dim=2560,
+ num_attention_heads=64,
+ attention_head_dim=40,
+ time_embed_dim=512,
+ parallel_config=parallel_config,
+ )
+
+ assert block.attn1.parallel_config is not None
+ assert block.attn1.parallel_config.sequence_parallel_size == 4
+
+
+def test_glm_image_has_sp_support():
+ """Test that GLM-Image has SP support implemented."""
+ from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
+ GlmImageTransformer2DModel,
+ )
+
+ # Check that the model has parallel_config support
+ assert hasattr(GlmImageTransformer2DModel, "__init__")
+
+ # Verify the model can be instantiated with SP config
+
+ # This test just verifies the structure exists
+ # Actual SP testing requires multi-GPU setup
+
+
+@pytest.mark.cuda
+@pytest.mark.sp
+def test_glm_image_sp_inference():
+ """Test SP inference (requires multi-GPU setup)."""
+ pytest.skip("Requires multi-GPU SP setup")
diff --git a/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py
similarity index 85%
rename from tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py
rename to tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py
index 2cda9116c7..626f78eed9 100644
--- a/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py
+++ b/tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py
@@ -12,7 +12,7 @@ class TestSetForwardContextNumTokens:
def test_sets_num_tokens_when_context_available(self, mocker):
"""num_tokens should be set on ForwardContext when available."""
- import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
mock_ctx = mocker.MagicMock()
del mock_ctx.in_profile_run # simulate missing attr
@@ -26,7 +26,7 @@ def test_sets_num_tokens_when_context_available(self, mocker):
def test_sets_in_profile_run_only_if_missing(self, mocker):
"""in_profile_run should not be overwritten if already set."""
- import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
mock_ctx = mocker.MagicMock()
mock_ctx.in_profile_run = True # already set
@@ -40,7 +40,7 @@ def test_sets_in_profile_run_only_if_missing(self, mocker):
def test_noop_when_context_unavailable(self, mocker):
"""Should do nothing when ForwardContext is not available."""
- import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
mocker.patch.object(hunyuan_moe._vllm_fc, "is_forward_context_available", return_value=False)
mock_get = mocker.patch.object(hunyuan_moe._vllm_fc, "get_forward_context")
@@ -55,11 +55,11 @@ class TestHunyuanFusedMoEPlatformDispatch:
def test_default_platform_uses_default_impl_qualname(self, mocker):
"""HunyuanFusedMoE should resolve the impl class from the platform hook."""
- import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
mock_platform = mocker.MagicMock()
mock_platform.get_diffusion_model_impl_qualname.return_value = (
- "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
)
mocker.patch.object(
@@ -71,7 +71,7 @@ def test_default_platform_uses_default_impl_qualname(self, mocker):
mock_impl = mocker.MagicMock()
mock_resolve.return_value = mock_impl
- from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import (
+ from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import (
HunyuanFusedMoE,
)
@@ -80,7 +80,7 @@ def test_default_platform_uses_default_impl_qualname(self, mocker):
mock_platform.prepare_diffusion_op_runtime.assert_called_once_with("hunyuan_fused_moe")
mock_platform.get_diffusion_model_impl_qualname.assert_called_once_with("hunyuan_fused_moe")
mock_resolve.assert_called_once_with(
- "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
)
mock_impl.assert_called_once_with(prefix="")
@@ -90,7 +90,7 @@ class TestHunyuanFusedMoEFactory:
def test_new_delegates_to_impl_class(self, mocker):
"""HunyuanFusedMoE(prefix=..., **kwargs) should instantiate and return impl instance."""
- import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
class MockImpl:
def __init__(self, *, prefix: str = "", **kwargs):
@@ -104,7 +104,7 @@ def __init__(self, *, prefix: str = "", **kwargs):
mock_impl_class = mocker.MagicMock(return_value=MockImpl(prefix="test", a=1))
mocker.patch.object(hunyuan_moe, "resolve_obj_by_qualname", return_value=mock_impl_class)
- from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import (
+ from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import (
HunyuanFusedMoE,
)
@@ -119,7 +119,7 @@ def __init__(self, *, prefix: str = "", **kwargs):
def test_make_expert_params_mapping_delegates_to_impl(self, mocker):
"""make_expert_params_mapping should delegate to impl class method."""
- import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
expected_mapping = [("a", "b", 0, "c")]
mock_platform = mocker.MagicMock()
@@ -130,7 +130,7 @@ def test_make_expert_params_mapping_delegates_to_impl(self, mocker):
mock_impl_class.make_expert_params_mapping = mocker.MagicMock(return_value=expected_mapping)
mocker.patch.object(hunyuan_moe, "resolve_obj_by_qualname", return_value=mock_impl_class)
- from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import (
+ from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import (
HunyuanFusedMoE,
)
diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py
new file mode 100644
index 0000000000..51f6a85f58
--- /dev/null
+++ b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py
@@ -0,0 +1,190 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for HunyuanImage3 AR sampler logic (stage transitions,
+ratio restriction, comprehension blocking)."""
+
+import pytest
+import torch
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+# Fake token IDs for testing (avoid importing the real model).
+END_OF_THINK = 100
+RECAPTION = 101
+END_OF_RECAPTION = 102
+ANSWER = 103
+BOI = 104
+SIZE_TOKEN = 105
+EOS = 106
+RATIO_START = 200
+RATIO_END = 210
+RATIO_OTHER_START = 220
+RATIO_OTHER_END = 223
+
+
+class FakeSamplerModel:
+ """Minimal stub that replicates the sampler-relevant attributes of
+ HunyuanImage3ForConditionalGeneration without loading real weights."""
+
+ def __init__(self, *, is_comprehension: bool = False):
+ self._is_comprehension = is_comprehension
+ self._eos_token_id = EOS
+ self._end_of_think_id = END_OF_THINK
+ self._recaption_id = RECAPTION
+ self._end_of_recaption_id = END_OF_RECAPTION
+ self._answer_id = ANSWER
+ self._mrope_boi_token_id = BOI
+ self._size_token_id = SIZE_TOKEN
+ self._start_ratio_id = RATIO_START
+ self._end_ratio_id = RATIO_END
+ self._ratio_other_slices = [(RATIO_OTHER_START, RATIO_OTHER_END + 1)]
+ self._all_ratio_ids = set(range(RATIO_START, RATIO_END + 1))
+ self._all_ratio_ids.update(range(RATIO_OTHER_START, RATIO_OTHER_END + 1))
+
+ self._stage_transitions: dict[int, list[int]] = {}
+ if not is_comprehension:
+ self._stage_transitions[END_OF_THINK] = [RECAPTION]
+ self._stage_transitions[END_OF_RECAPTION] = [ANSWER, BOI, SIZE_TOKEN]
+
+ self._blocked_token_ids: set[int] = set()
+ if is_comprehension:
+ self._blocked_token_ids.update([BOI, SIZE_TOKEN])
+ self._blocked_token_ids.update(self._all_ratio_ids)
+
+ # Bind the real methods from the model class.
+ from vllm_omni.model_executor.models.hunyuan_image3.hunyuan_image3 import (
+ HunyuanImage3ForConditionalGeneration as _Real,
+ )
+
+ _get_forced_token = _Real._get_forced_token
+ _apply_ratio_restriction = _Real._apply_ratio_restriction
+
+
+class TestGetForcedToken:
+ """Tests for the stateless _get_forced_token method."""
+
+ def setup_method(self):
+ self.model = FakeSamplerModel(is_comprehension=False)
+
+ def test_no_trigger_returns_none(self):
+ assert self.model._get_forced_token([1, 2, 3]) is None
+
+ def test_empty_history_returns_none(self):
+ assert self.model._get_forced_token([]) is None
+
+ def test_end_of_think_forces_recaption(self):
+ assert self.model._get_forced_token([END_OF_THINK]) == RECAPTION
+
+ def test_end_of_think_completed(self):
+ assert self.model._get_forced_token([END_OF_THINK, RECAPTION]) is None
+
+ def test_end_of_recaption_forces_answer(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION]
+ assert self.model._get_forced_token(tokens) == ANSWER
+
+ def test_end_of_recaption_forces_boi_after_answer(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER]
+ assert self.model._get_forced_token(tokens) == BOI
+
+ def test_end_of_recaption_forces_size_after_boi(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI]
+ assert self.model._get_forced_token(tokens) == SIZE_TOKEN
+
+ def test_full_sequence_complete(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI, SIZE_TOKEN]
+ assert self.model._get_forced_token(tokens) is None
+
+ def test_diverged_history_returns_none(self):
+ tokens = [END_OF_RECAPTION, 999] # 999 != ANSWER
+ assert self.model._get_forced_token(tokens) is None
+
+ def test_later_trigger_takes_precedence(self):
+ tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION]
+ assert self.model._get_forced_token(tokens) == ANSWER
+
+ def test_trigger_with_extra_tokens_before(self):
+ tokens = [1, 2, 3, END_OF_THINK]
+ assert self.model._get_forced_token(tokens) == RECAPTION
+
+
+class TestComprehensionBlocking:
+ """Tests for comprehension mode token blocking."""
+
+ def test_blocked_tokens_masked(self):
+ model = FakeSamplerModel(is_comprehension=True)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, BOI] = 5.0
+ logits[0, SIZE_TOKEN] = 3.0
+ logits[0, RATIO_START] = 2.0
+ min_score = torch.finfo(logits.dtype).min
+
+ for tid in model._blocked_token_ids:
+ if tid < vocab_size:
+ logits[0, tid] = min_score
+
+ assert logits[0, BOI].item() == min_score
+ assert logits[0, SIZE_TOKEN].item() == min_score
+ assert logits[0, RATIO_START].item() == min_score
+
+ def test_non_blocked_tokens_preserved(self):
+ model = FakeSamplerModel(is_comprehension=True)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, 50] = 7.0
+ min_score = torch.finfo(logits.dtype).min
+
+ for tid in model._blocked_token_ids:
+ if tid < vocab_size:
+ logits[0, tid] = min_score
+
+ assert logits[0, 50].item() == 7.0
+
+
+class TestRatioRestriction:
+ """Tests for _apply_ratio_restriction (greedy: only argmax ratio survives)."""
+
+ def test_greedy_selects_single_ratio_token(self):
+ model = FakeSamplerModel(is_comprehension=False)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, RATIO_START + 3] = 10.0
+ logits[0, RATIO_START + 1] = 5.0
+ logits[0, 50] = 20.0 # non-ratio, should be masked
+ min_score = torch.finfo(logits.dtype).min
+
+ model._apply_ratio_restriction(logits, 0, min_score)
+
+ assert logits[0, RATIO_START + 3].item() == 0
+ assert logits[0, RATIO_START + 1].item() == min_score
+ assert logits[0, 50].item() == min_score
+
+ def test_extra_ratio_slices_considered(self):
+ model = FakeSamplerModel(is_comprehension=False)
+ vocab_size = 300
+ logits = torch.zeros(1, vocab_size)
+ logits[0, RATIO_OTHER_START] = 15.0
+ logits[0, RATIO_START] = 5.0
+ min_score = torch.finfo(logits.dtype).min
+
+ model._apply_ratio_restriction(logits, 0, min_score)
+
+ assert logits[0, RATIO_OTHER_START].item() == 0
+ assert logits[0, RATIO_START].item() == min_score
+
+
+class TestForceEosAfterRatio:
+ """Tests that a ratio token as last_token forces EOS."""
+
+ def test_ratio_token_forces_eos(self):
+ model = FakeSamplerModel(is_comprehension=False)
+ vocab_size = 300
+ logits = torch.randn(1, vocab_size)
+ min_score = torch.finfo(logits.dtype).min
+
+ logits[0].fill_(min_score)
+ logits[0, model._eos_token_id] = 0
+
+ assert logits[0, EOS].item() == 0
+ non_eos_max = logits[0, :EOS].max().item()
+ assert non_eos_max == min_score
diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
new file mode 100644
index 0000000000..873b52bf7a
--- /dev/null
+++ b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from pathlib import Path
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+from PIL import Image
+
+from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import (
+ get_qwen_image_edit_plus_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
+
+
+def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path):
+ vae_dir = tmp_path / "vae"
+ vae_dir.mkdir()
+ # Keep the mock config intentionally minimal: this test only needs the
+ # fields touched during pre-process initialization.
+ (vae_dir / "config.json").write_text(json.dumps({"z_dim": 16}))
+
+ pre_process = get_qwen_image_edit_plus_pre_process_func(SimpleNamespace(model=str(tmp_path)))
+ image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
+ request = SimpleNamespace(
+ prompts=[
+ {
+ "prompt": "combine",
+ "multi_modal_data": {"image": [image, image, image, image, image]},
+ }
+ ],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ with pytest.raises(ValueError, match=r"At most 4 images are supported by this model"):
+ pre_process(request)
diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py
new file mode 100644
index 0000000000..f5676a0056
--- /dev/null
+++ b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py
@@ -0,0 +1,260 @@
+import inspect
+from types import SimpleNamespace
+
+import pytest
+import torch
+from torch import nn
+
+from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import (
+ QwenImagePipeline,
+)
+from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import (
+ QwenImageEditPipeline,
+)
+from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import (
+ QwenImageEditPlusPipeline,
+)
+from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_layered import (
+ QwenImageLayeredPipeline,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class _RejectingTextEncoder:
+ dtype = torch.float32
+
+ def __call__(self, *args, **kwargs):
+ raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length")
+
+
+class _FakeModelInputs:
+ def __init__(self, total_sequence_length: int):
+ attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long)
+ self.input_ids = attention_mask.clone()
+ self.attention_mask = attention_mask
+ self.pixel_values = None
+ self.image_grid_thw = None
+
+ def to(self, device):
+ return self
+
+
+class _FakeTokenizer:
+ def __init__(self, total_sequence_length: int | list[int]):
+ if isinstance(total_sequence_length, list):
+ self.total_sequence_lengths = list(total_sequence_length)
+ else:
+ self.total_sequence_lengths = [total_sequence_length]
+
+ def __call__(self, *args, **kwargs):
+ if len(self.total_sequence_lengths) > 1:
+ total_sequence_length = self.total_sequence_lengths.pop(0)
+ else:
+ total_sequence_length = self.total_sequence_lengths[0]
+ return _FakeModelInputs(total_sequence_length)
+
+
+class _FakeProcessor(_FakeTokenizer):
+ pass
+
+
+class _FakeScheduler:
+ def __init__(self):
+ self.begin_index = None
+
+ def set_begin_index(self, begin_index: int):
+ self.begin_index = begin_index
+
+
+PIPELINE_CASES = [
+ pytest.param(QwenImagePipeline, 34, "tokenizer", id="qwen-image"),
+ pytest.param(QwenImageLayeredPipeline, 34, "tokenizer", id="qwen-image-layered"),
+ pytest.param(QwenImageEditPipeline, 64, "processor", id="qwen-image-edit"),
+ pytest.param(QwenImageEditPlusPipeline, 64, "processor", id="qwen-image-edit-plus"),
+]
+
+
+def _make_pipeline(
+ pipeline_class: type,
+ *,
+ total_sequence_length: int,
+ drop_idx: int,
+ input_kind: str,
+):
+ pipeline = object.__new__(pipeline_class)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.text_encoder = _RejectingTextEncoder()
+ pipeline.tokenizer_max_length = 1024
+ pipeline.prompt_template_encode = "{}"
+ pipeline.prompt_template_encode_start_idx = drop_idx
+ pipeline.tokenizer = _FakeTokenizer([total_sequence_length, 0])
+ if input_kind == "processor":
+ pipeline.processor = _FakeProcessor(total_sequence_length)
+ return pipeline
+
+
+@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES)
+def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(
+ pipeline_class: type,
+ drop_idx: int,
+ input_kind: str,
+):
+ pipeline = _make_pipeline(
+ pipeline_class,
+ total_sequence_length=1025,
+ drop_idx=drop_idx,
+ input_kind=input_kind,
+ )
+
+ with pytest.raises(ValueError, match=r"got 1025 tokens, but `max_sequence_length` is 1024"):
+ pipeline.encode_prompt(prompt="prompt")
+
+
+@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES)
+def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(
+ pipeline_class: type,
+ drop_idx: int,
+ input_kind: str,
+):
+ pipeline = _make_pipeline(
+ pipeline_class,
+ total_sequence_length=17,
+ drop_idx=drop_idx,
+ input_kind=input_kind,
+ )
+
+ with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"):
+ pipeline.encode_prompt(prompt="prompt", max_sequence_length=16)
+
+
+def test_prepare_encode_defaults_to_tokenizer_max_length():
+ pipeline = object.__new__(QwenImagePipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.tokenizer_max_length = 1024
+ pipeline.vae_scale_factor = 8
+ pipeline.default_sample_size = 128
+ pipeline.scheduler = _FakeScheduler()
+ pipeline._extract_prompts = lambda prompts: (["prompt"], None)
+
+ captured = {}
+
+ def _fake_prepare_generation_context(**kwargs):
+ captured["max_sequence_length"] = kwargs["max_sequence_length"]
+ embeds = torch.ones((1, 1, 1))
+ mask = torch.ones((1, 1), dtype=torch.long)
+ return {
+ "prompt_embeds": embeds,
+ "prompt_embeds_mask": mask,
+ "negative_prompt_embeds": None,
+ "negative_prompt_embeds_mask": None,
+ "latents": embeds,
+ "timesteps": torch.tensor([1]),
+ "do_true_cfg": False,
+ "guidance": None,
+ "img_shapes": [[(1, 1, 1)]],
+ "txt_seq_lens": [1],
+ "negative_txt_seq_lens": None,
+ }
+
+ pipeline._prepare_generation_context = _fake_prepare_generation_context
+ state = SimpleNamespace(
+ prompts=["prompt"],
+ sampling=SimpleNamespace(
+ height=None,
+ width=None,
+ num_inference_steps=None,
+ sigmas=None,
+ guidance_scale_provided=False,
+ num_outputs_per_prompt=0,
+ generator=None,
+ true_cfg_scale=None,
+ max_sequence_length=None,
+ ),
+ )
+
+ pipeline.prepare_encode(state)
+
+ assert captured["max_sequence_length"] == 1024
+
+
+@pytest.mark.parametrize(
+ ("pipeline_class", "drop_idx"),
+ [
+ pytest.param(QwenImageEditPipeline, 64, id="qwen-image-edit"),
+ pytest.param(QwenImageEditPlusPipeline, 64, id="qwen-image-edit-plus"),
+ ],
+)
+def test_edit_pipelines_validate_text_prompt_length_before_image_token_expansion(
+ pipeline_class: type,
+ drop_idx: int,
+):
+ pipeline = object.__new__(pipeline_class)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.text_encoder = _RejectingTextEncoder()
+ pipeline.tokenizer_max_length = 1024
+ pipeline.prompt_template_encode = "{}"
+ pipeline.prompt_template_encode_start_idx = drop_idx
+ pipeline.tokenizer = _FakeTokenizer([8, 0])
+ pipeline.processor = _FakeProcessor(drop_idx + 1500)
+
+ with pytest.raises(AssertionError, match="text encoder should not run"):
+ pipeline.encode_prompt(prompt="short prompt")
+
+
+@pytest.mark.parametrize(
+ "pipeline_class",
+ [
+ pytest.param(QwenImagePipeline, id="qwen-image"),
+ pytest.param(QwenImageLayeredPipeline, id="qwen-image-layered"),
+ ],
+)
+def test_qwen_generation_validator_excludes_template_suffix_from_budget(pipeline_class: type):
+ pipeline = object.__new__(pipeline_class)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.text_encoder = _RejectingTextEncoder()
+ pipeline.tokenizer_max_length = 1024
+ pipeline.prompt_template_encode = "{}"
+ pipeline.prompt_template_encode_start_idx = 34
+ pipeline.tokenizer = _FakeTokenizer([1029, 5])
+
+ with pytest.raises(AssertionError, match="text encoder should not run"):
+ pipeline.encode_prompt(prompt="boundary prompt")
+
+
+@pytest.mark.parametrize(
+ "pipeline_class",
+ [
+ pytest.param(QwenImageEditPipeline, id="qwen-image-edit"),
+ pytest.param(QwenImageEditPlusPipeline, id="qwen-image-edit-plus"),
+ ],
+)
+def test_qwen_edit_validator_excludes_image_placeholders_from_budget(pipeline_class: type):
+ pipeline = object.__new__(pipeline_class)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.text_encoder = _RejectingTextEncoder()
+ pipeline.tokenizer_max_length = 1024
+ pipeline.prompt_template_encode = "{}"
+ pipeline.prompt_template_encode_start_idx = 64
+ pipeline.tokenizer = _FakeTokenizer([30, 20])
+ pipeline.processor = _FakeProcessor(1500)
+
+ with pytest.raises(AssertionError, match="text encoder should not run"):
+ pipeline.encode_prompt(prompt="short prompt")
+
+
+@pytest.mark.parametrize(
+ "pipeline_class",
+ [
+ QwenImagePipeline,
+ QwenImageLayeredPipeline,
+ QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
+ ],
+)
+def test_forward_max_sequence_length_default_is_1024(pipeline_class: type):
+ assert inspect.signature(pipeline_class.forward).parameters["max_sequence_length"].default == 1024
diff --git a/tests/diffusion/models/wan2_2/conftest.py b/tests/diffusion/models/wan2_2/conftest.py
new file mode 100644
index 0000000000..f836fa545f
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/conftest.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+
+class StubTransformer(nn.Module):
+ def __init__(self, *, name: str = "transformer", in_channels: int = 4, out_channels: int = 4) -> None:
+ super().__init__()
+ self.name = name
+ self.config = SimpleNamespace(
+ patch_size=(1, 2, 2),
+ in_channels=in_channels,
+ out_channels=out_channels,
+ image_dim=None,
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return torch.float32
+
+ def forward(self, **kwargs):
+ hidden_states = kwargs["hidden_states"]
+ return (torch.zeros_like(hidden_states[:, : self.config.out_channels]),)
+
+
+class StubScheduler:
+ def __init__(self, timesteps: list[int]) -> None:
+ self.timesteps = torch.tensor(timesteps, dtype=torch.int64)
+ self.config = SimpleNamespace(num_train_timesteps=1000)
+ self.set_timesteps_calls: list[tuple[int, torch.device]] = []
+
+ def set_timesteps(self, num_steps: int, device: torch.device) -> None:
+ self.set_timesteps_calls.append((num_steps, device))
+
+
+class StubVAE:
+ dtype = torch.float32
+
+ def __init__(self, z_dim: int = 4) -> None:
+ self.config = SimpleNamespace(
+ z_dim=z_dim,
+ scale_factor_temporal=4,
+ scale_factor_spatial=8,
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ def encode(self, video: torch.Tensor):
+ latent_frames = (video.shape[2] + self.config.scale_factor_temporal - 1) // self.config.scale_factor_temporal
+ latent_height = video.shape[-2] // self.config.scale_factor_spatial
+ latent_width = video.shape[-1] // self.config.scale_factor_spatial
+ latents = torch.ones(
+ video.shape[0],
+ self.config.z_dim,
+ latent_frames,
+ latent_height,
+ latent_width,
+ dtype=video.dtype,
+ device=video.device,
+ )
+ return SimpleNamespace(latents=latents)
+
+ def decode(self, latents: torch.Tensor, return_dict: bool = False):
+ del return_dict
+ return (latents,)
+
+
+@contextmanager
+def noop_progress_bar(*args, **kwargs):
+ del args, kwargs
+
+ class Bar:
+ def update(self) -> None:
+ return None
+
+ yield Bar()
diff --git a/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py
new file mode 100644
index 0000000000..04e834ac47
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+from PIL import Image
+from torch import nn
+
+from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import (
+ Wan22I2VPipeline,
+ get_wan22_i2v_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+def _make_i2v_pipeline(*, expand_timesteps: bool) -> Wan22I2VPipeline:
+ pipeline = object.__new__(Wan22I2VPipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = StubTransformer(name="high", in_channels=8, out_channels=4)
+ pipeline.transformer_2 = StubTransformer(name="low", in_channels=8, out_channels=4)
+ pipeline.vae = StubVAE(z_dim=4)
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.expand_timesteps = expand_timesteps
+ pipeline.progress_bar = noop_progress_bar
+ return pipeline
+
+
+def test_i2v_preprocess_requires_image_and_resizes_to_480p_aspect() -> None:
+ preprocess = get_wan22_i2v_pre_process_func(SimpleNamespace())
+ request = SimpleNamespace(
+ prompts=[{"prompt": "p", "multi_modal_data": {"image": Image.new("RGB", (320, 160), "red")}}],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ result = preprocess(request)
+ prompt = result.prompts[0]
+
+ assert result.sampling_params.height == 432
+ assert result.sampling_params.width == 880
+ assert prompt["multi_modal_data"]["image"].size == (880, 432)
+ assert prompt["additional_information"]["preprocessed_image"].shape[-2:] == (432, 880)
+
+ missing_image = SimpleNamespace(
+ prompts=[{"prompt": "p", "multi_modal_data": {}}],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+ with pytest.raises(ValueError, match="No image is provided"):
+ preprocess(missing_image)
+
+
+def test_i2v_diffuse_selects_stage_guidance_and_expands_timesteps() -> None:
+ pipeline = _make_i2v_pipeline(expand_timesteps=True)
+ latents = torch.zeros(1, 4, 2, 4, 4)
+ condition = torch.ones_like(latents)
+ first_frame_mask = torch.ones(1, 1, 2, 4, 4)
+ first_frame_mask[:, :, 0] = 0
+ timesteps = torch.tensor([900, 100])
+
+ calls = []
+
+ def fake_predict_noise_maybe_with_cfg(**kwargs):
+ positive = kwargs["positive_kwargs"]
+ calls.append(
+ {
+ "model": positive["current_model"].name,
+ "scale": kwargs["true_cfg_scale"],
+ "timestep_shape": tuple(positive["timestep"].shape),
+ "timestep_values": positive["timestep"].clone(),
+ "hidden_states": positive["hidden_states"].clone(),
+ }
+ )
+ return torch.ones_like(latents)
+
+ pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=torch.zeros(1, 2, 3),
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ guidance_low=1.0,
+ guidance_high=2.0,
+ boundary_timestep=500.0,
+ dtype=torch.float32,
+ attention_kwargs={},
+ condition=condition,
+ first_frame_mask=first_frame_mask,
+ )
+
+ assert [call["model"] for call in calls] == ["high", "low"]
+ assert [call["scale"] for call in calls] == [1.0, 2.0]
+ assert calls[0]["timestep_shape"] == (1, 8)
+ timestep_dtype = calls[0]["timestep_values"].dtype
+ torch.testing.assert_close(calls[0]["timestep_values"][0, :4], torch.zeros(4, dtype=timestep_dtype))
+ torch.testing.assert_close(calls[0]["timestep_values"][0, 4:], torch.full((4,), 900, dtype=timestep_dtype))
+ torch.testing.assert_close(calls[0]["hidden_states"][:, :, 0], torch.ones(1, 4, 4, 4))
+ torch.testing.assert_close(result, torch.full_like(latents, 2.0))
+
+
+def test_i2v_prepare_latents_builds_expand_condition_and_first_frame_mask() -> None:
+ pipeline = _make_i2v_pipeline(expand_timesteps=True)
+ latents, condition, first_frame_mask = pipeline.prepare_latents(
+ image=torch.zeros(1, 3, 16, 16),
+ batch_size=1,
+ num_channels_latents=4,
+ height=16,
+ width=16,
+ num_frames=5,
+ dtype=torch.float32,
+ device=torch.device("cpu"),
+ generator=torch.Generator(device="cpu").manual_seed(0),
+ )
+
+ assert latents.shape == (1, 4, 2, 2, 2)
+ assert condition.shape == (1, 4, 1, 2, 2)
+ assert first_frame_mask.shape == (1, 1, 2, 2, 2)
+ assert first_frame_mask[:, :, 0].sum() == 0
+ assert first_frame_mask[:, :, 1].sum() == 4
diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py
new file mode 100644
index 0000000000..64c2b271c9
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py
@@ -0,0 +1,150 @@
+from types import SimpleNamespace
+
+import PIL.Image
+import pytest
+import torch
+from torch import nn
+
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
+ WAN22_MAX_SEQUENCE_LENGTH,
+ Wan22Pipeline,
+)
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import (
+ Wan22I2VPipeline,
+)
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import (
+ Wan22TI2VPipeline,
+)
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import (
+ Wan22VACEPipeline,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class _RejectingTextEncoder:
+ dtype = torch.float32
+
+ def __call__(self, *args, **kwargs):
+ raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length")
+
+
+class _FakeTokenBatch:
+ def __init__(self, total_sequence_length: int):
+ attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long)
+ self.input_ids = attention_mask.clone()
+ self.attention_mask = attention_mask
+
+
+class _FakeTokenizer:
+ def __init__(self, total_sequence_length: int):
+ self.total_sequence_length = total_sequence_length
+
+ def __call__(self, *args, **kwargs):
+ return _FakeTokenBatch(self.total_sequence_length)
+
+
+PIPELINE_CASES = [
+ pytest.param(Wan22Pipeline, id="wan22-t2v"),
+ pytest.param(Wan22I2VPipeline, id="wan22-i2v"),
+ pytest.param(Wan22TI2VPipeline, id="wan22-ti2v"),
+ pytest.param(Wan22VACEPipeline, id="wan22-vace"),
+]
+
+
+def _make_pipeline(pipeline_class: type, *, total_sequence_length: int):
+ pipeline = object.__new__(pipeline_class)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.text_encoder = _RejectingTextEncoder()
+ pipeline.tokenizer = _FakeTokenizer(total_sequence_length)
+ pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
+ return pipeline
+
+
+@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES)
+def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type):
+ pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1)
+
+ with pytest.raises(ValueError, match=r"got 513 tokens, but `max_sequence_length` is 512"):
+ pipeline.encode_prompt(prompt="prompt")
+
+
+@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES)
+def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(pipeline_class: type):
+ pipeline = _make_pipeline(pipeline_class, total_sequence_length=17)
+
+ with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"):
+ pipeline.encode_prompt(prompt="prompt", max_sequence_length=16)
+
+
+def _sampling_params(**overrides):
+ defaults = dict(
+ height=None,
+ width=None,
+ num_frames=None,
+ num_inference_steps=None,
+ generator=None,
+ guidance_scale_provided=False,
+ guidance_scale_2=None,
+ boundary_ratio=None,
+ num_outputs_per_prompt=0,
+ max_sequence_length=None,
+ seed=None,
+ extra_args={},
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ )
+ defaults.update(overrides)
+ return SimpleNamespace(**defaults)
+
+
+@pytest.mark.parametrize(
+ ("pipeline_class", "prompt_value", "forward_kwargs"),
+ [
+ pytest.param(Wan22Pipeline, "prompt", {}, id="wan22-t2v"),
+ pytest.param(
+ Wan22I2VPipeline,
+ {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}},
+ {"image": PIL.Image.new("RGB", (64, 64))},
+ id="wan22-i2v",
+ ),
+ pytest.param(
+ Wan22TI2VPipeline,
+ {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}},
+ {"image": PIL.Image.new("RGB", (64, 64))},
+ id="wan22-ti2v",
+ ),
+ pytest.param(Wan22VACEPipeline, "prompt", {}, id="wan22-vace"),
+ ],
+)
+def test_forward_defaults_to_wan22_tokenizer_max_length(
+ pipeline_class: type,
+ prompt_value,
+ forward_kwargs,
+):
+ pipeline = object.__new__(pipeline_class)
+ nn.Module.__init__(pipeline)
+ pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
+ pipeline.boundary_ratio = None
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2))
+
+ captured = {}
+
+ def _fake_check_inputs(*args, **kwargs):
+ captured["max_sequence_length"] = kwargs["max_sequence_length"]
+ raise RuntimeError("stop after capture")
+
+ pipeline.check_inputs = _fake_check_inputs
+
+ req = SimpleNamespace(
+ prompts=[prompt_value],
+ sampling_params=_sampling_params(),
+ )
+
+ with pytest.raises(RuntimeError, match="stop after capture"):
+ pipeline.forward(req, **forward_kwargs)
+
+ assert captured["max_sequence_length"] == WAN22_MAX_SEQUENCE_LENGTH
diff --git a/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py b/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py
new file mode 100644
index 0000000000..54bb672ef8
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py
@@ -0,0 +1,155 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from contextlib import contextmanager
+from types import SimpleNamespace
+
+import pytest
+import torch
+from torch import nn
+
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+class _StubTransformer(nn.Module):
+ @property
+ def dtype(self) -> torch.dtype:
+ return torch.float32
+
+
+class _StubScheduler:
+ def __init__(self, timesteps: list[int]) -> None:
+ self.timesteps = torch.tensor(timesteps, dtype=torch.int64)
+ self.config = SimpleNamespace(num_train_timesteps=1000)
+ self.set_timesteps_calls: list[tuple[int, torch.device]] = []
+
+ def set_timesteps(self, num_steps: int, device: torch.device) -> None:
+ self.set_timesteps_calls.append((num_steps, device))
+
+
+@contextmanager
+def _noop_progress_bar(*args, **kwargs):
+ del args, kwargs
+
+ class _Bar:
+ def update(self) -> None:
+ return None
+
+ yield _Bar()
+
+
+def _make_pipeline() -> Wan22Pipeline:
+ pipeline = object.__new__(Wan22Pipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = _StubTransformer()
+ pipeline.transformer_2 = None
+ pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2), in_channels=4, out_channels=4)
+ pipeline.scheduler = _StubScheduler([9, 5])
+ pipeline.od_config = SimpleNamespace(flow_shift=5.0)
+ pipeline._sample_solver = "unipc"
+ pipeline._flow_shift = 5.0
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.boundary_ratio = 0.875
+ pipeline.expand_timesteps = False
+ pipeline._guidance_scale = None
+ pipeline._guidance_scale_2 = None
+ pipeline._num_timesteps = None
+ pipeline._current_timestep = None
+ pipeline.check_inputs = lambda **kwargs: None
+ pipeline.prepare_latents = lambda **kwargs: torch.zeros((1, 4, 1, 8, 8), dtype=torch.float32)
+ pipeline.progress_bar = _noop_progress_bar
+ return pipeline
+
+
+def test_forward_delegates_denoising_to_diffuse(monkeypatch) -> None:
+ pipeline = _make_pipeline()
+
+ prompt_embeds = torch.randn(1, 8)
+ captured: dict[str, object] = {}
+
+ def _fake_diffuse(**kwargs):
+ captured.update(kwargs)
+ return kwargs["latents"] + 1
+
+ pipeline.diffuse = _fake_diffuse # type: ignore[method-assign]
+
+ req = SimpleNamespace(
+ prompts=["prompt"],
+ sampling_params=SimpleNamespace(
+ height=None,
+ width=None,
+ num_frames=1,
+ num_inference_steps=2,
+ guidance_scale_provided=False,
+ guidance_scale=None,
+ guidance_scale_2=None,
+ boundary_ratio=None,
+ generator=None,
+ seed=None,
+ num_outputs_per_prompt=1,
+ max_sequence_length=32,
+ latents=None,
+ extra_args={},
+ ),
+ )
+
+ output = pipeline.forward(req, prompt_embeds=prompt_embeds, output_type="latent", guidance_scale=1.0)
+
+ assert torch.equal(output.output, torch.ones((1, 4, 1, 8, 8)))
+ assert torch.equal(captured["timesteps"], pipeline.scheduler.timesteps)
+ assert captured["guidance_low"] == 1.0
+ assert captured["guidance_high"] == 1.0
+ assert captured["boundary_timestep"] == pytest.approx(875.0)
+ assert captured["latent_condition"] is None
+ assert captured["first_frame_mask"] is None
+ assert pipeline.scheduler.set_timesteps_calls == [(2, torch.device("cpu"))]
+
+
+def test_diffuse_runs_prediction_and_scheduler_for_each_timestep() -> None:
+ pipeline = _make_pipeline()
+ latents = torch.zeros((1, 1, 1, 2, 2), dtype=torch.float32)
+ timesteps = torch.tensor([7, 3], dtype=torch.int64)
+ prompt_embeds = torch.randn(1, 8)
+
+ predict_calls: list[dict[str, object]] = []
+ scheduler_calls: list[tuple[float, int, float, bool]] = []
+
+ def _fake_predict_noise_maybe_with_cfg(**kwargs):
+ predict_calls.append(kwargs)
+ timestep = kwargs["positive_kwargs"]["timestep"]
+ assert isinstance(timestep, torch.Tensor)
+ return torch.full_like(latents, float(timestep[0].item()))
+
+ def _fake_scheduler_step_maybe_with_cfg(noise_pred, t, current_latents, do_true_cfg):
+ scheduler_calls.append(
+ (float(noise_pred[0, 0, 0, 0, 0]), int(t.item()), float(current_latents.sum()), do_true_cfg)
+ )
+ return current_latents + noise_pred
+
+ pipeline.predict_noise_maybe_with_cfg = _fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = _fake_scheduler_step_maybe_with_cfg # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=None,
+ guidance_low=1.0,
+ guidance_high=2.0,
+ boundary_timestep=5.0,
+ dtype=torch.float32,
+ attention_kwargs={},
+ )
+
+ assert len(predict_calls) == 2
+ assert predict_calls[0]["true_cfg_scale"] == 1.0
+ assert predict_calls[1]["true_cfg_scale"] == 2.0
+ assert scheduler_calls == [
+ (7.0, 7, 0.0, False),
+ (3.0, 3, 28.0, False),
+ ]
+ assert torch.equal(result, torch.full_like(latents, 10.0))
diff --git a/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py b/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py
new file mode 100644
index 0000000000..3147178697
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 as wan22_module
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
+ create_transformer_from_config,
+ load_transformer_config,
+ retrieve_latents,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+class _LatentDist:
+ def sample(self, generator):
+ assert isinstance(generator, torch.Generator)
+ return torch.tensor([1.0])
+
+ def mode(self):
+ return torch.tensor([2.0])
+
+
+def test_retrieve_latents_supports_sample_mode_argmax_and_direct_latents() -> None:
+ generator = torch.Generator(device="cpu")
+
+ assert retrieve_latents(SimpleNamespace(latent_dist=_LatentDist()), generator).item() == 1.0
+ assert retrieve_latents(SimpleNamespace(latent_dist=_LatentDist()), sample_mode="argmax").item() == 2.0
+ torch.testing.assert_close(retrieve_latents(SimpleNamespace(latents=torch.tensor([3.0]))), torch.tensor([3.0]))
+
+
+def test_retrieve_latents_rejects_unknown_encoder_output() -> None:
+ with pytest.raises(AttributeError, match="Could not access latents"):
+ retrieve_latents(SimpleNamespace())
+
+
+def test_load_transformer_config_reads_local_subfolder_config(tmp_path) -> None:
+ config_dir = tmp_path / "transformer_2"
+ config_dir.mkdir(parents=True)
+ (config_dir / "config.json").write_text(json.dumps({"patch_size": [1, 2, 2], "num_layers": 2}))
+
+ assert load_transformer_config(str(tmp_path), "transformer_2") == {"patch_size": [1, 2, 2], "num_layers": 2}
+ assert load_transformer_config(str(tmp_path), "missing") == {}
+
+
+def test_create_transformer_from_config_maps_supported_keys(monkeypatch) -> None:
+ captured = {}
+
+ class FakeTransformer:
+ def __init__(self, **kwargs) -> None:
+ captured.update(kwargs)
+
+ monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer)
+
+ transformer = create_transformer_from_config(
+ {
+ "patch_size": [1, 2, 2],
+ "num_attention_heads": 8,
+ "attention_head_dim": 128,
+ "in_channels": 16,
+ "out_channels": 16,
+ "text_dim": 4096,
+ "vace_layers": [0],
+ "ignored": "value",
+ }
+ )
+
+ assert isinstance(transformer, FakeTransformer)
+ assert captured == {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 8,
+ "attention_head_dim": 128,
+ "in_channels": 16,
+ "out_channels": 16,
+ "text_dim": 4096,
+ }
diff --git a/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py
new file mode 100644
index 0000000000..983350c4cf
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py
@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+from PIL import Image
+from torch import nn
+
+from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import (
+ Wan22TI2VPipeline,
+ get_wan22_ti2v_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+def _make_ti2v_pipeline() -> Wan22TI2VPipeline:
+ pipeline = object.__new__(Wan22TI2VPipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = StubTransformer(in_channels=4, out_channels=4)
+ pipeline.vae = StubVAE(z_dim=4)
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.progress_bar = noop_progress_bar
+ return pipeline
+
+
+def test_ti2v_preprocess_uses_720p_area_for_image_condition() -> None:
+ preprocess = get_wan22_ti2v_pre_process_func(SimpleNamespace())
+ request = SimpleNamespace(
+ prompts=[{"prompt": "p", "multi_modal_data": {"image": Image.new("RGB", (320, 160), "blue")}}],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ result = preprocess(request)
+
+ assert result.sampling_params.height == 672
+ assert result.sampling_params.width == 1344
+ assert result.prompts[0]["multi_modal_data"]["image"].size == (1344, 672)
+ assert result.prompts[0]["additional_information"]["preprocessed_image"].shape[-2:] == (672, 1344)
+
+
+def test_ti2v_diffuse_without_image_condition_expands_patch_timesteps() -> None:
+ pipeline = _make_ti2v_pipeline()
+ latents = torch.zeros(1, 4, 2, 4, 4)
+ calls = []
+
+ def fake_predict_noise_maybe_with_cfg(**kwargs):
+ calls.append(kwargs)
+ return torch.ones_like(latents)
+
+ pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=torch.tensor([7]),
+ prompt_embeds=torch.zeros(1, 2, 3),
+ negative_prompt_embeds=torch.zeros(1, 2, 3),
+ guidance_scale=3.0,
+ dtype=torch.float32,
+ attention_kwargs={"a": "b"},
+ num_latent_frames=2,
+ latent_height=4,
+ latent_width=4,
+ )
+
+ positive = calls[0]["positive_kwargs"]
+ assert calls[0]["do_true_cfg"] is True
+ assert positive["timestep"].shape == (1, 8)
+ torch.testing.assert_close(positive["timestep"], torch.full((1, 8), 7, dtype=positive["timestep"].dtype))
+ torch.testing.assert_close(positive["hidden_states"], latents)
+ torch.testing.assert_close(result, torch.ones_like(latents))
+
+
+def test_ti2v_prepare_i2v_latents_encodes_condition_and_masks_first_frame() -> None:
+ pipeline = _make_ti2v_pipeline()
+ latents, latent_condition, first_frame_mask = pipeline.prepare_i2v_latents(
+ image=torch.zeros(1, 3, 16, 16),
+ batch_size=1,
+ num_channels_latents=4,
+ height=16,
+ width=16,
+ num_frames=5,
+ dtype=torch.float32,
+ device=torch.device("cpu"),
+ generator=None,
+ latents=torch.zeros(1, 4, 2, 2, 2),
+ )
+
+ torch.testing.assert_close(latents, torch.zeros(1, 4, 2, 2, 2))
+ assert latent_condition.shape == (1, 4, 1, 2, 2)
+ assert first_frame_mask[:, :, 0].sum() == 0
+ assert first_frame_mask[:, :, 1].sum() == 4
diff --git a/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py
new file mode 100644
index 0000000000..9fa9b67c49
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+from PIL import Image
+from torch import nn
+
+from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import (
+ Wan22VACEPipeline,
+ create_vace_transformer_from_config,
+ get_wan22_vace_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+def _make_vace_pipeline() -> Wan22VACEPipeline:
+ pipeline = object.__new__(Wan22VACEPipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = StubTransformer(in_channels=4, out_channels=4)
+ pipeline.transformer_config = pipeline.transformer.config
+ pipeline.vae = StubVAE(z_dim=4)
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.progress_bar = noop_progress_bar
+ return pipeline
+
+
+def test_vace_preprocess_collects_reference_video_and_mask_inputs() -> None:
+ preprocess = get_wan22_vace_pre_process_func(SimpleNamespace())
+ ref = Image.new("RGB", (320, 160), "green")
+ frame = Image.new("RGB", (64, 64), "black")
+ mask = Image.new("L", (64, 64), 255)
+ request = SimpleNamespace(
+ prompts=[
+ {
+ "prompt": "p",
+ "multi_modal_data": {
+ "image": ref,
+ "video": [frame],
+ "mask": mask,
+ },
+ }
+ ],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ result = preprocess(request)
+ additional_info = result.prompts[0]["additional_information"]
+
+ assert result.sampling_params.height == 432
+ assert result.sampling_params.width == 880
+ assert additional_info["reference_images"] == [ref]
+ assert additional_info["source_video"] == [frame]
+ assert additional_info["mask"] == [mask]
+
+
+def test_create_vace_transformer_from_config_maps_vace_specific_keys(monkeypatch) -> None:
+ captured = {}
+
+ class FakeVACETransformer:
+ def __init__(self, **kwargs) -> None:
+ captured.update(kwargs)
+
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace.WanVACETransformer3DModel",
+ FakeVACETransformer,
+ )
+
+ transformer = create_vace_transformer_from_config(
+ {
+ "patch_size": [1, 2, 2],
+ "in_channels": 96,
+ "out_channels": 16,
+ "vace_layers": [0, 1, 2],
+ "vace_in_channels": 132,
+ "unknown": "ignored",
+ }
+ )
+
+ assert isinstance(transformer, FakeVACETransformer)
+ assert captured == {
+ "patch_size": (1, 2, 2),
+ "in_channels": 96,
+ "out_channels": 16,
+ "vace_layers": [0, 1, 2],
+ "vace_in_channels": 132,
+ }
+
+
+def test_vace_prepare_masks_encodes_spatial_stride_and_reference_padding() -> None:
+ pipeline = _make_vace_pipeline()
+ mask = torch.ones(1, 3, 5, 16, 16)
+ reference_images = [[torch.zeros(3, 16, 16), torch.zeros(3, 16, 16)]]
+
+ encoded = pipeline.prepare_masks(mask, reference_images)
+
+ assert encoded.shape == (1, 64, 4, 2, 2)
+ torch.testing.assert_close(encoded[:, :, :2], torch.zeros(1, 64, 2, 2, 2))
+ torch.testing.assert_close(encoded[:, :, 2:], torch.ones(1, 64, 2, 2, 2))
+
+
+def test_vace_diffuse_passes_context_and_scale_to_cfg_branches() -> None:
+ pipeline = _make_vace_pipeline()
+ latents = torch.zeros(1, 4, 1, 2, 2)
+ vace_context = torch.ones(1, 12, 1, 2, 2)
+ calls = []
+
+ def fake_predict_noise_maybe_with_cfg(**kwargs):
+ calls.append(kwargs)
+ return torch.ones_like(latents)
+
+ pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=torch.tensor([5]),
+ prompt_embeds=torch.zeros(1, 2, 3),
+ negative_prompt_embeds=torch.zeros(1, 2, 3),
+ guidance_scale=4.0,
+ dtype=torch.float32,
+ attention_kwargs={},
+ vace_context=vace_context,
+ vace_context_scale=0.75,
+ )
+
+ assert calls[0]["do_true_cfg"] is True
+ assert calls[0]["true_cfg_scale"] == 4.0
+ assert calls[0]["positive_kwargs"]["vace_context"] is vace_context
+ assert calls[0]["negative_kwargs"]["vace_context_scale"] == 0.75
+ torch.testing.assert_close(result, torch.ones_like(latents))
diff --git a/tests/diffusion/offloader/test_sequential_backend.py b/tests/diffusion/offloader/test_sequential_backend.py
index d18637a780..2539cc0689 100644
--- a/tests/diffusion/offloader/test_sequential_backend.py
+++ b/tests/diffusion/offloader/test_sequential_backend.py
@@ -3,8 +3,6 @@
"""Unit tests for SequentialOffloadBackend."""
-from unittest.mock import patch
-
import pytest
import torch
from torch import nn
@@ -44,7 +42,7 @@ def mock(self):
class TestMoveParamsPinMemory:
- def test_dtensor_skips_pin_memory(self, accelerator_device):
+ def test_dtensor_skips_pin_memory(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""DTensor should skip pin_memory to avoid RuntimeError."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
@@ -56,73 +54,73 @@ def fake_isinstance(obj, cls):
return True
return original_isinstance(obj, cls)
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- with patch("builtins.isinstance", fake_isinstance):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=True,
- )
- assert not tracker["called"], "pin_memory should not be called for DTensor"
-
- def test_regular_tensor_calls_pin_memory(self, accelerator_device):
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ monkeypatch.setattr("builtins.isinstance", fake_isinstance)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=True,
+ )
+ assert not tracker["called"], "pin_memory should not be called for DTensor"
+
+ def test_regular_tensor_calls_pin_memory(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""Regular tensor should call pin_memory when moving to CPU."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=True,
- )
- assert tracker["called"], "pin_memory should be called for regular tensors"
-
- def test_pin_memory_skipped_when_disabled(self, accelerator_device):
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=True,
+ )
+ assert tracker["called"], "pin_memory should be called for regular tensors"
+
+ def test_pin_memory_skipped_when_disabled(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""pin_memory should not be called when pin_memory=False."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=False,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=False,
- )
- assert not tracker["called"], "pin_memory should not be called when disabled"
-
- def test_pin_memory_skipped_for_non_cpu_target(self, accelerator_device):
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=False,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=False,
+ )
+ assert not tracker["called"], "pin_memory should not be called when disabled"
+
+ def test_pin_memory_skipped_for_non_cpu_target(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
"""pin_memory should not be called for non-CPU targets."""
module = _create_simple_module().to("cpu")
tracker, mock_pin = _track_pin_memory_calls()
- with patch.object(torch.Tensor, "pin_memory", mock_pin):
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=torch.device("cpu"),
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(module, accelerator_device, non_blocking=False, pin_memory=True)
- assert not tracker["called"], "pin_memory should not be called for non-CPU target"
+ monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=torch.device("cpu"),
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(module, accelerator_device, non_blocking=False, pin_memory=True)
+ assert not tracker["called"], "pin_memory should not be called for non-CPU target"
diff --git a/tests/diffusion/quantization/test_int8_config.py b/tests/diffusion/quantization/test_int8_config.py
index d4d5aa5a7f..875277ece4 100644
--- a/tests/diffusion/quantization/test_int8_config.py
+++ b/tests/diffusion/quantization/test_int8_config.py
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for Int8 quantization config."""
-from unittest.mock import MagicMock, patch
-
import pytest
import torch
from pytest_mock import MockerFixture
@@ -102,7 +100,7 @@ def test_quantization_config_string_and_dict_equivalent():
assert config_str.quantization_config.activation_scheme == config_dict.quantization_config.activation_scheme
-def test_get_quant_method(mocker: MockerFixture):
+def test_get_quant_method(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
"""Test for get_quant_method method for GPU"""
from vllm_omni.quantization.int8_config import Int8OnlineLinearMethod
@@ -111,18 +109,16 @@ def test_get_quant_method(mocker: MockerFixture):
def _fake_init(self, quant_config):
pass
- layer = MagicMock(spec=LinearBase)
+ layer = mocker.Mock(spec=LinearBase)
mocker.patch.object(Int8OnlineLinearMethod, "__init__", _fake_init)
prefix = "test_layer"
# Mock the platform to be GPU
- with (
- patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=True),
- patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=False),
- ):
- method = config.get_quant_method(layer, prefix)
- assert isinstance(method, Int8OnlineLinearMethod)
+ monkeypatch.setattr(current_omni_platform, "is_cuda", lambda: True)
+ monkeypatch.setattr(current_omni_platform, "is_npu", lambda: False)
+ method = config.get_quant_method(layer, prefix)
+ assert isinstance(method, Int8OnlineLinearMethod)
# Test skipping quantization for a layer
config.ignored_layers = [prefix]
@@ -130,22 +126,20 @@ def _fake_init(self, quant_config):
assert isinstance(method, UnquantizedLinearMethod)
-def test_get_npu_quant_method():
+def test_get_npu_quant_method(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
"""Test for get_quant_method method for NPU"""
from vllm_omni.quantization.int8_config import NPUInt8OnlineLinearMethod
config = build_quant_config("int8")
- layer = MagicMock(spec=LinearBase)
+ layer = mocker.Mock(spec=LinearBase)
prefix = "test_layer"
# Mock the platform to be NPU
- with (
- patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=False),
- patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=True),
- ):
- method = config.get_quant_method(layer, prefix)
- assert isinstance(method, NPUInt8OnlineLinearMethod)
+ monkeypatch.setattr(current_omni_platform, "is_cuda", lambda: False)
+ monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True)
+ method = config.get_quant_method(layer, prefix)
+ assert isinstance(method, NPUInt8OnlineLinearMethod)
# Test skipping quantization for a layer
config.ignored_layers = [prefix]
@@ -245,7 +239,7 @@ class TestNPUInt8LinearMethod:
@pytest.fixture
def mock_torch_npu(self, mocker):
- torch_npu = MagicMock()
+ torch_npu = mocker.MagicMock()
mocker.patch("vllm_omni.quantization.int8_config.torch_npu", return_value=torch_npu)
mocker.patch(
diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py
index 4324ba1e63..a64d9920e0 100644
--- a/tests/diffusion/test_diffusion_scheduler.py
+++ b/tests/diffusion/test_diffusion_scheduler.py
@@ -4,10 +4,10 @@
import queue
import threading
from types import SimpleNamespace
-from unittest.mock import Mock, patch
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm_omni.diffusion.data import DiffusionOutput, DiffusionRequestAbortedError
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
@@ -97,19 +97,19 @@ def initialize(self, od_config) -> None:
def add_request(self, request: OmniDiffusionRequest) -> str:
assert request is self._request
- self._state = Mock(sched_req_id=self._sched_req_id, req=request)
+ self._state = SimpleNamespace(sched_req_id=self._sched_req_id, req=request)
return self._sched_req_id
def schedule(self):
if self._scheduled or self._state is None:
- return Mock(
+ return SimpleNamespace(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
scheduled_req_ids=[],
is_empty=True,
)
self._scheduled = True
- return Mock(
+ return SimpleNamespace(
scheduled_new_reqs=[NewRequestData.from_state(self._state)],
scheduled_cached_reqs=CachedRequestData.make_empty(),
scheduled_req_ids=[self._state.sched_req_id],
@@ -153,7 +153,7 @@ def close(self) -> None:
class TestRequestScheduler:
def setup_method(self) -> None:
self.scheduler: RequestScheduler = RequestScheduler()
- self.scheduler.initialize(Mock())
+ self.scheduler.initialize(SimpleNamespace())
def test_single_request_success_lifecycle(self) -> None:
req_id = self.scheduler.add_request(_make_request("a"))
@@ -276,23 +276,23 @@ def test_request_id_mapping_lifecycle(self) -> None:
class TestDiffusionEngine:
- def test_add_req_and_wait_for_response_single_path(self) -> None:
+ def test_add_req_and_wait_for_response_single_path(self, mocker: MockerFixture) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(Mock())
+ engine.scheduler.initialize(SimpleNamespace())
engine._rpc_lock = threading.RLock()
engine.abort_queue = queue.Queue()
request = _make_request("engine")
runner_output = _make_request_output("engine")
- engine.execute_fn = Mock(return_value=runner_output)
+ engine.execute_fn = mocker.Mock(return_value=runner_output)
output = engine.add_req_and_wait_for_response(request)
assert output is runner_output.result
engine.execute_fn.assert_called_once()
- def test_supports_scheduler_interface_injection(self) -> None:
+ def test_supports_scheduler_interface_injection(self, mocker: MockerFixture) -> None:
request = _make_request("engine_iface")
runner_output = _make_request_output("engine_iface")
scheduler = _StubScheduler(request, runner_output)
@@ -301,33 +301,45 @@ def test_supports_scheduler_interface_injection(self) -> None:
engine.scheduler = scheduler
engine._rpc_lock = threading.RLock()
engine.abort_queue = queue.Queue()
- engine.execute_fn = Mock(return_value=runner_output)
+ engine.execute_fn = mocker.Mock(return_value=runner_output)
output = engine.add_req_and_wait_for_response(request)
assert output is runner_output.result
engine.execute_fn.assert_called_once()
- def test_initializes_injected_scheduler(self) -> None:
+ def test_initializes_injected_scheduler(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ) -> None:
request = _make_request("init")
scheduler = _StubScheduler(request, DiffusionOutput(output=None))
- od_config = Mock(model_class_name="mock_model")
- fake_executor_cls = Mock(return_value=Mock())
+ od_config = SimpleNamespace(model_class_name="mock_model")
+ fake_executor_cls = mocker.Mock(return_value=mocker.Mock())
- with (
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls),
- patch.object(DiffusionEngine, "_dummy_run", return_value=None),
- ):
- DiffusionEngine(od_config, scheduler=scheduler)
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class",
+ lambda *args, **kwargs: fake_executor_cls,
+ )
+ monkeypatch.setattr(DiffusionEngine, "_dummy_run", lambda self: None)
+
+ DiffusionEngine(od_config, scheduler=scheduler)
assert scheduler.initialized_with is od_config
fake_executor_cls.assert_called_once_with(od_config)
def test_scheduler_alias_keeps_default_request_scheduler(self) -> None:
scheduler = Scheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(SimpleNamespace())
req_id = scheduler.add_request(_make_request("alias"))
sched_output = scheduler.schedule()
@@ -336,10 +348,10 @@ def test_scheduler_alias_keeps_default_request_scheduler(self) -> None:
assert req_id in finished
assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED
- def test_step_raises_aborted_error(self) -> None:
+ def test_step_raises_aborted_error(self, mocker: MockerFixture) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.pre_process_func = None
- engine.add_req_and_wait_for_response = Mock(
+ engine.add_req_and_wait_for_response = mocker.Mock(
return_value=DiffusionOutput(aborted=True, abort_message="Request req-abort aborted.")
)
@@ -349,7 +361,7 @@ def test_step_raises_aborted_error(self) -> None:
def test_abort_queue_marks_request_finished_aborted(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(Mock())
+ engine.scheduler.initialize(SimpleNamespace())
engine.abort_queue = queue.Queue()
req_id = engine.scheduler.add_request(_make_request("req-abort"))
@@ -361,7 +373,7 @@ def test_abort_queue_marks_request_finished_aborted(self) -> None:
def test_finalize_finished_request_returns_aborted_output(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(Mock())
+ engine.scheduler.initialize(SimpleNamespace())
req_id = engine.scheduler.add_request(_make_request("req-finalize"))
engine.scheduler.finish_requests(req_id, DiffusionRequestStatus.FINISHED_ABORTED)
@@ -371,29 +383,40 @@ def test_finalize_finished_request_returns_aborted_output(self) -> None:
assert output.aborted is True
assert output.abort_message == "Request req-finalize aborted."
- def test_initializes_step_scheduler_when_step_execution_enabled(self) -> None:
- od_config = Mock(model_class_name="mock_model")
+ def test_initializes_step_scheduler_when_step_execution_enabled(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ) -> None:
+ od_config = SimpleNamespace(model_class_name="mock_model")
od_config.step_execution = True
- fake_executor = Mock()
- fake_executor_cls = Mock(return_value=fake_executor)
+ fake_executor = mocker.Mock()
+ fake_executor_cls = mocker.Mock(return_value=fake_executor)
- with (
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None),
- patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls),
- patch.object(DiffusionEngine, "_dummy_run", return_value=None),
- ):
- engine = DiffusionEngine(od_config)
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class",
+ lambda *args, **kwargs: fake_executor_cls,
+ )
+ monkeypatch.setattr(DiffusionEngine, "_dummy_run", lambda self: None)
+ engine = DiffusionEngine(od_config)
assert isinstance(engine.scheduler, StepScheduler)
assert engine.execute_fn is fake_executor.execute_step
fake_executor_cls.assert_called_once_with(od_config)
- def test_dummy_run_raises_on_output_error(self) -> None:
+ def test_dummy_run_raises_on_output_error(self, mocker: MockerFixture) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
- engine.od_config = Mock(model_class_name="mock_model")
+ engine.od_config = SimpleNamespace(model_class_name="mock_model")
engine.pre_process_func = None
- engine.add_req_and_wait_for_response = Mock(return_value=DiffusionOutput(error="boom"))
+ engine.add_req_and_wait_for_response = mocker.Mock(return_value=DiffusionOutput(error="boom"))
with pytest.raises(RuntimeError, match="Dummy run failed: boom"):
engine._dummy_run()
@@ -402,7 +425,7 @@ def test_dummy_run_raises_on_output_error(self) -> None:
class TestStepScheduler:
def setup_method(self) -> None:
self.scheduler: StepScheduler = StepScheduler()
- self.scheduler.initialize(Mock())
+ self.scheduler.initialize(SimpleNamespace())
def test_single_request_step_lifecycle(self) -> None:
request = _make_step_request("step", num_inference_steps=3)
diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py
index 68aba9ba3b..42687d4a1e 100644
--- a/tests/diffusion/test_diffusion_step_pipeline.py
+++ b/tests/diffusion/test_diffusion_step_pipeline.py
@@ -7,10 +7,10 @@
import threading
from contextlib import contextmanager
from types import SimpleNamespace
-from unittest.mock import Mock
import pytest
import torch
+from pytest_mock import MockerFixture
import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module
from tests.utils import hardware_test
@@ -542,11 +542,11 @@ def test_rejects_lora_requests_in_step_mode(self):
class TestExecutor:
"""MultiprocDiffusionExecutor.execute_step"""
- def test_execute_step_passes_through_runner_output(self):
+ def test_execute_step_passes_through_runner_output(self, mocker: MockerFixture):
executor = object.__new__(MultiprocDiffusionExecutor)
executor._ensure_open = lambda: None
expected = RunnerOutput(req_id="req-step", step_index=1, finished=False, result=None)
- executor.collective_rpc = Mock(return_value=expected)
+ executor.collective_rpc = mocker.Mock(return_value=expected)
request = _make_engine_request("req-step", num_inference_steps=2)
scheduler_output = _make_scheduler_output(request, sched_req_id="req-step")
@@ -578,9 +578,9 @@ class TestEngine:
),
],
)
- def test_step_engine_returns_error(self, execute_fn, expected_error):
+ def test_step_engine_returns_error(self, execute_fn, expected_error, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler, execute_fn=execute_fn)
output = engine.add_req_and_wait_for_response(_make_engine_request("req-error", num_inference_steps=2))
@@ -588,9 +588,9 @@ def test_step_engine_returns_error(self, execute_fn, expected_error):
assert output.output is None
assert expected_error in output.error
- def test_step_execution_completes(self):
+ def test_step_execution_completes(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-step", num_inference_steps=2)
@@ -614,9 +614,9 @@ def execute_fn(_):
assert output.error is None
assert torch.equal(output.output, torch.tensor([2.0]))
- def test_step_abort_stops_rescheduling_after_first_step(self):
+ def test_step_abort_stops_rescheduling_after_first_step(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-stop", num_inference_steps=4)
@@ -639,9 +639,9 @@ def execute_fn(_):
assert step["n"] == 1
_assert_aborted_output(output, "req-stop")
- def test_step_abort_after_reschedule_returns_aborted_output(self):
+ def test_step_abort_after_reschedule_returns_aborted_output(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-mid", num_inference_steps=4)
@@ -666,9 +666,9 @@ def execute_fn(sched_output):
assert step["n"] == 2
_assert_aborted_output(output, "req-mid")
- def test_finished_step_without_result_returns_error(self):
+ def test_finished_step_without_result_returns_error(self, mocker: MockerFixture):
scheduler = StepScheduler()
- scheduler.initialize(Mock())
+ scheduler.initialize(mocker.Mock())
engine = _make_engine(
scheduler,
execute_fn=lambda _: RunnerOutput(
diff --git a/tests/diffusion/test_diffusion_worker_cuda_profiler.py b/tests/diffusion/test_diffusion_worker_cuda_profiler.py
new file mode 100644
index 0000000000..4a3b22c212
--- /dev/null
+++ b/tests/diffusion/test_diffusion_worker_cuda_profiler.py
@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from pytest_mock import MockerFixture
+
+from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker
+
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
+
+
+@pytest.fixture
+def mock_od_config(mocker: MockerFixture):
+ """Create a mock OmniDiffusionConfig with a CUDA profiler backend."""
+ config = mocker.Mock()
+ config.profiler_config = mocker.Mock()
+ config.profiler_config.profiler = "cuda"
+ config.diffusion_load_format = "default"
+ return config
+
+
+@pytest.fixture
+def mock_diffusion_worker_dependencies(mocker: MockerFixture):
+ """Patch heavy worker dependencies for focused profiler tests."""
+ mocker.patch.object(DiffusionWorker, "init_device")
+ mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.DiffusionModelRunner")
+
+
+class TestDiffusionWorkerCudaProfiler:
+ def test_creates_cuda_profiler_wrapper(
+ self,
+ mocker: MockerFixture,
+ mock_od_config,
+ mock_diffusion_worker_dependencies,
+ ):
+ fake_profiler = mocker.Mock()
+ cuda_profiler = mocker.patch(
+ "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper",
+ return_value=fake_profiler,
+ )
+ create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler")
+
+ worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
+
+ cuda_profiler.assert_called_once_with(mock_od_config.profiler_config)
+ create_omni_profiler.assert_not_called()
+ assert worker.profiler is fake_profiler
+
+ def test_profile_start_stop_delegates_to_cuda_profiler(
+ self,
+ mocker: MockerFixture,
+ mock_od_config,
+ mock_diffusion_worker_dependencies,
+ ):
+ fake_profiler = mocker.Mock()
+ fake_profiler.start = mocker.Mock()
+ fake_profiler.stop = mocker.Mock()
+ mocker.patch(
+ "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper",
+ return_value=fake_profiler,
+ )
+
+ worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
+
+ assert worker.profile(is_start=True) is None
+ assert worker.profile(is_start=False) is None
+
+ fake_profiler.start.assert_called_once_with()
+ fake_profiler.stop.assert_called_once_with()
+
+ def test_returns_none_when_profiler_config_is_missing(
+ self,
+ mocker: MockerFixture,
+ mock_od_config,
+ mock_diffusion_worker_dependencies,
+ ):
+ mock_od_config.profiler_config = None
+ cuda_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper")
+ create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler")
+
+ worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
+
+ cuda_profiler.assert_not_called()
+ create_omni_profiler.assert_not_called()
+ assert worker.profiler is None
+
+ def test_cuda_backend_does_not_use_torch_profiler_factory(
+ self,
+ mocker: MockerFixture,
+ mock_od_config,
+ mock_diffusion_worker_dependencies,
+ ):
+ mocker.patch(
+ "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper",
+ return_value=mocker.Mock(),
+ )
+ create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler")
+
+ DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
+
+ create_omni_profiler.assert_not_called()
diff --git a/tests/diffusion/test_inline_stage_diffusion_client.py b/tests/diffusion/test_inline_stage_diffusion_client.py
new file mode 100644
index 0000000000..385f39b124
--- /dev/null
+++ b/tests/diffusion/test_inline_stage_diffusion_client.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import asyncio
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.inline_stage_diffusion_client import InlineStageDiffusionClient
+from vllm_omni.engine.stage_init_utils import StageMetadata
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.outputs import OmniRequestOutput
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture
+def mock_engine():
+ with patch("vllm_omni.diffusion.inline_stage_diffusion_client.DiffusionEngine") as mock:
+ engine_instance = MagicMock()
+ mock.make_engine.return_value = engine_instance
+ yield engine_instance
+
+
+@pytest.fixture
+def client(mock_engine):
+ metadata = StageMetadata(
+ stage_id=0,
+ stage_type="diffusion",
+ engine_output_type="image",
+ is_comprehension=False,
+ requires_multimodal_data=False,
+ engine_input_source="prompt",
+ final_output=True,
+ final_output_type="image",
+ default_sampling_params={},
+ custom_process_input_func=None,
+ model_stage=None,
+ runtime_cfg=None,
+ )
+ with patch.object(InlineStageDiffusionClient, "_enrich_config"):
+ od_config = MagicMock(spec=OmniDiffusionConfig)
+ c = InlineStageDiffusionClient(model="test_model", od_config=od_config, metadata=metadata, batch_size=1)
+ yield c
+ c.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_inline_dispatch_request_success(client, mock_engine):
+ # Setup mock engine step to return a successful result
+ mock_result = OmniRequestOutput.from_diffusion(request_id="req-1", images=[MagicMock()])
+ mock_engine.step.return_value = [mock_result]
+
+ sampling_params = OmniDiffusionSamplingParams()
+ await client.add_request_async("req-1", "A test prompt", sampling_params)
+
+ # Wait for the task to be processed
+ for _ in range(10):
+ output = client.get_diffusion_output_nowait()
+ if output is not None:
+ break
+ await asyncio.sleep(0.01)
+
+ assert output is not None
+ assert output.request_id == "req-1"
+ mock_engine.step.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_inline_dispatch_request_error(client, mock_engine):
+ # Setup mock engine step to raise an exception
+ mock_engine.step.side_effect = RuntimeError("Engine failure")
+
+ sampling_params = OmniDiffusionSamplingParams()
+ await client.add_request_async("req-err", "A test prompt", sampling_params)
+
+ for _ in range(10):
+ output = client.get_diffusion_output_nowait()
+ if output is not None:
+ break
+ await asyncio.sleep(0.01)
+
+ assert output is not None
+ assert output.request_id == "req-err"
+ assert output.error == "Engine failure"
+ assert not output.images
+
+
+def test_inline_shutdown(client, mock_engine):
+ assert not client._shutting_down
+
+ # Shutting down should cleanly cancel anything queued and close engine
+ client.shutdown()
+
+ assert client._shutting_down
+ mock_engine.close.assert_called_once()
diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py
index 517f98ddaa..4bc3e05fe9 100644
--- a/tests/diffusion/test_multiproc_engine_concurrency.py
+++ b/tests/diffusion/test_multiproc_engine_concurrency.py
@@ -3,7 +3,7 @@
import queue
import threading
-from unittest.mock import Mock, patch
+from types import SimpleNamespace
import pytest
import torch
@@ -24,11 +24,9 @@ def _tagged_output(tag: str) -> DiffusionOutput:
return DiffusionOutput(output=torch.tensor([0]), error=tag)
-def _mock_request(tag: str) -> Mock:
- """Return a mock ``OmniDiffusionRequest`` identifiable by *tag*."""
- req = Mock()
- req.request_ids = [tag]
- return req
+def _mock_request(tag: str):
+ """Return a lightweight request object identifiable by *tag*."""
+ return SimpleNamespace(request_ids=[tag])
def _make_executor(num_gpus: int = 1):
@@ -36,20 +34,18 @@ def _make_executor(num_gpus: int = 1):
Returns ``(executor, request_queue, result_queue)``.
"""
- od_cfg = Mock()
- od_cfg.num_gpus = num_gpus
-
- with patch.object(MultiprocDiffusionExecutor, "_init_executor"):
- executor = MultiprocDiffusionExecutor(od_cfg)
+ od_cfg = SimpleNamespace(num_gpus=num_gpus)
+ monkeypatch = pytest.MonkeyPatch()
+ monkeypatch.setattr(MultiprocDiffusionExecutor, "_init_executor", lambda self: None)
+ executor = MultiprocDiffusionExecutor(od_cfg)
+ monkeypatch.undo()
req_q: queue.Queue = queue.Queue()
res_q: queue.Queue = queue.Queue()
- mock_broadcast_mq = Mock()
- mock_broadcast_mq.enqueue = req_q.put
+ mock_broadcast_mq = SimpleNamespace(enqueue=req_q.put)
- mock_rmq = Mock()
- mock_rmq.dequeue = lambda timeout=None: res_q.get(timeout=timeout if timeout is not None else 10)
+ mock_rmq = SimpleNamespace(dequeue=lambda timeout=None: res_q.get(timeout=timeout if timeout is not None else 10))
executor._broadcast_mq = mock_broadcast_mq
executor._result_mq = mock_rmq
@@ -63,7 +59,7 @@ def _make_engine(num_gpus: int = 1):
executor, req_q, res_q = _make_executor(num_gpus)
engine = DiffusionEngine.__new__(DiffusionEngine)
sched = RequestScheduler()
- sched.initialize(Mock())
+ sched.initialize(SimpleNamespace())
engine.scheduler = sched
engine.executor = executor
engine._rpc_lock = threading.RLock()
diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
index 7a3caba11e..256e3e0a3f 100644
--- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
+++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
@@ -4,10 +4,12 @@
import threading
from collections import deque
from types import SimpleNamespace
+from unittest.mock import patch
import pytest
import torch
from pytest_mock import MockerFixture
+from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.request import RequestStatus
from vllm_omni.distributed.omni_connectors.transfer_adapter.base import OmniTransferAdapterBase
@@ -335,6 +337,27 @@ def test_cleanup_after_poll_flow(build_adapter):
assert "ext-flow" not in adapter.request_payload
+def test_finish_requests_restores_status(build_adapter):
+ """Abort path must pop ``requests_origin_status`` and restore pre-wait status.
+
+ While ``process_pending_chunks`` holds a request off the scheduler queues, the
+ adapter records the prior status (WAITING or RUNNING). ``finish_requests`` must
+ put that status back on the live ``Request`` so base ``Scheduler.finish_requests``
+ can finish bookkeeping without inconsistent state / crashes.
+ """
+ adapter, _ = build_adapter(stage_id=1)
+ req_id = "req-abort-during-chunk"
+ prior = RequestStatus.RUNNING
+ request = _req(req_id, RequestStatus.WAITING_FOR_CHUNK)
+ adapter.requests_origin_status[req_id] = prior
+ requests_map = {req_id: request}
+
+ adapter.finish_requests([req_id], RequestStatus.FINISHED_ABORTED, requests_map)
+
+ assert request.status == prior
+ assert req_id not in adapter.requests_origin_status
+
+
# ---------------------------------------------------------------
# Scheduler trigger tests
# ---------------------------------------------------------------
@@ -508,3 +531,31 @@ def test_ar_scheduler_defers_cleanup_and_queues_save_on_finished(mocker: MockerF
assert len(cleanup_calls) == 0
assert len(save_calls) == 1
+
+
+def test_omni_ar_scheduler_finish_requests(mocker: MockerFixture):
+ """``OmniARScheduler.finish_requests`` must run chunk adapter hook before vLLM base."""
+ from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler
+
+ order: list[str] = []
+
+ adapter = mocker.MagicMock()
+
+ def _adapter_finish(request_ids, finished_status, requests):
+ order.append("adapter")
+ return []
+
+ adapter.finish_requests.side_effect = _adapter_finish
+
+ def _super_finish(_self, request_ids, finished_status):
+ order.append("super")
+ return []
+
+ sched = OmniARScheduler.__new__(OmniARScheduler)
+ sched.chunk_transfer_adapter = adapter
+ sched.requests = {}
+
+ with patch.object(VLLMScheduler, "finish_requests", _super_finish):
+ OmniARScheduler.finish_requests(sched, ["r1"], RequestStatus.FINISHED_ABORTED)
+
+ assert order == ["adapter", "super"]
diff --git a/tests/distributed/omni_connectors/test_shm_connector.py b/tests/distributed/omni_connectors/test_shm_connector.py
new file mode 100644
index 0000000000..e702318e3f
--- /dev/null
+++ b/tests/distributed/omni_connectors/test_shm_connector.py
@@ -0,0 +1,184 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for SharedMemoryConnector focusing on TP / CFG / metadata fallback."""
+
+import pytest
+
+from vllm_omni.distributed.omni_connectors.connectors.shm_connector import (
+ SharedMemoryConnector,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture()
+def connector():
+ c = SharedMemoryConnector({"shm_threshold_bytes": 64})
+ yield c
+ c.close()
+
+
+# ── Key-based read (the fundamental SHM path) ────────────────────────
+
+
+class TestKeyBasedReadWrite:
+ def test_put_then_get_by_key(self, connector):
+ data = {"hello": "world", "n": 42}
+ ok, size, meta = connector.put("s0", "s1", "test_key_1", data)
+ assert ok
+ assert size > 0
+ assert "shm" in meta
+ assert "test_key_1" in connector._pending_keys
+
+ result = connector.get("s0", "s1", "test_key_1", metadata=None)
+ assert result is not None
+ obj, rsize = result
+ assert obj == data
+ assert rsize == size
+ assert "test_key_1" not in connector._pending_keys
+
+ def test_get_nonexistent_key_returns_none(self, connector):
+ result = connector.get("s0", "s1", "no_such_key_xyz", metadata=None)
+ assert result is None
+
+ def test_rank_aware_keys_independent(self, connector):
+ """Each TP rank writes/reads its own key — simulates homogeneous TP."""
+ payloads = {}
+ for rank in range(4):
+ key = f"req1_s0_0_{rank}_{rank}"
+ data = {"rank": rank, "values": list(range(rank, rank + 3))}
+ ok, _, _ = connector.put("s0", "s1", key, data)
+ assert ok
+ payloads[rank] = data
+
+ for rank in range(4):
+ key = f"req1_s0_0_{rank}_{rank}"
+ result = connector.get("s0", "s1", key, metadata=None)
+ assert result is not None
+ obj, _ = result
+ assert obj == payloads[rank]
+
+
+# ── Metadata fallback behaviour ──────────────────────────────────────
+
+
+class TestMetadataFallback:
+ def test_rdma_style_metadata_falls_back_to_key(self, connector):
+ """source_host/source_port metadata should be ignored; key read used."""
+ data = {"payload": True}
+ connector.put("s0", "s1", "fb_key_1", data)
+
+ rdma_meta = {"source_host": "10.0.0.1", "source_port": 12345}
+ result = connector.get("s0", "s1", "fb_key_1", metadata=rdma_meta)
+ assert result is not None
+ obj, _ = result
+ assert obj == data
+
+ def test_non_dict_metadata_falls_back_to_key(self, connector):
+ data = {"val": 99}
+ connector.put("s0", "s1", "fb_key_2", data)
+
+ result = connector.get("s0", "s1", "fb_key_2", metadata="not_a_dict")
+ assert result is not None
+ obj, _ = result
+ assert obj == data
+
+ def test_empty_dict_metadata_falls_back_to_key(self, connector):
+ data = {"x": 1}
+ connector.put("s0", "s1", "fb_key_3", data)
+
+ result = connector.get("s0", "s1", "fb_key_3", metadata={})
+ assert result is not None
+ obj, _ = result
+ assert obj == data
+
+ def test_shm_handle_metadata_still_works(self, connector):
+ """When metadata contains a proper 'shm' handle, use it directly."""
+ data = {"direct": True}
+ ok, size, meta = connector.put("s0", "s1", "shm_direct_1", data)
+ assert ok
+ result = connector.get("s0", "s1", "shm_direct_1", metadata=meta)
+ assert result is not None
+ obj, _ = result
+ assert obj == data
+
+ def test_metadata_keyed_by_request_id(self, connector):
+ """Metadata wrapped as {get_key: actual_meta} should be unwrapped."""
+ data = {"wrapped": True}
+ ok, size, meta = connector.put("s0", "s1", "wrap_key", data)
+ assert ok
+ wrapped = {"wrap_key": meta}
+ result = connector.get("s0", "s1", "wrap_key", metadata=wrapped)
+ assert result is not None
+ obj, _ = result
+ assert obj == data
+
+
+# ── Heterogeneous TP multi-key read ──────────────────────────────────
+
+
+class TestHeteroTPMultiKey:
+ def test_receiver_reads_multiple_sender_keys(self, connector):
+ """Simulates from_tp=2 -> to_tp=1: receiver reads 2 keys and merges."""
+ for sender_rank in range(2):
+ key = f"req1_s0_0_{sender_rank}_0"
+ data = {"sender": sender_rank, "shard": [sender_rank * 10]}
+ connector.put("s0", "s1", key, data)
+
+ shards = []
+ for sender_rank in range(2):
+ key = f"req1_s0_0_{sender_rank}_0"
+ result = connector.get("s0", "s1", key, metadata=None)
+ assert result is not None
+ obj, _ = result
+ shards.append(obj)
+
+ assert len(shards) == 2
+ assert shards[0]["sender"] == 0
+ assert shards[1]["sender"] == 1
+
+ def test_sender_writes_multiple_receiver_keys(self, connector):
+ """Simulates from_tp=1 -> to_tp=2: sender writes 2 sliced keys."""
+ for recv_rank in range(2):
+ key = f"req1_s0_0_0_{recv_rank}"
+ data = {"target": recv_rank, "slice": list(range(recv_rank, recv_rank + 2))}
+ connector.put("s0", "s1", key, data)
+
+ for recv_rank in range(2):
+ key = f"req1_s0_0_0_{recv_rank}"
+ result = connector.get("s0", "s1", key, metadata=None)
+ assert result is not None
+ obj, _ = result
+ assert obj["target"] == recv_rank
+
+
+# ── Cleanup ──────────────────────────────────────────────────────────
+
+
+class TestCleanup:
+ def test_cleanup_removes_unconsumed_segment(self, connector):
+ data = {"leak": True}
+ connector.put("s0", "s1", "cleanup_req_42", data)
+ assert "cleanup_req_42" in connector._pending_keys
+
+ connector.cleanup("req_42")
+ assert "cleanup_req_42" not in connector._pending_keys
+
+ result = connector.get("s0", "s1", "cleanup_req_42", metadata=None)
+ assert result is None
+
+ def test_cleanup_noop_for_consumed_segment(self, connector):
+ data = {"consumed": True}
+ connector.put("s0", "s1", "consumed_req_99", data)
+ connector.get("s0", "s1", "consumed_req_99", metadata=None)
+
+ connector.cleanup("req_99")
+ assert "consumed_req_99" not in connector._pending_keys
+
+ def test_close_cleans_all_pending(self, connector):
+ for i in range(3):
+ connector.put("s0", "s1", f"close_test_{i}", {"i": i})
+
+ assert len(connector._pending_keys) == 3
+ connector.close()
+ assert len(connector._pending_keys) == 0
diff --git a/tests/distributed/omni_connectors/test_tp_rank_aware.py b/tests/distributed/omni_connectors/test_tp_rank_aware.py
new file mode 100644
index 0000000000..d4793479aa
--- /dev/null
+++ b/tests/distributed/omni_connectors/test_tp_rank_aware.py
@@ -0,0 +1,716 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for rank-aware KV transfer (TP > 1) and heterogeneous TP support.
+
+Covers:
+- _build_rank_aware_send_keys / _build_rank_aware_recv_keys
+- _get_kv_source_ranks / _get_kv_target_ranks / get_kv_connector_key
+- update_sender_info storing base host/port
+- receive path constructing per-rank metadata for connector.get()
+- Mooncake connector _query_metadata_at and partial-metadata get() path
+"""
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm_omni.distributed.omni_connectors.kv_transfer_manager import (
+ KVCacheTransferData,
+ OmniKVCacheConfig,
+ OmniKVTransferManager,
+)
+from vllm_omni.distributed.omni_connectors.utils.initialization import (
+ KV_RANK_PORT_STRIDE,
+)
+from vllm_omni.distributed.omni_connectors.utils.kv_utils import (
+ KVTPTopology,
+ build_rank_aware_recv_keys,
+ build_rank_aware_send_keys,
+ get_kv_connector_key,
+ get_kv_source_ranks,
+ get_kv_target_ranks,
+ merge_received_rank_shards,
+ slice_received_rank_shard,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _make_manager(
+ from_tp: int = 1,
+ to_tp: int = 1,
+ local_rank: int = 0,
+ from_stage: str = "stage0",
+ to_stage: str = "stage1",
+ stage_id: str = "stage1",
+ need_recv: bool = True,
+ need_send: bool = False,
+ recv_timeout: float = 0.3,
+) -> OmniKVTransferManager:
+ """Build a manager with TP params injected, bypassing torch.distributed."""
+ config = OmniKVCacheConfig(
+ connector_config={"type": "mock"},
+ from_stage=from_stage,
+ to_stage=to_stage,
+ stage_id=stage_id,
+ need_recv_cache=need_recv,
+ need_send_cache=need_send,
+ recv_timeout=recv_timeout,
+ from_tp=from_tp,
+ to_tp=to_tp,
+ )
+ with (
+ patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=local_rank),
+ patch(
+ "vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size",
+ return_value=max(from_tp, to_tp),
+ ),
+ ):
+ mgr = OmniKVTransferManager(config)
+ return mgr
+
+
+def _make_payload(head_values: list[float], request_id: str = "req-1") -> dict:
+ head_tensor = torch.tensor(head_values, dtype=torch.float32).view(1, len(head_values), 1).repeat(2, 1, 1)
+ return {
+ "request_id": request_id,
+ "layer_blocks": {
+ "key_cache": [head_tensor.clone()],
+ "value_cache": [(head_tensor + 100).clone()],
+ },
+ "block_ids": [0],
+ "metadata": {"seq_len": 2},
+ }
+
+
+def _make_transfer_data(head_values: list[float], request_id: str = "req-1") -> KVCacheTransferData:
+ payload = _make_payload(head_values, request_id=request_id)
+ return KVCacheTransferData(
+ request_id=request_id,
+ layer_blocks=payload["layer_blocks"],
+ block_ids=payload["block_ids"],
+ metadata=payload["metadata"],
+ )
+
+
+# ── Key format helper ────────────────────────────────────────────────
+
+
+class TestConnectorKeyFormat:
+ def test_key_format_matches_pr2677(self):
+ key = get_kv_connector_key("req-1", "stage0", 0, 1, 2)
+ assert key == "req-1_stage0_0_1_2"
+
+ def test_key_fields_are_positional(self):
+ key = get_kv_connector_key("r", "s", 5, 3, 7)
+ parts = key.split("_")
+ assert parts == ["r", "s", "5", "3", "7"]
+
+
+# ── Source / target rank mapping ─────────────────────────────────────
+
+
+class TestRankMapping:
+ """Verify get_kv_target_ranks and get_kv_source_ranks for various TP configs."""
+
+ def test_homogeneous_tp2_rank0(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=0)
+ assert get_kv_target_ranks(topo) == [0]
+ assert get_kv_source_ranks(topo) == [0]
+
+ def test_homogeneous_tp2_rank1(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=1)
+ assert get_kv_target_ranks(topo) == [1]
+ assert get_kv_source_ranks(topo) == [1]
+
+ def test_homogeneous_tp4_rank3(self):
+ topo = KVTPTopology(source_tp_size=4, target_tp_size=4, local_rank=3)
+ assert get_kv_target_ranks(topo) == [3]
+ assert get_kv_source_ranks(topo) == [3]
+
+ def test_sender_gt_receiver_tp4_to_tp2_rank0(self):
+ """Receiver rank 0 should receive from sender rank 0 and 1."""
+ topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=0)
+ assert get_kv_source_ranks(topo) == [0, 1]
+
+ def test_sender_gt_receiver_tp4_to_tp2_rank1(self):
+ """Receiver rank 1 should receive from sender rank 2 and 3."""
+ topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=1)
+ assert get_kv_source_ranks(topo) == [2, 3]
+
+ def test_sender_lt_receiver_tp2_to_tp4_rank0(self):
+ """Sender rank 0 should send to receiver ranks 0 and 1."""
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
+ assert get_kv_target_ranks(topo) == [0, 1]
+
+ def test_sender_lt_receiver_tp2_to_tp4_rank1(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=1)
+ assert get_kv_target_ranks(topo) == [2, 3]
+
+ def test_receiver_lt_sender_source_ranks(self):
+ """Receiver rank 0 with tp2_to_tp4 should source from rank 0 only."""
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
+ assert get_kv_source_ranks(topo) == [0]
+
+ def test_invalid_topology_raises(self):
+ topo = KVTPTopology(source_tp_size=3, target_tp_size=2, local_rank=0)
+ with pytest.raises(ValueError, match="divisible"):
+ get_kv_source_ranks(topo)
+
+
+# ── _build_rank_aware_recv_keys ──────────────────────────────────────
+
+
+class TestBuildRankAwareRecvKeys:
+ """Verify build_rank_aware_recv_keys returns (key, from_rank) tuples."""
+
+ def test_tp1_returns_legacy_key_with_none_rank(self):
+ topo = KVTPTopology(source_tp_size=1, target_tp_size=1, local_rank=0)
+ pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
+ assert len(pairs) == 1
+ key, rank = pairs[0]
+ assert key == "omni_stage0_to_stage1_kv_cache_req-1"
+ assert rank is None
+
+ def test_homogeneous_tp2_rank0(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=0)
+ pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
+ assert len(pairs) == 1
+ key, rank = pairs[0]
+ assert key == "req-1_stage0_0_0_0"
+ assert rank == 0
+
+ def test_homogeneous_tp2_rank1(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=1)
+ pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
+ assert len(pairs) == 1
+ key, rank = pairs[0]
+ assert key == "req-1_stage0_0_1_1"
+ assert rank == 1
+
+ def test_heterogeneous_tp4_to_tp2_rank0_gets_two_keys(self):
+ """Receiver rank 0 with source_tp=4, target_tp=2 should get 2 keys."""
+ topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=0)
+ pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
+ assert len(pairs) == 2
+
+ keys = [k for k, _ in pairs]
+ ranks = [r for _, r in pairs]
+ assert keys == ["req-1_stage0_0_0_0", "req-1_stage0_0_1_0"]
+ assert ranks == [0, 1]
+
+ def test_heterogeneous_tp4_to_tp2_rank1_gets_two_keys(self):
+ topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=1)
+ pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
+ assert len(pairs) == 2
+
+ ranks = [r for _, r in pairs]
+ assert ranks == [2, 3]
+
+ def test_heterogeneous_tp2_to_tp4_rank2_gets_one_key(self):
+ """Receiver rank 2 with source_tp=2, target_tp=4 should get 1 key from sender rank 1."""
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=2)
+ pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
+ assert len(pairs) == 1
+ key, rank = pairs[0]
+ assert rank == 1
+ assert key == "req-1_stage0_0_1_2"
+
+
+# ── _build_rank_aware_send_keys ──────────────────────────────────────
+
+
+class TestBuildRankAwareSendKeys:
+ def test_tp1_returns_legacy_key(self):
+ topo = KVTPTopology(source_tp_size=1, target_tp_size=1, local_rank=0)
+ keys = build_rank_aware_send_keys("req-1", "stage0", "stage1", topo)
+ assert keys == ["omni_stage0_to_stage1_kv_cache_req-1"]
+
+ def test_homogeneous_tp2_rank0(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=0)
+ keys = build_rank_aware_send_keys("req-1", "stage0", "stage1", topo)
+ assert keys == ["req-1_stage0_0_0_0"]
+
+ def test_sender_lt_receiver_tp2_to_tp4_rank0_sends_two_keys(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
+ keys = build_rank_aware_send_keys("req-1", "stage0", "stage1", topo)
+ assert len(keys) == 2
+ assert keys == ["req-1_stage0_0_0_0", "req-1_stage0_0_0_1"]
+
+
+# ── update_sender_info stores base host/port ─────────────────────────
+
+
+class TestUpdateSenderInfoBase:
+ def test_stores_base_host_and_port(self):
+ mgr = _make_manager(from_tp=2, to_tp=2, local_rank=0)
+ mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
+
+ assert mgr._sender_base_host == "10.0.0.1"
+ assert mgr._sender_base_zmq_port == 50151
+
+ def test_rank1_adjusts_default_port_but_preserves_base(self):
+ mgr = _make_manager(from_tp=2, to_tp=2, local_rank=1)
+ mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
+
+ assert mgr._sender_base_host == "10.0.0.1"
+ assert mgr._sender_base_zmq_port == 50151
+ expected_adjusted = 50151 + 1 * KV_RANK_PORT_STRIDE
+ assert mgr.config.connector_config["sender_zmq_port"] == expected_adjusted
+
+ def test_nested_sender_info_resolves_correctly(self):
+ """Nested sender_info keyed by integer stage id should resolve
+ using recv_stages (engine_input_source → recv_from)."""
+ config = OmniKVCacheConfig(
+ connector_config={"type": "mock"},
+ stage_id=2,
+ engine_input_source=[1],
+ need_recv_cache=True,
+ from_tp=2,
+ to_tp=2,
+ )
+ with (
+ patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=0),
+ patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size", return_value=2),
+ ):
+ mgr = OmniKVTransferManager(config)
+ mgr.update_sender_info(
+ {
+ 0: {"host": "10.0.0.1", "zmq_port": 50151},
+ 1: {"host": "10.0.0.2", "zmq_port": 50152},
+ }
+ )
+ assert mgr._sender_base_host == "10.0.0.2"
+ assert mgr._sender_base_zmq_port == 50152
+
+
+# ── receive path constructs per-rank metadata ────────────────────────
+
+
+class TestReceiveConstructsMetadata:
+ """Verify that receive_kv_cache_for_request passes metadata with
+ correct (host, port) to connector.get() for heterogeneous TP."""
+
+ def test_tp1_no_metadata_passed(self):
+ """TP=1: connector.get() should be called WITHOUT metadata."""
+ mgr = _make_manager(from_tp=1, to_tp=1, local_rank=0, recv_timeout=0.05)
+ mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
+
+ calls = []
+
+ class _Connector:
+ def get(self, from_stage, to_stage, get_key, metadata=None):
+ calls.append({"key": get_key, "metadata": metadata})
+ return None
+
+ mgr._connector = _Connector()
+ mgr.receive_kv_cache_for_request("req-1")
+
+ assert len(calls) > 0
+ assert calls[0]["metadata"] is None
+
+ def test_homogeneous_tp2_rank0_passes_metadata(self):
+ """TP=2 rank 0: metadata should point to sender rank 0's port."""
+ mgr = _make_manager(from_tp=2, to_tp=2, local_rank=0, recv_timeout=0.05)
+ mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
+
+ calls = []
+
+ class _Connector:
+ def get(self, from_stage, to_stage, get_key, metadata=None):
+ calls.append({"key": get_key, "metadata": metadata})
+ return None
+
+ mgr._connector = _Connector()
+ mgr.receive_kv_cache_for_request("req-1")
+
+ assert len(calls) > 0
+ meta = calls[0]["metadata"]
+ assert meta is not None
+ assert meta["source_host"] == "10.0.0.1"
+ assert meta["source_port"] == 50151 + 0 * KV_RANK_PORT_STRIDE
+
+ def test_homogeneous_tp2_rank1_passes_metadata_with_offset(self):
+ mgr = _make_manager(from_tp=2, to_tp=2, local_rank=1, recv_timeout=0.05)
+ mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
+
+ calls = []
+
+ class _Connector:
+ def get(self, from_stage, to_stage, get_key, metadata=None):
+ calls.append({"key": get_key, "metadata": metadata})
+ return None
+
+ mgr._connector = _Connector()
+ mgr.receive_kv_cache_for_request("req-1")
+
+ meta = calls[0]["metadata"]
+ assert meta["source_port"] == 50151 + 1 * KV_RANK_PORT_STRIDE
+
+ def test_heterogeneous_tp4_to_tp2_rank0_multiple_metadata(self):
+ """Receiver rank 0 with source_tp=4, target_tp=2 should call get() with
+ two different metadata entries for sender ranks 0 and 1."""
+ mgr = _make_manager(from_tp=4, to_tp=2, local_rank=0, recv_timeout=0.05)
+ mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
+
+ calls = []
+
+ class _Connector:
+ def get(self, from_stage, to_stage, get_key, metadata=None):
+ calls.append({"key": get_key, "metadata": metadata})
+ return None
+
+ mgr._connector = _Connector()
+ mgr.receive_kv_cache_for_request("req-1")
+
+ seen_ports = set()
+ for c in calls:
+ if c["metadata"]:
+ seen_ports.add(c["metadata"]["source_port"])
+ expected_ports = {
+ 50151 + 0 * KV_RANK_PORT_STRIDE,
+ 50151 + 1 * KV_RANK_PORT_STRIDE,
+ }
+ assert expected_ports.issubset(seen_ports)
+
+
+# ── Mooncake connector _query_metadata_at ────────────────────────────
+
+
+class TestMooncakeQueryMetadataAt:
+ """Test the connector's _query_metadata_at method and partial-metadata
+ path in get() without requiring real RDMA/Mooncake."""
+
+ def test_query_metadata_at_returns_full_metadata(self):
+ """Mock the ZMQ interaction to verify _query_metadata_at returns
+ complete metadata including data_size."""
+
+ try:
+ from vllm_omni.distributed.omni_connectors.connectors.mooncake_transfer_engine_connector import (
+ MooncakeTransferEngineConnector,
+ QueryResponse,
+ )
+ except ImportError:
+ pytest.skip("Mooncake not available")
+
+ import msgspec
+
+ connector = MagicMock(spec=MooncakeTransferEngineConnector)
+ connector._get_req_socket = MagicMock()
+
+ mock_socket = MagicMock()
+ resp = QueryResponse(request_id="test_key@s0_s1", data_size=4096, is_fast_path=True)
+ mock_socket.recv.return_value = msgspec.msgpack.encode(resp)
+ connector._get_req_socket.return_value = mock_socket
+
+ result = MooncakeTransferEngineConnector._query_metadata_at(
+ connector,
+ "test_key@s0_s1",
+ "10.0.0.1",
+ 50151,
+ )
+
+ assert result is not None
+ assert result["source_host"] == "10.0.0.1"
+ assert result["source_port"] == 50151
+ assert result["data_size"] == 4096
+ assert result["is_fast_path"] is True
+
+ def test_query_metadata_at_returns_none_on_not_found(self):
+ try:
+ from vllm_omni.distributed.omni_connectors.connectors.mooncake_transfer_engine_connector import (
+ INFO_NOT_FOUND,
+ MooncakeTransferEngineConnector,
+ )
+ except ImportError:
+ pytest.skip("Mooncake not available")
+
+ connector = MagicMock(spec=MooncakeTransferEngineConnector)
+ mock_socket = MagicMock()
+ mock_socket.recv.return_value = INFO_NOT_FOUND
+ connector._get_req_socket.return_value = mock_socket
+
+ result = MooncakeTransferEngineConnector._query_metadata_at(
+ connector,
+ "test_key@s0_s1",
+ "10.0.0.1",
+ 50151,
+ )
+ assert result is None
+
+
+# ── Merge / slice hooks ──────────────────────────────────────────────
+
+
+class TestMergeSliceHooks:
+ def test_single_shard_passes_through(self):
+ payload = {"layer_blocks": {"key_cache": [1]}}
+ assert merge_received_rank_shards([payload]) == payload
+
+ def test_default_merger_concats_head_dim(self):
+ p0 = _make_payload([0.0])
+ p1 = _make_payload([1.0])
+ result = merge_received_rank_shards([p0, p1])
+ key_cache = result["layer_blocks"]["key_cache"][0]
+ value_cache = result["layer_blocks"]["value_cache"][0]
+ assert key_cache.shape == (2, 2, 1)
+ assert value_cache.shape == (2, 2, 1)
+ assert torch.equal(key_cache[:, :, 0], torch.tensor([[0.0, 1.0], [0.0, 1.0]]))
+ assert torch.equal(value_cache[:, :, 0], torch.tensor([[100.0, 101.0], [100.0, 101.0]]))
+
+ def test_custom_merger_hook_called(self):
+ merged = {"merged": True}
+ assert merge_received_rank_shards([{}, {}], merger=lambda payloads: merged) == merged
+
+ def test_slicer_hook_called(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
+ sliced = {"sliced": True}
+ assert slice_received_rank_shard({"full": True}, topo, slicer=lambda payload: sliced) == sliced
+
+ def test_default_slicer_extracts_rank_local_heads(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=1)
+ payload = _make_payload([0.0, 1.0])
+ result = slice_received_rank_shard(payload, topo)
+ key_cache = result["layer_blocks"]["key_cache"][0]
+ value_cache = result["layer_blocks"]["value_cache"][0]
+ assert key_cache.shape == (2, 1, 1)
+ assert value_cache.shape == (2, 1, 1)
+ assert torch.equal(key_cache[:, :, 0], torch.tensor([[1.0], [1.0]]))
+ assert torch.equal(value_cache[:, :, 0], torch.tensor([[101.0], [101.0]]))
+
+ def test_presliced_payload_is_not_sliced_twice(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=1)
+ payload = _make_payload([1.0])
+ payload["metadata"]["tp_head_slice"] = {"applied": True, "target_rank": 1}
+ result = slice_received_rank_shard(payload, topo)
+ assert result is payload
+
+ def test_round_trip_merge_from_tp4_to_tp2(self):
+ topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=1)
+ source_ranks = get_kv_source_ranks(topo)
+ payloads = [_make_payload([float(rank)]) for rank in source_ranks]
+ result = merge_received_rank_shards(payloads)
+ key_cache = result["layer_blocks"]["key_cache"][0]
+ assert torch.equal(key_cache[:, :, 0], torch.tensor([[2.0, 3.0], [2.0, 3.0]]))
+
+ def test_round_trip_slice_from_tp2_to_tp4(self):
+ topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=3)
+ payload = _make_payload([2.0, 3.0])
+ result = slice_received_rank_shard(payload, topo)
+ key_cache = result["layer_blocks"]["key_cache"][0]
+ assert torch.equal(key_cache[:, :, 0], torch.tensor([[3.0], [3.0]]))
+
+
+class TestSenderSideSlicing:
+ def test_transfer_slices_before_sending_to_multiple_targets(self):
+ mgr = _make_manager(
+ from_tp=2,
+ to_tp=4,
+ local_rank=0,
+ need_send=True,
+ need_recv=False,
+ )
+ sent_payloads = []
+
+ class _Connector:
+ supports_raw_data = False
+
+ def put(self, from_stage, to_stage, put_key, data):
+ sent_payloads.append((put_key, KVCacheTransferData.from_bytes(data)))
+ return True, len(data), {}
+
+ mgr._connector = _Connector()
+ mgr._transfer_kv_cache(_make_transfer_data([0.0, 1.0]), "req-1")
+
+ assert [key for key, _ in sent_payloads] == ["req-1_stage0_0_0_0", "req-1_stage0_0_0_1"]
+ assert sent_payloads[0][1]["layer_blocks"]["key_cache"][0].shape == (2, 1, 1)
+ assert sent_payloads[1][1]["layer_blocks"]["key_cache"][0].shape == (2, 1, 1)
+ assert torch.equal(
+ sent_payloads[0][1]["layer_blocks"]["key_cache"][0][:, :, 0],
+ torch.tensor([[0.0], [0.0]]),
+ )
+ assert torch.equal(
+ sent_payloads[1][1]["layer_blocks"]["key_cache"][0][:, :, 0],
+ torch.tensor([[1.0], [1.0]]),
+ )
+ assert sent_payloads[0][1]["metadata"]["tp_head_slice"]["target_rank"] == 0
+ assert sent_payloads[1][1]["metadata"]["tp_head_slice"]["target_rank"] == 1
+
+
+class _MockBroadcastGroup:
+ def __init__(self, world_size: int, rank_in_group: int, broadcast_value=None, recv_value=None):
+ self.world_size = world_size
+ self.rank_in_group = rank_in_group
+ self.broadcast_value = broadcast_value
+ self.recv_value = recv_value
+ self.broadcast_calls = []
+ self.send_calls = []
+ self.recv_calls = []
+ self.shm_broadcaster = None
+
+ def broadcast_object(self, obj=None, src: int = 0):
+ self.broadcast_calls.append((obj, src))
+ return self.broadcast_value if self.broadcast_value is not None else obj
+
+ def send_object(self, obj, dst: int):
+ self.send_calls.append((dst, obj))
+
+ def recv_object(self, src: int):
+ self.recv_calls.append(src)
+ return self.recv_value
+
+
+class TestDistributedReceive:
+ def test_tp_cfg_leader_receives_then_sends_branch_local_payloads(self):
+ mgr = _make_manager(from_tp=2, to_tp=4, local_rank=0)
+ req = SimpleNamespace(request_id="req-1", sampling_params=SimpleNamespace())
+ world_group = _MockBroadcastGroup(world_size=4, rank_in_group=2)
+ cfg_group = _MockBroadcastGroup(world_size=3, rank_in_group=0)
+
+ def _receive(req_obj, cfg_func, target_device):
+ req_obj.past_key_values = SimpleNamespace(key_cache=[torch.tensor([1.0])])
+ req_obj.kv_metadata = {"source": "leader"}
+ req_obj.sampling_params.past_key_values = req_obj.past_key_values
+ req_obj.sampling_params.kv_metadata = req_obj.kv_metadata
+ req_obj.sampling_params.cfg_text_past_key_values = SimpleNamespace(key_cache=[torch.tensor([2.0])])
+ req_obj.sampling_params.cfg_text_kv_metadata = {"source": "cfg_text"}
+ req_obj.sampling_params.cfg_img_past_key_values = SimpleNamespace(key_cache=[torch.tensor([3.0])])
+ req_obj.sampling_params.cfg_img_kv_metadata = {"source": "cfg_img"}
+ return True
+
+ mgr.receive_multi_kv_cache = MagicMock(side_effect=_receive)
+ with (
+ patch("vllm_omni.diffusion.distributed.parallel_state.get_world_group", return_value=world_group),
+ patch(
+ "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_world_size",
+ return_value=3,
+ ),
+ patch(
+ "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_rank",
+ return_value=0,
+ ),
+ patch("vllm_omni.diffusion.distributed.parallel_state.get_cfg_group", return_value=cfg_group),
+ ):
+ assert mgr.receive_multi_kv_cache_distributed(req) is True
+
+ mgr.receive_multi_kv_cache.assert_called_once()
+ assert mgr.receive_multi_kv_cache.call_args.args[2] == torch.device("cpu")
+ assert req.kv_metadata == {"source": "leader"}
+ assert cfg_group.broadcast_calls == []
+ assert [dst for dst, _ in cfg_group.send_calls] == [1, 2]
+ rank1_payload = cfg_group.send_calls[0][1]
+ rank2_payload = cfg_group.send_calls[1][1]
+ assert torch.equal(rank1_payload["past_key_values"].key_cache[0], torch.tensor([1.0]))
+ assert torch.equal(rank2_payload["past_key_values"].key_cache[0], torch.tensor([1.0]))
+ assert rank1_payload["sp.cfg_active_branch"] == "cfg_text"
+ assert rank2_payload["sp.cfg_active_branch"] == "cfg_img"
+ assert rank1_payload["sp.cfg_branch_roles"] == ["cfg_text", "cfg_img"]
+ assert rank2_payload["sp.cfg_branch_roles"] == ["cfg_text", "cfg_img"]
+ assert "sp.cfg_branch_past_key_values" in rank1_payload
+ assert "sp.cfg_branch_past_key_values" in rank2_payload
+ assert list(rank1_payload["sp.cfg_branch_past_key_values"].keys()) == ["cfg_text"]
+ assert list(rank2_payload["sp.cfg_branch_past_key_values"].keys()) == ["cfg_img"]
+ assert "sp.cfg_text_past_key_values" in rank1_payload
+ assert "sp.cfg_img_past_key_values" not in rank1_payload
+ assert "sp.cfg_img_past_key_values" in rank2_payload
+ assert "sp.cfg_text_past_key_values" not in rank2_payload
+
+ def test_tp_cfg_follower_receives_local_payload_without_receiving(self):
+ mgr = _make_manager(from_tp=2, to_tp=4, local_rank=1)
+ req = SimpleNamespace(request_id="req-1", sampling_params=SimpleNamespace())
+ world_group = _MockBroadcastGroup(world_size=4, rank_in_group=3)
+ cfg_payload = {
+ "past_key_values": SimpleNamespace(key_cache=[torch.tensor([1.0])]),
+ "kv_metadata": {"source": "main"},
+ "sp.past_key_values": SimpleNamespace(key_cache=[torch.tensor([1.0])]),
+ "sp.kv_metadata": {"source": "main"},
+ "sp.cfg_active_branch": "cfg_text",
+ "sp.cfg_branch_roles": ["cfg_text", "cfg_img"],
+ "sp.cfg_branch_past_key_values": {
+ "cfg_text": SimpleNamespace(key_cache=[torch.tensor([2.0])]),
+ },
+ "sp.cfg_branch_kv_metadata": {"cfg_text": {"source": "cfg-text"}},
+ "sp.cfg_text_past_key_values": SimpleNamespace(key_cache=[torch.tensor([2.0])]),
+ }
+ cfg_group = _MockBroadcastGroup(world_size=2, rank_in_group=1, recv_value=cfg_payload)
+
+ mgr.receive_multi_kv_cache = MagicMock(return_value=True)
+ with (
+ patch("vllm_omni.diffusion.distributed.parallel_state.get_world_group", return_value=world_group),
+ patch(
+ "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_world_size",
+ return_value=2,
+ ),
+ patch(
+ "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_rank",
+ return_value=1,
+ ),
+ patch("vllm_omni.diffusion.distributed.parallel_state.get_cfg_group", return_value=cfg_group),
+ ):
+ assert mgr.receive_multi_kv_cache_distributed(req) is True
+
+ mgr.receive_multi_kv_cache.assert_not_called()
+ assert req.kv_metadata == {"source": "main"}
+ assert torch.equal(req.past_key_values.key_cache[0], torch.tensor([1.0]))
+ assert torch.equal(req.sampling_params.past_key_values.key_cache[0], torch.tensor([1.0]))
+ assert req.sampling_params.cfg_active_branch == "cfg_text"
+ assert req.sampling_params.cfg_branch_roles == ["cfg_text", "cfg_img"]
+ assert torch.equal(
+ req.sampling_params.cfg_branch_past_key_values["cfg_text"].key_cache[0],
+ torch.tensor([2.0]),
+ )
+ assert req.sampling_params.cfg_branch_kv_metadata == {"cfg_text": {"source": "cfg-text"}}
+ assert torch.equal(req.sampling_params.cfg_text_past_key_values.key_cache[0], torch.tensor([2.0]))
+ assert cfg_group.broadcast_calls == []
+ assert cfg_group.recv_calls == [0]
+
+ def test_tp_without_cfg_keeps_independent_receive_path(self):
+ mgr = _make_manager(from_tp=2, to_tp=2, local_rank=1)
+ req = SimpleNamespace(request_id="req-1", sampling_params=SimpleNamespace())
+ world_group = _MockBroadcastGroup(world_size=2, rank_in_group=1)
+ mgr.receive_multi_kv_cache = MagicMock(return_value=True)
+
+ with patch("vllm_omni.diffusion.distributed.parallel_state.get_world_group", return_value=world_group):
+ assert mgr.receive_multi_kv_cache_distributed(req, target_device=torch.device("cpu")) is True
+
+ mgr.receive_multi_kv_cache.assert_called_once_with(req, None, torch.device("cpu"))
+
+
+# ── TP auto-detect ───────────────────────────────────────────────────
+
+
+class TestAutoDetectTP:
+ def test_auto_detect_when_config_defaults(self):
+ """When config from_tp/to_tp == 1 (default), manager should auto-detect."""
+ config = OmniKVCacheConfig(
+ connector_config={"type": "mock"},
+ from_stage="s0",
+ stage_id="s1",
+ need_recv_cache=True,
+ )
+ with (
+ patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=0),
+ patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size", return_value=4),
+ ):
+ mgr = OmniKVTransferManager(config)
+ assert mgr._tp_topo.source_tp_size == 4
+ assert mgr._tp_topo.target_tp_size == 4
+
+ def test_explicit_tp_overrides_auto_detect(self):
+ config = OmniKVCacheConfig(
+ connector_config={"type": "mock"},
+ from_stage="s0",
+ stage_id="s1",
+ need_recv_cache=True,
+ from_tp=2,
+ to_tp=4,
+ )
+ with (
+ patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=0),
+ patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size", return_value=8),
+ ):
+ mgr = OmniKVTransferManager(config)
+ assert mgr._tp_topo.source_tp_size == 2
+ assert mgr._tp_topo.target_tp_size == 4
diff --git a/tests/e2e/accuracy/conftest.py b/tests/e2e/accuracy/conftest.py
index 0a81b02075..3d614b8cdc 100644
--- a/tests/e2e/accuracy/conftest.py
+++ b/tests/e2e/accuracy/conftest.py
@@ -5,10 +5,13 @@
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
+from io import BytesIO
from pathlib import Path
import pytest
+import requests
import torch
+from PIL import Image
from tests.conftest import OmniServer, OmniServerParams
@@ -114,8 +117,8 @@ def generate_server(self):
params = self.generate_params
model = self.model_prefix + params.model
server_args = params.server_args or []
- if params.use_omni:
- server_args = ["--stage-init-timeout", "120", *server_args]
+ if params.use_omni and params.stage_init_timeout is not None:
+ server_args = ["--stage-init-timeout", str(params.stage_init_timeout), *server_args]
with OmniServer(
model,
server_args,
@@ -183,6 +186,28 @@ def accuracy_artifact_root() -> Path:
return root
+@pytest.fixture(scope="session")
+def qwen_bear_image(accuracy_artifact_root: Path) -> Image.Image:
+ """Download the Qwen bear image from the URL and save it to the accuracy artifact root."""
+ QWEN_BEAR_IMAGE_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/omni-assets/qwen-bear.png"
+ response = requests.get(QWEN_BEAR_IMAGE_URL, timeout=60)
+ response.raise_for_status()
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+ image.save(accuracy_artifact_root / "qwen_bear.png")
+ return image
+
+
+@pytest.fixture(scope="session")
+def rabbit_image(accuracy_artifact_root: Path) -> Image.Image:
+ """Download the rabbit image from the URL and save it to the accuracy artifact root."""
+ RABBIT_IMAGE_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/omni-assets/rabbit.png"
+ response = requests.get(RABBIT_IMAGE_URL, timeout=60)
+ response.raise_for_status()
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+ image.save(accuracy_artifact_root / "rabbit.png")
+ return image
+
+
def reset_artifact_dir(path: Path) -> Path:
if path.exists():
shutil.rmtree(path)
@@ -226,6 +251,7 @@ def _build_accuracy_server_config(
server_args=generate_server_args,
env_dict={"CUDA_VISIBLE_DEVICES": shared_gpu},
use_omni=True,
+ stage_init_timeout=300,
),
judge_params=OmniServerParams(
model=judge_model,
diff --git a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
index ac5f2cb3cf..960ea57960 100644
--- a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
+++ b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
@@ -106,9 +106,9 @@ def test_gedit_bench_h100_smoke(
group_summary = language_summary["by_group"][group]
assert set(group_summary) == {"count", "Q_SC", "Q_PQ", "Q_O"}
- assert summary["languages"]["en"]["overall"]["Q_SC"] >= 7.0
+ assert summary["languages"]["en"]["overall"]["Q_SC"] >= 6.95
assert summary["languages"]["en"]["overall"]["Q_PQ"] >= 5.8
- assert summary["languages"]["en"]["overall"]["Q_O"] >= 6.2
+ assert summary["languages"]["en"]["overall"]["Q_O"] >= 6.15
assert summary["languages"]["cn"]["overall"]["Q_SC"] >= 6.9
assert summary["languages"]["cn"]["overall"]["Q_PQ"] >= 5.7
assert summary["languages"]["cn"]["overall"]["Q_O"] >= 6.1
diff --git a/tests/e2e/accuracy/test_qwen_image.py b/tests/e2e/accuracy/test_qwen_image.py
new file mode 100644
index 0000000000..e73195017a
--- /dev/null
+++ b/tests/e2e/accuracy/test_qwen_image.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import base64
+import gc
+import io
+import os
+from pathlib import Path
+
+import pytest
+import requests
+import torch
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from PIL import Image
+
+from tests.conftest import (
+ OmniServer,
+ _run_post_test_cleanup,
+ _run_pre_test_cleanup,
+)
+from tests.e2e.accuracy.utils import assert_similarity, model_output_dir
+from tests.utils import hardware_test
+
+MODEL_ID = "Qwen/Qwen-Image"
+MODEL_ENV_VAR = "QWEN_IMAGE_MODEL"
+PROMPT = "A photo of a cat sitting on a laptop keyboard, digital art style."
+NEGATIVE_PROMPT = "blurry, low quality"
+WIDTH = 512
+HEIGHT = 512
+NUM_INFERENCE_STEPS = 20
+TRUE_CFG_SCALE = 4.0
+SEED = 42
+SSIM_THRESHOLD = 0.97
+PSNR_THRESHOLD = 30.0
+
+
+def _model_name() -> str:
+ return os.environ.get(MODEL_ENV_VAR, MODEL_ID)
+
+
+def _local_files_only(model: str) -> bool:
+ return Path(model).exists()
+
+
+def _run_vllm_omni_qwen_image(*, model: str, output_path: Path) -> Image.Image:
+ server_args = ["--num-gpus", "1", "--stage-init-timeout", "300", "--init-timeout", "900"]
+ with OmniServer(model, server_args, use_omni=True) as omni_server:
+ response = requests.post(
+ f"http://{omni_server.host}:{omni_server.port}/v1/images/generations",
+ json={
+ "model": omni_server.model,
+ "prompt": PROMPT,
+ "size": f"{WIDTH}x{HEIGHT}",
+ "n": 1,
+ "response_format": "b64_json",
+ "negative_prompt": NEGATIVE_PROMPT,
+ "num_inference_steps": NUM_INFERENCE_STEPS,
+ "true_cfg_scale": TRUE_CFG_SCALE,
+ "seed": SEED,
+ },
+ timeout=600,
+ )
+ response.raise_for_status()
+ payload = response.json()
+ assert len(payload["data"]) == 1
+ image_bytes = base64.b64decode(payload["data"][0]["b64_json"])
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+ image.load()
+ image.save(output_path)
+ return image
+
+
+def _run_diffusers_qwen_image(*, model: str, output_path: Path) -> Image.Image:
+ _run_pre_test_cleanup(enable_force=True)
+ pipe: DiffusionPipeline | None = None
+ try:
+ pipe = DiffusionPipeline.from_pretrained(
+ model,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ local_files_only=_local_files_only(model),
+ ).to("cuda")
+ generator = torch.Generator(device="cuda").manual_seed(SEED)
+ result = pipe( # pyright: ignore[reportCallIssue]
+ prompt=PROMPT,
+ negative_prompt=NEGATIVE_PROMPT,
+ width=WIDTH,
+ height=HEIGHT,
+ num_inference_steps=NUM_INFERENCE_STEPS,
+ true_cfg_scale=TRUE_CFG_SCALE,
+ generator=generator,
+ )
+ output_image = result.images[0].convert("RGB")
+ output_image.save(output_path)
+ return output_image
+ finally:
+ if pipe is not None and hasattr(pipe, "maybe_free_model_hooks"):
+ pipe.maybe_free_model_hooks()
+ del pipe
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ _run_post_test_cleanup(enable_force=True)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.benchmark
+@pytest.mark.diffusion
+@hardware_test(res={"cuda": "H100"}, num_cards=1)
+def test_qwen_image_matches_diffusers(accuracy_artifact_root: Path) -> None:
+ model = _model_name()
+ output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID)
+
+ vllm_output = _run_vllm_omni_qwen_image(model=model, output_path=output_dir / "vllm_omni.png")
+ diffusers_output = _run_diffusers_qwen_image(model=model, output_path=output_dir / "diffusers.png")
+
+ assert_similarity(
+ model_name=MODEL_ID,
+ vllm_image=vllm_output,
+ diffusers_image=diffusers_output,
+ width=WIDTH,
+ height=HEIGHT,
+ ssim_threshold=SSIM_THRESHOLD,
+ psnr_threshold=PSNR_THRESHOLD,
+ )
diff --git a/tests/e2e/accuracy/test_qwen_image_edit.py b/tests/e2e/accuracy/test_qwen_image_edit.py
new file mode 100644
index 0000000000..9a97010343
--- /dev/null
+++ b/tests/e2e/accuracy/test_qwen_image_edit.py
@@ -0,0 +1,232 @@
+from __future__ import annotations
+
+import gc
+from pathlib import Path
+
+import pytest
+import requests
+import torch
+from diffusers import QwenImageEditPipeline, QwenImageEditPlusPipeline
+from PIL import Image
+
+from benchmarks.accuracy.common import decode_base64_image, pil_to_png_bytes
+from tests.conftest import (
+ OmniServer,
+ _run_post_test_cleanup,
+ _run_pre_test_cleanup,
+)
+from tests.e2e.accuracy.utils import assert_similarity, model_output_dir
+from tests.utils import hardware_test
+
+SINGLE_MODEL = "Qwen/Qwen-Image-Edit"
+MULTIPLE_MODEL = "Qwen/Qwen-Image-Edit-2509"
+WIDTH = 512
+HEIGHT = 512
+NUM_INFERENCE_STEPS = 20
+TRUE_CFG_SCALE = 4.0
+SEED = 42
+SSIM_THRESHOLD = 0.94
+PSNR_THRESHOLD = 28.0
+
+PROMPT_SINGLE_IMAGE = "The input is a 2D cartoon bear mascot. Restyle it into a painterly oil artwork with warm colors while preserving the main structure."
+PROMPT_MULTIPLE_IMAGE = "Put the cartoon bear mascot and the furry rabbit into one coherent scene with a painterly oil artwork style and consistent lighting."
+NEGATIVE_PROMPT = "low quality, blurry, artifacts, distortion"
+SERVER_ARGS = ["--num-gpus", "1", "--stage-init-timeout", "300", "--init-timeout", "900"]
+
+
+def _run_vllm_omni_image_edit(
+ *,
+ omni_server: OmniServer,
+ prompt: str,
+ input_images: list[Image.Image],
+ output_path: Path,
+) -> Image.Image:
+ response = requests.post(
+ f"http://{omni_server.host}:{omni_server.port}/v1/images/edits",
+ data={
+ "model": omni_server.model,
+ "prompt": prompt,
+ "size": f"{WIDTH}x{HEIGHT}",
+ "n": 1,
+ "response_format": "b64_json",
+ "negative_prompt": NEGATIVE_PROMPT,
+ "num_inference_steps": NUM_INFERENCE_STEPS,
+ "true_cfg_scale": TRUE_CFG_SCALE,
+ "seed": SEED,
+ },
+ files=[
+ ("image", (f"image_{index}.png", pil_to_png_bytes(image), "image/png"))
+ for index, image in enumerate(input_images)
+ ],
+ timeout=600,
+ )
+ response.raise_for_status()
+ payload = response.json()
+ assert len(payload["data"]) == 1
+ image = decode_base64_image(payload["data"][0]["b64_json"])
+ image.load()
+ image.save(output_path)
+ return image
+
+
+def _run_diffusers_image_edit(
+ *,
+ model: str,
+ pipeline_class: type[QwenImageEditPipeline] | type[QwenImageEditPlusPipeline],
+ prompt: str,
+ input_images: list[Image.Image],
+ output_path: Path,
+) -> Image.Image:
+ _run_pre_test_cleanup(enable_force=True)
+ pipe: QwenImageEditPipeline | QwenImageEditPlusPipeline | None = None
+ device = torch.device("cuda:0")
+ torch.cuda.set_device(device)
+ try:
+ images = input_images[0] if len(input_images) == 1 else input_images
+ pipe = pipeline_class.from_pretrained(
+ model,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ ).to(device)
+ pipe.set_progress_bar_config(disable=False)
+ generator = torch.Generator(device=device).manual_seed(SEED)
+ result = pipe( # pyright: ignore[reportCallIssue]
+ prompt=prompt,
+ image=images,
+ negative_prompt=NEGATIVE_PROMPT,
+ num_inference_steps=NUM_INFERENCE_STEPS,
+ true_cfg_scale=TRUE_CFG_SCALE,
+ width=WIDTH,
+ height=HEIGHT,
+ generator=generator,
+ )
+ output_image = result.images[0].convert("RGB") # pyright: ignore[reportAttributeAccessIssue]
+ output_image.save(output_path)
+ return output_image
+ finally:
+ if pipe is not None and hasattr(pipe, "maybe_free_model_hooks"):
+ pipe.maybe_free_model_hooks()
+ del pipe
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ _run_post_test_cleanup(enable_force=True)
+
+
+def _vllm_omni_output_single_image(
+ accuracy_artifact_root: Path,
+ qwen_bear_image: Image.Image,
+) -> Image.Image:
+ output_dir = model_output_dir(accuracy_artifact_root, SINGLE_MODEL)
+ output_path = output_dir / "vllm_omni_single.png"
+ with OmniServer(model=SINGLE_MODEL, serve_args=SERVER_ARGS) as server:
+ output = _run_vllm_omni_image_edit(
+ omni_server=server,
+ prompt=PROMPT_SINGLE_IMAGE,
+ input_images=[qwen_bear_image],
+ output_path=output_path,
+ )
+ return output
+
+
+def _diffusers_output_single_image(accuracy_artifact_root: Path, qwen_bear_image: Image.Image) -> Image.Image:
+ output_dir = model_output_dir(accuracy_artifact_root, SINGLE_MODEL)
+ output_path = output_dir / "diffusers_single.png"
+ return _run_diffusers_image_edit(
+ model=SINGLE_MODEL,
+ pipeline_class=QwenImageEditPipeline,
+ prompt=PROMPT_SINGLE_IMAGE,
+ input_images=[qwen_bear_image],
+ output_path=output_path,
+ )
+
+
+def _vllm_omni_output_multiple_image(
+ accuracy_artifact_root: Path,
+ qwen_bear_image: Image.Image,
+ rabbit_image: Image.Image,
+) -> Image.Image:
+ output_dir = model_output_dir(accuracy_artifact_root, MULTIPLE_MODEL)
+ output_path = output_dir / "vllm_omni_multiple.png"
+ with OmniServer(model=MULTIPLE_MODEL, serve_args=SERVER_ARGS) as server:
+ output = _run_vllm_omni_image_edit(
+ omni_server=server,
+ prompt=PROMPT_MULTIPLE_IMAGE,
+ input_images=[qwen_bear_image, rabbit_image],
+ output_path=output_path,
+ )
+ return output
+
+
+def _diffusers_output_multiple_image(
+ accuracy_artifact_root: Path, qwen_bear_image: Image.Image, rabbit_image: Image.Image
+) -> Image.Image:
+ output_dir = model_output_dir(accuracy_artifact_root, MULTIPLE_MODEL)
+ output_path = output_dir / "diffusers_multiple.png"
+ return _run_diffusers_image_edit(
+ model=MULTIPLE_MODEL,
+ pipeline_class=QwenImageEditPlusPipeline,
+ prompt=PROMPT_MULTIPLE_IMAGE,
+ input_images=[qwen_bear_image, rabbit_image],
+ output_path=output_path,
+ )
+
+
+@pytest.mark.advanced_model
+@pytest.mark.benchmark
+@pytest.mark.diffusion
+@hardware_test(res={"cuda": "H100"}, num_cards=1)
+def test_qwen_image_edit_single_matches_diffusers(
+ accuracy_artifact_root: Path,
+ qwen_bear_image: Image.Image,
+) -> None:
+ vllm_image = _vllm_omni_output_single_image(
+ accuracy_artifact_root=accuracy_artifact_root,
+ qwen_bear_image=qwen_bear_image,
+ )
+ diffusers_image = _diffusers_output_single_image(
+ accuracy_artifact_root=accuracy_artifact_root,
+ qwen_bear_image=qwen_bear_image,
+ )
+ assert_similarity(
+ model_name=SINGLE_MODEL,
+ vllm_image=vllm_image,
+ diffusers_image=diffusers_image,
+ width=WIDTH,
+ height=HEIGHT,
+ ssim_threshold=SSIM_THRESHOLD,
+ psnr_threshold=PSNR_THRESHOLD,
+ )
+
+
+@pytest.mark.advanced_model
+@pytest.mark.benchmark
+@pytest.mark.diffusion
+@hardware_test(res={"cuda": "H100"}, num_cards=1)
+@pytest.mark.skip(
+ reason="Skipping as the second image seems to be ignored by the API. Will come back to this later after #2772 is merged."
+)
+def test_qwen_image_edit_multiple_matches_diffusers(
+ accuracy_artifact_root: Path,
+ qwen_bear_image: Image.Image,
+ rabbit_image: Image.Image,
+) -> None:
+ vllm_image = _vllm_omni_output_multiple_image(
+ accuracy_artifact_root=accuracy_artifact_root,
+ qwen_bear_image=qwen_bear_image,
+ rabbit_image=rabbit_image,
+ )
+ diffusers_image = _diffusers_output_multiple_image(
+ accuracy_artifact_root=accuracy_artifact_root,
+ qwen_bear_image=qwen_bear_image,
+ rabbit_image=rabbit_image,
+ )
+ assert_similarity(
+ model_name=MULTIPLE_MODEL,
+ vllm_image=vllm_image,
+ diffusers_image=diffusers_image,
+ width=WIDTH,
+ height=HEIGHT,
+ ssim_threshold=SSIM_THRESHOLD,
+ psnr_threshold=PSNR_THRESHOLD,
+ )
diff --git a/tests/e2e/accuracy/test_qwen_image_layered.py b/tests/e2e/accuracy/test_qwen_image_layered.py
new file mode 100644
index 0000000000..04b13df3bb
--- /dev/null
+++ b/tests/e2e/accuracy/test_qwen_image_layered.py
@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+import base64
+import gc
+import io
+import os
+from pathlib import Path
+
+import pytest
+import requests
+import torch
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from PIL import Image
+
+from tests.conftest import (
+ OmniServer,
+ _run_post_test_cleanup,
+ _run_pre_test_cleanup,
+)
+from tests.e2e.accuracy.utils import assert_image_sequence_similarity, model_output_dir
+from tests.utils import hardware_test
+
+MODEL_ID = "Qwen/Qwen-Image-Layered"
+MODEL_ENV_VAR = "QWEN_IMAGE_LAYERED_MODEL"
+PROMPT = "decompose into layers"
+NEGATIVE_PROMPT = " "
+NUM_INFERENCE_STEPS = 20
+TRUE_CFG_SCALE = 4.0
+SEED = 777
+LAYERS = 3
+RESOLUTION = 640
+SSIM_THRESHOLD = 0.97
+PSNR_THRESHOLD = 30.0
+
+
+def _model_name() -> str:
+ return os.environ.get(MODEL_ENV_VAR, MODEL_ID)
+
+
+def _local_files_only(model: str) -> bool:
+ return Path(model).exists()
+
+
+def _normalize_layered_images(images: object) -> list[Image.Image]:
+ if not isinstance(images, list) or not images:
+ raise AssertionError(f"Unexpected layered output container: {type(images).__name__}")
+
+ first_item = images[0]
+ if isinstance(first_item, Image.Image):
+ return [image.convert("RGBA") for image in images if isinstance(image, Image.Image)]
+ if isinstance(first_item, (list, tuple)):
+ return [image.convert("RGBA") for image in first_item if isinstance(image, Image.Image)]
+ raise AssertionError(f"Unexpected layered image element type: {type(first_item).__name__}")
+
+
+def _run_vllm_omni_qwen_image_layered(*, model: str, input_image: Image.Image, output_dir: Path) -> list[Image.Image]:
+ input_image.save(output_dir / "input.png")
+ server_args = ["--num-gpus", "1", "--stage-init-timeout", "300", "--init-timeout", "900"]
+ with OmniServer(model, server_args, use_omni=True) as omni_server:
+ buffer = io.BytesIO()
+ input_image.save(buffer, format="PNG")
+ buffer.seek(0)
+ response = requests.post(
+ f"http://{omni_server.host}:{omni_server.port}/v1/images/edits",
+ data={
+ "model": omni_server.model,
+ "prompt": PROMPT,
+ "size": "auto",
+ "n": 1,
+ "response_format": "b64_json",
+ "negative_prompt": NEGATIVE_PROMPT,
+ "num_inference_steps": NUM_INFERENCE_STEPS,
+ "true_cfg_scale": TRUE_CFG_SCALE,
+ "seed": SEED,
+ "layers": LAYERS,
+ "resolution": RESOLUTION,
+ },
+ files=[("image", ("input.png", buffer, "image/png"))],
+ timeout=600,
+ )
+ response.raise_for_status()
+ payload = response.json()
+ assert len(payload["data"]) == LAYERS
+ output_images = []
+ for item in payload["data"]:
+ image_bytes = base64.b64decode(item["b64_json"])
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGBA")
+ image.load()
+ output_images.append(image)
+ for index, image in enumerate(output_images, start=1):
+ image.save(output_dir / f"vllm_omni_layer_{index}.png")
+ return output_images
+
+
+def _run_diffusers_qwen_image_layered(*, model: str, input_image: Image.Image, output_dir: Path) -> list[Image.Image]:
+ _run_pre_test_cleanup(enable_force=True)
+ pipe: DiffusionPipeline | None = None
+ try:
+ pipe = DiffusionPipeline.from_pretrained(
+ model,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ local_files_only=_local_files_only(model),
+ ).to("cuda")
+ generator = torch.Generator(device="cuda").manual_seed(SEED)
+ result = pipe( # pyright: ignore[reportCallIssue]
+ image=input_image,
+ prompt=PROMPT,
+ negative_prompt=NEGATIVE_PROMPT,
+ num_inference_steps=NUM_INFERENCE_STEPS,
+ true_cfg_scale=TRUE_CFG_SCALE,
+ generator=generator,
+ num_images_per_prompt=1,
+ layers=LAYERS,
+ resolution=RESOLUTION,
+ )
+ output_images = _normalize_layered_images(result.images)
+ assert len(output_images) == LAYERS, f"Expected {LAYERS} diffusers layers, got {len(output_images)}"
+ for index, image in enumerate(output_images, start=1):
+ image.save(output_dir / f"diffusers_layer_{index}.png")
+ return output_images
+ finally:
+ if pipe is not None and hasattr(pipe, "maybe_free_model_hooks"):
+ pipe.maybe_free_model_hooks()
+ del pipe
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ _run_post_test_cleanup(enable_force=True)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.benchmark
+@pytest.mark.diffusion
+@hardware_test(res={"cuda": "H100"}, num_cards=1)
+def test_qwen_image_layered_matches_diffusers(accuracy_artifact_root: Path, qwen_bear_image: Image.Image) -> None:
+ model = _model_name()
+ output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID)
+ input_image = qwen_bear_image.convert("RGBA")
+
+ vllm_outputs = _run_vllm_omni_qwen_image_layered(model=model, input_image=input_image, output_dir=output_dir)
+ diffusers_outputs = _run_diffusers_qwen_image_layered(model=model, input_image=input_image, output_dir=output_dir)
+
+ assert_image_sequence_similarity(
+ model_name=MODEL_ID,
+ vllm_images=vllm_outputs,
+ diffusers_images=diffusers_outputs,
+ ssim_threshold=SSIM_THRESHOLD,
+ psnr_threshold=PSNR_THRESHOLD,
+ compare_mode="RGBA",
+ )
diff --git a/tests/e2e/accuracy/utils.py b/tests/e2e/accuracy/utils.py
new file mode 100644
index 0000000000..d722b69b01
--- /dev/null
+++ b/tests/e2e/accuracy/utils.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+
+
+def model_output_dir(parent_dir: Path, model: str) -> Path:
+ safe_model_name = model.split("/")[-1].replace(".", "_")
+ path = parent_dir / safe_model_name
+ path.mkdir(parents=True, exist_ok=True)
+ return path
+
+
+def assert_similarity(
+ *,
+ model_name: str,
+ vllm_image: Image.Image,
+ diffusers_image: Image.Image,
+ ssim_threshold: float,
+ psnr_threshold: float,
+ width: int | None = None,
+ height: int | None = None,
+ compare_mode: str = "RGB",
+) -> None:
+ requested_size = (width, height) if width is not None and height is not None else None
+ if requested_size is not None and diffusers_image.size != requested_size:
+ pytest.skip(
+ "Skipping as diffusers baseline output is corrupt and not comparable: "
+ f"dimensions do not match requested size; requested={requested_size}, got={diffusers_image.size}."
+ )
+
+ assert vllm_image.size == diffusers_image.size, (
+ f"Online and diffusers output sizes mismatch: online={vllm_image.size}, diffusers={diffusers_image.size}"
+ )
+
+ ssim_score, psnr_score = compute_image_ssim_psnr(
+ prediction=vllm_image,
+ reference=diffusers_image,
+ compare_mode=compare_mode,
+ )
+ print(f"{model_name} similarity metrics:")
+ print(f" SSIM: value={ssim_score:.6f}, threshold>={ssim_threshold:.6f}, range=[-1, 1], higher_is_better=True")
+ print(
+ f" PSNR: value={psnr_score:.6f} dB, threshold>={psnr_threshold:.6f} dB, range=[0, +inf), higher_is_better=True"
+ )
+
+ assert ssim_score >= ssim_threshold, (
+ f"SSIM below threshold for {model_name}: got {ssim_score:.6f}, expected >= {ssim_threshold:.6f}."
+ )
+ assert psnr_score >= psnr_threshold, (
+ f"PSNR below threshold for {model_name}: got {psnr_score:.6f}, expected >= {psnr_threshold:.6f}."
+ )
+
+
+def assert_image_sequence_similarity(
+ *,
+ model_name: str,
+ vllm_images: list[Image.Image],
+ diffusers_images: list[Image.Image],
+ ssim_threshold: float,
+ psnr_threshold: float,
+ compare_mode: str = "RGB",
+) -> None:
+ assert len(vllm_images) == len(diffusers_images), (
+ f"Output image count mismatch for {model_name}: online={len(vllm_images)}, diffusers={len(diffusers_images)}"
+ )
+ for index, (vllm_image, diffusers_image) in enumerate(zip(vllm_images, diffusers_images, strict=True), start=1):
+ assert_similarity(
+ model_name=f"{model_name}[layer={index}]",
+ vllm_image=vllm_image,
+ diffusers_image=diffusers_image,
+ ssim_threshold=ssim_threshold,
+ psnr_threshold=psnr_threshold,
+ compare_mode=compare_mode,
+ )
+
+
+def compute_image_ssim_psnr(
+ *,
+ prediction: Image.Image,
+ reference: Image.Image,
+ compare_mode: str = "RGB",
+) -> tuple[float, float]:
+ pred_tensor = _pil_to_batched_tensor(prediction, compare_mode=compare_mode)
+ ref_tensor = _pil_to_batched_tensor(reference, compare_mode=compare_mode)
+
+ ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0)
+ psnr_metric = PeakSignalNoiseRatio(data_range=1.0)
+
+ ssim_value = float(ssim_metric(pred_tensor, ref_tensor).item())
+ psnr_value = float(psnr_metric(pred_tensor, ref_tensor).item())
+ return ssim_value, psnr_value
+
+
+def _pil_to_batched_tensor(image: Image.Image, *, compare_mode: str) -> torch.Tensor:
+ array = np.asarray(image.convert(compare_mode), dtype=np.float32) / 255.0
+ tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)
+ return tensor
diff --git a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
index 3cdda1f9ff..bec82e0257 100644
--- a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
+++ b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
@@ -567,6 +567,7 @@ def test_wan22_i2v_diffusers_offline_generates_video(
@pytest.mark.benchmark
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=2)
+@pytest.mark.skip(reason="issue: #2874")
@pytest.mark.parametrize("omni_server", SERVER_CASES, indirect=True)
def test_wan22_i2v_online_serving_generates_video(
omni_server,
diff --git a/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
index 590244acd2..b7768c071f 100644
--- a/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
+++ b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
@@ -47,15 +47,9 @@ stage_args:
engine_args:
model_stage: dit
max_num_seqs: 1
- gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: mp
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 1
- load_format: dummy
omni_kv_config:
need_recv_cache: true
engine_input_source: [0]
@@ -70,9 +64,6 @@ stage_args:
# Top-level runtime config with Mooncake connector
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
mooncake_connector:
name: MooncakeConnector
@@ -86,4 +77,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
index b7999652e2..504f3c98e9 100644
--- a/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
+++ b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
@@ -46,15 +46,9 @@ stage_args:
engine_args:
model_stage: dit
max_num_seqs: 1
- gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 1
- load_format: dummy
omni_kv_config:
need_recv_cache: true
engine_input_source: [0]
@@ -68,10 +62,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -84,4 +74,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml
deleted file mode 100644
index f93a6c7147..0000000000
--- a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,103 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# This config is optimized for CI e2e tests.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0"
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 896
- max_num_batched_tokens: 896
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 896
- max_num_batched_tokens: 896
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py
index a0c3f6cc9f..be79aa7348 100644
--- a/tests/e2e/offline_inference/test_bagel_img2img.py
+++ b/tests/e2e/offline_inference/test_bagel_img2img.py
@@ -22,9 +22,9 @@
from PIL import Image
from vllm.assets.image import ImageAsset
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
+from vllm_omni import Omni
from vllm_omni.platforms import current_omni_platform
# Reference pixel data extracted from the known-good output image
@@ -32,30 +32,30 @@
# prompt='Change the grass color to red',
# input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (157, 172, 217)},
- {"position": (400, 50), "rgb": (105, 144, 218)},
- {"position": (700, 100), "rgb": (118, 159, 233)},
- {"position": (150, 400), "rgb": (195, 34, 60)},
- {"position": (512, 336), "rgb": (222, 214, 193)},
- {"position": (700, 400), "rgb": (197, 15, 43)},
- {"position": (100, 600), "rgb": (105, 13, 18)},
- {"position": (400, 600), "rgb": (169, 33, 44)},
- {"position": (700, 600), "rgb": (101, 86, 93)},
- {"position": (256, 256), "rgb": (181, 202, 222)},
+ {"position": (100, 100), "rgb": (156, 172, 217)},
+ {"position": (400, 50), "rgb": (105, 144, 217)},
+ {"position": (700, 100), "rgb": (118, 159, 232)},
+ {"position": (150, 400), "rgb": (180, 22, 52)},
+ {"position": (512, 336), "rgb": (221, 211, 194)},
+ {"position": (700, 400), "rgb": (192, 10, 46)},
+ {"position": (100, 600), "rgb": (102, 12, 22)},
+ {"position": (400, 600), "rgb": (161, 28, 47)},
+ {"position": (700, 600), "rgb": (100, 87, 94)},
+ {"position": (256, 256), "rgb": (181, 201, 221)},
]
if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (156, 172, 215)},
- {"position": (400, 50), "rgb": (106, 144, 216)},
- {"position": (700, 100), "rgb": (118, 158, 231)},
- {"position": (150, 400), "rgb": (183, 23, 48)},
- {"position": (512, 336), "rgb": (218, 215, 191)},
- {"position": (700, 400), "rgb": (194, 14, 42)},
- {"position": (100, 600), "rgb": (105, 10, 16)},
- {"position": (400, 600), "rgb": (167, 33, 46)},
- {"position": (700, 600), "rgb": (102, 86, 92)},
- {"position": (256, 256), "rgb": (181, 201, 220)},
+ {"position": (100, 100), "rgb": (156, 172, 217)},
+ {"position": (400, 50), "rgb": (105, 144, 217)},
+ {"position": (700, 100), "rgb": (118, 159, 232)},
+ {"position": (150, 400), "rgb": (180, 22, 52)},
+ {"position": (512, 336), "rgb": (221, 211, 194)},
+ {"position": (700, 400), "rgb": (192, 10, 46)},
+ {"position": (100, 600), "rgb": (102, 12, 22)},
+ {"position": (400, 600), "rgb": (161, 28, 47)},
+ {"position": (700, 600), "rgb": (100, 87, 94)},
+ {"position": (256, 256), "rgb": (181, 201, 221)},
]
PIXEL_TOLERANCE = 10
@@ -210,11 +210,10 @@ def test_bagel_img2img_shared_memory_connector(run_level):
input_image = _load_input_image()
config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
config_path = _resolve_stage_config(config_path, run_level)
- omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=config_path, stage_init_timeout=300)
-
- try:
- generated_image = _generate_bagel_img2img(omni, input_image)
+ with OmniRunner(
+ "ByteDance-Seed/BAGEL-7B-MoT",
+ stage_configs_path=config_path,
+ ) as runner:
+ generated_image = _generate_bagel_img2img(runner.omni, input_image)
if run_level == "advanced_model":
_validate_pixels(generated_image)
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_bagel_lora.py b/tests/e2e/offline_inference/test_bagel_lora.py
index 593a640478..501d23eaa8 100644
--- a/tests/e2e/offline_inference/test_bagel_lora.py
+++ b/tests/e2e/offline_inference/test_bagel_lora.py
@@ -22,7 +22,6 @@
from vllm_omni.outputs import OmniRequestOutput
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
@@ -32,9 +31,9 @@
from PIL import Image
from safetensors.torch import save_file
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
+from vllm_omni import Omni
from vllm_omni.lora.request import LoRARequest
from vllm_omni.lora.utils import stable_lora_int_id
@@ -154,8 +153,8 @@ def _make_file_lora_request(adapter_dir: Path) -> LoRARequest:
def test_bagel_lora_scale_and_deactivation(run_level, tmp_path):
"""Validate LoRA effect, bounded perturbation, and clean deactivation."""
config_path = _resolve_stage_config(BAGEL_STAGE_CONFIG, run_level)
- omni = Omni(model=MODEL, stage_configs_path=config_path, stage_init_timeout=300)
- try:
+ with OmniRunner(MODEL, stage_configs_path=config_path) as runner:
+ omni = runner.omni
lora_request = _make_file_lora_request(tmp_path / "bagel_lora")
# 1) Baseline (no LoRA)
@@ -194,5 +193,3 @@ def test_bagel_lora_scale_and_deactivation(run_level, tmp_path):
# (d) Deactivation fully restores base model
assert diff_restored == 0.0, f"Base model not restored after LoRA deactivation: diff={diff_restored}"
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py
index 7cce8da3a7..534b873068 100644
--- a/tests/e2e/offline_inference/test_bagel_text2img.py
+++ b/tests/e2e/offline_inference/test_bagel_text2img.py
@@ -16,7 +16,6 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
import signal
import socket
import subprocess
@@ -28,9 +27,9 @@
import pytest
from PIL import Image
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
+from vllm_omni import Omni
from vllm_omni.platforms import current_omni_platform
# Reference pixel data extracted from the known-good output image
@@ -38,30 +37,30 @@
# "Generated with seed=52, num_inference_steps=15,
# prompt='A futuristic city skyline at twilight, cyberpunk style'"
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (121, 118, 100)},
- {"position": (400, 50), "rgb": (163, 162, 143)},
- {"position": (700, 100), "rgb": (170, 156, 127)},
- {"position": (150, 400), "rgb": (129, 127, 112)},
- {"position": (512, 512), "rgb": (135, 61, 59)},
- {"position": (700, 400), "rgb": (205, 107, 43)},
- {"position": (100, 700), "rgb": (197, 177, 157)},
- {"position": (400, 700), "rgb": (139, 107, 86)},
- {"position": (700, 700), "rgb": (247, 205, 146)},
- {"position": (256, 256), "rgb": (171, 160, 153)},
+ {"position": (100, 100), "rgb": (115, 113, 94)},
+ {"position": (400, 50), "rgb": (159, 160, 144)},
+ {"position": (700, 100), "rgb": (164, 151, 123)},
+ {"position": (150, 400), "rgb": (120, 121, 107)},
+ {"position": (512, 512), "rgb": (165, 133, 127)},
+ {"position": (700, 400), "rgb": (217, 130, 66)},
+ {"position": (100, 700), "rgb": (191, 168, 152)},
+ {"position": (400, 700), "rgb": (130, 96, 77)},
+ {"position": (700, 700), "rgb": (247, 203, 140)},
+ {"position": (256, 256), "rgb": (167, 156, 150)},
]
if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (123, 119, 100)},
- {"position": (400, 50), "rgb": (162, 161, 142)},
- {"position": (700, 100), "rgb": (171, 156, 127)},
- {"position": (150, 400), "rgb": (131, 128, 112)},
- {"position": (512, 512), "rgb": (134, 61, 59)},
- {"position": (700, 400), "rgb": (204, 107, 43)},
- {"position": (100, 700), "rgb": (201, 180, 165)},
- {"position": (400, 700), "rgb": (140, 108, 87)},
- {"position": (700, 700), "rgb": (247, 205, 145)},
- {"position": (256, 256), "rgb": (171, 160, 153)},
+ {"position": (100, 100), "rgb": (115, 113, 94)},
+ {"position": (400, 50), "rgb": (159, 160, 144)},
+ {"position": (700, 100), "rgb": (164, 151, 123)},
+ {"position": (150, 400), "rgb": (120, 121, 107)},
+ {"position": (512, 512), "rgb": (165, 133, 127)},
+ {"position": (700, 400), "rgb": (217, 130, 66)},
+ {"position": (100, 700), "rgb": (191, 168, 152)},
+ {"position": (400, 700), "rgb": (130, 96, 77)},
+ {"position": (700, 700), "rgb": (247, 203, 140)},
+ {"position": (256, 256), "rgb": (167, 156, 150)},
]
# Maximum allowed difference per color channel
@@ -199,14 +198,13 @@ def test_bagel_text2img_shared_memory_connector(run_level):
"""Test Bagel text2img with shared memory connector."""
config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
config_path = _resolve_stage_config(config_path, run_level)
- omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=config_path, stage_init_timeout=300)
-
- try:
- generated_image = _generate_bagel_image(omni)
+ with OmniRunner(
+ "ByteDance-Seed/BAGEL-7B-MoT",
+ stage_configs_path=config_path,
+ ) as runner:
+ generated_image = _generate_bagel_image(runner.omni)
if run_level == "advanced_model":
_validate_pixels(generated_image)
- finally:
- omni.close()
def _wait_for_port(host: str, port: int, timeout: int = 30) -> bool:
@@ -319,7 +317,6 @@ def test_bagel_text2img_mooncake_connector(run_level):
mooncake_master_proc = None
temp_config_file = None
- omni = None
try:
_cleanup_mooncake_processes()
@@ -349,15 +346,16 @@ def test_bagel_text2img_mooncake_connector(run_level):
)
temp_config_file = _resolve_stage_config(temp_config_file, run_level)
- omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=temp_config_file, stage_init_timeout=300)
-
- generated_image = _generate_bagel_image(omni)
- if run_level == "advanced_model":
- _validate_pixels(generated_image)
+ with OmniRunner(
+ "ByteDance-Seed/BAGEL-7B-MoT",
+ stage_configs_path=temp_config_file,
+ stage_init_timeout=300,
+ ) as runner:
+ generated_image = _generate_bagel_image(runner.omni)
+ if run_level == "advanced_model":
+ _validate_pixels(generated_image)
finally:
- if omni:
- omni.close()
if temp_config_file:
try:
os.unlink(temp_config_file)
diff --git a/tests/e2e/offline_inference/test_bagel_understanding.py b/tests/e2e/offline_inference/test_bagel_understanding.py
index 6f95e7ee00..bbee329807 100644
--- a/tests/e2e/offline_inference/test_bagel_understanding.py
+++ b/tests/e2e/offline_inference/test_bagel_understanding.py
@@ -21,15 +21,13 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
import pytest
from vllm.assets.image import ImageAsset
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
MODEL_NAME = "ByteDance-Seed/BAGEL-7B-MoT"
STAGE_CONFIG = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
@@ -76,13 +74,11 @@ def _extract_text(omni_outputs: list) -> str:
def test_bagel_text2text(run_level):
"""Test Bagel text2text produces correct text output."""
config_path = _resolve_stage_config(STAGE_CONFIG, run_level)
- omni = Omni(
- model=MODEL_NAME,
+ with OmniRunner(
+ MODEL_NAME,
stage_configs_path=config_path,
- stage_init_timeout=300,
- )
-
- try:
+ ) as runner:
+ omni = runner.omni
prompt = "<|im_start|>user\nWhere is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
params_list = omni.default_sampling_params_list
omni_outputs = list(
@@ -100,8 +96,6 @@ def test_bagel_text2text(run_level):
assert text == REFERENCE_TEXT_TEXT2TEXT, (
f"Text mismatch: expected {REFERENCE_TEXT_TEXT2TEXT!r}, got {text!r}"
)
- finally:
- omni.close()
@pytest.mark.core_model
@@ -112,13 +106,12 @@ def test_bagel_img2text(run_level):
"""Test Bagel img2text produces correct text output."""
input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
config_path = _resolve_stage_config(STAGE_CONFIG, run_level)
- omni = Omni(
- model=MODEL_NAME,
+ with OmniRunner(
+ MODEL_NAME,
stage_configs_path=config_path,
stage_init_timeout=300,
- )
-
- try:
+ ) as runner:
+ omni = runner.omni
prompt = "<|im_start|>user\n<|image_pad|>\nPlease describe this image<|im_end|>\n<|im_start|>assistant\n"
params_list = omni.default_sampling_params_list
omni_outputs = list(
@@ -140,5 +133,3 @@ def test_bagel_img2text(run_level):
if run_level == "advanced_model":
assert text == REFERENCE_TEXT_IMG2TEXT, f"Text mismatch: expected {REFERENCE_TEXT_IMG2TEXT!r}, got {text!r}"
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py
index 0e31413dc0..fc08da7bed 100644
--- a/tests/e2e/offline_inference/test_cache_dit.py
+++ b/tests/e2e/offline_inference/test_cache_dit.py
@@ -8,27 +8,15 @@
It uses minimal settings to keep test time short for CI.
"""
-import os
-import sys
-from pathlib import Path
-
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
-
# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]
@@ -48,20 +36,17 @@ def test_cache_dit(model_name: str):
"residual_diff_threshold": 0.24,
"max_continuous_cached_steps": 3,
}
- m = None
- try:
- m = Omni(
- model=model_name,
- cache_backend="cache_dit",
- cache_config=cache_config,
- )
-
+ with OmniRunner(
+ model_name,
+ cache_backend="cache_dit",
+ cache_config=cache_config,
+ ) as runner:
# Use minimal settings for fast testing
height = 256
width = 256
num_inference_steps = 4 # Minimal steps for fast test
- outputs = m.generate(
+ outputs = runner.omni.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -90,9 +75,3 @@ def test_cache_dit(model_name: str):
# Check image size
assert images[0].width == width
assert images[0].height == height
- except Exception as e:
- print(f"Test failed with error: {e}")
- raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
index f3830f02e9..257755ef8b 100644
--- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
@@ -1,22 +1,14 @@
import gc
-import sys
-from pathlib import Path
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
models = ["riverclouds/qwen_image_random"]
@@ -27,30 +19,29 @@ def inference(model_name: str, offload: bool = True):
current_omni_platform.reset_peak_memory_stats()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(
- model=model_name,
+ with OmniRunner(
+ model_name,
# TODO: we might want to add overlapped feature e2e tests
# cache_backend="cache_dit",
enable_cpu_offload=offload,
- )
- current_omni_platform.reset_peak_memory_stats()
- height = 256
- width = 256
+ ) as runner:
+ current_omni_platform.reset_peak_memory_stats()
+ height = 256
+ width = 256
- m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=9,
- guidance_scale=0.0,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ runner.omni.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=9,
+ guidance_scale=0.0,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
- del m
gc.collect()
current_omni_platform.empty_cache()
diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
index 6132f1bd0e..bdfd594c77 100644
--- a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
@@ -1,21 +1,12 @@
-import sys
-from pathlib import Path
-
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
# Models to test and expected saved memory in MB, correspondingly
MODELS_SAVED_MEMORY_MB = {
"riverclouds/qwen_image_random": 4500,
@@ -33,34 +24,33 @@ def run_inference(
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(
- model=model_name,
+ with OmniRunner(
+ model_name,
enable_layerwise_offload=layerwise_offload,
# TODO: we might want to add overlapped feature e2e tests
# cache_backend="cache_dit",
boundary_ratio=0.875,
flow_shift=5.0,
- )
-
- current_omni_platform.reset_peak_memory_stats()
-
- # Refer to tests/e2e/offline_inference/test_t2v_model.py
- # Use minimal settings for testing
- height = 480
- width = 640
- num_frames = 5
-
- m.generate(
- "A cat sitting on a table",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- guidance_scale=1.0,
- num_inference_steps=num_inference_steps,
- num_frames=num_frames,
- ),
- )
+ ) as runner:
+ current_omni_platform.reset_peak_memory_stats()
+
+ # Refer to tests/e2e/offline_inference/test_t2v_model.py
+ # Use minimal settings for testing
+ height = 480
+ width = 640
+ num_frames = 5
+
+ runner.omni.generate(
+ "A cat sitting on a table",
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ guidance_scale=1.0,
+ num_inference_steps=num_inference_steps,
+ num_frames=num_frames,
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py
index b414fe30ee..7edd03f20d 100644
--- a/tests/e2e/offline_inference/test_diffusion_lora.py
+++ b/tests/e2e/offline_inference/test_diffusion_lora.py
@@ -7,6 +7,7 @@
import torch
from safetensors.torch import save_file
+from tests.conftest import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
@@ -16,15 +17,12 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
-from vllm_omni import Omni
-
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# This test is specific to Z-Image LoRA behavior. Keep it focused on a single
# model to reduce runtime and avoid extra downloads.
models = ["Tongyi-MAI/Z-Image-Turbo"]
-DIFFUSION_INIT_TIMEOUT_S = 600
@pytest.mark.parametrize("model_name", models)
@@ -77,12 +75,8 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
)
return str(adapter_dir)
- m = Omni(
- model=model_name,
- stage_init_timeout=DIFFUSION_INIT_TIMEOUT_S,
- init_timeout=DIFFUSION_INIT_TIMEOUT_S,
- )
- try:
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
# high resolution may cause OOM on L4
height = 256
width = 256
@@ -140,5 +134,3 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
diff = np.abs(np.array(images[0], dtype=np.int16) - np.array(images_lora[0], dtype=np.int16)).mean()
assert diff > 0.0
- finally:
- m.close()
diff --git a/tests/e2e/offline_inference/test_dynin_omni.py b/tests/e2e/offline_inference/test_dynin_omni.py
index d17e7b8175..5388ac6746 100644
--- a/tests/e2e/offline_inference/test_dynin_omni.py
+++ b/tests/e2e/offline_inference/test_dynin_omni.py
@@ -18,7 +18,6 @@
import torch
from transformers import AutoTokenizer
-from tests.conftest import OmniRunner
from tests.utils import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -37,6 +36,7 @@
pytestmark = [
pytest.mark.core_model,
pytest.mark.omni,
+ pytest.mark.parametrize("omni_runner", test_params, indirect=True),
]
@@ -291,20 +291,11 @@ def _numel(value: Any) -> int:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_t2i_decode_to_image(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_t2i_decode_to_image(omni_runner) -> None:
_configure_dynin_config_env()
prompt = _build_t2i_decode_prompt(dynin_config_path=DYNIN_CONFIG_PATH)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
image_output = _find_stage_output(outputs, "image")
assert image_output is not None
@@ -314,25 +305,16 @@ def test_dynin_t2i_decode_to_image(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_mmu_to_text(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_mmu_to_text(omni_runner) -> None:
_configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
prompt = _build_mmu_prompt(
tokenizer=tokenizer,
question="What is 2 + 2? Answer in one short sentence.",
dynin_config_path=DYNIN_CONFIG_PATH,
)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
text_output = _find_stage_output(outputs, "text")
assert text_output is not None
@@ -341,11 +323,9 @@ def test_dynin_mmu_to_text(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_image_to_text(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_image_to_text(omni_runner) -> None:
_configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
prompt = _build_mmu_multimodal_prompt(
tokenizer=tokenizer,
question="Describe the image briefly in one sentence.",
@@ -353,14 +333,7 @@ def test_dynin_image_to_text(test_config: tuple[str, str]) -> None:
image=_generate_synthetic_image(),
)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
text_output = _find_stage_output(outputs, "text")
assert text_output is not None
@@ -369,11 +342,9 @@ def test_dynin_image_to_text(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_speech_to_text(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_speech_to_text(omni_runner) -> None:
_configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
prompt = _build_mmu_multimodal_prompt(
tokenizer=tokenizer,
question="Transcribe the audio briefly in one sentence.",
@@ -381,14 +352,7 @@ def test_dynin_speech_to_text(test_config: tuple[str, str]) -> None:
audio=_generate_synthetic_audio(),
)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
text_output = _find_stage_output(outputs, "text")
assert text_output is not None
@@ -397,20 +361,11 @@ def test_dynin_speech_to_text(test_config: tuple[str, str]) -> None:
@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-@pytest.mark.parametrize("test_config", test_params)
-def test_dynin_t2s_decode_to_audio(test_config: tuple[str, str]) -> None:
- model, stage_config_path = test_config
+def test_dynin_t2s_decode_to_audio(omni_runner) -> None:
_configure_dynin_config_env()
prompt = _build_t2s_decode_prompt(dynin_config_path=DYNIN_CONFIG_PATH)
- with OmniRunner(
- model,
- seed=42,
- stage_configs_path=stage_config_path,
- stage_init_timeout=600,
- init_timeout=600,
- ) as runner:
- outputs = runner.generate([prompt])
+ outputs = omni_runner.generate([prompt])
audio_output = _find_stage_output(outputs, "audio")
assert audio_output is not None
diff --git a/tests/e2e/offline_inference/test_expert_parallel.py b/tests/e2e/offline_inference/test_expert_parallel.py
index ba126986ec..29d84d7a3e 100644
--- a/tests/e2e/offline_inference/test_expert_parallel.py
+++ b/tests/e2e/offline_inference/test_expert_parallel.py
@@ -18,8 +18,8 @@
import torch.distributed as dist
from PIL import Image
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -96,12 +96,26 @@ def _run_inference(
tensor_parallel_size=tensor_parallel_size,
enable_expert_parallel=enable_expert_parallel,
)
- omni = Omni(model=model_name, parallel_config=parallel_config)
-
try:
- # Warmup run (not timed)
- if warmup:
- _ = omni.generate(
+ with OmniRunner(model_name, parallel_config=parallel_config) as runner:
+ omni = runner.omni
+ # Warmup run (not timed)
+ if warmup:
+ _ = omni.generate(
+ PROMPT,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=DEFAULT_STEPS,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ # Timed run
+ start = time.time()
+ outputs = omni.generate(
PROMPT,
OmniDiffusionSamplingParams(
height=height,
@@ -112,28 +126,13 @@ def _run_inference(
num_outputs_per_prompt=1,
),
)
+ elapsed_ms = (time.time() - start) * 1000
- # Timed run
- start = time.time()
- outputs = omni.generate(
- PROMPT,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=DEFAULT_STEPS,
- guidance_scale=guidance_scale,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
- num_outputs_per_prompt=1,
- ),
- )
- elapsed_ms = (time.time() - start) * 1000
-
- return InferenceResult(
- images=outputs[0].images,
- elapsed_ms=elapsed_ms,
- )
+ return InferenceResult(
+ images=outputs[0].images,
+ elapsed_ms=elapsed_ms,
+ )
finally:
- omni.close()
_cleanup_distributed()
diff --git a/tests/e2e/offline_inference/test_flux.py b/tests/e2e/offline_inference/test_flux.py
new file mode 100644
index 0000000000..02c6787be2
--- /dev/null
+++ b/tests/e2e/offline_inference/test_flux.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Tests for Flux1 Schnell."""
+
+import pytest
+from PIL import Image
+
+from vllm_omni.entrypoints.omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+MODEL = "black-forest-labs/FLUX.1-schnell"
+
+
+@pytest.mark.core_model
+@pytest.mark.diffusion
+def test_flux_schnell_text_to_image():
+ """Test FLUX.1-schnell text-to-image generation."""
+ omni = Omni(model=MODEL)
+
+ omni_outputs = list(
+ omni.generate(
+ prompts=["A photo of a cat sitting on a laptop"],
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=512,
+ width=512,
+ num_inference_steps=2,
+ seed=42,
+ ),
+ )
+ )
+
+ assert len(omni_outputs) > 0
+ images = omni_outputs[0].images
+ assert len(images) == 1
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (512, 512)
diff --git a/tests/e2e/offline_inference/test_flux2_klein_inpaint.py b/tests/e2e/offline_inference/test_flux2_klein_inpaint.py
new file mode 100644
index 0000000000..ac7169361f
--- /dev/null
+++ b/tests/e2e/offline_inference/test_flux2_klein_inpaint.py
@@ -0,0 +1,216 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+End-to-end test for Flux2 Klein inpainting.
+
+"""
+
+# ruff: noqa: E402
+
+import os
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from PIL import Image, ImageDraw
+
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.outputs import OmniRequestOutput
+from vllm_omni.platforms import current_omni_platform
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
+
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+
+MODEL = "black-forest-labs/FLUX.2-klein-4B"
+
+_HEIGHT = 512
+_WIDTH = 512
+_NUM_INFERENCE_STEPS = 4
+
+
+def _create_test_image(width: int = _WIDTH, height: int = _HEIGHT, color: tuple = (128, 128, 128)) -> Image.Image:
+ return Image.new("RGB", (width, height), color)
+
+
+def _create_test_mask(width: int = _WIDTH, height: int = _HEIGHT) -> Image.Image:
+ mask = Image.new("L", (width, height), 0)
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle([width // 4, height // 4, width * 3 // 4, height * 3 // 4], fill=255)
+ return mask
+
+
+def _create_test_inputs(color: tuple = (100, 150, 200)):
+ return _create_test_image(_WIDTH, _HEIGHT, color), _create_test_mask(_WIDTH, _HEIGHT)
+
+
+def _extract_images_from_output(outputs: list) -> list[Image.Image]:
+ images = []
+ for req_output in outputs:
+ if hasattr(req_output, "images") and req_output.images:
+ images.extend(req_output.images)
+ elif hasattr(req_output, "request_output") and req_output.request_output:
+ stage_out = req_output.request_output
+ if isinstance(stage_out, OmniRequestOutput) and hasattr(stage_out, "images"):
+ images.extend(stage_out.images)
+ elif isinstance(stage_out, list):
+ for s in stage_out:
+ if hasattr(s, "images") and s.images:
+ images.extend(s.images)
+ return images
+
+
+@pytest.mark.core_model
+@pytest.mark.diffusion
+def test_flux2_klein_inpaint_basic():
+ m = None
+ try:
+ m = Omni(model=MODEL)
+ input_image, mask_image = _create_test_inputs()
+
+ outputs = m.generate(
+ prompts=[
+ {
+ "prompt": "Fill in the masked area with a beautiful garden",
+ "multi_modal_data": {"image": input_image, "mask_image": mask_image},
+ }
+ ],
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=_HEIGHT,
+ width=_WIDTH,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ guidance_scale=0.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ images = _extract_images_from_output(list(outputs))
+ assert len(images) == 1
+ assert images[0].size == (_WIDTH, _HEIGHT)
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
+
+
+@pytest.mark.diffusion
+def test_flux2_klein_inpaint_deterministic():
+ m = None
+ try:
+ m = Omni(model=MODEL)
+ input_image, mask_image = _create_test_inputs()
+ seed = 12345
+
+ gen1 = torch.Generator(current_omni_platform.device_type).manual_seed(seed)
+ gen2 = torch.Generator(current_omni_platform.device_type).manual_seed(seed)
+
+ outputs1 = m.generate(
+ prompts=[
+ {
+ "prompt": "A red flower in a field",
+ "multi_modal_data": {"image": input_image, "mask_image": mask_image},
+ }
+ ],
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=_HEIGHT,
+ width=_WIDTH,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ guidance_scale=0.0,
+ generator=gen1,
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ outputs2 = m.generate(
+ prompts=[
+ {
+ "prompt": "A red flower in a field",
+ "multi_modal_data": {"image": input_image, "mask_image": mask_image},
+ }
+ ],
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=_HEIGHT,
+ width=_WIDTH,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ guidance_scale=0.0,
+ generator=gen2,
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ images1 = _extract_images_from_output(list(outputs1))
+ images2 = _extract_images_from_output(list(outputs2))
+
+ assert len(images1) == 1
+ assert len(images2) == 1
+
+ assert list(images1[0].getdata()) == list(images2[0].getdata()), (
+ "Same input with same seed should produce identical output. "
+ "This is critical for offline/online consistency."
+ )
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
+
+
+@pytest.mark.diffusion
+def test_flux2_klein_inpaint_different_seeds_different_output():
+ m = None
+ try:
+ m = Omni(model=MODEL)
+ input_image, mask_image = _create_test_inputs()
+
+ gen1 = torch.Generator(current_omni_platform.device_type).manual_seed(42)
+ gen2 = torch.Generator(current_omni_platform.device_type).manual_seed(99999)
+
+ outputs1 = m.generate(
+ prompts=[
+ {
+ "prompt": "A beautiful landscape",
+ "multi_modal_data": {"image": input_image, "mask_image": mask_image},
+ }
+ ],
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=_HEIGHT,
+ width=_WIDTH,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ guidance_scale=0.0,
+ generator=gen1,
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ outputs2 = m.generate(
+ prompts=[
+ {
+ "prompt": "A beautiful landscape",
+ "multi_modal_data": {"image": input_image, "mask_image": mask_image},
+ }
+ ],
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=_HEIGHT,
+ width=_WIDTH,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ guidance_scale=0.0,
+ generator=gen2,
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ images1 = _extract_images_from_output(list(outputs1))
+ images2 = _extract_images_from_output(list(outputs2))
+
+ assert len(images1) == 1
+ assert len(images2) == 1
+
+ different_pixel_count = sum(1 for p1, p2 in zip(images1[0].getdata(), images2[0].getdata()) if p1 != p2)
+ assert different_pixel_count > 0, "Different seeds should produce different outputs"
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
diff --git a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
index 42aab7f26a..cbcd1009dd 100644
--- a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
+++ b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
@@ -8,31 +8,21 @@
"""
import gc
-import sys
-from pathlib import Path
+import os as _os
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
QUANTIZED_MODEL = "vllm-project-org/FLUX.1-dev-AutoRound-w4a16"
BASELINE_MODEL = "black-forest-labs/FLUX.1-dev"
-# Allow overriding via environment for local testing
-import os as _os
-
QUANTIZED_MODEL = _os.environ.get("FLUX_AUTOROUND_MODEL", QUANTIZED_MODEL)
BASELINE_MODEL = _os.environ.get("FLUX_BASELINE_MODEL", BASELINE_MODEL)
@@ -51,19 +41,18 @@ def _generate_image(model_name: str, **extra_kwargs) -> tuple[list, float]:
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(model=model_name, enforce_eager=True, **extra_kwargs)
-
- current_omni_platform.reset_peak_memory_stats()
- outputs = m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=HEIGHT,
- width=WIDTH,
- num_inference_steps=NUM_STEPS,
- guidance_scale=0.0,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ with OmniRunner(model_name, enforce_eager=True, **extra_kwargs) as runner:
+ current_omni_platform.reset_peak_memory_stats()
+ outputs = runner.omni.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ OmniDiffusionSamplingParams(
+ height=HEIGHT,
+ width=WIDTH,
+ num_inference_steps=NUM_STEPS,
+ guidance_scale=0.0,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
@@ -74,7 +63,6 @@ def _generate_image(model_name: str, **extra_kwargs) -> tuple[list, float]:
assert isinstance(req_out, OmniRequestOutput) and hasattr(req_out, "images")
images = req_out.images
- del m
gc.collect()
current_omni_platform.empty_cache()
diff --git a/tests/e2e/offline_inference/test_flux_kontext.py b/tests/e2e/offline_inference/test_flux_kontext.py
index 93dca21c9a..cd711d6b81 100644
--- a/tests/e2e/offline_inference/test_flux_kontext.py
+++ b/tests/e2e/offline_inference/test_flux_kontext.py
@@ -9,23 +9,14 @@
- Image editing with text guidance
"""
-import os
-import sys
-from pathlib import Path
-
import pytest
from PIL import Image
+from vllm.assets.image import ImageAsset
+from tests.conftest import OmniRunner
from vllm_omni.diffusion.data import DiffusionParallelConfig
-from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
-
MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
@@ -33,17 +24,15 @@
@pytest.mark.diffusion
def test_flux_kontext_text_to_image():
"""Test FluxKontext text-to-image generation with real model."""
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
),
enable_cpu_offload=False,
- )
-
- try:
+ ) as runner:
omni_outputs = list(
- omni.generate(
+ runner.omni.generate(
prompts=["A photo of a cat sitting on a laptop"],
sampling_params_list=OmniDiffusionSamplingParams(
height=512,
@@ -54,43 +43,37 @@ def test_flux_kontext_text_to_image():
)
)
- assert len(omni_outputs) > 0
- output = omni_outputs[0]
- images = None
- if output.images:
- images = output.images
- elif hasattr(output, "request_output") and output.request_output:
- for stage_out in output.request_output:
- if hasattr(stage_out, "images") and stage_out.images:
- images = stage_out.images
- break
+ assert len(omni_outputs) > 0
+ output = omni_outputs[0]
+ images = None
+ if output.images:
+ images = output.images
+ elif hasattr(output, "request_output") and output.request_output:
+ for stage_out in output.request_output:
+ if hasattr(stage_out, "images") and stage_out.images:
+ images = stage_out.images
+ break
- assert images is not None
- assert len(images) > 0
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (512, 512)
- finally:
- omni.close()
+ assert images is not None
+ assert len(images) > 0
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (512, 512)
@pytest.mark.core_model
@pytest.mark.diffusion
def test_flux_kontext_image_edit():
"""Test FluxKontext image-to-image editing with real model."""
- from vllm.assets.image import ImageAsset
-
input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
),
enable_cpu_offload=False,
- )
-
- try:
+ ) as runner:
omni_outputs = list(
- omni.generate(
+ runner.omni.generate(
prompts=[
{
"prompt": "Transform this image into a Vincent van Gogh style painting",
@@ -107,20 +90,18 @@ def test_flux_kontext_image_edit():
)
)
- assert len(omni_outputs) > 0
- output = omni_outputs[0]
- images = None
- if output.images:
- images = output.images
- elif hasattr(output, "request_output") and output.request_output:
- for stage_out in output.request_output:
- if hasattr(stage_out, "images") and stage_out.images:
- images = stage_out.images
- break
-
- assert images is not None
- assert len(images) > 0
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (512, 512)
- finally:
- omni.close()
+ assert len(omni_outputs) > 0
+ output = omni_outputs[0]
+ images = None
+ if output.images:
+ images = output.images
+ elif hasattr(output, "request_output") and output.request_output:
+ for stage_out in output.request_output:
+ if hasattr(stage_out, "images") and stage_out.images:
+ images = stage_out.images
+ break
+
+ assert images is not None
+ assert len(images) > 0
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (512, 512)
diff --git a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
index 5522f33eaa..ec4f4693d7 100644
--- a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
+++ b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
@@ -8,6 +8,7 @@
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
+from tests.conftest import OmniRunner
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -16,7 +17,7 @@
MODEL_NAME = "tencent/HunyuanImage-3.0"
LOCAL_CLIP_PATH = "openai/clip-vit-base-patch32"
REPO_ROOT = Path(__file__).resolve().parents[3]
-STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image_3_moe.yaml"
+STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_t2i.yaml"
pytestmark = [pytest.mark.advanced_model, pytest.mark.diffusion]
@@ -271,16 +272,11 @@ def clip_bundle() -> tuple[CLIPModel, CLIPProcessor]:
@pytest.fixture(scope="module")
def omni() -> Generator[Omni, None, None]:
- engine = Omni(
- model=MODEL_NAME,
+ with OmniRunner(
+ MODEL_NAME,
stage_configs_path=str(STAGE_CONFIG_PATH),
- stage_init_timeout=600,
- init_timeout=900,
- )
- try:
- yield engine
- finally:
- engine.close()
+ ) as runner:
+ yield runner.omni
def _extract_generated_image(outputs: list[object]) -> Image.Image:
diff --git a/tests/e2e/offline_inference/test_magi_human.py b/tests/e2e/offline_inference/test_magi_human.py
index 8648216a92..abb7f9c163 100644
--- a/tests/e2e/offline_inference/test_magi_human.py
+++ b/tests/e2e/offline_inference/test_magi_human.py
@@ -8,9 +8,9 @@
import numpy as np
import pytest
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
-from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -49,12 +49,6 @@ def test_magi_human_e2e(run_level):
model_path = "SII-GAIR/daVinci-MagiHuman-Base-1080p"
- omni = Omni(
- model=model_path,
- init_timeout=1200,
- tensor_parallel_size=2,
- )
-
prompt = (
"A young woman with long, wavy golden blonde hair and bright blue eyes, "
"wearing a fitted ivory silk blouse with a delicate lace collar, sits "
@@ -94,7 +88,12 @@ def test_magi_human_e2e(run_level):
},
)
- try:
+ with OmniRunner(
+ model_path,
+ init_timeout=1200,
+ tensor_parallel_size=2,
+ ) as runner:
+ omni = runner.omni
outputs = list(
omni.generate(
prompts=[prompt],
@@ -140,5 +139,3 @@ def test_magi_human_e2e(run_level):
assert len(video_bytes) > 1000, f"MP4 too small ({len(video_bytes)} bytes)"
_validate_mp4(video_bytes)
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_mammoth_moda2.py b/tests/e2e/offline_inference/test_mammoth_moda2.py
index 5293b5ed1b..ff744c86e1 100644
--- a/tests/e2e/offline_inference/test_mammoth_moda2.py
+++ b/tests/e2e/offline_inference/test_mammoth_moda2.py
@@ -23,10 +23,9 @@
import torch
from vllm.sampling_params import SamplingParams
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
-
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
@@ -116,8 +115,6 @@ def test_mammothmoda2_t2i_e2e():
- A fixed set of pixel values matches a golden reference
(regenerate with ``UPDATE_GOLDEN=1``).
"""
- from vllm_omni import Omni
-
if not Path(MODEL_PATH).exists():
pytest.skip(f"Model weights not found at {MODEL_PATH}")
if not Path(T2I_STAGE_CONFIG).exists():
@@ -135,8 +132,8 @@ def test_mammothmoda2_t2i_e2e():
prompt_text = "A cat sitting on a laptop keyboard"
formatted_prompt = _format_t2i_prompt(prompt_text, ar_width, ar_height)
- omni = Omni(model=MODEL_PATH, stage_configs_path=T2I_STAGE_CONFIG, trust_remote_code=True)
- try:
+ with OmniRunner(MODEL_PATH, stage_configs_path=T2I_STAGE_CONFIG, trust_remote_code=True) as runner:
+ omni = runner.omni
# Greedy / deterministic sampling so pixel values are reproducible.
ar_sampling = SamplingParams(
temperature=0.0,
@@ -211,5 +208,3 @@ def test_mammothmoda2_t2i_e2e():
found_image = True
assert found_image, "No image tensor found in pipeline output"
- finally:
- omni.close()
diff --git a/tests/e2e/offline_inference/test_ming_flash_omni.py b/tests/e2e/offline_inference/test_ming_flash_omni.py
new file mode 100644
index 0000000000..be0ed3b056
--- /dev/null
+++ b/tests/e2e/offline_inference/test_ming_flash_omni.py
@@ -0,0 +1,142 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+
+from pathlib import Path
+
+import pytest
+
+from tests.conftest import (
+ generate_synthetic_audio,
+ generate_synthetic_image,
+ generate_synthetic_video,
+ modify_stage_config,
+)
+from tests.utils import hardware_test
+
+models = ["Jonathan1909/Ming-flash-omni-2.0"]
+
+# Ming-specific
+SYSTEM_PROMPT = "你是一个友好的AI助手。\n\ndetailed thinking off"
+EOS_TOKEN = "<|role_end|>"
+IMAGE_TOKEN = ""
+VIDEO_TOKEN = ""
+AUDIO_TOKEN = ""
+
+
+def build_prompt(user_text: str) -> str:
+ """Build a Ming chat prompt."""
+ return (
+ f"SYSTEM {SYSTEM_PROMPT}{EOS_TOKEN}HUMAN {user_text}{EOS_TOKEN}ASSISTANT "
+ )
+
+
+def get_eager_config():
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"),
+ updates={
+ "stage_args": {
+ 0: {
+ "engine_args.enforce_eager": "true",
+ },
+ },
+ },
+ )
+ return path
+
+
+stage_configs = [get_eager_config()]
+test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_text_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test text-only input processing and text output generation.
+ Input Modal: text
+ Output Modal: text
+ """
+ prompt = build_prompt("请详细介绍鹦鹉的生活习性。")
+ request_config = {"prompts": prompt, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_image_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test image understanding with text output.
+ Input Modal: image + text
+ Output Modal: text
+ """
+ image = generate_synthetic_image(224, 224)["np_array"]
+ prompt = build_prompt(f"{IMAGE_TOKEN}Describe this image briefly.")
+ request_config = {"prompts": prompt, "images": image, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_audio_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test audio understanding with text output.
+ Input Modal: audio + text
+ Output Modal: text
+ """
+ audio = generate_synthetic_audio(2, 1, 16000)["np_array"]
+ if len(audio.shape) == 2:
+ audio = audio.squeeze()
+ prompt = build_prompt(f"{AUDIO_TOKEN}Please recognize the language of this speech and transcribe it. Format: oral.")
+ request_config = {"prompts": prompt, "audios": audio, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_video_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test video understanding with text output.
+ Input Modal: video + text
+ Output Modal: text
+ """
+ video = generate_synthetic_video(224, 224, 30)["np_array"]
+ prompt = build_prompt(f"{VIDEO_TOKEN}Describe what is happening in this video.")
+ request_config = {"prompts": prompt, "videos": video, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_mixed_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test mixed modality input (image + audio) with text output.
+ Input Modal: image + audio + text
+ Output Modal: text
+ """
+ image = generate_synthetic_image(224, 224)["np_array"]
+ audio = generate_synthetic_audio(2, 1, 16000)["np_array"]
+ if len(audio.shape) == 2:
+ audio = audio.squeeze()
+ prompt = build_prompt(f"{IMAGE_TOKEN}{AUDIO_TOKEN}Describe the image and transcribe the audio.")
+ request_config = {"prompts": prompt, "images": image, "audios": audio, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
diff --git a/tests/e2e/offline_inference/test_omnivoice.py b/tests/e2e/offline_inference/test_omnivoice.py
index 4b093e357d..bb4c8a5dd7 100644
--- a/tests/e2e/offline_inference/test_omnivoice.py
+++ b/tests/e2e/offline_inference/test_omnivoice.py
@@ -16,6 +16,7 @@
import numpy as np
import pytest
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
MODEL = "k2-fsa/OmniVoice"
@@ -37,48 +38,42 @@ def test_omnivoice_text_to_audio() -> None:
Input Modal: text
Output Modal: audio
"""
- from vllm_omni.entrypoints.omni import Omni
+ from vllm_omni.inputs.data import OmniDiffusionSamplingParams
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
stage_configs_path=get_stage_config(),
trust_remote_code=True,
log_stats=True,
- )
-
- try:
+ ) as runner:
prompts = {"prompt": "Hello, this is a test for text to audio."}
- from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
sampling_params_list = [OmniDiffusionSamplingParams()]
- outputs = list(omni.generate(prompts, sampling_params_list=sampling_params_list))
+ outputs = list(runner.omni.generate(prompts, sampling_params_list=sampling_params_list))
- assert len(outputs) > 0, "No outputs generated"
+ assert len(outputs) > 0, "No outputs generated"
- # Check final output has audio
- final_output = outputs[-1]
- ro = final_output.request_output
- assert ro is not None, "No request_output"
+ # Check final output has audio
+ final_output = outputs[-1]
+ ro = final_output.request_output
+ assert ro is not None, "No request_output"
- mm = getattr(ro, "multimodal_output", None)
- if not mm and ro.outputs:
- mm = getattr(ro.outputs[0], "multimodal_output", None)
+ mm = getattr(ro, "multimodal_output", None)
+ if not mm and ro.outputs:
+ mm = getattr(ro.outputs[0], "multimodal_output", None)
- assert mm is not None, "No multimodal_output"
- assert "audio" in mm, f"No 'audio' key in multimodal_output: {mm.keys()}"
+ assert mm is not None, "No multimodal_output"
+ assert "audio" in mm, f"No 'audio' key in multimodal_output: {mm.keys()}"
- audio = mm["audio"]
- if isinstance(audio, np.ndarray):
- audio_np = audio
- else:
- audio_np = audio.cpu().numpy().squeeze()
+ audio = mm["audio"]
+ if isinstance(audio, np.ndarray):
+ audio_np = audio
+ else:
+ audio_np = audio.cpu().numpy().squeeze()
- assert audio_np.size > 0, "Audio output is empty"
- rms = np.sqrt(np.mean(audio_np**2))
- assert rms > 0.01, f"Audio RMS too low ({rms:.4f}), likely silence"
+ assert audio_np.size > 0, "Audio output is empty"
+ rms = np.sqrt(np.mean(audio_np**2))
+ assert rms > 0.01, f"Audio RMS too low ({rms:.4f}), likely silence"
- print(f"Generated audio: {len(audio_np) / 24000:.2f}s, rms={rms:.4f}")
- finally:
- omni.close()
+ print(f"Generated audio: {len(audio_np) / 24000:.2f}s, rms={rms:.4f}")
diff --git a/tests/e2e/offline_inference/test_quantization_fp8.py b/tests/e2e/offline_inference/test_quantization_fp8.py
index f71c53de74..291779fd93 100644
--- a/tests/e2e/offline_inference/test_quantization_fp8.py
+++ b/tests/e2e/offline_inference/test_quantization_fp8.py
@@ -29,7 +29,6 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
from typing import Any
@@ -37,8 +36,8 @@
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
@@ -61,16 +60,15 @@ def _generate_single_stage_image(
Returns (images, peak_memory_gib).
"""
- omni_kwargs: dict[str, Any] = {"model": model, **extra_omni_kwargs}
+ omni_kwargs: dict[str, Any] = dict(extra_omni_kwargs)
if quantization:
omni_kwargs["quantization"] = quantization
- omni = Omni(**omni_kwargs)
- try:
+ with OmniRunner(model, **omni_kwargs) as runner:
torch.cuda.reset_peak_memory_stats()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed)
- outputs = omni.generate(
+ outputs = runner.omni.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -94,8 +92,6 @@ def _generate_single_stage_image(
assert images[0].height == height
return images, peak_mem
- finally:
- omni.close()
def _generate_bagel_image(
@@ -115,8 +111,9 @@ def _generate_bagel_image(
if quantization_config:
omni_kwargs["quantization_config"] = quantization_config
- omni = Omni(**omni_kwargs)
- try:
+ model_name = omni_kwargs.pop("model")
+ with OmniRunner(model_name, **omni_kwargs) as runner:
+ omni = runner.omni
torch.cuda.reset_peak_memory_stats()
params_list = omni.default_sampling_params_list
@@ -168,8 +165,6 @@ def _generate_bagel_image(
)
return generated_image, peak_mem
- finally:
- omni.close()
# ─── Single-stage diffusion model tests ──────────────────────────────────────
diff --git a/tests/e2e/offline_inference/test_qwen2_5_omni.py b/tests/e2e/offline_inference/test_qwen2_5_omni.py
index 4c4315aab9..4500ebfbe2 100644
--- a/tests/e2e/offline_inference/test_qwen2_5_omni.py
+++ b/tests/e2e/offline_inference/test_qwen2_5_omni.py
@@ -2,8 +2,6 @@
E2E tests for Qwen2.5-Omni model with mixed modality inputs, audio and text output.
"""
-from pathlib import Path
-
import pytest
from tests.conftest import (
@@ -12,36 +10,31 @@
generate_synthetic_video,
modify_stage_config,
)
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
from vllm_omni.platforms import current_omni_platform
models = ["Qwen/Qwen2.5-Omni-7B"]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section. NPU still uses the legacy per-platform YAML until it
+# also migrates to the new schema.
+_CI_DEPLOY = get_deploy_config_path("ci/qwen2_5_omni.yaml")
+
def get_cuda_graph_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "qwen2_5_omni_ci.yaml"),
+ return modify_stage_config(
+ _CI_DEPLOY,
updates={
- "stage_args": {
- 0: {
- "engine_args.enforce_eager": "true",
- },
- 1: {"engine_args.enforce_eager": "true"},
+ "stages": {
+ 0: {"enforce_eager": True},
+ 1: {"enforce_eager": True},
},
},
)
- return path
-
-
-# CI stage config optimized for 24GB GPU (L4/RTX3090) or NPU
-if current_omni_platform.is_npu():
- stage_config = str(Path(__file__).parent / "stage_configs" / "npu" / "qwen2_5_omni_ci.yaml")
-elif current_omni_platform.is_rocm():
- # ROCm stage config optimized for MI325 GPU
- stage_config = str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen2_5_omni_ci.yaml")
-elif current_omni_platform.is_xpu():
- # Intel XPU stage config optimized for B60 GPU
- stage_config = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen2_5_omni_ci.yaml")
+
+
+if current_omni_platform.is_rocm() or current_omni_platform.is_xpu() or current_omni_platform.is_npu():
+ stage_config = _CI_DEPLOY
else:
stage_config = get_cuda_graph_config()
diff --git a/tests/e2e/offline_inference/test_qwen3_omni.py b/tests/e2e/offline_inference/test_qwen3_omni.py
index cc0af437ec..0df89c3e88 100644
--- a/tests/e2e/offline_inference/test_qwen3_omni.py
+++ b/tests/e2e/offline_inference/test_qwen3_omni.py
@@ -7,41 +7,37 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
from tests.conftest import (
generate_synthetic_video,
modify_stage_config,
)
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
from vllm_omni.platforms import current_omni_platform
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section. Only CUDA needs an extra enforce_eager tweak.
+_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
+
+
def get_cuda_graph_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml"),
+ return modify_stage_config(
+ _CI_DEPLOY,
updates={
- "stage_args": {
- 0: {
- "engine_args.enforce_eager": "true",
- },
- 1: {"engine_args.enforce_eager": "true"},
+ "stages": {
+ 0: {"enforce_eager": True},
+ 1: {"enforce_eager": True},
},
},
)
- return path
-# CI stage config for 2xH100-80G GPUs or AMD GPU MI325
-if current_omni_platform.is_rocm():
- # ROCm stage config optimized for MI325 GPU
- stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen3_omni_ci.yaml")]
-elif current_omni_platform.is_xpu():
- stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
+if current_omni_platform.is_rocm() or current_omni_platform.is_xpu():
+ stage_configs = [_CI_DEPLOY]
else:
stage_configs = [get_cuda_graph_config()]
diff --git a/tests/e2e/offline_inference/test_qwen3_tts_base.py b/tests/e2e/offline_inference/test_qwen3_tts_base.py
index be7bd50a36..a706798043 100644
--- a/tests/e2e/offline_inference/test_qwen3_tts_base.py
+++ b/tests/e2e/offline_inference/test_qwen3_tts_base.py
@@ -13,12 +13,10 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
from tests.conftest import modify_stage_config
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
REF_AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
@@ -26,23 +24,31 @@
def get_cuda_graph_config():
- path = modify_stage_config(
- get_stage_config(),
+ """Build a temp deploy yaml mirroring the deleted qwen3_tts_no_async_chunk.yaml.
+
+ Composes the synchronous (no-async-chunk) variant on top of the bundled
+ qwen3_tts.yaml prod default, with cudagraphs disabled. Replaces the deleted
+ standalone variant yaml; same effective config, no checked-in file needed.
+ """
+ return modify_stage_config(
+ get_deploy_config_path("qwen3_tts.yaml"),
updates={
- "stage_args": {
+ "async_chunk": False,
+ "stages": {
0: {
- "engine_args.enforce_eager": "true",
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
+ },
+ 1: {
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
},
- 1: {"engine_args.enforce_eager": "true"},
},
},
)
- return path
-
-
-def get_stage_config(name: str = "qwen3_tts_no_async_chunk.yaml"):
- """Get the no_async_chunk stage config path (async_chunk disable, cuda_graph disabled)."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
# Same structure as test_qwen3_omni: models, stage_configs, test_params
diff --git a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
index 67d72df908..cf411349c3 100644
--- a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
+++ b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
@@ -13,34 +13,40 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
from tests.conftest import modify_stage_config
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
def get_cuda_graph_config():
- path = modify_stage_config(
- get_stage_config(),
+ """Build a temp deploy yaml mirroring the deleted qwen3_tts_no_async_chunk.yaml.
+
+ Composes the synchronous (no-async-chunk) variant on top of the bundled
+ qwen3_tts.yaml prod default, with cudagraphs disabled. Replaces the deleted
+ standalone variant yaml; same effective config, no checked-in file needed.
+ """
+ return modify_stage_config(
+ get_deploy_config_path("qwen3_tts.yaml"),
updates={
- "stage_args": {
+ "async_chunk": False,
+ "stages": {
0: {
- "engine_args.enforce_eager": "true",
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
+ },
+ 1: {
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
},
- 1: {"engine_args.enforce_eager": "true"},
},
},
)
- return path
-
-
-def get_stage_config(name: str = "qwen3_tts_no_async_chunk.yaml"):
- """Get the no_async_chunk stage config path (async_chunk disable, cuda_graph disabled)."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
# Same structure as test_qwen3_omni: models, stage_configs, test_params
diff --git a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
index d5f82f893e..f0b0b55c9f 100644
--- a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
+++ b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
@@ -28,7 +28,6 @@
import argparse
import asyncio
-import os
import sys
import time
import uuid
@@ -37,6 +36,7 @@
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -48,9 +48,6 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
-from vllm_omni import Omni
-
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# ------------------------------------------------------------------
models = ["tiny-random/Qwen-Image"]
@@ -391,31 +388,28 @@ async def main(model: str, num_prompts: int, mode: str, batch_size: int = 1) ->
def test_diffusion_batching_sync_sequential(model_name: str):
"""Test that synchronous Omni can generate images for multiple prompts
submitted sequentially (one at a time) and each returns a valid image."""
- m = None
try:
- m = Omni(model=model_name)
- sp = _default_sync_sampling_params()
- prompts = TEST_PROMPTS[:4]
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ sp = _default_sync_sampling_params()
+ prompts = TEST_PROMPTS[:4]
- for i, prompt in enumerate(prompts):
- outputs = m.generate(prompt, sp)
- first_output = outputs[0]
- assert first_output.final_output_type == "image", (
- f"Expected 'image', got '{first_output.final_output_type}'"
- )
+ for i, prompt in enumerate(prompts):
+ outputs = m.generate(prompt, sp)
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image", (
+ f"Expected 'image', got '{first_output.final_output_type}'"
+ )
- # Images are surfaced both at top-level and inside request_output
- images = _extract_images(first_output)
- assert len(images) >= 1, f"Expected at least 1 image for prompt {i}, got {len(images)}"
- assert images[0].width == 256
- assert images[0].height == 256
- print(f" prompt {i}: OK ({len(images)} images)")
+ # Images are surfaced both at top-level and inside request_output
+ images = _extract_images(first_output)
+ assert len(images) >= 1, f"Expected at least 1 image for prompt {i}, got {len(images)}"
+ assert images[0].width == 256
+ assert images[0].height == 256
+ print(f" prompt {i}: OK ({len(images)} images)")
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
@pytest.mark.core_model
@@ -431,34 +425,31 @@ def test_diffusion_batching_sync_multi_prompt(model_name: str):
handling at the diffusion stage, not the explicit list-batch path
(which is only available via AsyncOmni).
"""
- m = None
try:
- m = Omni(model=model_name)
- sp = _default_sync_sampling_params()
- prompts = TEST_PROMPTS[:4]
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ sp = _default_sync_sampling_params()
+ prompts = TEST_PROMPTS[:4]
- outputs = m.generate(prompts, sp)
- assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
+ outputs = m.generate(prompts, sp)
+ assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
- for i, output in enumerate(outputs):
- assert output.final_output_type == "image", (
- f"Output {i} final_output_type expected 'image', got '{output.final_output_type}'"
- )
- images = _extract_images(output)
- assert images and len(images) >= 1, f"Expected at least 1 image for prompt {i}"
- assert images[0].width == 256
- assert images[0].height == 256
- print(f" prompt {i}: OK ({len(images)} images, request_id={output.request_id})")
-
- # Verify all request_ids are distinct
- request_ids = [o.request_id for o in outputs]
- assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids found: {request_ids}"
+ for i, output in enumerate(outputs):
+ assert output.final_output_type == "image", (
+ f"Output {i} final_output_type expected 'image', got '{output.final_output_type}'"
+ )
+ images = _extract_images(output)
+ assert images and len(images) >= 1, f"Expected at least 1 image for prompt {i}"
+ assert images[0].width == 256
+ assert images[0].height == 256
+ print(f" prompt {i}: OK ({len(images)} images, request_id={output.request_id})")
+
+ # Verify all request_ids are distinct
+ request_ids = [o.request_id for o in outputs]
+ assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids found: {request_ids}"
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
@pytest.mark.core_model
@@ -552,32 +543,29 @@ async def _inner():
def test_diffusion_batching_num_outputs(model_name: str):
"""Test that the diffusion model respects num_outputs_per_prompt and
generates the correct number of images per request."""
- m = None
try:
- m = Omni(model=model_name)
- num_outputs = 2
- sp = _default_sync_sampling_params(num_outputs_per_prompt=num_outputs)
-
- outputs = m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- sp,
- )
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ num_outputs = 2
+ sp = _default_sync_sampling_params(num_outputs_per_prompt=num_outputs)
+
+ outputs = m.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ sp,
+ )
- first_output = outputs[0]
- assert first_output.final_output_type == "image"
- images = _extract_images(first_output)
- assert images is not None and len(images) == num_outputs, (
- f"Expected {num_outputs} images, got {len(images) if images else 0}"
- )
- for img in images:
- assert img.width == 256
- assert img.height == 256
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image"
+ images = _extract_images(first_output)
+ assert images is not None and len(images) == num_outputs, (
+ f"Expected {num_outputs} images, got {len(images) if images else 0}"
+ )
+ for img in images:
+ assert img.width == 256
+ assert img.height == 256
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
@pytest.mark.core_model
@@ -587,34 +575,31 @@ def test_diffusion_batching_num_outputs(model_name: str):
def test_diffusion_batching_distinct_results(model_name: str):
"""Test that different prompts produce distinct images when batched,
ensuring the batching logic does not mix up results across requests."""
- m = None
try:
- m = Omni(model=model_name)
- sp = _default_sync_sampling_params()
- prompts = [
- {"prompt": "a bright red apple on a white table", "negative_prompt": "blurry"},
- {"prompt": "a blue ocean with white waves crashing", "negative_prompt": "blurry"},
- ]
-
- outputs = m.generate(prompts, sp)
- assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
-
- # Verify each output has a unique request_id
- request_ids = [o.request_id for o in outputs]
- assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids: {request_ids}"
-
- # Verify each output has images
- for i, output in enumerate(outputs):
- images = _extract_images(output)
- assert images and len(images) >= 1, f"No images for prompt {i}"
- assert images[0].width == 256
- assert images[0].height == 256
+ with OmniRunner(model_name) as runner:
+ m = runner.omni
+ sp = _default_sync_sampling_params()
+ prompts = [
+ {"prompt": "a bright red apple on a white table", "negative_prompt": "blurry"},
+ {"prompt": "a blue ocean with white waves crashing", "negative_prompt": "blurry"},
+ ]
+
+ outputs = m.generate(prompts, sp)
+ assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
+
+ # Verify each output has a unique request_id
+ request_ids = [o.request_id for o in outputs]
+ assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids: {request_ids}"
+
+ # Verify each output has images
+ for i, output in enumerate(outputs):
+ images = _extract_images(output)
+ assert images and len(images) >= 1, f"No images for prompt {i}"
+ assert images[0].width == 256
+ assert images[0].height == 256
except Exception as e:
print(f"Test failed with error: {e}")
raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
# ------------------------------------------------------------------
diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py
index 16239a1c52..d3abccd78c 100644
--- a/tests/e2e/offline_inference/test_sequence_parallel.py
+++ b/tests/e2e/offline_inference/test_sequence_parallel.py
@@ -20,8 +20,8 @@
import torch.distributed as dist
from PIL import Image
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
-from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -92,49 +92,48 @@ def _run_inference(
warmup: If True, run one warmup iteration before the timed run.
"""
parallel_config = DiffusionParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree)
- omni = Omni(
- model=model_name,
- parallel_config=parallel_config,
- dtype=dtype,
- attention_backend=attn_backend,
- )
-
try:
- # Warmup run (not timed)
- if warmup:
- _ = omni.generate(
+ with OmniRunner(
+ model_name,
+ parallel_config=parallel_config,
+ dtype=dtype,
+ attention_backend=attn_backend,
+ ) as runner:
+ omni = runner.omni
+ # Warmup run (not timed)
+ if warmup:
+ _ = omni.generate(
+ PROMPT,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=DEFAULT_STEPS,
+ guidance_scale=0.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed + 1000),
+ num_outputs_per_prompt=1,
+ ),
+ )
+
+ # Timed run
+ start = time.time()
+ outputs = omni.generate(
PROMPT,
OmniDiffusionSamplingParams(
height=height,
width=width,
num_inference_steps=DEFAULT_STEPS,
guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed + 1000),
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
num_outputs_per_prompt=1,
),
)
+ elapsed_ms = (time.time() - start) * 1000
- # Timed run
- start = time.time()
- outputs = omni.generate(
- PROMPT,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=DEFAULT_STEPS,
- guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
- num_outputs_per_prompt=1,
- ),
- )
- elapsed_ms = (time.time() - start) * 1000
-
- return InferenceResult(
- images=outputs[0].request_output.images,
- elapsed_ms=elapsed_ms,
- )
+ return InferenceResult(
+ images=outputs[0].request_output.images,
+ elapsed_ms=elapsed_ms,
+ )
finally:
- omni.close()
_cleanup_distributed()
diff --git a/tests/e2e/offline_inference/test_stable_audio_expansion.py b/tests/e2e/offline_inference/test_stable_audio_expansion.py
new file mode 100644
index 0000000000..54c1799e14
--- /dev/null
+++ b/tests/e2e/offline_inference/test_stable_audio_expansion.py
@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Stable Audio offline e2e: real weights, FP8 + TeaCache (single job to save GPU).
+
+NOTE: This test instantiates Omni directly instead of using the omni_runner
+fixture (introduced in PR #2711) because the fixture's parametrize interface
+only accepts (model, stage_config_path) and does not support extra kwargs like
+quantization, cache_backend, or cache_config.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+import torch
+
+from tests.conftest import assert_audio_valid
+from tests.utils import hardware_test
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.outputs import OmniRequestOutput
+from vllm_omni.platforms import current_omni_platform
+
+_SAMPLE_RATE = 44100
+_CLIP_DURATION_S = 2.0
+
+
+def generate_stable_audio_short_clip(
+ omni: Omni,
+ *,
+ audio_start_in_s: float = 0.0,
+ audio_end_in_s: float = 2.0,
+ num_inference_steps: int = 4,
+ seed: int = 42,
+) -> np.ndarray:
+ """Run a minimal Stable Audio generation and return audio as (batch, channels, samples)."""
+ outputs = omni.generate(
+ prompts={
+ "prompt": "The sound of a dog barking",
+ "negative_prompt": "Low quality.",
+ },
+ sampling_params_list=OmniDiffusionSamplingParams(
+ num_inference_steps=num_inference_steps,
+ guidance_scale=7.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
+ num_outputs_per_prompt=1,
+ extra_args={
+ "audio_start_in_s": audio_start_in_s,
+ "audio_end_in_s": audio_end_in_s,
+ },
+ ),
+ )
+
+ assert outputs is not None
+ first_output = outputs[0]
+ # Outer OmniRequestOutput.final_output_type comes from get_stage_metadata.
+ # The nested request_output is the worker OmniRequestOutput
+ # (e.g. final_output_type="audio") and holds the multimodal payload.
+ # Follow-up: add StableAudioPipeline stage YAML, and pass model into
+ # _create_default_diffusion_stage_cfg so default diffusion metadata can set
+ # final_output_type to "audio" for future audio pipelines without YAML.
+ assert first_output.final_output_type == "image"
+ assert hasattr(first_output, "request_output") and first_output.request_output
+
+ req_out = first_output.request_output
+ assert isinstance(req_out, OmniRequestOutput)
+ assert req_out.final_output_type == "audio"
+ assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output
+ audio = req_out.multimodal_output.get("audio")
+ assert isinstance(audio, np.ndarray)
+ return audio
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.cache
+@hardware_test(res={"cuda": "L4", "xpu": "B60"})
+def test_stable_audio_quantization_and_teacache() -> None:
+ """Stable Audio Open on real Hub weights with FP8 + TeaCache (covers former L2 smoke + L4 features).
+
+ CI should provide ``HF_TOKEN`` if the checkpoint is gated.
+ """
+ m = Omni(
+ model="stabilityai/stable-audio-open-1.0",
+ quantization="fp8",
+ cache_backend="tea_cache",
+ cache_config={"rel_l1_thresh": 0.2},
+ )
+ try:
+ audio = generate_stable_audio_short_clip(m)
+ assert_audio_valid(
+ audio,
+ sample_rate=_SAMPLE_RATE,
+ channels=2,
+ duration_s=_CLIP_DURATION_S,
+ )
+ finally:
+ m.close()
diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py
deleted file mode 100644
index ff4d9b4017..0000000000
--- a/tests/e2e/offline_inference/test_stable_audio_model.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import sys
-from pathlib import Path
-
-import numpy as np
-import pytest
-import torch
-
-from tests.utils import hardware_test
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-from vllm_omni.platforms import current_omni_platform
-
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
-
-# Use random weights model for CI testing (small, no authentication required)
-models = ["linyueqian/stable_audio_random"]
-
-
-@pytest.mark.core_model
-@pytest.mark.diffusion
-@hardware_test(res={"cuda": "L4", "xpu": "B60"})
-@pytest.mark.parametrize("model_name", models)
-def test_stable_audio_model(model_name: str):
- m = Omni(model=model_name)
-
- # Use minimal settings for testing
- # Generate a short 2-second audio clip with minimal inference steps
- audio_start_in_s = 0.0
- audio_end_in_s = 2.0 # Short duration for fast testing
- sample_rate = 44100 # Stable Audio uses 44100 Hz
-
- outputs = m.generate(
- prompts={
- "prompt": "The sound of a dog barking",
- "negative_prompt": "Low quality.",
- },
- sampling_params_list=OmniDiffusionSamplingParams(
- num_inference_steps=4, # Minimal steps for speed
- guidance_scale=7.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- num_outputs_per_prompt=1,
- extra_args={
- "audio_start_in_s": audio_start_in_s,
- "audio_end_in_s": audio_end_in_s,
- },
- ),
- )
-
- # Extract audio from OmniRequestOutput
- assert outputs is not None
- first_output = outputs[0]
- assert first_output.final_output_type == "image"
- assert hasattr(first_output, "request_output") and first_output.request_output
-
- req_out = first_output.request_output
- assert isinstance(req_out, OmniRequestOutput)
- assert req_out.final_output_type == "audio"
- assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output
- audio = req_out.multimodal_output.get("audio")
- assert isinstance(audio, np.ndarray)
- # audio shape: (batch, channels, samples)
- # For stable-audio-open-1.0: sample_rate=44100, so 2 seconds = 88200 samples
- assert audio.ndim == 3
- assert audio.shape[0] == 1 # batch size
- assert audio.shape[1] == 2 # stereo channels
- expected_samples = int((audio_end_in_s - audio_start_in_s) * sample_rate)
- assert audio.shape[2] == expected_samples # 88200 samples for 2 seconds
diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py
index 77b2b3aaf2..fc54f9a7ff 100644
--- a/tests/e2e/offline_inference/test_t2i_model.py
+++ b/tests/e2e/offline_inference/test_t2i_model.py
@@ -1,7 +1,3 @@
-import os
-import sys
-from pathlib import Path
-
import pytest
import torch
@@ -10,14 +6,12 @@
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
+# Match unprefixed HF id even when MODEL_PREFIX is set (omni_runner resolves full path).
+_QWEN_IMAGE_RANDOM_ID = "riverclouds/qwen_image_random"
-from vllm_omni import Omni
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+def _is_qwen_image_random(model_path: str) -> bool:
+ return model_path.rstrip("/").endswith(_QWEN_IMAGE_RANDOM_ID)
models = ["Tongyi-MAI/Z-Image-Turbo", "riverclouds/qwen_image_random"]
@@ -26,62 +20,56 @@
# TODO: When NPU support is ready, remove this branch.
if current_omni_platform.is_npu():
models = ["Tongyi-MAI/Z-Image-Turbo", "Qwen/Qwen-Image"]
-elif current_omni_platform.is_rocm():
- # TODO: When ROCm support is ready, remove this branch.
- # Current upstream vLLM has issues running riverclouds/qwen_image_random
- # on ROCm
- models = ["Tongyi-MAI/Z-Image-Turbo"]
+
+# omni_runner expects (model, stage_configs_path); single-stage diffusion has no YAML.
+test_params = [(m, None) for m in models]
@pytest.mark.core_model
@pytest.mark.advanced_model
@pytest.mark.diffusion
-@hardware_test(res={"cuda": "L4", "rocm": "MI325", "xpu": "B60"}, num_cards={"cuda": 1, "rocm": 2, "xpu": 2})
-@pytest.mark.parametrize("model_name", models)
-def test_diffusion_model(model_name: str, run_level):
- if run_level == "core_model" and model_name != "riverclouds/qwen_image_random":
+@hardware_test(res={"cuda": "L4", "rocm": "MI325", "xpu": "B60"}, num_cards={"cuda": 1, "rocm": 1, "xpu": 2})
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_diffusion_model(omni_runner, run_level):
+ resolved = omni_runner.model_name
+ if run_level == "core_model" and not _is_qwen_image_random(resolved):
pytest.skip()
- if run_level == "advanced_model" and model_name == "riverclouds/qwen_image_random":
+ if run_level == "advanced_model" and _is_qwen_image_random(resolved):
pytest.skip()
- m = None
- try:
- m = Omni(model=model_name)
- # high resolution may cause OOM on L4
- height = 256
- width = 256
- outputs = m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=2,
- guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- num_outputs_per_prompt=2,
- ),
- )
- # Extract images from request_output['images']
- first_output = outputs[0]
- assert first_output.final_output_type == "image"
- if not hasattr(first_output, "request_output") or not first_output.request_output:
- raise ValueError("No request_output found in OmniRequestOutput")
-
- req_out = first_output.request_output
- if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
- raise ValueError("Invalid request_output structure or missing 'images' key")
-
- images = req_out.images
-
- assert len(images) == 2
- # check image size
- assert images[0].width == width
- assert images[0].height == height
- images[0].save("image_output.png")
- except Exception as e:
- print(f"Test failed with error: {e}")
- raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
+ # high resolution may cause OOM on L4
+ height = 256
+ width = 256
+ sampling = OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=2,
+ guidance_scale=0.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ num_outputs_per_prompt=2,
+ )
+
+ # OmniRunner.generate() is typed for list[TextPrompt]; diffusion uses Omni.generate(str, ...).
+ outputs = omni_runner.omni.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ sampling,
+ )
+
+ # Extract images from request_output['images']
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image"
+ if not hasattr(first_output, "request_output") or not first_output.request_output:
+ raise ValueError("No request_output found in OmniRequestOutput")
+
+ req_out = first_output.request_output
+ if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
+ raise ValueError("Invalid request_output structure or missing 'images' key")
+
+ images = req_out.images
+
+ assert len(images) == 2
+ # check image size
+ assert images[0].width == width
+ assert images[0].height == height
+ images[0].save("image_output.png")
diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py
index 94c9dedf74..6fe623cfc8 100644
--- a/tests/e2e/offline_inference/test_t2v_model.py
+++ b/tests/e2e/offline_inference/test_t2v_model.py
@@ -1,22 +1,13 @@
import os
-import sys
-from pathlib import Path
import pytest
import torch
+from tests.conftest import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
models = ["Wan-AI/Wan2.2-T2V-A14B-Diffusers"]
@@ -24,28 +15,28 @@
@pytest.mark.parametrize("model_name", models)
def test_video_diffusion_model(model_name: str):
- m = Omni(
- model=model_name,
+ with OmniRunner(
+ model_name,
boundary_ratio=0.875,
flow_shift=5.0,
- )
- # Use minimal settings for testing
- # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
- # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
- height = 480
- width = 640
- num_frames = 5
- outputs = m.generate(
- prompts="A cat sitting on a table",
- sampling_params_list=OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_frames=num_frames,
- num_inference_steps=2,
- guidance_scale=1.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ ) as runner:
+ # Use minimal settings for testing
+ # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
+ # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
+ height = 480
+ width = 640
+ num_frames = 5
+ outputs = runner.omni.generate(
+ prompts="A cat sitting on a table",
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ num_inference_steps=2,
+ guidance_scale=1.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
first_output = outputs[0]
assert first_output.final_output_type == "image"
if not hasattr(first_output, "request_output") or not first_output.request_output:
diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py
index efc0e43e86..7cd1c5a479 100644
--- a/tests/e2e/offline_inference/test_teacache.py
+++ b/tests/e2e/offline_inference/test_teacache.py
@@ -8,26 +8,14 @@
It uses minimal settings to keep test time short for CI.
"""
-import os
-import sys
-from pathlib import Path
-
import pytest
import torch
+from tests.conftest import OmniRunner
from tests.utils import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.platforms import current_omni_platform
-
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
-
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+from vllm_omni.platforms import current_omni_platform
# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]
@@ -44,20 +32,17 @@ def test_teacache(model_name: str):
cache_config = {
"rel_l1_thresh": 0.2, # Default threshold
}
- m = None
- try:
- m = Omni(
- model=model_name,
- cache_backend="tea_cache",
- cache_config=cache_config,
- )
-
+ with OmniRunner(
+ model_name,
+ cache_backend="tea_cache",
+ cache_config=cache_config,
+ ) as runner:
# Use minimal settings for fast testing
height = 256
width = 256
num_inference_steps = 4 # Minimal steps for fast test
- outputs = m.generate(
+ outputs = runner.omni.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -86,9 +71,3 @@ def test_teacache(model_name: str):
# Check image size
assert images[0].width == width
assert images[0].height == height
- except Exception as e:
- print(f"Test failed with error: {e}")
- raise
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
diff --git a/tests/e2e/offline_inference/test_vae_decode_parallelism.py b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
index cee76fac2e..0fce28d669 100644
--- a/tests/e2e/offline_inference/test_vae_decode_parallelism.py
+++ b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
@@ -18,7 +18,7 @@
import time
-from vllm_omni import Omni
+from tests.conftest import OmniRunner
from vllm_omni.platforms import current_omni_platform
# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
@@ -72,23 +72,22 @@ def is_nextstep_model(model_name: str) -> bool:
def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile, vae_patch_parallel_size=1):
- m = None
- try:
- parallel_config = DiffusionParallelConfig(
- tensor_parallel_size=tp,
- vae_patch_parallel_size=vae_patch_parallel_size,
- )
+ parallel_config = DiffusionParallelConfig(
+ tensor_parallel_size=tp,
+ vae_patch_parallel_size=vae_patch_parallel_size,
+ )
- omni_kwargs = {
- "model": model_configs["model_name"],
- "vae_use_tiling": using_tile,
- "parallel_config": parallel_config,
- }
- use_nextstep = is_nextstep_model(model_configs["model_name"])
- if use_nextstep:
- # NextStep-1.1 requires explicit pipeline class
- omni_kwargs["model_class_name"] = "NextStep11Pipeline"
- m = Omni(**omni_kwargs)
+ omni_kwargs = {
+ "vae_use_tiling": using_tile,
+ "parallel_config": parallel_config,
+ }
+ use_nextstep = is_nextstep_model(model_configs["model_name"])
+ if use_nextstep:
+ # NextStep-1.1 requires explicit pipeline class
+ omni_kwargs["model_class_name"] = "NextStep11Pipeline"
+
+ with OmniRunner(model_configs["model_name"], **omni_kwargs) as runner:
+ m = runner.omni
image = Image.new("RGB", (out_width, out_height), (0, 0, 0))
start = time.perf_counter()
outputs = m.generate(
@@ -115,9 +114,6 @@ def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile,
# frames shape: (batch, num_frames, height, width, channels)
cost = (end - start) * 1000
return frames, cost
- finally:
- if m is not None:
- m.close()
cleanup_dist_env_and_memory()
diff --git a/tests/e2e/offline_inference/test_voxcpm.py b/tests/e2e/offline_inference/test_voxcpm.py
new file mode 100644
index 0000000000..d7f65525e9
--- /dev/null
+++ b/tests/e2e/offline_inference/test_voxcpm.py
@@ -0,0 +1,156 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""E2E test for VoxCPM offline inference."""
+
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pytest
+import torch
+
+import tests.conftest as omni_test_conftest
+from tests.conftest import OmniRunner
+from tests.utils import hardware_test
+from vllm_omni.model_executor.models.voxcpm.voxcpm_runtime_utils import (
+ prepare_voxcpm_hf_config_dir,
+ resolve_voxcpm_model_dir,
+)
+
+VOXCPM_MODEL = os.environ.get("VOXCPM_MODEL", "OpenBMB/VoxCPM1.5")
+STAGE_CONFIG = str(
+ Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
+)
+SAMPLE_RATE = 24000
+
+
+@pytest.fixture(autouse=True)
+def _patch_npu_cleanup_for_voxcpm(monkeypatch: pytest.MonkeyPatch):
+ """Limit the NPU cleanup workaround to this VoxCPM test module only."""
+ original_cleanup = omni_test_conftest.cleanup_dist_env_and_memory
+
+ def _safe_cleanup() -> None:
+ try:
+ original_cleanup()
+ except RuntimeError as exc:
+ if "Allocator for npu is not a DeviceAllocator" in str(exc):
+ return
+ raise
+
+ monkeypatch.setattr(omni_test_conftest, "cleanup_dist_env_and_memory", _safe_cleanup)
+
+
+def _build_prompt(text: str) -> dict[str, Any]:
+ return {
+ "prompt_token_ids": [1],
+ "additional_information": {
+ "text": [text],
+ "cfg_value": [2.0],
+ "inference_timesteps": [10],
+ "min_len": [2],
+ "max_new_tokens": [1024],
+ },
+ }
+
+
+def _extract_audio_tensor(multimodal_output: dict[str, Any]) -> torch.Tensor:
+ audio = multimodal_output.get("audio", multimodal_output.get("model_outputs"))
+ assert audio is not None, f"No audio output found, keys={list(multimodal_output.keys())}"
+
+ if isinstance(audio, list):
+ parts: list[torch.Tensor] = []
+ for item in audio:
+ if item is None:
+ continue
+ tensor = torch.as_tensor(item)
+ if tensor.numel() == 0:
+ continue
+ parts.append(tensor.float().cpu().reshape(-1))
+ return torch.cat(parts, dim=-1) if parts else torch.zeros((0,), dtype=torch.float32)
+
+ return torch.as_tensor(audio).float().cpu().reshape(-1)
+
+
+def _extract_final_multimodal_output(outputs) -> dict[str, Any]:
+ for item in reversed(outputs):
+ request_output = getattr(item, "request_output", None)
+ if request_output is not None:
+ multimodal_output = getattr(request_output, "multimodal_output", None)
+ if isinstance(multimodal_output, dict):
+ return multimodal_output
+ completions = getattr(request_output, "outputs", None) or []
+ for completion in completions:
+ multimodal_output = getattr(completion, "multimodal_output", None)
+ if isinstance(multimodal_output, dict):
+ return multimodal_output
+
+ multimodal_output = getattr(item, "multimodal_output", None)
+ if isinstance(multimodal_output, dict):
+ return multimodal_output
+
+ raise AssertionError("No multimodal audio output found in VoxCPM generate results")
+
+
+@pytest.fixture
+def voxcpm_model_path(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> str:
+ model_dir = resolve_voxcpm_model_dir(VOXCPM_MODEL)
+
+ hf_config_env = os.environ.get("VLLM_OMNI_VOXCPM_HF_CONFIG_PATH")
+ if hf_config_env:
+ hf_config_dir = Path(hf_config_env).expanduser()
+ else:
+ hf_config_dir = tmp_path / "voxcpm_hf_config"
+
+ if not (hf_config_dir / "config.json").exists():
+ prepare_voxcpm_hf_config_dir(model_dir, hf_config_dir)
+
+ monkeypatch.setenv("VLLM_OMNI_VOXCPM_HF_CONFIG_PATH", str(hf_config_dir))
+ return str(model_dir)
+
+
+def test_prepare_voxcpm_hf_config_dir(tmp_path: Path):
+ model_dir = tmp_path / "model"
+ model_dir.mkdir()
+ (model_dir / "config.json").write_text(json.dumps({"hidden_size": 1024}), encoding="utf-8")
+ (model_dir / "generation_config.json").write_text(json.dumps({"do_sample": False}), encoding="utf-8")
+
+ hf_config_dir = prepare_voxcpm_hf_config_dir(model_dir, tmp_path / "voxcpm_hf_config")
+
+ prepared_config = json.loads((hf_config_dir / "config.json").read_text(encoding="utf-8"))
+ assert prepared_config["model_type"] == "voxcpm"
+ assert prepared_config["architectures"] == ["VoxCPMForConditionalGeneration"]
+ assert (hf_config_dir / "generation_config.json").exists()
+
+
+def test_resolve_voxcpm_model_dir_local_path(tmp_path: Path):
+ model_dir = tmp_path / "OpenBMB" / "VoxCPM1.5"
+ model_dir.mkdir(parents=True)
+
+ assert resolve_voxcpm_model_dir(str(model_dir)) == model_dir
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "L4"}, num_cards=1)
+def test_voxcpm_zero_shot_001(voxcpm_model_path: str):
+ with OmniRunner(voxcpm_model_path, stage_configs_path=STAGE_CONFIG) as runner:
+ outputs = list(runner.omni.generate(_build_prompt("Hello, this is a VoxCPM offline inference test.")))
+
+ assert outputs, "No outputs returned"
+
+ multimodal_output = _extract_final_multimodal_output(outputs)
+ audio = _extract_audio_tensor(multimodal_output)
+ assert audio.numel() > SAMPLE_RATE // 2, f"Audio too short: {audio.numel()} samples"
+
+ duration_s = audio.shape[0] / SAMPLE_RATE
+ assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
+
+ peak = float(torch.max(torch.abs(audio)).item()) if audio.numel() > 0 else 0.0
+ assert peak > 0.01, "Generated audio appears to be silence"
+
+ audio_np = audio.numpy()
+ rms = float(np.sqrt(np.mean(np.square(audio_np)))) if audio_np.size else 0.0
+ assert rms > 1e-4, "Generated audio RMS too low"
diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py
new file mode 100644
index 0000000000..e37d3f74df
--- /dev/null
+++ b/tests/e2e/offline_inference/test_voxcpm2.py
@@ -0,0 +1,130 @@
+"""E2E test for VoxCPM2 native AR offline inference."""
+
+import os
+
+import pytest
+import torch
+
+from tests.conftest import OmniRunner
+from tests.utils import hardware_test
+
+VOXCPM2_MODEL = "openbmb/VoxCPM2"
+STAGE_CONFIG = os.path.join(
+ os.path.dirname(__file__),
+ "..",
+ "..",
+ "..",
+ "vllm_omni",
+ "model_executor",
+ "stage_configs",
+ "voxcpm2.yaml",
+)
+SAMPLE_RATE = 48000
+
+
+@pytest.fixture(scope="module")
+def voxcpm2_engine():
+ """Create VoxCPM2 engine for testing."""
+ with OmniRunner(VOXCPM2_MODEL, stage_configs_path=STAGE_CONFIG) as runner:
+ yield runner.omni
+
+
+def _extract_audio(multimodal_output: dict) -> torch.Tensor:
+ """Extract the final complete audio tensor from multimodal output."""
+ assert isinstance(multimodal_output, dict), f"Expected dict, got {type(multimodal_output)}"
+
+ # Output processor accumulates per-step audio chunks under "audio".
+ audio = multimodal_output.get("audio")
+ if audio is None:
+ audio = multimodal_output.get("model_outputs")
+ assert audio is not None, f"No audio key, got {list(multimodal_output.keys())}"
+
+ if isinstance(audio, list):
+ valid = [torch.as_tensor(x).float().cpu().reshape(-1) for x in audio if x is not None]
+ assert valid, "No valid audio tensors in output list"
+ audio = torch.cat(valid, dim=0) if len(valid) > 1 else valid[0]
+
+ assert isinstance(audio, torch.Tensor), f"Expected Tensor, got {type(audio)}"
+ return audio
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "L4"}, num_cards=1)
+def test_voxcpm2_zero_shot_001(voxcpm2_engine):
+ """Test zero-shot TTS produces valid audio output."""
+ outputs = voxcpm2_engine.generate([{"prompt": "Hello, this is a test."}])
+ assert len(outputs) == 1
+
+ audio = _extract_audio(outputs[0].outputs[0].multimodal_output)
+ duration_s = audio.shape[0] / SAMPLE_RATE
+ assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "L4"}, num_cards=1)
+def test_voxcpm2_voice_clone_002(voxcpm2_engine):
+ """Test voice cloning with a reference audio file.
+
+ Uses the example ``reference_speaker.wav`` bundled with the voxcpm
+ package. Skipped if the file is not present.
+ """
+ # Try to locate a reference wav from the voxcpm package / env override
+ candidates = []
+ env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH")
+ if env_path:
+ candidates.append(os.path.join(env_path, "..", "examples", "reference_speaker.wav"))
+ try:
+ import voxcpm # noqa: F401 (only used to locate path)
+
+ vox_dir = os.path.dirname(os.path.dirname(os.path.abspath(voxcpm.__file__)))
+ candidates.append(os.path.join(vox_dir, "examples", "reference_speaker.wav"))
+ except ImportError:
+ pass
+
+ ref_path = next((p for p in candidates if p and os.path.exists(p)), None)
+ if ref_path is None:
+ pytest.skip("No reference audio available for voice clone test")
+
+ outputs = voxcpm2_engine.generate(
+ [
+ {
+ "prompt": "Hello, this is a voice clone demo.",
+ "additional_information": {"reference_audio": ref_path},
+ }
+ ]
+ )
+ assert len(outputs) == 1
+
+ audio = _extract_audio(outputs[0].outputs[0].multimodal_output)
+ duration_s = audio.shape[0] / SAMPLE_RATE
+ assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "L4"}, num_cards=1)
+def test_voxcpm2_prefill_decode_mixed_batch_003(voxcpm2_engine):
+ """Regression: prefill+decode mixed batch must not crash (PR #2903)."""
+ long_prompt = (
+ "This is a deliberately long prompt that will stay in the decode "
+ "phase for many steps so that subsequent shorter prompts keep "
+ "entering prefill alongside it, reproducing the prefill plus "
+ "decode mixed batch scheduling pattern."
+ )
+ short_prompts = [
+ "Hello one.",
+ "Hello two.",
+ "Hello three.",
+ "Hello four.",
+ ]
+ requests = [{"prompt": long_prompt}] + [{"prompt": p} for p in short_prompts]
+
+ outputs = voxcpm2_engine.generate(requests)
+ assert len(outputs) == len(requests)
+
+ for i, out in enumerate(outputs):
+ audio = _extract_audio(out.outputs[0].multimodal_output)
+ duration_s = audio.shape[0] / SAMPLE_RATE
+ assert 0.1 < duration_s < 30.0, f"Request {i} audio duration out of range: {duration_s:.2f}s"
diff --git a/tests/e2e/offline_inference/test_voxtral_tts.py b/tests/e2e/offline_inference/test_voxtral_tts.py
index b559cc252d..4f440f243b 100644
--- a/tests/e2e/offline_inference/test_voxtral_tts.py
+++ b/tests/e2e/offline_inference/test_voxtral_tts.py
@@ -19,7 +19,6 @@
import uuid
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
@@ -30,10 +29,9 @@
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from vllm import SamplingParams
-from tests.conftest import modify_stage_config
+from tests.conftest import OmniRunner, modify_stage_config
from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
-from vllm_omni.entrypoints.omni import Omni
MODEL = "mistralai/Voxtral-4B-TTS-2603"
STAGE_CONFIG = str(
@@ -83,14 +81,12 @@ def test_voxtral_tts_offline_basic(run_level):
"""Test basic Voxtral TTS offline inference with a voice preset."""
stage_config = _resolve_stage_config(run_level)
- omni = Omni(
- model=MODEL,
+ with OmniRunner(
+ MODEL,
stage_configs_path=stage_config,
- stage_init_timeout=300,
enforce_eager=True,
- )
-
- try:
+ ) as runner:
+ omni = runner.omni
inputs = _compose_request(MODEL, TEST_TEXT, VOICE)
sampling_params = SamplingParams(max_tokens=2500)
@@ -127,9 +123,6 @@ def test_voxtral_tts_offline_basic(run_level):
# Verify audio isn't all zeros / silence
assert np.max(np.abs(audio_array)) > 0.01, "Audio appears to be silence"
- finally:
- omni.close()
-
@pytest.mark.advanced_model
@pytest.mark.omni
diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py
index 9d9db16a40..27edc48f20 100644
--- a/tests/e2e/offline_inference/test_zimage_parallelism.py
+++ b/tests/e2e/offline_inference/test_zimage_parallelism.py
@@ -12,7 +12,6 @@
"""
import os
-import sys
import time
from pathlib import Path
@@ -20,21 +19,14 @@
import pytest
import torch
from PIL import Image
-from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from tests.conftest import OmniRunner
from tests.utils import DeviceMemoryMonitor, hardware_test
-from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# ruff: noqa: E402
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
PROMPT = "a photo of a cat sitting on a laptop keyboard"
@@ -97,61 +89,61 @@ def _run_zimage_generate(
device_index = current_omni_platform.current_device()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- m = Omni(
- model=_get_zimage_model(),
- parallel_config=DiffusionParallelConfig(
- tensor_parallel_size=tp_size,
- vae_patch_parallel_size=vae_patch_parallel_size,
- ),
- enforce_eager=enforce_eager,
- vae_use_tiling=vae_use_tiling,
- )
try:
- # NOTE: Omni closes itself when a generate() call is exhausted.
- # To avoid measuring teardown time (process shutdown, memory cleanup),
- # we measure the latency to produce *subsequent* outputs within a single
- # generator run.
- #
- # This also serves as a warmup: the first output may include extra
- # compilation/caching overhead, while later outputs are closer to
- # steady-state inference.
- gen = m.generate(
- [PROMPT] * num_requests,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=num_inference_steps,
- guidance_scale=0.0,
- seed=seed,
- num_outputs_per_prompt=1,
+ # Each run needs a distinct DiffusionParallelConfig; use OmniRunner per call (not the
+ # parametrized omni_runner fixture, which is fixed per module).
+ with OmniRunner(
+ _get_zimage_model(),
+ parallel_config=DiffusionParallelConfig(
+ tensor_parallel_size=tp_size,
+ vae_patch_parallel_size=vae_patch_parallel_size,
),
- py_generator=True,
- )
-
- warmup_output = next(gen)
-
- t_prev = time.perf_counter()
- per_request_times_s: list[float] = []
- last_output = warmup_output
- for _ in range(num_requests - 1):
- last_output = next(gen)
- t_now = time.perf_counter()
- per_request_times_s.append(t_now - t_prev)
- t_prev = t_now
-
- # Ensure the generator is fully consumed so it can clean up.
- for _ in gen:
- pass
-
- median_time_s = float(np.median(per_request_times_s))
-
- peak_memory_mb = monitor.peak_used_mb
-
- return _extract_single_image([last_output]), median_time_s, peak_memory_mb
+ enforce_eager=enforce_eager,
+ vae_use_tiling=vae_use_tiling,
+ ) as runner:
+ # NOTE: Omni closes itself when a generate() call is exhausted.
+ # To avoid measuring teardown time (process shutdown, memory cleanup),
+ # we measure the latency to produce *subsequent* outputs within a single
+ # generator run.
+ #
+ # This also serves as a warmup: the first output may include extra
+ # compilation/caching overhead, while later outputs are closer to
+ # steady-state inference.
+ gen = runner.omni.generate(
+ [PROMPT] * num_requests,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=0.0,
+ seed=seed,
+ num_outputs_per_prompt=1,
+ ),
+ py_generator=True,
+ )
+
+ warmup_output = next(gen)
+
+ t_prev = time.perf_counter()
+ per_request_times_s: list[float] = []
+ last_output = warmup_output
+ for _ in range(num_requests - 1):
+ last_output = next(gen)
+ t_now = time.perf_counter()
+ per_request_times_s.append(t_now - t_prev)
+ t_prev = t_now
+
+ # Ensure the generator is fully consumed so it can clean up.
+ for _ in gen:
+ pass
+
+ median_time_s = float(np.median(per_request_times_s))
+
+ peak_memory_mb = monitor.peak_used_mb
+
+ return _extract_single_image([last_output]), median_time_s, peak_memory_mb
finally:
monitor.stop()
- m.close()
- cleanup_dist_env_and_memory()
@pytest.mark.advanced_model
@@ -159,8 +151,8 @@ def _run_zimage_generate(
@pytest.mark.parallel
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
def test_zimage_tensor_parallel_tp2(tmp_path: Path):
- if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
- pytest.skip("Z-Image TP e2e test is only supported on CUDA for now.")
+ if current_omni_platform.is_npu():
+ pytest.skip("Z-Image TP e2e test is only supported on CUDA and ROCm for now.")
if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
pytest.skip("Z-Image TP=2 requires >= 2 devices.")
@@ -211,7 +203,9 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
)
print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}")
- assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"
+ # ROCm is not optimized TP2 can be slower than TP1
+ if not current_omni_platform.is_rocm():
+ assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"
print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}")
assert tp2_peak_mem < tp1_peak_mem, (
@@ -221,8 +215,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
@pytest.mark.integration
def test_zimage_vae_patch_parallel_tp2(tmp_path: Path):
- if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
- pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA for now.")
+ if current_omni_platform.is_npu():
+ pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA and ROCm for now.")
if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 devices.")
diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py
index e2d75e0d19..d801020c9d 100644
--- a/tests/e2e/online_serving/test_bagel_expansion.py
+++ b/tests/e2e/online_serving/test_bagel_expansion.py
@@ -88,7 +88,7 @@ def _get_diffusion_feature_cases(model: str):
],
),
id="parallel_tp_2",
- marks=PARALLEL_FEATURE_MARKS,
+ marks=[*PARALLEL_FEATURE_MARKS, pytest.mark.skip(reason="issue: #2862")],
),
# Ulysses-SP degree=2 (2 GPUs)
pytest.param(
diff --git a/tests/e2e/online_serving/test_bagel_online.py b/tests/e2e/online_serving/test_bagel_online.py
index ca24f5f81f..a3f999f13d 100644
--- a/tests/e2e/online_serving/test_bagel_online.py
+++ b/tests/e2e/online_serving/test_bagel_online.py
@@ -47,7 +47,7 @@
OmniServerParams(
model=MODEL,
stage_config_path=STAGE_CONFIGS_PATH,
- server_args=["--stage-init-timeout", "300"],
+ stage_init_timeout=300,
),
]
diff --git a/tests/e2e/online_serving/test_dynin_omni_expansion.py b/tests/e2e/online_serving/test_dynin_omni_expansion.py
index 4648c424fe..710c480f08 100644
--- a/tests/e2e/online_serving/test_dynin_omni_expansion.py
+++ b/tests/e2e/online_serving/test_dynin_omni_expansion.py
@@ -30,7 +30,7 @@
T2S_PROMPT = "Please read this sentence naturally: Hello from Dynin-Omni online serving."
I2I_PROMPT = "Transform this outdoor nature boardwalk scene into a painting style with vivid colors."
-TEST_PARAMS = [OmniServerParams(model=MODEL, stage_config_path=STAGE_CONFIG)]
+TEST_PARAMS = [OmniServerParams(model=MODEL, stage_config_path=STAGE_CONFIG, stage_init_timeout=600)]
_STAGE_COUNT = 3
_I2I_STAGE_SAMPLING = {"max_tokens": 1, "temperature": 0.0, "top_p": 1.0, "detokenize": False}
@@ -120,7 +120,7 @@ def _build_i2i_messages(prompt: str) -> list[dict]:
@pytest.mark.advanced_model
@pytest.mark.omni
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
def test_send_i2i_request_001(omni_server, openai_client) -> None:
request_config = {
@@ -136,7 +136,7 @@ def test_send_i2i_request_001(omni_server, openai_client) -> None:
@pytest.mark.advanced_model
@pytest.mark.omni
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
def test_send_t2i_request_001(omni_server, openai_client) -> None:
request_config = {
@@ -149,7 +149,7 @@ def test_send_t2i_request_001(omni_server, openai_client) -> None:
@pytest.mark.core_model
@pytest.mark.omni
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
def test_send_t2s_request_001(omni_server, dynin_t2s_openai_client) -> None:
request_config = {
diff --git a/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py b/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py
new file mode 100644
index 0000000000..f59a0e783d
--- /dev/null
+++ b/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+End-to-end tests for Flux2 Klein inpainting in online serving mode.
+
+Uses /v1/images/edits endpoint which is the correct API for image inpainting.
+"""
+
+import base64
+from io import BytesIO
+
+import httpx
+import pytest
+from PIL import Image, ImageDraw
+
+from tests.conftest import OmniServer, OmniServerParams
+
+MODEL = "black-forest-labs/FLUX.2-klein-4B"
+
+_HEIGHT = 512
+_WIDTH = 512
+_NUM_INFERENCE_STEPS = 4
+
+
+def _get_diffusion_feature_cases(model: str):
+ return [
+ pytest.param(
+ OmniServerParams(
+ model=model,
+ server_args=["--tensor-parallel-size", "2"],
+ ),
+ id="tp2_basic",
+ ),
+ ]
+
+
+def _image_to_base64_jpeg(image: Image.Image) -> str:
+ buffer = BytesIO()
+ image.save(buffer, format="JPEG")
+ buffer.seek(0)
+ return base64.b64encode(buffer.read()).decode("utf-8")
+
+
+def _create_test_mask_base64(width: int = _WIDTH, height: int = _HEIGHT) -> str:
+ mask = Image.new("L", (width, height), 0)
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle([width // 4, height // 4, width * 3 // 4, height * 3 // 4], fill=255)
+ return _image_to_base64_jpeg(mask)
+
+
+def _compare_images(img1: Image.Image, img2: Image.Image) -> bool:
+ return list(img1.getdata()) == list(img2.getdata())
+
+
+def _send_edit_request(host: str, port: int, model: str, image_b64: str, mask_b64: str, prompt: str, **kwargs):
+ url = f"http://{host}:{port}/v1/images/edits"
+ files = {
+ "image": ("image.jpg", base64.b64decode(image_b64), "image/jpeg"),
+ "mask_image": ("mask.jpg", base64.b64decode(mask_b64), "image/jpeg"),
+ }
+ data = {"prompt": prompt, "model": model, **kwargs}
+ with httpx.Client(timeout=60.0) as client:
+ response = client.post(url, files=files, data=data)
+ response.raise_for_status()
+ return response.json()
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.parametrize("omni_server", _get_diffusion_feature_cases(MODEL), indirect=True)
+def test_flux2_klein_inpaint_basic(omni_server: OmniServer):
+ input_image_b64 = _image_to_base64_jpeg(Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)))
+ mask_b64 = _create_test_mask_base64()
+
+ result = _send_edit_request(
+ host=omni_server.host,
+ port=omni_server.port,
+ model=MODEL,
+ image_b64=input_image_b64,
+ mask_b64=mask_b64,
+ prompt="Fill in the masked area with a beautiful garden",
+ guidance_scale=1.0,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ n=1,
+ seed=42,
+ )
+
+ assert "data" in result and len(result["data"]) == 1
+ img_data = result["data"][0].get("b64_json") or result["data"][0].get("url", "").split(",")[-1]
+ img = Image.open(BytesIO(base64.b64decode(img_data)))
+ assert img.size == (_WIDTH, _HEIGHT)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.parametrize("omni_server", _get_diffusion_feature_cases(MODEL), indirect=True)
+def test_flux2_klein_inpaint_deterministic(omni_server: OmniServer):
+ input_image_b64 = _image_to_base64_jpeg(Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)))
+ mask_b64 = _create_test_mask_base64()
+ prompt = "A red flower in a field"
+
+ result1 = _send_edit_request(
+ host=omni_server.host,
+ port=omni_server.port,
+ model=MODEL,
+ image_b64=input_image_b64,
+ mask_b64=mask_b64,
+ prompt=prompt,
+ guidance_scale=1.0,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ n=1,
+ seed=12345,
+ )
+
+ result2 = _send_edit_request(
+ host=omni_server.host,
+ port=omni_server.port,
+ model=MODEL,
+ image_b64=input_image_b64,
+ mask_b64=mask_b64,
+ prompt=prompt,
+ guidance_scale=1.0,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ n=1,
+ seed=12345,
+ )
+
+ img1_data = result1["data"][0].get("b64_json") or result1["data"][0].get("url", "").split(",")[-1]
+ img2_data = result2["data"][0].get("b64_json") or result2["data"][0].get("url", "").split(",")[-1]
+
+ img1 = Image.open(BytesIO(base64.b64decode(img1_data)))
+ img2 = Image.open(BytesIO(base64.b64decode(img2_data)))
+
+ assert _compare_images(img1, img2), (
+ "Same input with same seed should produce identical output. This is critical for offline/online consistency."
+ )
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.parametrize("omni_server", _get_diffusion_feature_cases(MODEL), indirect=True)
+def test_flux2_klein_inpaint_multiple_outputs(omni_server: OmniServer):
+ input_image_b64 = _image_to_base64_jpeg(Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)))
+ mask_b64 = _create_test_mask_base64()
+
+ result = _send_edit_request(
+ host=omni_server.host,
+ port=omni_server.port,
+ model=MODEL,
+ image_b64=input_image_b64,
+ mask_b64=mask_b64,
+ prompt="A beautiful landscape",
+ guidance_scale=1.0,
+ num_inference_steps=_NUM_INFERENCE_STEPS,
+ n=2,
+ seed=42,
+ )
+
+ assert "data" in result and len(result["data"]) == 2
diff --git a/tests/e2e/online_serving/test_flux_2_dev_expansion.py b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
index 9d96a48c0c..f7477ed803 100644
--- a/tests/e2e/online_serving/test_flux_2_dev_expansion.py
+++ b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
@@ -27,7 +27,7 @@
NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark"
SINGLE_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"})
-PARALLEL_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=2)
+PARALLEL_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"}, num_cards=2)
def _get_flux_2_dev_feature_cases(model: str):
@@ -48,8 +48,6 @@ def _get_flux_2_dev_feature_cases(model: str):
OmniServerParams(
model=model,
server_args=[
- "--cache-backend",
- "cache_dit",
"--enable-cpu-offload",
"--cfg-parallel-size",
"2",
diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py
index 8c826591a5..fb1e3ea1e0 100644
--- a/tests/e2e/online_serving/test_images_generations_lora.py
+++ b/tests/e2e/online_serving/test_images_generations_lora.py
@@ -28,7 +28,7 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODEL = "Tongyi-MAI/Z-Image-Turbo"
-DIFFUSION_INIT_TIMEOUT_S = 700
+DIFFUSION_INIT_TIMEOUT_S = 900
PROMPT = "a photo of a cat sitting on a laptop keyboard"
diff --git a/tests/e2e/online_serving/test_ming_flash_omni.py b/tests/e2e/online_serving/test_ming_flash_omni.py
new file mode 100644
index 0000000000..35b7b64c06
--- /dev/null
+++ b/tests/e2e/online_serving/test_ming_flash_omni.py
@@ -0,0 +1,247 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+E2E online serving tests for Ming-flash-omni-2.0 model (Thinker stage).
+Tests multimodal understanding via OpenAI-compatible API.
+"""
+
+import os
+from pathlib import Path
+
+import pytest
+
+from tests.conftest import (
+ OmniServerParams,
+ dummy_messages_from_mix_data,
+ generate_synthetic_audio,
+ generate_synthetic_image,
+ generate_synthetic_video,
+ modify_stage_config,
+)
+from tests.utils import hardware_test
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+
+models = ["Jonathan1909/Ming-flash-omni-2.0"]
+
+
+def get_eager_config():
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"),
+ updates={
+ "stage_args": {
+ 0: {
+ "engine_args.enforce_eager": "true",
+ },
+ },
+ },
+ )
+ return path
+
+
+stage_configs = [get_eager_config()]
+
+# Create parameter combinations for model and stage config
+test_params = [
+ OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs
+]
+
+
+def get_system_prompt():
+ return {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "你是一个友好的AI助手。\n\ndetailed thinking off",
+ }
+ ],
+ }
+
+
+def get_prompt(prompt_type="text_only"):
+ prompts = {
+ "text_only": "What is the capital of China? Answer in 20 words.",
+ "text_image": "What is in this image?",
+ "text_audio": "What is in this audio?",
+ "text_video": "What is in this video?",
+ "mix": "What is recited in the audio? What is in this image? What is in this video?",
+ }
+ return prompts.get(prompt_type, prompts["text_only"])
+
+
+def get_max_batch_size(size_type="few"):
+ batch_sizes = {"few": 5, "medium": 100, "large": 256}
+ return batch_sizes.get(size_type, 5)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_text_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: text
+ Output Modal: text
+ Input Setting: stream=False
+ Datasets: single request
+ """
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ content_text=get_prompt("text_only"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": False,
+ "modalities": ["text"],
+ "key_words": {"text": ["beijing"]},
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_text_to_text_stream_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: text
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: few requests
+ """
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ content_text=get_prompt("text_only"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ "key_words": {"text": ["beijing"]},
+ }
+
+ openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_image_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: image + text
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: single request
+ """
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ image_data_url=image_data_url,
+ content_text=get_prompt("text_image"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_audio_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: audio + text
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: single request
+ """
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ audio_data_url=audio_data_url,
+ content_text=get_prompt("text_audio"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_video_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: video + text
+ Output Modal: text
+ Input Setting: stream=False
+ Datasets: single request
+ """
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ video_data_url=video_data_url,
+ content_text=get_prompt("text_video"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": False,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_mix_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: text + audio + image + video
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: single request
+ """
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ video_data_url=video_data_url,
+ image_data_url=image_data_url,
+ audio_data_url=audio_data_url,
+ content_text=get_prompt("mix"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
diff --git a/tests/e2e/online_serving/test_nextstep_expansion.py b/tests/e2e/online_serving/test_nextstep_expansion.py
new file mode 100644
index 0000000000..cd3d7f9bca
--- /dev/null
+++ b/tests/e2e/online_serving/test_nextstep_expansion.py
@@ -0,0 +1,71 @@
+"""
+Online serving E2E for NextStep-1.1 text-to-image (tensor parallel).
+"""
+
+import os
+
+import pytest
+
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+)
+from tests.utils import hardware_marks
+
+# L4: 4 GPUs + TP=4; XPU B60: 2 cards (use num_cards={"cuda": 4, "xpu": 4} if needed)
+FOUR_CARD_MARKS = hardware_marks(
+ res={"cuda": "L4", "xpu": "B60"},
+ num_cards={"cuda": 2, "xpu": 2},
+)
+
+POSITIVE_PROMPT = "A small red barn in a snowy field, simple illustration."
+NEGATIVE_PROMPT = "blurry, low quality"
+
+_DEFAULT_MODEL = "stepfun-ai/NextStep-1.1"
+
+
+def _get_diffusion_feature_cases(model: str):
+ """Single online config: TP=4, explicit pipeline class."""
+ return [
+ pytest.param(
+ OmniServerParams(
+ model=model,
+ server_args=[
+ "--tensor-parallel-size",
+ "2",
+ "--model-class-name",
+ "NextStep11Pipeline",
+ ],
+ ),
+ id="nextstep_tp4_pipeline",
+ marks=FOUR_CARD_MARKS,
+ ),
+ ]
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.parametrize(
+ "omni_server",
+ _get_diffusion_feature_cases(model=os.environ.get("VLLM_TEST_NEXTSTEP_MODEL", _DEFAULT_MODEL)),
+ indirect=True,
+)
+def test_nextstep_11(omni_server: OmniServer, openai_client: OpenAIClientHandler):
+ messages = dummy_messages_from_mix_data(content_text=POSITIVE_PROMPT)
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "extra_body": {
+ "height": 512,
+ "width": 512,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "guidance_scale_2": 1.0,
+ "negative_prompt": NEGATIVE_PROMPT,
+ "seed": 42,
+ },
+ }
+
+ openai_client.send_diffusion_request(request_config)
diff --git a/tests/e2e/online_serving/test_omnivoice.py b/tests/e2e/online_serving/test_omnivoice.py
index ec1981aab2..4a0069f402 100644
--- a/tests/e2e/online_serving/test_omnivoice.py
+++ b/tests/e2e/online_serving/test_omnivoice.py
@@ -17,9 +17,16 @@
import httpx
import pytest
-from tests.conftest import OmniServerParams
+from tests.conftest import OmniServerParams, generate_synthetic_audio
from tests.utils import hardware_test
+try:
+ from transformers import HiggsAudioV2TokenizerModel # noqa: F401
+
+ _HAS_VOICE_CLONE = True
+except ImportError:
+ _HAS_VOICE_CLONE = False
+
MODEL = "k2-fsa/OmniVoice"
STAGE_CONFIG = str(
@@ -40,6 +47,16 @@
MIN_AUDIO_BYTES = 5000
+def _get_ref_audio_b64() -> str:
+ """Generate synthetic speech for reference audio.
+
+ Returns:
+ Base64 data URL string (data:audio/wav;base64,...)
+ """
+ audio_data = generate_synthetic_audio(duration=2, num_channels=1, sample_rate=24000)
+ return f"data:audio/wav;base64,{audio_data['base64']}"
+
+
def make_speech_request(
host: str,
port: int,
@@ -82,3 +99,102 @@ def test_speech_auto_voice(self, omni_server) -> None:
assert len(response.content) > MIN_AUDIO_BYTES, (
f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
)
+
+
+def make_voice_clone_request(
+ host: str,
+ port: int,
+ text: str,
+ ref_audio_b64: str,
+ ref_text: str | None = None,
+ timeout: float = 180.0,
+) -> httpx.Response:
+ """Make a voice cloning request to the /v1/audio/speech endpoint.
+
+ Args:
+ host: Server host
+ port: Server port
+ text: Text to synthesize
+ ref_audio_b64: Base64-encoded reference audio data URL
+ ref_text: Optional transcript of reference audio
+ timeout: Request timeout in seconds
+
+ Returns:
+ httpx.Response object
+ """
+ url = f"http://{host}:{port}/v1/audio/speech"
+ payload = {
+ "input": text,
+ "ref_audio": ref_audio_b64,
+ }
+ if ref_text:
+ payload["ref_text"] = ref_text
+
+ with httpx.Client(timeout=timeout) as client:
+ return client.post(url, json=payload)
+
+
+@pytest.mark.skipif(not _HAS_VOICE_CLONE, reason="Voice cloning requires transformers>=5.3.0")
+@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
+class TestOmniVoiceVoiceCloning:
+ """E2E tests for OmniVoice voice cloning functionality."""
+
+ @pytest.mark.core_model
+ @pytest.mark.omni
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_voice_clone_ref_audio_only(self, omni_server) -> None:
+ """Test voice cloning with ref_audio only (x_vector mode)."""
+ ref_audio_b64 = _get_ref_audio_b64()
+
+ response = make_voice_clone_request(
+ host=omni_server.host,
+ port=omni_server.port,
+ text="Hello, this is a voice cloning test.",
+ ref_audio_b64=ref_audio_b64,
+ )
+
+ assert response.status_code == 200, f"Request failed: {response.text}"
+ assert response.headers.get("content-type") == "audio/wav"
+ assert verify_wav_audio(response.content), "Response is not valid WAV audio"
+ assert len(response.content) > MIN_AUDIO_BYTES, (
+ f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
+ )
+
+ @pytest.mark.core_model
+ @pytest.mark.omni
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_voice_clone_ref_audio_and_text(self, omni_server) -> None:
+ """Test voice cloning with ref_audio and ref_text (in-context mode)."""
+ ref_audio_b64 = _get_ref_audio_b64()
+ ref_text = "This is the reference transcript."
+
+ response = make_voice_clone_request(
+ host=omni_server.host,
+ port=omni_server.port,
+ text="Hello, this is a voice cloning test with in-context learning.",
+ ref_audio_b64=ref_audio_b64,
+ ref_text=ref_text,
+ )
+
+ assert response.status_code == 200, f"Request failed: {response.text}"
+ assert response.headers.get("content-type") == "audio/wav"
+ assert verify_wav_audio(response.content), "Response is not valid WAV audio"
+ assert len(response.content) > MIN_AUDIO_BYTES, (
+ f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
+ )
+
+ @pytest.mark.core_model
+ @pytest.mark.omni
+ @hardware_test(res={"cuda": "L4"}, num_cards=1)
+ def test_voice_clone_invalid_ref_audio_format(self, omni_server) -> None:
+ """Test that invalid ref_audio format returns a clear error."""
+ response = make_voice_clone_request(
+ host=omni_server.host,
+ port=omni_server.port,
+ text="This should fail with invalid ref_audio.",
+ ref_audio_b64="not_a_valid_uri",
+ )
+
+ assert response.status_code in (400, 422), (
+ f"Expected 400/422 for invalid ref_audio format, got {response.status_code}"
+ )
diff --git a/tests/e2e/online_serving/test_qwen2_5_omni.py b/tests/e2e/online_serving/test_qwen2_5_omni.py
index e2913ce021..ba333e498c 100644
--- a/tests/e2e/online_serving/test_qwen2_5_omni.py
+++ b/tests/e2e/online_serving/test_qwen2_5_omni.py
@@ -3,7 +3,6 @@
"""
import os
-from pathlib import Path
import pytest
@@ -15,8 +14,7 @@
generate_synthetic_video,
modify_stage_config,
)
-from tests.utils import hardware_test
-from vllm_omni.platforms import current_omni_platform
+from tests.utils import get_deploy_config_path, hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
@@ -24,20 +22,9 @@
models = ["Qwen/Qwen2.5-Omni-7B"]
-
-def get_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "qwen2_5_omni_ci.yaml"),
- )
- return path
-
-
-# CI stage config for 2xH100-80G GPUs or AMD GPU MI325
-if current_omni_platform.is_rocm():
- # ROCm stage config optimized for MI325 GPU
- stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen2_5_omni_ci.yaml")]
-else:
- stage_configs = [get_config()]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section in vllm_omni/deploy/ci/qwen2_5_omni.yaml.
+stage_configs = [modify_stage_config(get_deploy_config_path("ci/qwen2_5_omni.yaml"))]
# Create parameter combinations for model and stage config
test_params = [
diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py
index fcda20ba38..62eca6349f 100644
--- a/tests/e2e/online_serving/test_qwen3_omni.py
+++ b/tests/e2e/online_serving/test_qwen3_omni.py
@@ -3,7 +3,6 @@
"""
import os
-from pathlib import Path
import pytest
@@ -15,7 +14,7 @@
generate_synthetic_video,
modify_stage_config,
)
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
from vllm_omni.platforms import current_omni_platform
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -24,35 +23,64 @@
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
+# Set VLLM_TEST_PD_MODE=1 to test PD disaggregation (follow-up — deploy overlay not yet migrated).
+_USE_PD = os.environ.get("VLLM_TEST_PD_MODE", "0") == "1"
-def get_chunk_config():
+_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
+
+
+def get_chunk_config(config_path: str | None = None):
+ """Load the qwen3_omni CI deploy yaml with async_chunk modifications for streaming mode."""
+ if config_path is None:
+ config_path = _CI_DEPLOY
+ # TODO: remove this workaround once legacy `stage_args` path is deleted.
+ # The pipeline (qwen3_omni/pipeline.py) already wires
+ # thinker2talker_async_chunk / talker2code2wav_async_chunk on stage 0/1,
+ # so only async_chunk needs flipping. Writing nested `engine_args:` into
+ # the new-schema overlay trips _parse_stage_deploy's legacy branch and
+ # drops flat fields (load_format, max_num_seqs, ...).
+ return modify_stage_config(config_path, updates={"async_chunk": True})
+
+
+def get_prefix_caching_config(config_path: str):
+ """Create a stage config with prefix caching enabled on the thinker (stage 0)."""
path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml"),
+ config_path,
updates={
- "async_chunk": True,
"stage_args": {
- 0: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
- },
- 1: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- },
+ 0: {"engine_args.enable_prefix_caching": True},
},
},
- deletes={"stage_args": {2: ["custom_process_input_func"]}},
)
return path
+# Platform-specific overrides live inside the new deploy yaml's ``platforms:``
+# section, so a single ``_CI_DEPLOY`` path serves CUDA, ROCm, and XPU.
+# TODO: re-add VLLM_TEST_PD_MODE branch once the PD-disaggregation deploy
+# overlay has been migrated to the new schema (previously used the deleted
+# ``qwen3_omni_moe_pd_ci.yaml`` stage-configs file).
if current_omni_platform.is_xpu():
- stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
-else: # MI325 GPU should share the same config as H100
+ stage_configs = [_CI_DEPLOY]
+else: # CUDA + ROCm MI325 share the same deploy config
stage_configs = [get_chunk_config()]
+prefix_caching_stage_configs = [get_prefix_caching_config(_CI_DEPLOY)]
# Create parameter combinations for model and stage config
test_params = [
OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs
]
+# For prefix caching, we need to enable prompt token details so that we
+# can determine if any tokens were cached.
+prefix_test_params = [
+ OmniServerParams(
+ model=model,
+ stage_config_path=stage_config,
+ server_args=["--enable-prompt-tokens-details"], # Enable prompt tokens details to get cached_tokens
+ )
+ for model in models
+ for stage_config in prefix_caching_stage_configs
+]
def get_system_prompt():
@@ -75,6 +103,7 @@ def get_prompt(prompt_type="text_only"):
prompts = {
"text_only": "What is the capital of China? Answer in 20 words.",
"mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
+ "text_image": "What color are the squares in this image?",
}
return prompts.get(prompt_type, prompts["text_only"])
@@ -87,7 +116,8 @@ def get_max_batch_size(size_type="few"):
@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.")
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
"""
@@ -120,13 +150,14 @@ def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
}
# Test single completion
- openai_client.send_omni_request(request_config)
+ openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.")
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_text_to_text_001(omni_server, openai_client) -> None:
"""
@@ -147,3 +178,42 @@ def test_text_to_text_001(omni_server, openai_client) -> None:
}
openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+
+
+@pytest.mark.advanced_model
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.parametrize("omni_server", prefix_test_params, indirect=True)
+@pytest.mark.skip(reason="issue: #2833")
+def test_thinker_prefix_caching(omni_server, openai_client) -> None:
+ """
+ Test thinker prefix caching by sending identical requests with an image (i.e.,
+ a large shared prefix) and verifying that the second request uses cached tokens
+ & produces the same output.
+ """
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ image_data_url=image_data_url,
+ content_text=get_prompt("text_image"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": False,
+ "modalities": ["text"],
+ }
+
+ response_1 = openai_client.send_omni_request(request_config, request_num=1)[0]
+ response_2 = openai_client.send_omni_request(request_config, request_num=1)[0]
+
+ assert response_1.success
+ assert response_2.success
+ assert response_2.cached_tokens is not None
+ # We should cache the vast majority of the prompt (image + up to last full block),
+ # and set seed in the CI config, so the second request should give an identical
+ # response for the generated input image, even if we use dummy weights
+ assert response_2.cached_tokens > 0
+ assert response_1.text_content == response_2.text_content
diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
index 1637627695..3152a8f982 100644
--- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
@@ -6,10 +6,7 @@
import os
-from vllm_omni.platforms import current_omni_platform
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-from pathlib import Path
import pytest
@@ -21,7 +18,7 @@
generate_synthetic_video,
modify_stage_config,
)
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
@@ -29,47 +26,68 @@
IMAGE_KEY = ["square", "quadrate", "rectangle"]
VIDEO_KEY = ["sphere", "globe", "circle", "round", "ball"]
+# Heavier synthetic inputs than the default expansion cases (longer timeline / more pixels).
+# Long video: 120s @ 30fps => 3600 frames (generate_synthetic_video in tests/conftest.py).
+# Use 224² spatial size to bound RAM (~W*H*num_frames*3) vs. 288² at this frame count.
+LONG_VIDEO_WIDTH = 224
+LONG_VIDEO_HEIGHT = 224
+LONG_VIDEO_FRAMES = 3600
+LARGE_IMAGE_WIDTH = 1920
+LARGE_IMAGE_HEIGHT = 1080
+LONG_AUDIO_DURATION_SEC = 120
+
+
+def get_batch_token_config(default_path):
+ """Override stage 1's max_num_batched_tokens to exercise small-batch paths.
-def get_chunk_config(default_path):
- path = modify_stage_config(
+ Uses the new flat-stage schema (``stages..``); the legacy
+ ``stage_args..engine_args.`` path no longer applies because
+ the deploy YAML doesn't nest engine fields under ``engine_args:``.
+ """
+ return modify_stage_config(
default_path,
updates={
- "async_chunk": True,
- "stage_args": {
- 0: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
- },
- 1: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- },
- },
+ "stages": {1: {"max_num_batched_tokens": 64}},
},
- deletes={"stage_args": {2: ["custom_process_input_func"]}},
)
- return path
-def get_batch_token_config(default_path):
- path = modify_stage_config(
+def get_async_chunk_config(default_path):
+ """Flip async_chunk on and bump stage 0 thinker output to 2048 tokens.
+
+ Pipeline registry (qwen3_omni/pipeline.py) already wires
+ thinker2talker_async_chunk / talker2code2wav_async_chunk on stages 0/1,
+ so no per-stage processor override is needed. Using only flat-schema
+ writes so _parse_stage_deploy stays in its flat branch (nested
+ ``engine_args:`` would drop other overlay fields).
+ """
+ return modify_stage_config(
default_path,
updates={
- "stage_args": {1: {"engine_args.max_num_batched_tokens": 64}},
+ "stages": {0: {"default_sampling_params.max_tokens": 2048}},
},
)
- return path
-# CI stage config for 2*H100-80G GPUs
-default_path = str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
+# CI deploy YAML (single file; xpu deltas applied via ``platforms:`` section).
+# The overlay explicitly sets ``async_chunk: False``, so ``default`` tests the
+# sync path and ``async_chunk`` tests the streaming path with a longer thinker
+# output — two distinct scenarios, kept as separate parametrizations.
+default_path = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-if current_omni_platform.is_xpu():
- default_path = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")
-
-# Create parameter combinations for model and stage config
test_params = [
- pytest.param(OmniServerParams(model=model, stage_config_path=default_path, use_stage_cli=True), id="default"),
pytest.param(
- OmniServerParams(model=model, stage_config_path=get_chunk_config(default_path), use_stage_cli=True),
+ OmniServerParams(
+ model=model, stage_config_path=default_path, use_stage_cli=True, server_args=["--no-async-chunk"]
+ ),
+ id="default",
+ ),
+ pytest.param(
+ OmniServerParams(
+ model=model,
+ stage_config_path=get_async_chunk_config(default_path),
+ use_stage_cli=True,
+ ),
id="async_chunk",
),
]
@@ -167,88 +185,17 @@ def test_text_to_text_audio_001(omni_server, openai_client) -> None:
@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_image_to_text_001(omni_server, openai_client) -> None:
- """
- Input Modal: image
- Output Modal: text
- Input Setting: stream=True
- Datasets: single request
- """
- image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
- messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "modalities": ["text"],
- "stream": True,
- "key_words": {"image": IMAGE_KEY},
- }
-
- openai_client.send_omni_request(request_config)
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_image_to_audio_001(omni_server, openai_client) -> None:
- """
- Input Modal: image
- Output Modal: audio
- Input Setting: stream=False
- Datasets: single request
- """
- image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
- messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "modalities": ["audio"],
- "key_words": {"image": IMAGE_KEY},
- }
-
- openai_client.send_omni_request(request_config)
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_image_to_text_audio_001(omni_server, openai_client) -> None:
- """
- Input Modal: image
- Output Modal: text, audio
- Input Setting: stream=False
- Datasets: few requests
- """
- image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(1280, 720)['base64']}"
-
- messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "key_words": {"image": IMAGE_KEY},
- }
-
- openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_video_to_text_001(omni_server, openai_client) -> None:
+def test_text_video_to_text_001(omni_server, openai_client) -> None:
"""
- Input Modal: video
+ Input Modal: long synthetic video (120s @ 30fps, LONG_VIDEO_FRAMES frames)
Output Modal: text
Input Setting: stream=False
Datasets: single request
"""
- video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
- messages = dummy_messages_from_mix_data(video_data_url=video_data_url)
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(LONG_VIDEO_WIDTH, LONG_VIDEO_HEIGHT, LONG_VIDEO_FRAMES)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ video_data_url=video_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_video")
+ )
request_config = {
"model": omni_server.model,
@@ -257,28 +204,29 @@ def test_video_to_text_001(omni_server, openai_client) -> None:
"key_words": {"video": VIDEO_KEY},
}
- openai_client.send_omni_request(request_config)
+ openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
@pytest.mark.advanced_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_video_to_audio_001(omni_server, openai_client) -> None:
+@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
+def test_text_audio_to_text_audio_001(omni_server, openai_client) -> None:
"""
- Input Modal: video
- Output Modal: audio
+ Input Modal: text, audio
+ Output Modal: text, audio
Input Setting: stream=False
Datasets: single request
"""
- video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
- messages = dummy_messages_from_mix_data(video_data_url=video_data_url)
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ audio_data_url=audio_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_audio")
+ )
request_config = {
"model": omni_server.model,
"messages": messages,
- "modalities": ["audio"],
- "key_words": {"video": VIDEO_KEY},
+ "key_words": {"audio": AUDIO_KEY},
}
openai_client.send_omni_request(request_config)
@@ -287,22 +235,25 @@ def test_video_to_audio_001(omni_server, openai_client) -> None:
@pytest.mark.advanced_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_video_to_text_audio_001(omni_server, openai_client) -> None:
+@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
+def test_text_audio_to_text_audio_002(omni_server, openai_client) -> None:
"""
- Input Modal: video
+ Input Modal: text, long-duration audio (~LONG_AUDIO_DURATION_SEC s WAV)
Output Modal: text, audio
Input Setting: stream=False
- Datasets: few requests
+ Datasets: single request
"""
- video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
-
- messages = dummy_messages_from_mix_data(video_data_url=video_data_url)
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(LONG_AUDIO_DURATION_SEC, 1)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ audio_data_url=audio_data_url,
+ system_prompt=get_system_prompt(),
+ content_text=get_prompt("text_audio"),
+ )
request_config = {
"model": omni_server.model,
"messages": messages,
- "key_words": {"video": VIDEO_KEY},
+ "key_words": {"audio": AUDIO_KEY},
}
openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
@@ -312,22 +263,23 @@ def test_video_to_text_audio_001(omni_server, openai_client) -> None:
@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
-def test_text_audio_to_text_audio_001(omni_server, openai_client) -> None:
+def test_text_image_to_text_audio_001(omni_server, openai_client) -> None:
"""
- Input Modal: text, audio
+ Input Modal: text, image
Output Modal: text, audio
Input Setting: stream=False
Datasets: single request
"""
- audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+
messages = dummy_messages_from_mix_data(
- audio_data_url=audio_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_audio")
+ image_data_url=image_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_image")
)
request_config = {
"model": omni_server.model,
"messages": messages,
- "key_words": {"audio": AUDIO_KEY},
+ "key_words": {"image": IMAGE_KEY},
}
openai_client.send_omni_request(request_config)
@@ -337,17 +289,21 @@ def test_text_audio_to_text_audio_001(omni_server, openai_client) -> None:
@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
-def test_text_image_to_text_audio_001(omni_server, openai_client) -> None:
+def test_large_image_to_text_audio_001(omni_server, openai_client) -> None:
"""
- Input Modal: text, image
+ Input Modal: text, high-resolution image (1080p-class JPEG)
Output Modal: text, audio
Input Setting: stream=False
Datasets: single request
"""
- image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ image_data_url = (
+ f"data:image/jpeg;base64,{generate_synthetic_image(LARGE_IMAGE_WIDTH, LARGE_IMAGE_HEIGHT)['base64']}"
+ )
messages = dummy_messages_from_mix_data(
- image_data_url=image_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_image")
+ image_data_url=image_data_url,
+ system_prompt=get_system_prompt(),
+ content_text=get_prompt("text_image"),
)
request_config = {
@@ -356,7 +312,7 @@ def test_text_image_to_text_audio_001(omni_server, openai_client) -> None:
"key_words": {"image": IMAGE_KEY},
}
- openai_client.send_omni_request(request_config)
+ openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
@pytest.mark.advanced_model
@@ -422,6 +378,7 @@ def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+@pytest.mark.skip(reason="issue: #2827")
def test_audio_in_video_001(omni_server, openai_client) -> None:
"""
Input Modal: text + video (synthetic MP4 with embedded audio; ``use_audio_in_video`` uses audio from the video).
@@ -542,6 +499,7 @@ def test_speaker_001(omni_server, openai_client) -> None:
@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+@pytest.mark.skip(reason="Known issue: occasional inaccuracy in voice recognition.")
def test_speaker_002(omni_server, openai_client) -> None:
"""
Input Modal: text only (one-word answer constraint).
diff --git a/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py b/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py
new file mode 100644
index 0000000000..6a7cf1c67e
--- /dev/null
+++ b/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py
@@ -0,0 +1,206 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+E2E online tests for Qwen3-Omni /v1/realtime WebSocket (streaming PCM in, audio out).
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import io
+import json
+import os
+import wave
+
+import pytest
+import websockets
+
+from tests.conftest import (
+ OmniServerParams,
+ convert_audio_bytes_to_text,
+ cosine_similarity_text,
+ generate_synthetic_audio,
+ modify_stage_config,
+)
+from tests.utils import get_deploy_config_path, hardware_test
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+
+# The new-schema CI overlay bakes in async_chunk: False and covers CUDA/ROCm/XPU
+# via its ``platforms:`` section, so one path serves all three.
+default_stage_config = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
+
+
+def _realtime_stage_config_path() -> str:
+ """CI omni layout without async_chunk; stage 0 thinker max_tokens=10."""
+ return modify_stage_config(
+ default_stage_config,
+ updates={"stages": {0: {"default_sampling_params.max_tokens": 10}}},
+ )
+
+
+realtime_server_params = [
+ pytest.param(
+ OmniServerParams(
+ model=MODEL,
+ stage_config_path=_realtime_stage_config_path(),
+ use_stage_cli=True,
+ ),
+ id="thinker_max_tokens_10",
+ ),
+]
+
+
+def _pcm16_mono_16k_from_wav_bytes(wav_bytes: bytes) -> bytes:
+ with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
+ if wf.getnchannels() != 1:
+ raise ValueError(f"Expected mono WAV, got {wf.getnchannels()} channels")
+ if wf.getsampwidth() != 2:
+ raise ValueError(f"Expected 16-bit PCM, sampwidth={wf.getsampwidth()}")
+ if wf.getframerate() != 16000:
+ raise ValueError(f"Expected 16 kHz input for /v1/realtime, got {wf.getframerate()} Hz")
+ if wf.getcomptype() != "NONE":
+ raise ValueError(f"Expected uncompressed PCM, comptype={wf.getcomptype()!r}")
+ return wf.readframes(wf.getnframes())
+
+
+def _wav_bytes_from_pcm16(pcm: bytes, sample_rate_hz: int) -> bytes:
+ buf = io.BytesIO()
+ with wave.open(buf, "wb") as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(sample_rate_hz)
+ wf.writeframes(pcm)
+ return buf.getvalue()
+
+
+async def _run_realtime_audio_roundtrip(
+ host: str,
+ port: int,
+ model: str,
+ pcm16: bytes,
+ *,
+ chunk_ms: int = 100,
+) -> dict:
+ uri = f"ws://{host}:{port}/v1/realtime"
+ incremental: list[bytes] = []
+ output_sr = 24000
+ text_chunks: list[str] = []
+ final_text = ""
+ delta_events = 0
+
+ bytes_per_ms = 16000 * 2 // 1000
+ chunk_bytes = max(bytes_per_ms * chunk_ms, 2)
+
+ async with websockets.connect(uri, max_size=64 * 1024 * 1024) as ws:
+ await ws.send(json.dumps({"type": "session.update", "model": model}))
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False}))
+
+ for i in range(0, len(pcm16), chunk_bytes):
+ chunk = pcm16[i : i + chunk_bytes]
+ await ws.send(
+ json.dumps(
+ {
+ "type": "input_audio_buffer.append",
+ "audio": base64.b64encode(chunk).decode("utf-8"),
+ }
+ )
+ )
+
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
+
+ while True:
+ message = await asyncio.wait_for(ws.recv(), timeout=600)
+ if isinstance(message, bytes):
+ continue
+
+ event = json.loads(message)
+ event_type = event.get("type")
+
+ if event_type == "session.created":
+ continue
+
+ if event_type == "response.audio.delta":
+ delta_events += 1
+ sr = event.get("sample_rate_hz")
+ if isinstance(sr, int) and sr > 0:
+ output_sr = sr
+ audio_b64 = event.get("audio", "")
+ if audio_b64:
+ incremental.append(base64.b64decode(audio_b64))
+ continue
+
+ if event_type == "transcription.delta":
+ d = event.get("delta", "")
+ if d:
+ text_chunks.append(d)
+ continue
+
+ if event_type == "transcription.done":
+ final_text = event.get("text", "") or "".join(text_chunks)
+ continue
+
+ if event_type == "response.audio.done":
+ break
+
+ if event_type == "error":
+ raise AssertionError(f"WebSocket error: {event}")
+
+ raise AssertionError(f"Unexpected WebSocket event: {event}")
+
+ out_pcm = b"".join(incremental)
+ return {
+ "output_pcm": out_pcm,
+ "output_sample_rate": output_sr,
+ "transcription_text": final_text if final_text else "".join(text_chunks),
+ "delta_events": delta_events,
+ }
+
+
+class TestQwen3OmniRealtimeWebSocket:
+ @pytest.mark.advanced_model
+ @pytest.mark.omni
+ @hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+ @pytest.mark.parametrize("omni_server", realtime_server_params, indirect=True)
+ def test_streaming_audio_input_pcm_output(self, omni_server) -> None:
+ """
+ Short streamed 16 kHz mono PCM16 input; expect streamed PCM16 audio deltas and
+ transcription. Verify Whisper(output audio) aligns with model text (same idea
+ as multimodal omni e2e).
+ """
+ syn = generate_synthetic_audio(10, 1, sample_rate=16000)
+ wav_bytes = base64.b64decode(syn["base64"])
+ pcm16 = _pcm16_mono_16k_from_wav_bytes(wav_bytes)
+
+ result = asyncio.run(
+ _run_realtime_audio_roundtrip(
+ omni_server.host,
+ omni_server.port,
+ omni_server.model,
+ pcm16,
+ chunk_ms=100,
+ )
+ )
+
+ out_pcm = result["output_pcm"]
+ assert result["delta_events"] >= 1
+ assert out_pcm, "No output PCM from response.audio.delta"
+ assert len(out_pcm) % 2 == 0
+ assert len(out_pcm) >= 4096, "Output audio unexpectedly small"
+ assert result["output_sample_rate"] > 0
+
+ final_text = (result["transcription_text"] or "").strip()
+ assert final_text, "Expected non-empty transcription (model text stream)"
+
+ wav_out = _wav_bytes_from_pcm16(out_pcm, result["output_sample_rate"])
+ whisper_text = convert_audio_bytes_to_text(wav_out).strip()
+ assert whisper_text, "Whisper returned empty string for synthesized output audio"
+
+ sim = cosine_similarity_text(whisper_text.lower(), final_text.lower())
+ assert sim > 0.9, (
+ f"Output audio transcript should match model text (sim={sim:.3f}): "
+ f"whisper={whisper_text!r}, model_text={final_text!r}"
+ )
diff --git a/tests/e2e/online_serving/test_qwen3_tts_base.py b/tests/e2e/online_serving/test_qwen3_tts_base.py
index 002f9d9972..c97fdef5bc 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_base.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_base.py
@@ -12,12 +12,10 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
@@ -25,11 +23,6 @@
REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -48,7 +41,7 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
index 3c33485e4f..364865d286 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
@@ -12,12 +12,10 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
@@ -25,11 +23,6 @@
REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -48,16 +41,19 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
),
+ # Synchronous (no async-chunk) variant — ``--no-async-chunk`` alone
+ # flips the deploy yaml's bool and the pipeline dispatches to the
+ # end-to-end codec processor. No variant yaml / pipeline needed.
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts_no_async_chunk.yaml"),
- server_args=["--trust-remote-code", "--disable-log-stats"],
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ server_args=["--trust-remote-code", "--disable-log-stats", "--no-async-chunk"],
),
id="no_async_chunk",
),
diff --git a/tests/e2e/online_serving/test_qwen3_tts_batch.py b/tests/e2e/online_serving/test_qwen3_tts_batch.py
index d0d6336618..bf13884997 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_batch.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_batch.py
@@ -27,13 +27,15 @@
convert_audio_file_to_text,
cosine_similarity_text,
)
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"
+STAGE_INIT_TIMEOUT_S = 120
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
+def get_stage_config(name: str = "qwen3_tts.yaml") -> str:
+ """Resolve a deploy config path under vllm_omni/deploy/."""
+ return get_deploy_config_path(name)
@pytest.fixture(scope="module")
@@ -47,7 +49,7 @@ def omni_server():
"--stage-configs-path",
stage_config_path,
"--stage-init-timeout",
- "120",
+ str(STAGE_INIT_TIMEOUT_S),
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
@@ -337,7 +339,7 @@ def omni_server_batch2():
"--stage-configs-path",
config_path,
"--stage-init-timeout",
- "120",
+ str(STAGE_INIT_TIMEOUT_S),
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
index fb60df725b..d19c652689 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
@@ -12,21 +12,14 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -45,7 +38,7 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
index 03a985896e..4087532d63 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
@@ -12,21 +12,14 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="english"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -46,16 +39,19 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
),
+ # Synchronous (no async-chunk) variant — ``--no-async-chunk`` alone
+ # flips the deploy yaml's bool and the pipeline dispatches to the
+ # end-to-end codec processor. No variant yaml / pipeline needed.
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts_no_async_chunk.yaml"),
- server_args=["--trust-remote-code", "--disable-log-stats"],
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ server_args=["--trust-remote-code", "--disable-log-stats", "--no-async-chunk"],
),
id="no_async_chunk",
),
diff --git a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
index 64e13e1557..d4212bb5b1 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
@@ -13,16 +13,16 @@
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
import struct
-from pathlib import Path
import httpx
import pytest
from tests.conftest import OmniServer
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
MODEL_BASE = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
MODEL_BASE_1_7B = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
+STAGE_INIT_TIMEOUT_S = 120
# A synthetic 1024-dim speaker embedding (all 0.1 — not a real voice, but
# exercises the full code path through the talker's _build_prompt_embeds).
@@ -36,10 +36,8 @@
MAX_NEW_TOKENS = 256
-def get_stage_config():
- return str(
- Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml"
- )
+def get_stage_config() -> str:
+ return get_deploy_config_path("qwen3_tts.yaml")
def _server_args():
@@ -47,7 +45,7 @@ def _server_args():
"--stage-configs-path",
get_stage_config(),
"--stage-init-timeout",
- "120",
+ str(STAGE_INIT_TIMEOUT_S),
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_websocket.py b/tests/e2e/online_serving/test_qwen3_tts_websocket.py
index df05146011..dddba6e58a 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_websocket.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_websocket.py
@@ -7,24 +7,22 @@
import asyncio
import json
import os
-from pathlib import Path
import pytest
import websockets
from tests.conftest import OmniServer
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"
+STAGE_INIT_TIMEOUT_S = 120
def get_stage_config() -> str:
- return str(
- Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml"
- )
+ return get_deploy_config_path("qwen3_tts.yaml")
@pytest.fixture(scope="module")
@@ -37,7 +35,7 @@ def omni_server():
"--stage-configs-path",
stage_config_path,
"--stage-init-timeout",
- "120",
+ str(STAGE_INIT_TIMEOUT_S),
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
diff --git a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml
new file mode 100644
index 0000000000..fb0c72cc51
--- /dev/null
+++ b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml
@@ -0,0 +1,35 @@
+# Thinker stage only
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ engine_args:
+ model_stage: thinker
+ model_arch: MingFlashOmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ max_model_len: 32768
+ tensor_parallel_size: 4
+ hf_config_name: llm_config
+ load_format: dummy
+ mm_processor_cache_gb: 0
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ max_tokens: 100
+ repetition_penalty: 1.05
+ seed: 42
+ detokenize: true
+ ignore_eos: false
diff --git a/tests/e2e/stage_configs/dynin_omni_ci.yaml b/tests/e2e/stage_configs/dynin_omni_ci.yaml
index 0240007510..525b7d888c 100644
--- a/tests/e2e/stage_configs/dynin_omni_ci.yaml
+++ b/tests/e2e/stage_configs/dynin_omni_ci.yaml
@@ -72,13 +72,8 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml
deleted file mode 100644
index a7c637d486..0000000000
--- a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090).
-# This config is optimized for CI e2e tests.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.9
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- load_format: dummy
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.4
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- load_format: dummy
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.5 #increase the gpu memory utilization to enable the test on H800
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- max_num_batched_tokens: 8192
- max_model_len: 8192
- load_format: dummy
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 8192
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml
deleted file mode 100644
index 9401382847..0000000000
--- a/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.9
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
deleted file mode 100644
index 08dd49de95..0000000000
--- a/tests/e2e/stage_configs/qwen3_omni_ci.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
-- stage_id: 0
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 5
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 32768
- max_model_len: 32768
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- load_format: dummy
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 150
- seed: 42
- ignore_eos: False
- detokenize: True
- repetition_penalty: 1.05
-
-- stage_id: 1
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 5
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.5
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- max_model_len: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- load_format: dummy
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 1000
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
-- stage_id: 2
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 5
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 100000
- hf_config_name: thinker_config
- async_scheduling: false
- load_format: dummy
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2000
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml
deleted file mode 100644
index 0c756ce56b..0000000000
--- a/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,106 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090).
-# This config is optimized for CI e2e tests.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- max_num_batched_tokens: 4096
- max_model_len: 4096
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml
deleted file mode 100644
index ac2b1fbd71..0000000000
--- a/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml
+++ /dev/null
@@ -1,100 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
- - stage_id: 0
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- load_format: dummy
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 100
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- load_format: dummy
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 100
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- load_format: dummy
- async_scheduling: false
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 200
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml
deleted file mode 100644
index 14ef3c3438..0000000000
--- a/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,108 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config is verified with 2 * Intel Arc Pro B60 XPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.9 # thinker weight is around 16.74GB for Qwen2.5-Omni-7B
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.5 # talker weight is 6.03GB for Qwen2.5-Omni-7B
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "2"
- engine_args:
- max_num_seqs: 1
- model_stage: code2wav
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.3 # code2wav weight is around 1.46GB for Qwen2.5-Omni-7B
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml
deleted file mode 100644
index c4586e0664..0000000000
--- a/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config is verified with 8 * Intel Arc Pro B60 XPU.
-stage_args:
-- stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0,1,2,3"
- engine_args:
- max_num_seqs: 1
- model_stage: thinker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.85 # thinker weight is around 61.08GB for Qwen3-Omni-30B-A3B-Instruct
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
- max_model_len: 4096
- enable_prefix_caching: false
- hf_config_name: thinker_config
- tensor_parallel_size: 4
- max_cudagraph_capture_size: 0
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 100
- seed: 42
- ignore_eos: False
- detokenize: True
- repetition_penalty: 1.05
-
-- stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "4"
- engine_args:
- max_num_seqs: 1
- model_stage: talker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6 # talker weight is around 8.5GB for Qwen3-Omni-30B-A3B-Instruct
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 4096
- max_model_len: 4096
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- max_cudagraph_capture_size: 0
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
-- stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "5"
- engine_args:
- max_num_seqs: 1
- model_stage: code2wav
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3 # code2wav weight is around 0.4GB for Qwen3-Omni-30B-A3B-Instruct
- skip_mm_profiling: true
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 100000
- hf_config_name: thinker_config
- async_scheduling: false
- max_cudagraph_capture_size: 0
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2000
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index 5584b15d9f..4d69f24c56 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -4,7 +4,9 @@
explicitly patch values that differ from vLLM.
"""
+import argparse
import inspect
+from types import SimpleNamespace
from unittest.mock import Mock
import pytest
@@ -14,6 +16,7 @@
from vllm_omni.config.model import OmniModelConfig
from vllm_omni.engine.arg_utils import OmniEngineArgs
+from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -36,21 +39,28 @@ def test_default_stage_id_is_concrete_int():
assert cfg.stage_id == 0
-def test_multimodal_kwarg_overrides():
+def test_multimodal_kwarg_overrides(mocker):
"""Ensure that overrides in the multimodal config are preserved."""
- # Get a different value than the default for a multimodal field
sig = inspect.signature(OmniEngineArgs)
default_mm_cache = sig.parameters["mm_processor_cache_gb"].default
override_val = default_mm_cache + 1
- # NOTE: This needs to be a model that resolves to supports_multimodal=True
- # in vLLM, otherwise we won't have an MM config
+ fake_model_config = SimpleNamespace(
+ multimodal_config=SimpleNamespace(mm_processor_cache_gb=override_val),
+ )
+
+ def _fake_parent_create_model_config(self):
+ assert self.mm_processor_cache_gb == override_val
+ return fake_model_config
+
+ mocker.patch.object(EngineArgs, "create_model_config", _fake_parent_create_model_config)
+ mocker.patch.object(OmniModelConfig, "from_vllm_model_config", side_effect=lambda model_config, **_: model_config)
+
cfg = OmniEngineArgs(
model="Qwen/Qwen2-VL-2B-Instruct",
mm_processor_cache_gb=override_val,
).create_model_config()
- # Ensure that the override was applied correctly
assert cfg.multimodal_config is not None
assert cfg.multimodal_config.mm_processor_cache_gb == override_val
@@ -100,7 +110,7 @@ def test_qwen3_tts_codec_frame_rate_patching():
vllm_config = EngineArgs().create_model_config()
# Create a mock talking config with a dummy value for position_id_per_seconds
- mock_talker_config = Mock()
+ mock_talker_config = SimpleNamespace()
mock_talker_config.position_id_per_seconds = 12.3
vllm_config.hf_config.talker_config = mock_talker_config
@@ -116,6 +126,19 @@ def test_qwen3_tts_codec_frame_rate_patching():
assert omni_config.codec_frame_rate_hz == 12.3
+def test_from_cli_args_picks_up_stage_configs_path():
+ """from_cli_args should pick up stage_configs_path from namespace."""
+ ns = argparse.Namespace(
+ model="facebook/opt-125m",
+ stage_configs_path="/some/path.yaml",
+ custom_pipeline_args=None,
+ )
+
+ args = OmniEngineArgs.from_cli_args(ns)
+ assert args.stage_configs_path == "/some/path.yaml"
+ assert args.custom_pipeline_args is None
+
+
def test_stage_specific_text_config_override():
"""Ensure dependent attributes are updated when using stage-specific config."""
vllm_config = EngineArgs().create_model_config()
@@ -124,13 +147,12 @@ def test_stage_specific_text_config_override():
# Switch the created hf text config with a mock whose
# values we want to pull through the text config helper
stage_text_config = vllm_config.hf_text_config
- vllm_config.hf_text_config = Mock()
+ vllm_config.hf_text_config = SimpleNamespace()
stage_text_config.sliding_window = 4096
stage_text_config.attention_chunk_size = 2048
# Move the stage config's text config getter & thinker config
- mock_stage_config = Mock()
- mock_stage_config.get_text_config.return_value = stage_text_config
+ mock_stage_config = SimpleNamespace(get_text_config=lambda: stage_text_config)
vllm_config.hf_config.thinker_config = mock_stage_config
# Ensure that create from a vLLM config correctly pulls the
@@ -144,3 +166,92 @@ def test_stage_specific_text_config_override():
assert omni_config.attention_chunk_size == 2048
assert omni_config.max_model_len == 4096
assert omni_config.hf_text_config.sliding_window is None
+
+
+def test_stage_configs_path_field():
+ """OmniEngineArgs with stage_configs_path should construct without error."""
+ args = OmniEngineArgs(stage_configs_path="/some/path.yaml")
+ assert args.stage_configs_path == "/some/path.yaml"
+
+
+def test_voxcpm_model_arch_injects_model_type_override(mocker):
+ """Ensure VoxCPM model_arch injects hf_overrides for config resolution."""
+ mocker.patch.object(OmniEngineArgs, "_ensure_omni_models_registered", return_value=True)
+ mocker.patch.object(OmniEngineArgs, "_patch_empty_hf_config")
+ mocker.patch.object(EngineArgs, "create_model_config", return_value=Mock())
+ mocker.patch.object(OmniModelConfig, "from_vllm_model_config", return_value=Mock())
+
+ args = OmniEngineArgs(
+ model="OpenBMB/VoxCPM1.5",
+ model_arch="VoxCPMForConditionalGeneration",
+ )
+ args.create_model_config()
+
+ assert args.hf_overrides["architectures"] == ["VoxCPMForConditionalGeneration"]
+ assert args.hf_overrides["model_type"] == "voxcpm"
+ args._patch_empty_hf_config.assert_called_once_with("voxcpm")
+
+
+def test_strip_single_engine_args():
+ """_strip_single_engine_args should remove EngineArgs fields but keep omni fields."""
+ kwargs = {
+ # Parent EngineArgs fields — should be stripped
+ "compilation_config": '{"cudagraph_mode": "FULL_AND_PIECEWISE"}',
+ "tensor_parallel_size": 4,
+ "gpu_memory_utilization": 0.9,
+ "model": "some/model",
+ # Parent field that should be kept (allowlisted)
+ "worker_extension_cls": "some.Extension",
+ # OmniEngineArgs-only / non-engine fields — should pass through
+ "stage_configs_path": "/path/to/yaml",
+ "custom_pipeline_args": {"pipeline_class": "my.Pipeline"},
+ "mode": "text-to-image",
+ "lora_path": "/some/lora",
+ }
+
+ filtered = AsyncOmniEngine._strip_single_engine_args(kwargs)
+
+ # Stripped — parent EngineArgs fields
+ assert "compilation_config" not in filtered
+ assert "tensor_parallel_size" not in filtered
+ assert "gpu_memory_utilization" not in filtered
+ assert "model" not in filtered
+
+ # Stripped — orchestrator-level OmniEngineArgs field
+ assert "stage_configs_path" not in filtered
+
+ # Kept
+ assert filtered["worker_extension_cls"] == "some.Extension"
+ assert filtered["custom_pipeline_args"] == {"pipeline_class": "my.Pipeline"}
+ assert filtered["mode"] == "text-to-image"
+ assert filtered["lora_path"] == "/some/lora"
+
+
+def test_strip_single_engine_args_model_does_not_trigger_warning(mocker):
+ """model is always in kwargs (callers set it via from_cli_args/asdict),
+ so it should not cause the override warning by itself or appear in it."""
+ mock_warn = mocker.patch("vllm_omni.engine.async_omni_engine.logger.warning")
+
+ # Typical caller kwargs: model is always present, no other parent
+ # EngineArgs fields are explicitly overridden.
+ AsyncOmniEngine._strip_single_engine_args(
+ {
+ "model": "some/model",
+ "custom_pipeline_args": {"pipeline_class": "my.Pipeline"},
+ }
+ )
+ mock_warn.assert_not_called()
+
+ # When there *are* genuinely surprising overrides alongside model,
+ # the warning should mention them but not model.
+ AsyncOmniEngine._strip_single_engine_args(
+ {
+ "model": "some/model",
+ "tensor_parallel_size": 4,
+ "custom_pipeline_args": {"pipeline_class": "my.Pipeline"},
+ }
+ )
+ mock_warn.assert_called_once()
+ warned_args = mock_warn.call_args[0][-1] # the formatted arg list
+ assert "tensor_parallel_size" in warned_args
+ assert "model" not in warned_args
diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py
index 34fdf45ea2..e7f2bb679f 100644
--- a/tests/engine/test_async_omni_engine_abort.py
+++ b/tests/engine/test_async_omni_engine_abort.py
@@ -2,20 +2,24 @@
import os
import sys
from contextlib import ExitStack
-from pathlib import Path
import pytest
from vllm import SamplingParams
from vllm.inputs import PromptType
-from tests.utils import hardware_test
+# Side-effect import: registers QWEN2_5_OMNI_THINKER_ONLY_PIPELINE in the
+# pipeline registry so the materialized deploy overlay below can select it
+# via its top-level ``pipeline:`` field.
+import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401, E402
+from tests.utils import get_deploy_config_path, hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
SEED = 42
-stage_config = str(Path(__file__).parent.parent / "e2e" / "stage_configs" / "qwen2_5_omni_thinker_ci.yaml")
+# Single-stage thinker-only deploy, materialized from tests.utils._CI_OVERLAYS.
+stage_config = get_deploy_config_path("ci/qwen2_5_omni_thinker_only.yaml")
model = "Qwen/Qwen2.5-Omni-7B"
diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py
index ed6a7277b4..3700e426d4 100644
--- a/tests/engine/test_async_omni_engine_input.py
+++ b/tests/engine/test_async_omni_engine_input.py
@@ -1,6 +1,5 @@
-from unittest.mock import Mock
-
import pytest
+from pytest_mock import MockerFixture
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
@@ -24,18 +23,18 @@ def _make_engine_core_request() -> EngineCoreRequest:
)
-def test_build_add_request_message_preserves_additional_information():
+def test_build_add_request_message_preserves_additional_information(mocker: MockerFixture):
engine = object.__new__(AsyncOmniEngine)
params = SamplingParams(max_tokens=8)
engine.default_sampling_params_list = [params]
engine.stage_metadata = [{"stage_type": "llm"}]
engine.supported_tasks = ("speech",)
- input_processor = Mock()
+ input_processor = mocker.Mock()
input_processor.process_inputs.return_value = _make_engine_core_request()
engine.input_processor = input_processor
- output_processor = Mock()
+ output_processor = mocker.Mock()
engine.output_processors = [output_processor]
prompt = {
@@ -63,18 +62,18 @@ def test_build_add_request_message_preserves_additional_information():
output_processor.add_request.assert_called_once()
-def test_build_add_request_message_with_resumable_streaming():
+def test_build_add_request_message_with_resumable_streaming(mocker: MockerFixture):
engine = object.__new__(AsyncOmniEngine)
params = SamplingParams(max_tokens=8)
engine.default_sampling_params_list = [params]
engine.stage_metadata = [{"stage_type": "llm"}]
engine.supported_tasks = ("generate",)
- input_processor = Mock()
+ input_processor = mocker.Mock()
input_processor.process_inputs.return_value = _make_engine_core_request()
engine.input_processor = input_processor
- output_processor = Mock()
+ output_processor = mocker.Mock()
engine.output_processors = [output_processor]
msg = engine._build_add_request_message(
diff --git a/tests/engine/test_async_omni_engine_outputs.py b/tests/engine/test_async_omni_engine_outputs.py
index ccf9e8cb6b..ef3cfab3bf 100644
--- a/tests/engine/test_async_omni_engine_outputs.py
+++ b/tests/engine/test_async_omni_engine_outputs.py
@@ -5,36 +5,36 @@
"""
import queue
-from unittest.mock import MagicMock
import pytest
+from pytest_mock import MockerFixture
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-def _make_engine(output_queue, *, thread_alive: bool = True) -> AsyncOmniEngine:
+def _make_engine(output_queue, mocker: MockerFixture, *, thread_alive: bool = True) -> AsyncOmniEngine:
"""Create an AsyncOmniEngine bypassing __init__."""
engine = object.__new__(AsyncOmniEngine)
engine.output_queue = output_queue
- engine.orchestrator_thread = MagicMock(
- is_alive=MagicMock(return_value=thread_alive),
+ engine.orchestrator_thread = mocker.MagicMock(
+ is_alive=mocker.MagicMock(return_value=thread_alive),
)
return engine
-def test_try_get_output_raises_after_orchestrator_dies():
+def test_try_get_output_raises_after_orchestrator_dies(mocker: MockerFixture):
"""Draining remaining results then hitting an empty queue with a dead
orchestrator must raise RuntimeError so callers know the pipeline is gone."""
- mock_queue = MagicMock()
+ mock_queue = mocker.MagicMock()
# First call succeeds; second call finds the queue empty.
mock_queue.sync_q.get.side_effect = [
{"type": "output", "request_id": "r1"},
queue.Empty,
]
- engine = _make_engine(mock_queue, thread_alive=True)
+ engine = _make_engine(mock_queue, mocker, thread_alive=True)
# Collect the one buffered result.
assert engine.try_get_output()["request_id"] == "r1"
@@ -47,15 +47,15 @@ def test_try_get_output_raises_after_orchestrator_dies():
@pytest.mark.asyncio
-async def test_try_get_output_async_raises_after_orchestrator_dies():
+async def test_try_get_output_async_raises_after_orchestrator_dies(mocker: MockerFixture):
"""Same scenario as above but for the async variant."""
- mock_queue = MagicMock()
+ mock_queue = mocker.MagicMock()
mock_queue.sync_q.get_nowait.side_effect = [
{"type": "output", "request_id": "r1"},
queue.Empty,
]
- engine = _make_engine(mock_queue, thread_alive=True)
+ engine = _make_engine(mock_queue, mocker, thread_alive=True)
assert (await engine.try_get_output_async())["request_id"] == "r1"
diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py
index 31d3ed7751..5c2a9edb77 100644
--- a/tests/engine/test_async_omni_engine_stage_init.py
+++ b/tests/engine/test_async_omni_engine_stage_init.py
@@ -1,5 +1,6 @@
import importlib
import os
+import threading
import types
import pytest
@@ -86,6 +87,238 @@ def _fake_setup_stage_devices(_stage_id, _runtime_cfg):
os.environ[env_var] = old_env
+def test_initialize_stages_passes_stage_init_timeout_to_diffusion_handshake(monkeypatch):
+ """Regression test for stage_init_timeout passing to complete_diffusion_handshake
+ in the diffusion stage path.
+ """
+ import vllm_omni.diffusion.data as diffusion_data_mod
+ import vllm_omni.diffusion.stage_diffusion_client as client_mod
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ from vllm_omni.platforms import current_omni_platform
+
+ engine = object.__new__(AsyncOmniEngine)
+ engine.log_stats = False
+ engine.model = "dummy-model"
+ engine.config_path = "dummy-config"
+ engine.num_stages = 2
+ engine.async_chunk = False
+ engine.diffusion_batch_size = 1
+ engine.single_stage_mode = False
+ engine._omni_master_server = None
+ engine.stage_configs = [types.SimpleNamespace(stage_id=0, stage_type="diffusion", engine_args={})]
+
+ metadata = types.SimpleNamespace(
+ stage_id=0,
+ stage_type="diffusion",
+ runtime_cfg={"devices": "0"},
+ prompt_expand_func=None,
+ final_output=True,
+ final_output_type="image",
+ default_sampling_params=None,
+ custom_process_input_func=None,
+ engine_input_source=None,
+ cfg_kv_collect_func=None,
+ )
+
+ captured_timeout = None
+ device_env_var = current_omni_platform.device_control_env_var
+ prev_device_env = os.environ.get(device_env_var)
+ os.environ[device_env_var] = "0"
+
+ monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None)
+ monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "extract_stage_metadata", lambda _cfg: metadata)
+ monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
+ monkeypatch.setattr(
+ engine_mod,
+ "finalize_initialized_stages",
+ lambda stage_clients, _input_processor: (
+ stage_clients,
+ [types.SimpleNamespace()],
+ [{"final_output_type": "image"}],
+ ),
+ )
+ monkeypatch.setattr(
+ diffusion_data_mod.OmniDiffusionConfig,
+ "from_kwargs",
+ classmethod(lambda cls, **kwargs: types.SimpleNamespace(parallel_config=types.SimpleNamespace(world_size=1))),
+ )
+ monkeypatch.setattr(
+ client_mod,
+ "spawn_diffusion_proc",
+ lambda model, od_cfg: (object(), "ipc://handshake", "ipc://request", "ipc://response"),
+ )
+
+ def _capture_handshake_timeout(proc, handshake_address, handshake_timeout):
+ nonlocal captured_timeout
+ captured_timeout = handshake_timeout
+
+ monkeypatch.setattr(client_mod, "complete_diffusion_handshake", _capture_handshake_timeout)
+ monkeypatch.setattr(
+ client_mod.zmq,
+ "Context",
+ lambda: types.SimpleNamespace(socket=lambda _: types.SimpleNamespace(connect=lambda _: None)),
+ )
+
+ try:
+ engine._initialize_stages(stage_init_timeout=302)
+ finally:
+ if prev_device_env is None:
+ os.environ.pop(device_env_var, None)
+ else:
+ os.environ[device_env_var] = prev_device_env
+
+ assert captured_timeout == 302
+
+
+def test_launch_llm_stage_passes_stage_init_timeout_to_complete_stage_handshake(monkeypatch):
+ """Regression test for stage_init_timeout reaching complete_stage_handshake
+ in the LLM stage path.
+ """
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ from vllm_omni.platforms import current_omni_platform
+
+ engine = object.__new__(AsyncOmniEngine)
+ engine.log_stats = False
+ engine.model = "dummy-model"
+ engine.single_stage_mode = False
+ engine._omni_master_server = None
+ engine.stage_configs = []
+
+ metadata = types.SimpleNamespace(stage_id=0, runtime_cfg={"devices": "0"})
+ fake_vllm_config = types.SimpleNamespace()
+ fake_addresses = types.SimpleNamespace()
+ fake_proc = types.SimpleNamespace()
+
+ captured_timeout = None
+
+ device_env_var = current_omni_platform.device_control_env_var
+ prev_device_env = os.environ.get(device_env_var)
+ os.environ[device_env_var] = "0"
+
+ monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
+ monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object))
+ monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: [])
+ monkeypatch.setattr(
+ engine_mod,
+ "spawn_stage_core",
+ lambda **_: (fake_addresses, fake_proc, "ipc://handshake"),
+ )
+
+ def _capture_stage_timeout(_proc, _handshake_addr, _addresses, _vllm_cfg, handshake_timeout):
+ nonlocal captured_timeout
+ captured_timeout = handshake_timeout
+
+ monkeypatch.setattr(engine_mod, "complete_stage_handshake", _capture_stage_timeout)
+
+ try:
+ engine._launch_llm_stage(
+ stage_cfg=types.SimpleNamespace(engine_args={}),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=302,
+ llm_stage_launch_lock=threading.Lock(),
+ )
+ finally:
+ if prev_device_env is None:
+ os.environ.pop(device_env_var, None)
+ else:
+ os.environ[device_env_var] = prev_device_env
+
+ assert captured_timeout == 302
+
+
+def test_launch_llm_stage_releases_launch_lock_before_complete_stage_handshake(monkeypatch):
+ """Regression test for parallel LLM stage startup during handshake wait."""
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ from vllm_omni.platforms import current_omni_platform
+
+ engine = object.__new__(AsyncOmniEngine)
+ engine.log_stats = False
+ engine.model = "dummy-model"
+ engine.single_stage_mode = False
+ engine._omni_master_server = None
+ engine.stage_configs = []
+
+ fake_vllm_config = types.SimpleNamespace()
+ fake_addresses = types.SimpleNamespace()
+ shared_launch_lock = threading.Lock()
+ counter_lock = threading.Lock()
+ first_handshake_started = threading.Event()
+ second_stage_spawned = threading.Event()
+ allow_first_handshake_to_finish = threading.Event()
+ launch_errors: list[BaseException] = []
+ spawn_count = 0
+
+ device_env_var = current_omni_platform.device_control_env_var
+ prev_device_env = os.environ.get(device_env_var)
+ os.environ[device_env_var] = "0"
+
+ monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
+ monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object))
+ monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: [])
+
+ def _spawn_stage_core(**_):
+ nonlocal spawn_count
+ with counter_lock:
+ spawn_count += 1
+ call_idx = spawn_count
+ if call_idx == 2:
+ second_stage_spawned.set()
+ return fake_addresses, types.SimpleNamespace(), f"ipc://handshake-{call_idx}"
+
+ def _complete_stage_handshake(_proc, handshake_address, _addresses, _vllm_cfg, _timeout):
+ if handshake_address == "ipc://handshake-1":
+ first_handshake_started.set()
+ assert second_stage_spawned.wait(timeout=1), (
+ "second stage did not reach spawn_stage_core while first stage waited in handshake"
+ )
+ assert allow_first_handshake_to_finish.wait(timeout=1), (
+ "second stage did not enter handshake while first stage was still waiting"
+ )
+ else:
+ allow_first_handshake_to_finish.set()
+
+ monkeypatch.setattr(engine_mod, "spawn_stage_core", _spawn_stage_core)
+ monkeypatch.setattr(engine_mod, "complete_stage_handshake", _complete_stage_handshake)
+
+ def _launch_stage(stage_id: int) -> None:
+ metadata = types.SimpleNamespace(stage_id=stage_id, runtime_cfg={"devices": str(stage_id)})
+ try:
+ engine._launch_llm_stage(
+ stage_cfg=types.SimpleNamespace(engine_args={}),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=302,
+ llm_stage_launch_lock=shared_launch_lock,
+ )
+ except BaseException as exc: # pragma: no cover - surfaced through assertion below
+ launch_errors.append(exc)
+
+ try:
+ first_thread = threading.Thread(target=_launch_stage, args=(0,))
+ first_thread.start()
+ assert first_handshake_started.wait(timeout=1), "first stage never entered handshake"
+
+ second_thread = threading.Thread(target=_launch_stage, args=(1,))
+ second_thread.start()
+
+ first_thread.join(timeout=3)
+ second_thread.join(timeout=3)
+ finally:
+ if prev_device_env is None:
+ os.environ.pop(device_env_var, None)
+ else:
+ os.environ[device_env_var] = prev_device_env
+
+ assert not first_thread.is_alive()
+ assert not second_thread.is_alive()
+ assert second_stage_spawned.is_set()
+ assert not launch_errors
+
+
def test_attach_llm_stage_uses_omni_input_preprocessor(monkeypatch):
"""Regression test for GLM-Image t2i preprocessing path.
@@ -147,3 +380,70 @@ def __init__(self, vllm_config, renderer=None):
assert input_processor is not None
assert isinstance(input_processor.input_preprocessor, DummyOmniInputPreprocessor)
assert input_processor.input_preprocessor.renderer is input_processor.renderer
+
+
+def test_inject_kv_stage_info_infers_sender_tp_topology():
+ from vllm_omni.engine.stage_init_utils import inject_kv_stage_info
+
+ stage0 = types.SimpleNamespace(
+ stage_id=0,
+ engine_args={
+ "tensor_parallel_size": 4,
+ "omni_kv_config": {
+ "need_send_cache": True,
+ "omni_from_stage": "0",
+ "omni_to_stage": "1",
+ },
+ },
+ engine_input_source=[],
+ )
+ stage1 = types.SimpleNamespace(
+ stage_id=1,
+ engine_args={
+ "parallel_config": {
+ "tensor_parallel_size": 2,
+ "cfg_parallel_size": 1,
+ },
+ "omni_kv_config": {"need_recv_cache": True},
+ },
+ engine_input_source=[0],
+ )
+
+ inject_kv_stage_info(stage0, 0, [stage0, stage1])
+
+ assert stage0.engine_args["omni_kv_config"]["stage_id"] == 0
+ assert stage0.engine_args["omni_kv_config"]["rank_mapping"] == {"from_tp": 4, "to_tp": 2}
+
+
+def test_inject_kv_stage_info_infers_receiver_tp_topology():
+ from vllm_omni.engine.stage_init_utils import inject_kv_stage_info
+
+ stage0 = types.SimpleNamespace(
+ stage_id=0,
+ engine_args={
+ "tensor_parallel_size": 4,
+ "omni_kv_config": {"need_send_cache": True},
+ },
+ engine_input_source=[],
+ )
+ stage1 = types.SimpleNamespace(
+ stage_id=1,
+ engine_args={
+ "parallel_config": {
+ "tensor_parallel_size": 2,
+ "cfg_parallel_size": 1,
+ },
+ "omni_kv_config": {
+ "need_recv_cache": True,
+ "omni_from_stage": "0",
+ "omni_to_stage": "1",
+ },
+ },
+ engine_input_source=[0],
+ )
+
+ inject_kv_stage_info(stage1, 1, [stage0, stage1])
+
+ assert stage1.engine_args["omni_kv_config"]["stage_id"] == 1
+ assert stage1.engine_args["omni_kv_config"]["engine_input_source"] == [0]
+ assert stage1.engine_args["omni_kv_config"]["rank_mapping"] == {"from_tp": 4, "to_tp": 2}
diff --git a/tests/engine/test_cfg_companion_tracker.py b/tests/engine/test_cfg_companion_tracker.py
new file mode 100644
index 0000000000..f856a38c3e
--- /dev/null
+++ b/tests/engine/test_cfg_companion_tracker.py
@@ -0,0 +1,82 @@
+import pytest
+
+from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def test_register_companion_and_cleanup():
+ tracker = CfgCompanionTracker()
+
+ tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
+ tracker.register_companion("req1", "cfg_img", "req1__cfg_img")
+
+ assert tracker.is_companion("req1__cfg_text")
+ assert tracker.get_companion_request_ids("req1") == {
+ "cfg_text": "req1__cfg_text",
+ "cfg_img": "req1__cfg_img",
+ }
+
+ removed = tracker.cleanup_parent("req1")
+
+ assert sorted(removed) == ["req1__cfg_img", "req1__cfg_text"]
+ assert not tracker.is_companion("req1__cfg_text")
+ assert tracker.get_companion_request_ids("req1") == {}
+
+
+def test_attach_cfg_request_ids_clones_diffusion_params():
+ tracker = CfgCompanionTracker()
+ tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
+
+ params = OmniDiffusionSamplingParams()
+ updated = tracker.attach_cfg_request_ids("req1", params)
+
+ assert updated is not params
+ assert params.cfg_kv_request_ids is None
+ assert updated.cfg_kv_request_ids == {"cfg_text": "req1__cfg_text"}
+
+
+def test_abort_parent_expands_to_companions_and_cleans_up_deferred_parent():
+ tracker = CfgCompanionTracker()
+ tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
+ tracker.defer_parent("req1", {"out": 1}, stage_id=0)
+
+ aborted = tracker.abort_parents(["req1"])
+
+ assert sorted(aborted) == ["req1", "req1__cfg_text"]
+ assert not tracker.is_companion("req1__cfg_text")
+ assert tracker.pop_pending_parent("req1") is None
+
+
+def test_abort_companion_does_not_expand_to_parent():
+ tracker = CfgCompanionTracker()
+ tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
+
+ aborted = tracker.abort_parents(["req1__cfg_text"])
+
+ assert aborted == ["req1__cfg_text"]
+
+
+def test_companion_completion_flushes_deferred_parent():
+ tracker = CfgCompanionTracker()
+ tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
+ tracker.defer_parent("req1", {"out": 123}, stage_id=0)
+
+ assert not tracker.all_companions_done("req1")
+ assert tracker.on_companion_completed("req1__cfg_text") == "req1"
+ assert tracker.all_companions_done("req1")
+
+ popped = tracker.pop_pending_parent("req1")
+ assert popped is not None
+ assert popped["engine_outputs"] == {"out": 123}
+ assert popped["stage_id"] == 0
+
+
+def test_companion_completion_without_registered_parent_asserts():
+ tracker = CfgCompanionTracker()
+ tracker._companion_ids.add("req1__cfg_text")
+ tracker._companion_to_parent["req1__cfg_text"] = "req1"
+
+ with pytest.raises(AssertionError, match="completed before parent req1 was registered"):
+ tracker.on_companion_completed("req1__cfg_text")
diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py
index 7bf2eccf7f..0b549f58e9 100644
--- a/tests/engine/test_orchestrator.py
+++ b/tests/engine/test_orchestrator.py
@@ -70,7 +70,7 @@ def get_diffusion_output_nowait(self):
def set_engine_outputs(self, outputs) -> None:
return None
- def process_engine_inputs(self, stage_list, prompt=None):
+ def process_engine_inputs(self, stage_list, prompt=None, streaming_context=None):
return list(self.next_inputs)
async def abort_requests_async(self, request_ids: list[str]) -> None:
diff --git a/tests/engine/test_orchestrator_kv_sender_info.py b/tests/engine/test_orchestrator_kv_sender_info.py
index 94da4ce717..7e3fe0906e 100644
--- a/tests/engine/test_orchestrator_kv_sender_info.py
+++ b/tests/engine/test_orchestrator_kv_sender_info.py
@@ -4,6 +4,7 @@
import pytest
from vllm import SamplingParams
+from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
from vllm_omni.engine.orchestrator import Orchestrator, OrchestratorRequestState
from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -130,7 +131,7 @@ def test_forward_to_diffusion_attaches_kv_sender_info():
orchestrator.num_stages = 2
orchestrator.stage_clients = [sender_stage, diffusion_stage]
- orchestrator._companion_map = {}
+ orchestrator._cfg_tracker = CfgCompanionTracker()
orchestrator.stage_vllm_configs = [None, None]
orchestrator.output_processors = [None, None]
@@ -161,7 +162,7 @@ def test_forward_to_diffusion_uses_engine_input_source_for_kv_sender_info():
orchestrator.num_stages = 3
orchestrator.stage_clients = [source_stage, previous_stage, diffusion_stage]
- orchestrator._companion_map = {}
+ orchestrator._cfg_tracker = CfgCompanionTracker()
orchestrator.stage_vllm_configs = [None, None, None]
orchestrator.output_processors = [None, None, None]
diff --git a/tests/engine/test_single_stage_mode.py b/tests/engine/test_single_stage_mode.py
index 627a98395f..28ccccaa2b 100644
--- a/tests/engine/test_single_stage_mode.py
+++ b/tests/engine/test_single_stage_mode.py
@@ -17,10 +17,11 @@
import threading
from contextlib import contextmanager
+from types import SimpleNamespace
from typing import Any
-from unittest.mock import MagicMock, Mock, patch
import pytest
+from pytest_mock import MockerFixture
from vllm.v1.engine.utils import EngineZmqAddresses
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
@@ -41,31 +42,33 @@
# ---------------------------------------------------------------------------
-def _make_stage_cfg(stage_id: int, stage_type: str = "llm") -> Mock:
+def _make_stage_cfg(stage_id: int, stage_type: str = "llm"):
"""Return a lightweight stage config mock."""
- cfg = Mock()
- cfg.stage_id = stage_id
- cfg.stage_type = stage_type
- cfg.engine_args = MagicMock()
- cfg.engine_args.async_chunk = False
- cfg.engine_args.model_stage = None
- cfg.engine_args.engine_output_type = None
- return cfg
+ return SimpleNamespace(
+ stage_id=stage_id,
+ stage_type=stage_type,
+ engine_args=SimpleNamespace(
+ async_chunk=False,
+ model_stage=None,
+ engine_output_type=None,
+ ),
+ )
def _make_started_llm_stage(stage_id: int) -> StartedLlmStage:
"""Return a minimal StartedLlmStage for mocking."""
- addresses = Mock()
- addresses.inputs = ["tcp://127.0.0.1:5000"]
- addresses.outputs = ["tcp://127.0.0.1:5001"]
- addresses.frontend_stats_publish_address = None
+ addresses = SimpleNamespace(
+ inputs=["tcp://127.0.0.1:5000"],
+ outputs=["tcp://127.0.0.1:5001"],
+ frontend_stats_publish_address=None,
+ )
return StartedLlmStage(
stage_id=stage_id,
- metadata=Mock(stage_id=stage_id),
- vllm_config=Mock(),
- executor_class=Mock(),
- engine_manager=Mock(),
- coordinator=Mock(),
+ metadata=SimpleNamespace(stage_id=stage_id),
+ vllm_config=SimpleNamespace(),
+ executor_class=SimpleNamespace(),
+ engine_manager=SimpleNamespace(),
+ coordinator=SimpleNamespace(),
addresses=addresses,
)
@@ -348,74 +351,80 @@ class TestSingleStageModeDetection:
the orchestrator thread, so no actual engines are started.
"""
- def _make_engine_no_thread(self, **kwargs: Any) -> AsyncOmniEngine:
+ def _make_engine_no_thread(self, mocker: MockerFixture, **kwargs: Any) -> AsyncOmniEngine:
"""Create an AsyncOmniEngine without starting the orchestrator thread."""
stage_cfg = _make_stage_cfg(0)
mock_stage_configs = [stage_cfg]
- with (
- patch.object(
- AsyncOmniEngine,
- "_resolve_stage_configs",
- return_value=("/fake/path", mock_stage_configs),
- ),
- patch.object(
- AsyncOmniEngine,
- "_bootstrap_orchestrator",
- ),
- patch("threading.Thread") as mock_thread_cls,
- patch("concurrent.futures.Future") as mock_future_cls,
- ):
- mock_future = Mock()
- mock_future.result.return_value = Mock() # simulates a loop
- mock_future_cls.return_value = mock_future
+ mocker.patch.object(
+ AsyncOmniEngine,
+ "_resolve_stage_configs",
+ return_value=("/fake/path", mock_stage_configs),
+ )
+ mocker.patch.object(
+ AsyncOmniEngine,
+ "_bootstrap_orchestrator",
+ )
+ mock_thread_cls = mocker.patch("threading.Thread")
+ mock_future_cls = mocker.patch("concurrent.futures.Future")
+
+ mock_future = mocker.Mock()
+ mock_future.result.return_value = mocker.Mock() # simulates a loop
+ mock_future_cls.return_value = mock_future
- mock_thread = Mock()
- mock_thread.is_alive.return_value = False
- mock_thread_cls.return_value = mock_thread
+ mock_thread = mocker.Mock()
+ mock_thread.is_alive.return_value = False
+ mock_thread_cls.return_value = mock_thread
- engine = AsyncOmniEngine(model="fake-model", **kwargs)
+ engine = AsyncOmniEngine(model="fake-model", **kwargs)
return engine
- def test_explicit_single_stage_mode_true(self):
+ def test_explicit_single_stage_mode_true(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
single_stage_mode=True,
omni_master_address="127.0.0.1",
omni_master_port=20000,
)
assert engine.single_stage_mode is True
- def test_stage_id_kwarg_promotes_to_single_stage_mode(self):
+ def test_stage_id_kwarg_promotes_to_single_stage_mode(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
stage_id=0,
omni_master_address="127.0.0.1",
omni_master_port=20001,
)
assert engine.single_stage_mode is True
- def test_stage_id_kwarg_sets_filter(self):
+ def test_stage_id_kwarg_sets_filter(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
stage_id=1,
omni_master_address="127.0.0.1",
omni_master_port=20002,
)
assert engine._single_stage_id_filter == 1
- def test_no_stage_id_no_single_stage_mode(self):
- engine = self._make_engine_no_thread()
+ def test_no_stage_id_no_single_stage_mode(self, mocker: MockerFixture):
+ engine = self._make_engine_no_thread(
+ mocker,
+ )
assert engine.single_stage_mode is False
assert engine._single_stage_id_filter is None
- def test_single_stage_mode_without_stage_id_has_no_filter(self):
+ def test_single_stage_mode_without_stage_id_has_no_filter(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
single_stage_mode=True,
omni_master_address="127.0.0.1",
omni_master_port=20003,
)
assert engine._single_stage_id_filter is None
- def test_master_address_and_port_stored(self):
+ def test_master_address_and_port_stored(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
+ mocker,
stage_id=0,
omni_master_address="10.0.0.1",
omni_master_port=12345,
@@ -423,8 +432,10 @@ def test_master_address_and_port_stored(self):
assert engine._omni_master_address == "10.0.0.1"
assert engine._omni_master_port == 12345
- def test_omni_master_server_starts_as_none(self):
- engine = self._make_engine_no_thread()
+ def test_omni_master_server_starts_as_none(self, mocker: MockerFixture):
+ engine = self._make_engine_no_thread(
+ mocker,
+ )
assert engine._omni_master_server is None
@@ -448,7 +459,7 @@ class TestInitializeStagesRouting:
def _build_engine_skeleton(
self,
- stage_cfgs: list[Mock],
+ stage_cfgs: list[Any],
single_stage_mode: bool,
stage_id_filter: int | None,
omni_master_address: str = "127.0.0.1",
@@ -478,8 +489,8 @@ def _build_engine_skeleton(
engine.prompt_expand_func = None
return engine
- def _fake_metadata(self, stage_id: int, stage_type: str = "llm") -> Mock:
- meta = Mock()
+ def _fake_metadata(self, mocker: MockerFixture, stage_id: int, stage_type: str = "llm") -> Any:
+ meta = mocker.Mock()
meta.stage_id = stage_id
meta.stage_type = stage_type
meta.runtime_cfg = {}
@@ -492,13 +503,14 @@ def _fake_metadata(self, stage_id: int, stage_type: str = "llm") -> Mock:
def _run_initialize_stages_mocked(
self,
+ mocker: MockerFixture,
engine: AsyncOmniEngine,
- stage_cfgs: list[Mock],
+ stage_cfgs: list[Any],
*,
launch_side_effect: Any = None,
remote_side_effect: Any = None,
attach_result: Any = None,
- ) -> tuple[Mock, Mock]:
+ ) -> tuple[Any, Any]:
"""Execute _initialize_stages with all heavy helpers mocked.
Returns (mock_launch_llm_stage, mock_create_remote_llm_stage).
@@ -509,167 +521,217 @@ def _run_initialize_stages_mocked(
if getattr(cfg, "stage_type", "llm") != "diffusion"
}
- default_attach = (Mock(), Mock(), Mock(), Mock())
+ default_attach = (mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())
- mock_launch = Mock(
+ mock_launch = mocker.Mock(
side_effect=launch_side_effect
or (lambda cfg, meta, spec, timeout, llm_stage_launch_lock, kv: started_by_stage[meta.stage_id])
)
- mock_remote = Mock(
+ mock_remote = mocker.Mock(
side_effect=remote_side_effect or (lambda cfg, meta, spec, timeout, srv: started_by_stage[meta.stage_id])
)
- mock_attach = Mock(return_value=attach_result or default_attach)
+ mock_attach = mocker.Mock(return_value=attach_result or default_attach)
- mock_oms = Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.side_effect = lambda sid: Mock()
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ mock_oms.get_zmq_addresses.side_effect = lambda sid: mocker.Mock()
finalized = (
- [Mock() for _ in stage_cfgs],
- [Mock() for _ in stage_cfgs],
+ [mocker.Mock() for _ in stage_cfgs],
+ [mocker.Mock() for _ in stage_cfgs],
[{"final_output": True, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
)
- with (
- patch.object(engine, "_launch_llm_stage", mock_launch),
- patch.object(engine, "_create_remote_llm_stage", mock_remote),
- patch.object(engine, "_attach_llm_stage", mock_attach),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ mocker.patch.object(engine, "_launch_llm_stage", mock_launch)
+ mocker.patch.object(engine, "_create_remote_llm_stage", mock_remote)
+ mocker.patch.object(engine, "_attach_llm_stage", mock_attach)
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.prepare_engine_environment",
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
return mock_launch, mock_remote
# -- single-stage mode: stage matches filter → local launch ---------------
- def test_matching_stage_uses_launch_llm_stage(self):
+ def test_matching_stage_uses_launch_llm_stage(self, mocker: MockerFixture):
"""stage_id == _single_stage_id_filter → _launch_llm_stage is called."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
assert 0 in launched_ids, "_launch_llm_stage should be called for stage 0"
- def test_non_matching_stage_uses_create_remote_llm_stage(self):
+ def test_non_matching_stage_uses_create_remote_llm_stage(self, mocker: MockerFixture):
"""stage_id != _single_stage_id_filter → _create_remote_llm_stage is called."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
assert 1 in remote_ids, "_create_remote_llm_stage should be called for stage 1"
- def test_filter_1_routes_correctly(self):
+ def test_filter_1_routes_correctly(self, mocker: MockerFixture):
"""With filter=1, stage 0 is remote and stage 1 is local."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=1)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
assert 1 in launched_ids, "stage 1 should be launched locally with filter=1"
assert 0 in remote_ids, "stage 0 should use remote path with filter=1"
- def test_no_filter_all_stages_use_launch_path(self):
+ def test_no_filter_all_stages_use_launch_path(self, mocker: MockerFixture):
"""single_stage_mode=True but no filter → all stages use _launch_llm_stage."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
assert mock_remote.call_count == 0, "No remote launches without a filter"
launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
assert set(launched_ids) == {0, 1}
- def test_non_single_stage_mode_never_calls_create_remote(self):
+ def test_non_single_stage_mode_never_calls_create_remote(self, mocker: MockerFixture):
"""Outside single_stage_mode, _create_remote_llm_stage must not be called."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(engine, stage_cfgs)
+ mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
assert mock_remote.call_count == 0
- def test_omni_master_server_started_in_single_stage_mode(self):
+ def test_omni_master_server_started_in_single_stage_mode(self, mocker: MockerFixture):
"""OmniMasterServer.start() must be called when single_stage_mode=True."""
stage_cfgs = [_make_stage_cfg(0)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = Mock()
- finalized = ([Mock()], [Mock()], [{"final_output": True, "final_output_type": None, "stage_type": "llm"}])
-
- with (
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ mock_oms.get_zmq_addresses.return_value = mocker.Mock()
+ finalized = (
+ [mocker.Mock()],
+ [mocker.Mock()],
+ [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
+ )
+
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms.start.assert_called_once()
- def test_omni_master_server_uses_configured_stage_ids(self):
+ def test_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture):
"""Configured stage IDs, not list indexes, should drive pre-allocation."""
stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = Mock()
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ mock_oms.get_zmq_addresses.return_value = mocker.Mock()
finalized = (
- [Mock(), Mock()],
- [Mock(), Mock()],
+ [mocker.Mock(), mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
)
- with (
- patch.object(
- engine, "_launch_llm_stage", side_effect=[_make_started_llm_stage(7), _make_started_llm_stage(11)]
- ),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(11)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms) as mock_oms_cls,
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ mocker.patch.object(
+ engine,
+ "_launch_llm_stage",
+ side_effect=[_make_started_llm_stage(7), _make_started_llm_stage(11)],
+ )
+ mocker.patch.object(
+ engine,
+ "_create_remote_llm_stage",
+ return_value=_make_started_llm_stage(11),
+ )
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mock_oms_cls = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms_cls.assert_called_once_with(
master_address=engine._omni_master_address,
@@ -677,73 +739,121 @@ def test_omni_master_server_uses_configured_stage_ids(self):
stage_ids=[7, 11],
)
- def test_single_stage_filter_uses_configured_stage_ids(self):
+ def test_single_stage_filter_uses_configured_stage_ids(self, mocker: MockerFixture):
"""Local/remote dispatch should compare against configured stage IDs."""
stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = Mock(spec=OmniMasterServer)
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
finalized = (
- [Mock(), Mock()],
- [Mock(), Mock()],
+ [mocker.Mock(), mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
)
- with (
- patch.object(engine, "_launch_llm_stage", side_effect=[_make_started_llm_stage(7)]) as mock_launch,
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(11)) as mock_remote,
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ mock_launch = mocker.patch.object(
+ engine,
+ "_launch_llm_stage",
+ side_effect=[_make_started_llm_stage(7)],
+ )
+ mock_remote = mocker.patch.object(
+ engine,
+ "_create_remote_llm_stage",
+ return_value=_make_started_llm_stage(11),
+ )
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
assert [call.args[1].stage_id for call in mock_launch.call_args_list] == [7]
assert [call.args[1].stage_id for call in mock_remote.call_args_list] == [11]
- def test_omni_master_server_preallocates_diffusion_stage_ids(self):
+ def test_omni_master_server_preallocates_diffusion_stage_ids(self, mocker: MockerFixture):
"""Diffusion stages should also receive OmniMasterServer allocations."""
stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11, stage_type="diffusion")]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = Mock(spec=OmniMasterServer)
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
finalized = (
- [Mock(), Mock()],
- [Mock(), Mock()],
+ [mocker.Mock(), mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[
{"final_output": False, "final_output_type": None, "stage_type": "llm"},
{"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
],
)
- with (
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(7)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(7)),
- patch.object(engine, "_launch_diffusion_stage", return_value=Mock()),
- patch.object(engine, "_create_remote_diffusion_stage", return_value=Mock()),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms) as mock_oms_cls,
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(7))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(7))
+ mocker.patch.object(engine, "_launch_diffusion_stage", return_value=mocker.Mock())
+ mocker.patch.object(
+ engine,
+ "_create_remote_diffusion_stage",
+ return_value=mocker.Mock(),
+ )
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mock_oms_cls = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms_cls.assert_called_once_with(
master_address=engine._omni_master_address,
@@ -751,135 +861,200 @@ def test_omni_master_server_preallocates_diffusion_stage_ids(self):
stage_ids=[7, 11],
)
- def test_duplicate_llm_stage_ids_raise(self):
+ def test_duplicate_llm_stage_ids_raise(self, mocker: MockerFixture):
"""Duplicate configured LLM stage IDs should fail fast."""
stage_cfgs = [_make_stage_cfg(3), _make_stage_cfg(3)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=3)
- with (
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- pytest.raises(ValueError, match="Duplicate stage_id"),
- ):
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ with pytest.raises(ValueError, match="Duplicate stage_id"):
engine._initialize_stages(stage_init_timeout=60)
- def test_omni_master_server_not_started_in_normal_mode(self):
+ def test_omni_master_server_not_started_in_normal_mode(self, mocker: MockerFixture):
"""OmniMasterServer must NOT be instantiated outside single_stage_mode."""
stage_cfgs = [_make_stage_cfg(0)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- finalized = ([Mock()], [Mock()], [{"final_output": True, "final_output_type": None, "stage_type": "llm"}])
-
- with (
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer") as mock_oms_cls,
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id),
- ),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ finalized = (
+ [mocker.Mock()],
+ [mocker.Mock()],
+ [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
+ )
+
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mock_oms_cls = mocker.patch("vllm_omni.engine.async_omni_engine.OmniMasterServer")
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_oms_cls.assert_not_called()
- def test_single_stage_mode_missing_master_address_raises(self):
+ def test_single_stage_mode_missing_master_address_raises(self, mocker: MockerFixture):
"""single_stage_mode without master address/port raises ValueError."""
stage_cfgs = [_make_stage_cfg(0)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
engine._omni_master_address = None # missing
engine._omni_master_port = None
- with (
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- pytest.raises(ValueError, match="omni_master_address"),
- ):
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ with pytest.raises(ValueError, match="omni_master_address"):
engine._initialize_stages(stage_init_timeout=60)
- def test_matching_diffusion_stage_uses_local_registered_launch(self):
+ def test_matching_diffusion_stage_uses_local_registered_launch(self, mocker: MockerFixture):
"""A local diffusion stage should use the registered single-stage launch path."""
stage_cfgs = [_make_stage_cfg(0, stage_type="diffusion"), _make_stage_cfg(1)]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = Mock(spec=OmniMasterServer)
- diffusion_client = Mock(stage_type="diffusion")
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ diffusion_client = mocker.Mock(stage_type="diffusion")
finalized = (
- [diffusion_client, Mock()],
- [Mock(), Mock()],
+ [diffusion_client, mocker.Mock()],
+ [mocker.Mock(), mocker.Mock()],
[
{"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
{"final_output": False, "final_output_type": None, "stage_type": "llm"},
],
)
- with (
- patch.object(engine, "_launch_diffusion_stage", return_value=diffusion_client) as mock_local_diff,
- patch.object(engine, "_create_remote_diffusion_stage") as mock_remote_diff,
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(1)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(1)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
+ mock_local_diff = mocker.patch.object(
+ engine,
+ "_launch_diffusion_stage",
+ return_value=diffusion_client,
+ )
+ mock_remote_diff = mocker.patch.object(engine, "_create_remote_diffusion_stage")
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(1))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(1))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
assert mock_local_diff.call_count == 1
assert mock_local_diff.call_args.args[1].stage_id == 0
mock_remote_diff.assert_not_called()
- def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self):
+ def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self, mocker: MockerFixture):
"""A non-local diffusion stage should attach via the remote diffusion path."""
stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1, stage_type="diffusion")]
engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = Mock(spec=OmniMasterServer)
- remote_diffusion_client = Mock(stage_type="diffusion")
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
+ remote_diffusion_client = mocker.Mock(stage_type="diffusion")
finalized = (
- [Mock(), remote_diffusion_client],
- [Mock(), Mock()],
+ [mocker.Mock(), remote_diffusion_client],
+ [mocker.Mock(), mocker.Mock()],
[
{"final_output": False, "final_output_type": None, "stage_type": "llm"},
{"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
],
)
- with (
- patch.object(engine, "_launch_diffusion_stage") as mock_local_diff,
- patch.object(
- engine, "_create_remote_diffusion_stage", return_value=remote_diffusion_client
- ) as mock_remote_diff,
- patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0)),
- patch.object(engine, "_attach_llm_stage", return_value=(Mock(), Mock(), Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.OmniMasterServer", return_value=mock_oms),
- patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment"),
- patch("vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", return_value=None),
- patch("vllm_omni.engine.async_omni_engine.get_stage_connector_spec", return_value={}),
- patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", return_value=(None, None, None)
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(cfg.stage_id, getattr(cfg, "stage_type", "llm")),
+ mock_local_diff = mocker.patch.object(engine, "_launch_diffusion_stage")
+ mock_remote_diff = mocker.patch.object(
+ engine,
+ "_create_remote_diffusion_stage",
+ return_value=remote_diffusion_client,
+ )
+ mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(
+ engine,
+ "_attach_llm_stage",
+ return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.OmniMasterServer",
+ return_value=mock_oms,
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
+ return_value=None,
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
+ return_value={},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
+ side_effect=lambda cfg: self._fake_metadata(
+ mocker,
+ cfg.stage_id,
+ getattr(cfg, "stage_type", "llm"),
),
- patch("vllm_omni.engine.async_omni_engine.finalize_initialized_stages", return_value=finalized),
- ):
- engine._initialize_stages(stage_init_timeout=60)
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
+ return_value=finalized,
+ )
+
+ engine._initialize_stages(stage_init_timeout=60)
mock_local_diff.assert_not_called()
assert mock_remote_diff.call_count == 1
@@ -894,45 +1069,47 @@ def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self):
class TestLaunchDiffusionStage:
"""Test local diffusion stage launch wiring."""
- def test_registers_stage_with_public_master_properties(self):
+ def test_registers_stage_with_public_master_properties(self, mocker: MockerFixture):
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.diffusion_batch_size = 4
stage_cfg = _make_stage_cfg(5, stage_type="diffusion")
- metadata = Mock(stage_id=5)
- omni_master_server = Mock(spec=OmniMasterServer)
+ metadata = mocker.Mock(stage_id=5)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.address = "127.0.0.1"
omni_master_server.port = 25000
- proc = Mock()
- diffusion_client = Mock()
-
- with (
- patch("vllm_omni.engine.async_omni_engine.build_diffusion_config", return_value="diffusion-config"),
- patch(
- "vllm_omni.engine.async_omni_engine.register_stage_with_omni_master",
- return_value=(
- "tcp://127.0.0.1:25001",
- "tcp://127.0.0.1:25002",
- "tcp://127.0.0.1:25003",
- ),
- ) as mock_register,
- patch(
- "vllm_omni.engine.async_omni_engine.spawn_diffusion_proc",
- return_value=(proc, None, None, None),
- ) as mock_spawn,
- patch("vllm_omni.engine.async_omni_engine.complete_diffusion_handshake") as mock_handshake,
- patch(
- "vllm_omni.engine.async_omni_engine.StageDiffusionClient.from_addresses",
- return_value=diffusion_client,
- ) as mock_from_addresses,
- ):
- result = engine._launch_diffusion_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- omni_master_server=omni_master_server,
- )
+ proc = mocker.Mock()
+ diffusion_client = mocker.Mock()
+
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_diffusion_config",
+ return_value="diffusion-config",
+ )
+ mock_register = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.register_stage_with_omni_master",
+ return_value=(
+ "tcp://127.0.0.1:25001",
+ "tcp://127.0.0.1:25002",
+ "tcp://127.0.0.1:25003",
+ ),
+ )
+ mock_spawn = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.spawn_diffusion_proc",
+ return_value=(proc, None, None, None),
+ )
+ mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_diffusion_handshake")
+ mock_from_addresses = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.StageDiffusionClient.from_addresses",
+ return_value=diffusion_client,
+ )
+
+ result = engine._launch_diffusion_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ omni_master_server=omni_master_server,
+ )
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -967,14 +1144,14 @@ def test_registers_stage_with_public_master_properties(self):
class TestCreateRemoteLlmStage:
"""Test _create_remote_llm_stage delegates correctly."""
- def _engine(self) -> AsyncOmniEngine:
+ def _engine(self, mocker: MockerFixture) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.single_stage_mode = True
engine._single_stage_id_filter = 0
- engine._omni_master_server = Mock(spec=OmniMasterServer)
- engine._omni_master_server.get_zmq_addresses.return_value = Mock()
- engine._omni_master_server.get_allocation.return_value = Mock()
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine._omni_master_server.get_zmq_addresses.return_value = mocker.Mock()
+ engine._omni_master_server.get_allocation.return_value = mocker.Mock()
engine._omni_master_server.get_stage_config.return_value = {
"stage_id": 0,
"stage_type": "llm",
@@ -982,42 +1159,40 @@ def _engine(self) -> AsyncOmniEngine:
}
return engine
- @contextmanager
- def _patch_build_and_connect(self, stage_id: int):
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ def _mock_build_and_connect(self, mocker: MockerFixture, stage_id: int):
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- eng_mgr = Mock()
- coordinator = Mock()
+ eng_mgr = mocker.Mock()
+ coordinator = mocker.Mock()
@contextmanager
def fake_connect_cm(*args, **kwargs):
yield eng_mgr, coordinator, fake_addresses
- with (
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=fake_connect_cm(),
- ) as mock_connect,
- ):
- yield mock_connect, fake_vllm_config, fake_executor_cls, fake_addresses
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": stage_id},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mock_connect = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
+ return_value=fake_connect_cm(),
+ )
+
+ return mock_connect, fake_vllm_config, fake_executor_cls, fake_addresses
- def test_returns_started_llm_stage_with_correct_stage_id(self):
- engine = self._engine()
+ def test_returns_started_llm_stage_with_correct_stage_id(self, mocker: MockerFixture):
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(1)
- metadata = Mock(stage_id=1)
+ metadata = mocker.Mock(stage_id=1)
omni_ms = engine._omni_master_server
omni_ms.get_stage_config.return_value = {
"stage_id": 1,
@@ -1025,93 +1200,93 @@ def test_returns_started_llm_stage_with_correct_stage_id(self):
"engine_args": {},
}
- with self._patch_build_and_connect(1):
- result = engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ self._mock_build_and_connect(mocker, 1)
+ result = engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
assert isinstance(result, StartedLlmStage)
assert result.stage_id == 1
- def test_connect_remote_engine_cores_called_with_stage_id(self):
- engine = self._engine()
+ def test_connect_remote_engine_cores_called_with_stage_id(self, mocker: MockerFixture):
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(2)
- metadata = Mock(stage_id=2)
+ metadata = mocker.Mock(stage_id=2)
omni_ms = engine._omni_master_server
- omni_ms.get_zmq_addresses.return_value = Mock(inputs=["x"], outputs=["y"])
+ omni_ms.get_zmq_addresses.return_value = mocker.Mock(inputs=["x"], outputs=["y"])
omni_ms.get_stage_config.return_value = {
"stage_id": 2,
"stage_type": "llm",
"engine_args": {},
}
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@contextmanager
def fake_connect_cm(*args, **kwargs):
- yield Mock(), Mock(), fake_addresses
+ yield mocker.Mock(), mocker.Mock(), fake_addresses
- with (
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 2},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores", return_value=fake_connect_cm()
- ) as mock_connect,
- ):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 2},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mock_connect = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
+ return_value=fake_connect_cm(),
+ )
+
+ engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
mock_connect.assert_called_once()
_, kwargs = mock_connect.call_args
assert kwargs.get("stage_id") == 2 or mock_connect.call_args.args[-1] == 2
omni_ms.get_stage_config.assert_called_once_with(2, timeout_s=60)
- def test_missing_registered_stage_config_raises_value_error(self):
- engine = self._engine()
+ def test_missing_registered_stage_config_raises_value_error(self, mocker: MockerFixture):
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(3)
- metadata = Mock(stage_id=3)
+ metadata = mocker.Mock(stage_id=3)
omni_ms = engine._omni_master_server
omni_ms.get_stage_config.return_value = None
- with patch("vllm_omni.engine.async_omni_engine.build_engine_args_dict") as mock_build_args:
- with pytest.raises(
- ValueError,
- match="Remote stage 3 registered without stage config",
- ):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ mock_build_args = mocker.patch("vllm_omni.engine.async_omni_engine.build_engine_args_dict")
+ with pytest.raises(
+ ValueError,
+ match="Remote stage 3 registered without stage config",
+ ):
+ engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
mock_build_args.assert_not_called()
- def test_exception_during_connect_closes_started_stage(self):
+ def test_exception_during_connect_closes_started_stage(self, mocker: MockerFixture):
"""If an error occurs after StartedLlmStage creation, close_started_llm_stage is called."""
- engine = self._engine()
+ engine = self._engine(mocker)
stage_cfg = _make_stage_cfg(1)
- metadata = Mock(stage_id=1)
+ metadata = mocker.Mock(stage_id=1)
omni_ms = engine._omni_master_server
omni_ms.get_stage_config.return_value = {
"stage_id": 1,
@@ -1121,26 +1296,30 @@ def test_exception_during_connect_closes_started_stage(self):
@contextmanager
def boom(*args, **kwargs):
- yield Mock(), Mock(), Mock()
+ yield mocker.Mock(), mocker.Mock(), mocker.Mock()
raise RuntimeError("handshake failed")
- with (
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 1},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.connect_remote_engine_cores", return_value=boom()),
- patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage") as mock_close,
- ):
- with pytest.raises(RuntimeError, match="handshake failed"):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 1},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
+ return_value=boom(),
+ )
+ mock_close = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
+ with pytest.raises(RuntimeError, match="handshake failed"):
+ engine._create_remote_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ omni_master_server=omni_ms,
+ )
mock_close.assert_called_once()
@@ -1148,27 +1327,29 @@ class TestConnectRemoteEngineCoresCoordinator:
"""Test coordinator launch parity with launch_core_engines."""
@staticmethod
- def _build_vllm_config(*, dp_rank: int = 0, offline_mode: bool = False, needs_dp_coordinator: bool = True) -> Mock:
- parallel_config = Mock()
+ def _build_vllm_config(
+ mocker: MockerFixture, *, dp_rank: int = 0, offline_mode: bool = False, needs_dp_coordinator: bool = True
+ ) -> Any:
+ parallel_config = mocker.Mock()
parallel_config.data_parallel_size_local = 1
parallel_config.data_parallel_size = 2
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = 0 if offline_mode else None
- vllm_config = Mock()
+ vllm_config = mocker.Mock()
vllm_config.parallel_config = parallel_config
vllm_config.needs_dp_coordinator = needs_dp_coordinator
- vllm_config.model_config = Mock(is_moe=False)
+ vllm_config.model_config = mocker.Mock(is_moe=False)
return vllm_config
- def test_uses_registered_coordinator_addresses(self):
- vllm_config = self._build_vllm_config(dp_rank=0, offline_mode=False, needs_dp_coordinator=True)
+ def test_uses_registered_coordinator_addresses(self, mocker: MockerFixture):
+ vllm_config = self._build_vllm_config(mocker, dp_rank=0, offline_mode=False, needs_dp_coordinator=True)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
inputs=["tcp://client-in"], outputs=["tcp://client-out"]
)
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses(
coordinator_input="tcp://coord-in",
coordinator_output="tcp://coord-out",
@@ -1177,103 +1358,107 @@ def test_uses_registered_coordinator_addresses(self):
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
+ yield mocker.Mock()
- with (
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup") as mock_wait,
- ):
- with connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=7,
- ) as (_, yielded_coordinator, yielded_addresses):
- assert yielded_coordinator is None
- assert yielded_addresses.coordinator_input == "tcp://coord-in"
- assert yielded_addresses.coordinator_output == "tcp://coord-out"
- assert yielded_addresses.frontend_stats_publish_address == "tcp://stats"
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
+ with connect_remote_engine_cores(
+ vllm_config=vllm_config,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ ) as (_, yielded_coordinator, yielded_addresses):
+ assert yielded_coordinator is None
+ assert yielded_addresses.coordinator_input == "tcp://coord-in"
+ assert yielded_addresses.coordinator_output == "tcp://coord-out"
+ assert yielded_addresses.frontend_stats_publish_address == "tcp://stats"
omni_master_server.get_stage_coordinator_addresses.assert_called_once_with(7)
mock_wait.assert_called_once()
- def test_defaults_to_no_coordinator_addresses_when_none_registered(self):
+ def test_defaults_to_no_coordinator_addresses_when_none_registered(self, mocker: MockerFixture):
vllm_config = self._build_vllm_config(
+ mocker,
dp_rank=0,
offline_mode=False,
needs_dp_coordinator=True,
)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
inputs=["tcp://client-in"], outputs=["tcp://client-out"]
)
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses()
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
+ yield mocker.Mock()
- with (
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup"),
- ):
- with connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=7,
- ) as (_, yielded_coordinator, yielded_addresses):
- assert yielded_coordinator is None
- assert yielded_addresses.coordinator_input is None
- assert yielded_addresses.coordinator_output is None
- assert yielded_addresses.frontend_stats_publish_address is None
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
+ with connect_remote_engine_cores(
+ vllm_config=vllm_config,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ ) as (_, yielded_coordinator, yielded_addresses):
+ assert yielded_coordinator is None
+ assert yielded_addresses.coordinator_input is None
+ assert yielded_addresses.coordinator_output is None
+ assert yielded_addresses.frontend_stats_publish_address is None
class TestLaunchOmniCoreEngines:
"""Tests for local omni engine launch wiring."""
- def test_registers_stage_once_and_reuses_handshake_for_all_local_engines(self):
- parallel_config = Mock(
+ def test_registers_stage_once_and_reuses_handshake_for_all_local_engines(self, mocker: MockerFixture):
+ parallel_config = mocker.Mock(
data_parallel_size_local=2,
data_parallel_size=4,
data_parallel_rank=3,
)
- vllm_config = Mock(parallel_config=parallel_config)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.address = "127.0.0.1"
omni_master_server.port = 26000
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
stage_config = {"stage_id": 7, "stage_type": "llm"}
- local_engine_manager = Mock()
+ local_engine_manager = mocker.Mock()
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
-
- with (
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ) as mock_register,
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch(
- "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
- return_value=local_engine_manager,
- ) as mock_manager_cls,
- patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup"),
- ):
- with launch_omni_core_engines(
- vllm_config=vllm_config,
- executor_class=Mock(),
- log_stats=False,
- omni_master_server=omni_master_server,
- stage_id=7,
- stage_config=stage_config,
- ) as (yielded_manager, yielded_coordinator, yielded_addresses):
- assert yielded_manager is local_engine_manager
- assert yielded_coordinator is None
+ yield mocker.Mock()
+
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mock_manager_cls = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
+ return_value=local_engine_manager,
+ )
+ mocker.patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup")
+ with launch_omni_core_engines(
+ vllm_config=vllm_config,
+ executor_class=mocker.Mock(),
+ log_stats=False,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ stage_config=stage_config,
+ ) as (yielded_manager, yielded_coordinator, yielded_addresses):
+ assert yielded_manager is local_engine_manager
+ assert yielded_coordinator is None
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -1292,55 +1477,56 @@ def fake_socket_ctx(*args, **kwargs):
assert manager_kwargs["handshake_address"] == "tcp://127.0.0.1:26001"
assert manager_kwargs["executor_class"] is not None
- def test_registers_stage_with_coordinator_when_started(self):
- parallel_config = Mock(
+ def test_registers_stage_with_coordinator_when_started(self, mocker: MockerFixture):
+ parallel_config = mocker.Mock(
data_parallel_size_local=1,
data_parallel_size=2,
data_parallel_rank=0,
)
- vllm_config = Mock(parallel_config=parallel_config)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
vllm_config.needs_dp_coordinator = True
- vllm_config.model_config = Mock(is_moe=False)
+ vllm_config.model_config = mocker.Mock(is_moe=False)
- omni_master_server = Mock(spec=OmniMasterServer)
+ omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.address = "127.0.0.1"
omni_master_server.port = 26000
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
inputs=["tcp://client-in"], outputs=["tcp://client-out"]
)
- omni_master_server.get_allocation.return_value = Mock(handshake_bind_address="tcp://127.0.0.1:26001")
+ omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
- coordinator = Mock()
+ coordinator = mocker.Mock()
coordinator.proc.pid = 1234
coordinator.get_engine_socket_addresses.return_value = ("tcp://coord-in", "tcp://coord-out")
coordinator.get_stats_publish_address.return_value = "tcp://stats"
@contextmanager
def fake_socket_ctx(*args, **kwargs):
- yield Mock()
-
- with (
- patch("vllm_omni.engine.stage_engine_startup.DPCoordinator", return_value=coordinator),
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ) as mock_register,
- patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()),
- patch(
- "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
- return_value=Mock(),
- ) as mock_manager_cls,
- patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup") as mock_wait,
+ yield mocker.Mock()
+
+ mocker.patch("vllm_omni.engine.stage_engine_startup.DPCoordinator", return_value=coordinator)
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
+ return_value=fake_socket_ctx(),
+ )
+ mock_manager_cls = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
+ return_value=mocker.Mock(),
+ )
+ mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup")
+ with launch_omni_core_engines(
+ vllm_config=vllm_config,
+ executor_class=mocker.Mock(),
+ log_stats=False,
+ omni_master_server=omni_master_server,
+ stage_id=7,
+ stage_config={"stage_id": 7},
):
- with launch_omni_core_engines(
- vllm_config=vllm_config,
- executor_class=Mock(),
- log_stats=False,
- omni_master_server=omni_master_server,
- stage_id=7,
- stage_config={"stage_id": 7},
- ):
- pass
+ pass
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -1363,19 +1549,20 @@ class TestLaunchLlmStageSingleStageMode:
"""Test that _launch_llm_stage selects launch_omni_core_engines when
single_stage_mode=True and _omni_master_server is set."""
- def _build_engine_with_oms(self) -> AsyncOmniEngine:
+ def _build_engine_with_oms(self, mocker: MockerFixture) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.single_stage_mode = True
engine._single_stage_id_filter = 0
engine._llm_stage_launch_lock = threading.Lock()
- mock_oms = Mock(spec=OmniMasterServer)
+ engine.stage_configs = []
+ mock_oms = mocker.Mock(spec=OmniMasterServer)
mock_oms.address = "127.0.0.1"
mock_oms.port = 25000
- alloc = Mock()
+ alloc = mocker.Mock()
alloc.handshake_bind_address = "tcp://127.0.0.1:25001"
mock_oms.get_allocation.return_value = alloc
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@@ -1383,110 +1570,107 @@ def _build_engine_with_oms(self) -> AsyncOmniEngine:
engine._omni_master_server = mock_oms
return engine
- @contextmanager
- def _patch_launch_omni_cm(self, stage_id: int):
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ def _mock_launch_omni(self, mocker: MockerFixture, stage_id: int):
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- eng_mgr = Mock()
+ eng_mgr = mocker.Mock()
@contextmanager
def fake_launch_omni(*args, **kwargs):
yield eng_mgr, None, fake_addresses
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.release_device_locks",
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- ) as mock_launch_omni,
- ):
- yield mock_launch_omni
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": stage_id},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ return mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ return_value=fake_launch_omni(),
+ )
- def test_launch_omni_core_engines_used_in_single_stage_mode(self):
+ def test_launch_omni_core_engines_used_in_single_stage_mode(self, mocker: MockerFixture):
"""single_stage_mode + _omni_master_server → launch_omni_core_engines."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
stage_cfg = _make_stage_cfg(0)
- with self._patch_launch_omni_cm(0) as mock_launch_omni:
- result = engine._launch_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mock_launch_omni = self._mock_launch_omni(mocker, 0)
+ result = engine._launch_llm_stage(
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_launch_omni.assert_called_once()
assert mock_launch_omni.call_args.kwargs["stage_config"] is stage_cfg
assert isinstance(result, StartedLlmStage)
assert result.stage_id == 0
- def test_spawn_stage_core_used_in_normal_mode(self):
+ def test_spawn_stage_core_used_in_normal_mode(self, mocker: MockerFixture):
"""~single_stage_mode → spawn_stage_core + complete_stage_handshake."""
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.single_stage_mode = False
engine._omni_master_server = None
engine._llm_stage_launch_lock = threading.Lock()
+ engine.stage_configs = []
- fake_vllm_config = Mock()
- fake_executor_cls = Mock()
- fake_addresses = Mock()
+ fake_vllm_config = mocker.Mock()
+ fake_executor_cls = mocker.Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- fake_proc = Mock()
+ fake_proc = mocker.Mock()
fake_handshake_address = "ipc:///tmp/fake-handshake"
+ stage_init_timeout = 60
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- ),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch(
- "vllm_omni.engine.async_omni_engine.spawn_stage_core",
- return_value=(fake_addresses, fake_proc, fake_handshake_address),
- ) as mock_spawn,
- patch("vllm_omni.engine.async_omni_engine.complete_stage_handshake") as mock_handshake,
- patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines") as mock_omni,
- ):
- metadata = Mock(stage_id=0, runtime_cfg={})
- result = engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(fake_vllm_config, fake_executor_cls),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mock_spawn = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.spawn_stage_core",
+ return_value=(fake_addresses, fake_proc, fake_handshake_address),
+ )
+ mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_stage_handshake")
+ mock_omni = mocker.patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines")
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
+ result = engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=stage_init_timeout,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_spawn.assert_called_once_with(
vllm_config=fake_vllm_config,
@@ -1498,55 +1682,64 @@ def test_spawn_stage_core_used_in_normal_mode(self):
fake_handshake_address,
fake_addresses,
fake_vllm_config,
+ stage_init_timeout,
)
mock_omni.assert_not_called()
assert isinstance(result, StartedLlmStage)
assert result.proc is fake_proc
- def test_launch_omni_passes_stage_id_and_master_server(self):
+ def test_launch_omni_passes_stage_id_and_master_server(self, mocker: MockerFixture):
"""launch_omni_core_engines receives the correct stage_id and omni_master_server."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
captured_kwargs: dict[str, Any] = {}
@contextmanager
def capturing_launch(*args, **kwargs):
captured_kwargs.update(kwargs)
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
- yield Mock(), None, fake_addresses
+ yield mocker.Mock(), None, fake_addresses
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines", side_effect=capturing_launch),
- ):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ side_effect=capturing_launch,
+ )
+
+ engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
assert captured_kwargs.get("stage_id") == 0
assert captured_kwargs.get("omni_master_server") is engine._omni_master_server
- def test_launch_omni_context_exits_before_stage_cleanup_on_error(self):
+ def test_launch_omni_context_exits_before_stage_cleanup_on_error(self, mocker: MockerFixture):
"""Errors after entering the omni launch context still unwind it first."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@@ -1556,47 +1749,51 @@ def test_launch_omni_context_exits_before_stage_cleanup_on_error(self):
@contextmanager
def fake_launch_omni(*args, **kwargs):
try:
- yield Mock(), None, fake_addresses
+ yield mocker.Mock(), None, fake_addresses
finally:
events.append("launch_exit")
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- ),
- patch("vllm_omni.engine.async_omni_engine.logger.info", side_effect=RuntimeError("boom")),
- patch(
- "vllm_omni.engine.async_omni_engine.close_started_llm_stage",
- side_effect=lambda _started: events.append("stage_close"),
- ) as mock_close_stage,
- ):
- with pytest.raises(RuntimeError, match="boom"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ return_value=fake_launch_omni(),
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.logger.info", side_effect=RuntimeError("boom"))
+ mock_close_stage = mocker.patch(
+ "vllm_omni.engine.async_omni_engine.close_started_llm_stage",
+ side_effect=lambda _started: events.append("stage_close"),
+ )
+ with pytest.raises(RuntimeError, match="boom"):
+ engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_close_stage.assert_called_once()
assert events == ["launch_exit", "stage_close"]
- def test_base_exception_propagates_without_started_stage_cleanup(self):
+ def test_base_exception_propagates_without_started_stage_cleanup(self, mocker: MockerFixture):
"""BaseException subclasses should bypass the Exception cleanup path."""
- engine = self._build_engine_with_oms()
- metadata = Mock(stage_id=0, runtime_cfg={})
+ engine = self._build_engine_with_oms(mocker)
+ metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- fake_addresses = Mock()
+ fake_addresses = mocker.Mock()
fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
fake_addresses.frontend_stats_publish_address = None
@@ -1609,37 +1806,41 @@ class FatalLaunchInterrupt(BaseException):
@contextmanager
def fake_launch_omni(*args, **kwargs):
try:
- yield Mock(), None, fake_addresses
+ yield mocker.Mock(), None, fake_addresses
finally:
events.append("launch_exit")
- with (
- patch("vllm_omni.engine.async_omni_engine.setup_stage_devices"),
- patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- ),
- patch("vllm_omni.engine.async_omni_engine.build_vllm_config", return_value=(Mock(), Mock())),
- patch("vllm_omni.engine.async_omni_engine.acquire_device_locks", return_value=[]),
- patch("vllm_omni.engine.async_omni_engine.release_device_locks"),
- patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- ),
- patch(
- "vllm_omni.engine.async_omni_engine.logger.info",
- side_effect=FatalLaunchInterrupt("stop"),
- ),
- patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage") as mock_close_stage,
- ):
- with pytest.raises(FatalLaunchInterrupt, match="stop"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
+ mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
+ return_value={"model": "fake", "stage_id": 0},
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.build_vllm_config",
+ return_value=(mocker.Mock(), mocker.Mock()),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.acquire_device_locks",
+ return_value=[],
+ )
+ mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
+ return_value=fake_launch_omni(),
+ )
+ mocker.patch(
+ "vllm_omni.engine.async_omni_engine.logger.info",
+ side_effect=FatalLaunchInterrupt("stop"),
+ )
+ mock_close_stage = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
+ with pytest.raises(FatalLaunchInterrupt, match="stop"):
+ engine._launch_llm_stage(
+ stage_cfg=_make_stage_cfg(0),
+ metadata=metadata,
+ stage_connector_spec={},
+ stage_init_timeout=60,
+ llm_stage_launch_lock=threading.Lock(),
+ )
mock_close_stage.assert_not_called()
assert events == ["launch_exit"]
diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py
index c91c5a5c75..607b3eaa81 100644
--- a/tests/entrypoints/openai_api/test_image_server.py
+++ b/tests/entrypoints/openai_api/test_image_server.py
@@ -106,10 +106,13 @@ def test_encode_image_base64():
class MockGenerationResult:
- """Mock result object from AsyncOmni.generate()"""
+ """Mock result object compatible with current diffusion output shape."""
def __init__(self, images):
self.images = images
+ self.request_output = SimpleNamespace(images=images)
+ self.stage_durations = {}
+ self.peak_memory_mb = 0.0
class FakeAsyncOmni:
@@ -117,20 +120,26 @@ class FakeAsyncOmni:
def __init__(self, images=None):
self.stage_configs = [
- SimpleNamespace(stage_type="llm"),
- SimpleNamespace(stage_type="diffusion"),
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
]
self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()]
self.captured_sampling_params_list = None
self.captured_prompt = None
self._images = images or [Image.new("RGB", (64, 64), color="green")]
- async def generate(self, prompt, request_id, sampling_params_list):
- self.captured_sampling_params_list = sampling_params_list
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
self.captured_prompt = prompt
images = [img.copy() for img in self._images]
yield MockGenerationResult(images)
+ def __class_getitem__(cls, item):
+ return cls
+
@pytest.fixture
def mock_async_diffusion(mocker: MockerFixture):
@@ -177,7 +186,7 @@ def test_client(mock_async_diffusion):
[BaseModelPath(name="Qwen/Qwen-Image", model_path="Qwen/Qwen-Image")]
)
app.state.args = Namespace(
- default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}',
+ default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1024 * 1792,
)
@@ -189,18 +198,60 @@ def async_omni_test_client():
"""Create test client with mocked AsyncOmni engine."""
from fastapi import FastAPI
+ from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ class FakeAsyncOmniClass(AsyncOmni):
+ def __init__(self):
+ stage_configs = [
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ]
+ default_sampling_params_list = [
+ SamplingParams(temperature=0.1),
+ OmniDiffusionSamplingParams(),
+ ]
+ self.engine = SimpleNamespace(
+ stage_configs=stage_configs,
+ default_sampling_params_list=default_sampling_params_list,
+ )
+ self.default_sampling_params_list = default_sampling_params_list
+ self.captured_sampling_params_list = None
+ self.captured_prompt = None
+ self._images = [Image.new("RGB", (64, 64), color="green")]
+ self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
+
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
+ self.captured_prompt = prompt
+ images = [img.copy() for img in self._images]
+ yield MockGenerationResult(images)
+
+ def __class_getitem__(cls, item):
+ return cls
+
+ def get_diffusion_od_config(self):
+ return self.od_config
app = FastAPI()
app.include_router(router)
- app.state.engine_client = FakeAsyncOmni()
+ engine = FakeAsyncOmniClass()
+ chat_handler = object.__new__(OmniOpenAIServingChat)
+ chat_handler.engine_client = engine
+ chat_handler._diffusion_engine = None
+ app.state.openai_serving_chat = chat_handler
+ app.state.engine_client = engine
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
]
app.state.args = Namespace(
- default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
+ default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1048576, # 1024*1024 to support resolution tests
)
return TestClient(app)
@@ -211,18 +262,60 @@ def async_omni_rgba_test_client():
"""Create test client with mocked AsyncOmni engine returning RGBA output."""
from fastapi import FastAPI
+ from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ class FakeAsyncOmniClass(AsyncOmni):
+ def __init__(self):
+ stage_configs = [
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ]
+ default_sampling_params_list = [
+ SamplingParams(temperature=0.1),
+ OmniDiffusionSamplingParams(),
+ ]
+ self.engine = SimpleNamespace(
+ stage_configs=stage_configs,
+ default_sampling_params_list=default_sampling_params_list,
+ )
+ self.default_sampling_params_list = default_sampling_params_list
+ self.captured_sampling_params_list = None
+ self.captured_prompt = None
+ self._images = [Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))]
+ self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
+
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
+ self.captured_prompt = prompt
+ images = [img.copy() for img in self._images]
+ yield MockGenerationResult(images)
+
+ def __class_getitem__(cls, item):
+ return cls
+
+ def get_diffusion_od_config(self):
+ return self.od_config
app = FastAPI()
app.include_router(router)
- app.state.engine_client = FakeAsyncOmni(images=[Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))])
+ engine = FakeAsyncOmniClass()
+ chat_handler = object.__new__(OmniOpenAIServingChat)
+ chat_handler.engine_client = engine
+ chat_handler._diffusion_engine = None
+ app.state.openai_serving_chat = chat_handler
+ app.state.engine_client = engine
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
]
app.state.args = Namespace(
- default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
+ default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1048576,
)
return TestClient(app)
@@ -233,18 +326,57 @@ def async_omni_stage_configs_only_client():
"""Create test client with refactored AsyncOmni compatibility surface only."""
from fastapi import FastAPI
+ from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ class FakeAsyncOmniClass(AsyncOmni):
+ def __init__(self):
+ stage_configs = [
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ]
+ default_sampling_params_list = [
+ SamplingParams(temperature=0.1),
+ OmniDiffusionSamplingParams(),
+ ]
+ self.engine = SimpleNamespace(
+ stage_configs=stage_configs,
+ default_sampling_params_list=default_sampling_params_list,
+ )
+ self.default_sampling_params_list = default_sampling_params_list
+ self.captured_sampling_params_list = None
+ self.captured_prompt = None
+ self._images = [Image.new("RGB", (64, 64), color="green")]
+ self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
+
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
+ self.captured_prompt = prompt
+ images = [img.copy() for img in self._images]
+ yield MockGenerationResult(images)
+
+ def __class_getitem__(cls, item):
+ return cls
+
+ def get_diffusion_od_config(self):
+ return self.od_config
app = FastAPI()
app.include_router(router)
- engine = FakeAsyncOmni()
+ engine = FakeAsyncOmniClass()
assert not hasattr(engine, "stage_list")
app.state.engine_client = engine
- # Intentionally do not populate app.state.stage_configs. Refactored
- # AsyncOmni exposes stage_configs on the engine instance.
+ chat_handler = object.__new__(OmniOpenAIServingChat)
+ chat_handler.engine_client = engine
+ chat_handler._diffusion_engine = None
+ app.state.openai_serving_chat = chat_handler
app.state.args = Namespace(
- default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
+ default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1024 * 1792,
)
return TestClient(app)
@@ -306,6 +438,9 @@ def test_models_endpoint_no_engine():
def test_generate_single_image(test_client):
"""Test generating a single image"""
+ # Single-stage path should not require openai_serving_chat.
+ assert not hasattr(test_client.app.state, "openai_serving_chat")
+
response = test_client.post(
"/v1/images/generations",
json={
@@ -374,6 +509,43 @@ def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_
assert captured[1].seed == 11
+def test_multistage_images_async_omni_construction(async_omni_test_client):
+ """Regression: multistage image generation builds the expected chat-style payload."""
+ response = async_omni_test_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "a cat",
+ "n": 2,
+ "size": "128x256",
+ "seed": 7,
+ "num_inference_steps": 12,
+ "guidance_scale": 6.5,
+ },
+ )
+ assert response.status_code == 200
+
+ engine = async_omni_test_client.app.state.engine_client
+ captured_prompt = engine.captured_prompt
+ assert captured_prompt["prompt"] == "a cat"
+ assert captured_prompt["modalities"] == ["image"]
+ assert captured_prompt["mm_processor_kwargs"] == {
+ "target_h": 256,
+ "target_w": 128,
+ }
+
+ captured = engine.captured_sampling_params_list
+ assert captured is not None
+ assert len(captured) == 2
+ assert captured[0].temperature == 0.1
+ assert captured[0].seed == 7
+ assert captured[1].num_outputs_per_prompt == 2
+ assert captured[1].width == 128
+ assert captured[1].height == 256
+ assert captured[1].seed == 7
+ assert captured[1].num_inference_steps == 12
+ assert captured[1].guidance_scale == 6.5
+
+
def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only_client):
"""Regression: image edits accepts refactored AsyncOmni without stage_list."""
img_bytes = make_test_image_bytes((16, 16))
@@ -679,6 +851,19 @@ def test_model_field_omitted_works(test_client):
assert response.status_code == 200
+def test_generate_images_rejects_model_mismatch(test_client):
+ response = test_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "test",
+ "model": "Qwen/Qwen-Image-2512",
+ "size": "1024x1024",
+ },
+ )
+ assert response.status_code == 400
+ assert "model mismatch" in response.json()["detail"].lower()
+
+
def make_test_image_bytes(size=(64, 64)) -> bytes:
img = Image.new(
"RGB",
@@ -782,6 +967,77 @@ def test_image_edit_rejects_multiple_images_when_model_does_not_support_them(asy
assert engine.captured_prompt is None
+def test_image_edit_rejects_model_mismatch(test_client):
+ img_bytes = make_test_image_bytes((16, 16))
+ response = test_client.post(
+ "/v1/images/edits",
+ files=[("image", img_bytes)],
+ data={
+ "prompt": "edit me",
+ "model": "Qwen/Qwen-Image-Edit",
+ },
+ )
+ assert response.status_code == 400
+ assert "model mismatch" in response.json()["detail"].lower()
+
+
+def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_test_client):
+ engine = async_omni_test_client.app.state.engine_client
+ engine.get_diffusion_od_config = lambda: SimpleNamespace(
+ supports_multimodal_inputs=True,
+ max_multimodal_image_inputs=4,
+ )
+
+ response = async_omni_test_client.post(
+ "/v1/images/edits",
+ files=[
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ],
+ data={"prompt": "hello world."},
+ )
+
+ assert response.status_code == 400
+ assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
+ assert engine.captured_prompt is None
+
+
+def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511_before_loading(
+ async_omni_test_client, monkeypatch: pytest.MonkeyPatch
+):
+ import vllm_omni.entrypoints.openai.api_server as api_server_module
+
+ engine = async_omni_test_client.app.state.engine_client
+ engine.get_diffusion_od_config = lambda: SimpleNamespace(
+ supports_multimodal_inputs=True,
+ max_multimodal_image_inputs=4,
+ )
+
+ def _fail_load(*args, **kwargs):
+ raise AssertionError("_load_input_images should not run for over-limit requests")
+
+ monkeypatch.setattr(api_server_module, "_load_input_images", _fail_load)
+
+ response = async_omni_test_client.post(
+ "/v1/images/edits",
+ files=[
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ],
+ data={"prompt": "hello world."},
+ )
+
+ assert response.status_code == 400
+ assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
+ assert engine.captured_prompt is None
+
+
def test_image_edit_parameter_pass(async_omni_test_client):
img_bytes_1 = make_test_image_bytes((16, 16))
@@ -960,6 +1216,7 @@ def test_image_edit_parameter_default(async_omni_test_client):
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
+ assert captured_sampling_params.generator_device == "cpu"
# Test that a size exceeding max_generated_image_size returns 400
response = async_omni_test_client.post(
@@ -993,6 +1250,7 @@ def test_image_edit_parameter_default_single_stage(test_client):
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
+ assert captured_sampling_params.generator_device == "cpu"
# Size exceeding max_generated_image_size (1024*1792) returns 400
response = test_client.post(
@@ -1165,3 +1423,91 @@ def test_image_edit_with_seed_zero_single_stage(test_client):
f"Expected seed=0, but got seed={captured_sampling_params.seed}. "
"This indicates the bug where seed=0 is treated as falsy."
)
+
+
+def test_normalize_image():
+ """Test _normalize_image with various input types"""
+ import numpy as np
+
+ from vllm_omni.entrypoints.openai.api_server import _normalize_image
+
+ # Test PIL Image input
+ img = Image.new("RGB", (64, 64), color="red")
+ result = _normalize_image(img)
+ assert isinstance(result, Image.Image)
+ assert result.size == (64, 64)
+
+ # Test uint8 numpy array
+ arr = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = _normalize_image(arr)
+ assert isinstance(result, Image.Image)
+ assert result.size == (64, 64)
+
+ # Test float [0, 1] numpy array
+ arr = np.random.rand(64, 64, 3).astype(np.float32)
+ result = _normalize_image(arr)
+ assert isinstance(result, Image.Image)
+ assert result.size == (64, 64)
+
+ # Test float [-1, 1] numpy array
+ arr = np.random.rand(64, 64, 3).astype(np.float32) * 2 - 1
+ result = _normalize_image(arr)
+ assert isinstance(result, Image.Image)
+ assert result.size == (64, 64)
+
+ # Test batch dimensions (1, 1, H, W, C)
+ arr = np.random.randint(0, 255, (1, 1, 64, 64, 3), dtype=np.uint8)
+ result = _normalize_image(arr)
+ assert isinstance(result, Image.Image)
+ assert result.size == (64, 64)
+
+
+def test_extract_images_from_result():
+ """Test _extract_images_from_result with various result formats"""
+ import numpy as np
+
+ from vllm_omni.entrypoints.openai.api_server import _extract_images_from_result
+
+ # Test empty result
+ class EmptyResult:
+ pass
+
+ result = EmptyResult()
+ images = _extract_images_from_result(result)
+ assert images == []
+
+ # Test nested batch: [np.array(shape=(3, 64, 64, 3))]
+ batch = np.random.randint(0, 255, (3, 1, 64, 64, 3), dtype=np.uint8)
+
+ class BatchResult:
+ def __init__(self):
+ self.images = [batch]
+
+ result = BatchResult()
+ images = _extract_images_from_result(result)
+ assert len(images) == 3
+ assert all(isinstance(img, Image.Image) for img in images)
+ assert all(img.size == (64, 64) for img in images)
+
+ # Test dict path: result.request_output["images"]
+ class DictRequestOutput:
+ def __init__(self):
+ self.request_output = {"images": [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)]}
+
+ result = DictRequestOutput()
+ images = _extract_images_from_result(result)
+ assert len(images) == 1
+ assert isinstance(images[0], Image.Image)
+
+ # Test attribute path: result.request_output.images
+ class AttrRequestOutput:
+ def __init__(self):
+ self.request_output = type(
+ "obj", (), {"images": [np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)]}
+ )()
+
+ result = AttrRequestOutput()
+ images = _extract_images_from_result(result)
+ assert len(images) == 1
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (32, 32)
diff --git a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py
new file mode 100644
index 0000000000..a9b9f53ba8
--- /dev/null
+++ b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Regression tests for multistage diffusion generation input construction."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+from PIL import Image
+from vllm.sampling_params import SamplingParams
+
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture
+def serving_chat():
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ return object.__new__(OmniOpenAIServingChat)
+
+
+def test_build_multistage_generation_inputs_applies_stage_specific_overrides(serving_chat):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ engine = SimpleNamespace(
+ stage_configs=[
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ],
+ default_sampling_params_list=[
+ SamplingParams(temperature=0.2, seed=11),
+ OmniDiffusionSamplingParams(),
+ OmniDiffusionSamplingParams(),
+ ],
+ )
+ reference_image = Image.new("RGB", (24, 24), color="green")
+ extra_body = {
+ "negative_prompt": "blurry",
+ "num_inference_steps": 28,
+ "guidance_scale": 7.5,
+ "true_cfg_scale": 5.0,
+ "guidance_scale_2": 1.25,
+ "layers": 6,
+ "resolution": 1024,
+ "lora": {"name": "adapter-a", "path": "/tmp/adapter-a", "scale": 0.6},
+ }
+ gen_params = OmniDiffusionSamplingParams(height=768, width=1024, seed=0, num_outputs_per_prompt=2)
+
+ engine_prompt, sampling_params_list = OmniOpenAIServingChat._build_multistage_generation_inputs(
+ serving_chat,
+ engine=engine,
+ prompt="draw a robot",
+ extra_body=extra_body,
+ reference_images=[reference_image],
+ gen_params=gen_params,
+ )
+
+ assert engine_prompt["prompt"] == "draw a robot"
+ assert engine_prompt["modalities"] == ["img2img"]
+ assert engine_prompt["negative_prompt"] == "blurry"
+ assert engine_prompt["mm_processor_kwargs"] == {"target_h": 768, "target_w": 1024}
+ assert engine_prompt["multi_modal_data"]["img2img"].size == (24, 24)
+
+ assert len(sampling_params_list) == 3
+ assert sampling_params_list[0].temperature == 0.2
+ assert sampling_params_list[0].seed == 0
+ assert sampling_params_list[1].height == 768
+ assert sampling_params_list[1].width == 1024
+ assert sampling_params_list[1].seed == 0
+ assert sampling_params_list[1].num_inference_steps == 28
+ assert sampling_params_list[1].guidance_scale == 7.5
+ assert sampling_params_list[1].num_outputs_per_prompt == 2
+ assert sampling_params_list[1].true_cfg_scale == 5.0
+ assert sampling_params_list[1].lora_request.name == "adapter-a"
+ assert sampling_params_list[2].height == 768
+ assert sampling_params_list[2].width == 1024
+ assert sampling_params_list[2].num_inference_steps == 28
+ assert engine.default_sampling_params_list[1].height is None
+ assert engine.default_sampling_params_list[2].resolution == 640
diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
index fa4c1e195d..4190b1fbb1 100644
--- a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
+++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
@@ -284,6 +284,185 @@ def test_apply_request_overrides_applies_values(serving_chat, mock_request, defa
assert result.top_k == 1 # YAML custom param preserved
+# =============================================================================
+# Tests for empty-list handling in _apply_request_overrides
+# =============================================================================
+
+
+def test_apply_overrides_empty_stop_list_preserves_default(serving_chat, mocker):
+ """Test that request.stop=[] does NOT override YAML default stop words."""
+ default_params = SamplingParams(temperature=0.5, stop=["<|im_end|>"])
+ request = mocker.MagicMock()
+ request.temperature = None
+ request.top_p = None
+ request.top_k = None
+ request.max_tokens = None
+ request.min_tokens = None
+ request.seed = None
+ request.ignore_eos = None
+ request.stop = [] # empty list — should be treated as "not set"
+ request.stop_token_ids = None
+ request.frequency_penalty = None
+ request.presence_penalty = None
+
+ result = serving_chat._apply_request_overrides(default_params, request)
+
+ assert result.stop == ["<|im_end|>"] # YAML default preserved
+
+
+def test_apply_overrides_nonempty_stop_list_overrides_default(serving_chat, mocker):
+ """Test that request.stop=["\\n"] overrides YAML default stop words."""
+ default_params = SamplingParams(temperature=0.5, stop=["<|im_end|>"])
+ request = mocker.MagicMock()
+ request.temperature = None
+ request.top_p = None
+ request.top_k = None
+ request.max_tokens = None
+ request.min_tokens = None
+ request.seed = None
+ request.ignore_eos = None
+ request.stop = ["\n"] # non-empty list — should override
+ request.stop_token_ids = None
+ request.frequency_penalty = None
+ request.presence_penalty = None
+
+ result = serving_chat._apply_request_overrides(default_params, request)
+
+ assert result.stop == ["\n"] # Overridden by request
+
+
+def test_apply_overrides_empty_stop_token_ids_preserves_default(serving_chat, mocker):
+ """Test that request.stop_token_ids=[] does NOT override YAML default."""
+ default_params = SamplingParams(temperature=0.5, stop_token_ids=[2, 3])
+ request = mocker.MagicMock()
+ request.temperature = None
+ request.top_p = None
+ request.top_k = None
+ request.max_tokens = None
+ request.min_tokens = None
+ request.seed = None
+ request.ignore_eos = None
+ request.stop = None
+ request.stop_token_ids = [] # empty list — should be treated as "not set"
+ request.frequency_penalty = None
+ request.presence_penalty = None
+
+ result = serving_chat._apply_request_overrides(default_params, request)
+
+ assert result.stop_token_ids == [2, 3] # YAML default preserved
+
+
+def test_apply_overrides_nonempty_stop_token_ids_overrides_default(serving_chat, mocker):
+ """Test that request.stop_token_ids=[100] overrides YAML default."""
+ default_params = SamplingParams(temperature=0.5, stop_token_ids=[2, 3])
+ request = mocker.MagicMock()
+ request.temperature = None
+ request.top_p = None
+ request.top_k = None
+ request.max_tokens = None
+ request.min_tokens = None
+ request.seed = None
+ request.ignore_eos = None
+ request.stop = None
+ request.stop_token_ids = [100] # non-empty list — should override
+ request.frequency_penalty = None
+ request.presence_penalty = None
+
+ result = serving_chat._apply_request_overrides(default_params, request)
+
+ assert result.stop_token_ids == [100] # Overridden by request
+
+
+def test_apply_overrides_mixed_empty_and_nonempty_lists(serving_chat, mocker):
+ """Test mixing empty and non-empty list fields with scalar fields."""
+ default_params = SamplingParams(
+ temperature=0.4,
+ stop=["<|end|>"],
+ stop_token_ids=[2],
+ )
+ request = mocker.MagicMock()
+ request.temperature = 0.9
+ request.top_p = None
+ request.top_k = None
+ request.max_tokens = None
+ request.min_tokens = None
+ request.seed = None
+ request.ignore_eos = None
+ request.stop = [] # empty — should NOT override
+ request.stop_token_ids = [100, 200] # non-empty — SHOULD override
+ request.frequency_penalty = None
+ request.presence_penalty = None
+
+ result = serving_chat._apply_request_overrides(default_params, request)
+
+ assert result.temperature == 0.9 # Scalar override works
+ assert result.stop == ["<|end|>"] # Empty list did NOT override
+ assert result.stop_token_ids == [100, 200] # Non-empty list DID override
+
+
+def test_apply_overrides_none_scalar_still_preserves_default(serving_chat, mocker):
+ """Regression: ensure None scalar values still don't override defaults."""
+ default_params = SamplingParams(temperature=0.5, max_tokens=100, seed=42)
+ request = mocker.MagicMock()
+ request.temperature = None
+ request.top_p = None
+ request.top_k = None
+ request.max_tokens = None
+ request.min_tokens = None
+ request.seed = None
+ request.ignore_eos = None
+ request.stop = None
+ request.stop_token_ids = None
+ request.frequency_penalty = None
+ request.presence_penalty = None
+
+ result = serving_chat._apply_request_overrides(default_params, request)
+
+ assert result.temperature == 0.5
+ assert result.max_tokens == 100
+ assert result.seed == 42
+
+
+def test_apply_overrides_both_lists_empty_preserves_defaults(serving_chat, mocker):
+ """Test that both stop=[] and stop_token_ids=[] preserve YAML defaults."""
+ default_params = SamplingParams(
+ temperature=0.5,
+ stop=["<|end|>", "\\n"],
+ stop_token_ids=[2, 32000],
+ )
+ request = mocker.MagicMock()
+ request.temperature = None
+ request.top_p = None
+ request.top_k = None
+ request.max_tokens = None
+ request.min_tokens = None
+ request.seed = None
+ request.ignore_eos = None
+ request.stop = []
+ request.stop_token_ids = []
+ request.frequency_penalty = None
+ request.presence_penalty = None
+
+ result = serving_chat._apply_request_overrides(default_params, request)
+
+ assert result.stop == ["<|end|>", "\\n"]
+ assert result.stop_token_ids == [2, 32000]
+
+
+def test_build_sampling_params_list_empty_stop_preserves_yaml(serving_chat, mock_request):
+ """Test that empty stop list in request preserves YAML defaults via
+ _build_sampling_params_list_from_request."""
+ mock_request.stop = []
+ mock_request.stop_token_ids = []
+
+ result = serving_chat._build_sampling_params_list_from_request(mock_request)
+
+ comprehension_params = result[0]
+ # Empty lists should NOT override — YAML defaults are preserved
+ assert comprehension_params.stop == []
+ assert comprehension_params.stop_token_ids == []
+
+
# =============================================================================
# Tests for _get_comprehension_stage_index
# =============================================================================
diff --git a/tests/entrypoints/openai_api/test_serving_chat_speaker.py b/tests/entrypoints/openai_api/test_serving_chat_speaker.py
new file mode 100644
index 0000000000..97c05e45b4
--- /dev/null
+++ b/tests/entrypoints/openai_api/test_serving_chat_speaker.py
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for chat endpoint speaker validation."""
+
+import asyncio
+from types import SimpleNamespace
+
+import pytest
+from pytest_mock import MockerFixture
+
+from vllm_omni.entrypoints.openai.utils import (
+ get_supported_speakers_from_hf_config,
+ validate_requested_speaker,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture
+def serving_chat():
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ instance = object.__new__(OmniOpenAIServingChat)
+ instance._supported_speakers = None
+ return instance
+
+
+def _make_hf_config(mocker: MockerFixture, *, speaker_id: dict | None = None, spk_id: dict | None = None):
+ hf_config = mocker.MagicMock()
+ talker_config = mocker.MagicMock()
+ talker_config.speaker_id = speaker_id
+ talker_config.spk_id = spk_id
+ hf_config.talker_config = talker_config
+ return hf_config
+
+
+def test_validate_requested_speaker_accepts_case_insensitive_value():
+ supported = {"vivian", "ethan"}
+ assert validate_requested_speaker("Vivian", supported) == "vivian"
+ assert validate_requested_speaker(" vivian ", supported) == "vivian"
+
+
+def test_validate_requested_speaker_rejects_invalid_value_with_supported_list():
+ supported = {"vivian", "ethan"}
+ with pytest.raises(ValueError, match="Invalid speaker 'uncle_fu'. Supported: ethan, vivian"):
+ validate_requested_speaker("uncle_fu", supported)
+
+
+def test_validate_requested_speaker_skips_validation_when_supported_empty():
+ assert validate_requested_speaker("anything", set()) == "anything"
+ assert validate_requested_speaker(" ", {"vivian"}) is None
+
+
+def test_get_supported_speakers_from_hf_config_uses_spk_id_fallback(mocker: MockerFixture):
+ hf_config = _make_hf_config(mocker, speaker_id=None, spk_id={"Serena": 0})
+ assert get_supported_speakers_from_hf_config(hf_config) == {"serena"}
+
+
+def test_get_supported_speakers_caches_normalized_keys(mocker: MockerFixture, serving_chat):
+ serving_chat.model_config = mocker.MagicMock()
+ serving_chat.model_config.hf_config = _make_hf_config(mocker, speaker_id={"Vivian": 0, "Ethan": 1})
+
+ assert serving_chat._get_supported_speakers() == {"vivian", "ethan"}
+
+ # Cached value should be reused even if the config changes afterwards.
+ serving_chat.model_config.hf_config.talker_config.speaker_id = {"Serena": 2}
+ assert serving_chat._get_supported_speakers() == {"vivian", "ethan"}
+
+
+def test_create_chat_completion_converts_value_error_to_error_response(mocker: MockerFixture, serving_chat):
+ serving_chat._diffusion_mode = False
+ serving_chat._check_model = mocker.AsyncMock(return_value=None)
+ serving_chat.engine_client = mocker.MagicMock(errored=False)
+ serving_chat._maybe_get_adapters = mocker.MagicMock(return_value=None)
+ serving_chat.models = mocker.MagicMock()
+ serving_chat.models.model_name.return_value = "test-model"
+ serving_chat.renderer = mocker.MagicMock()
+ serving_chat.renderer.get_tokenizer.return_value = mocker.MagicMock()
+ serving_chat.reasoning_parser_cls = None
+ serving_chat.tool_parser = None
+ serving_chat.use_harmony = False
+ serving_chat.enable_auto_tools = False
+ serving_chat.exclude_tools_when_tool_choice_none = False
+ serving_chat.trust_request_chat_template = False
+ serving_chat.chat_template = None
+ serving_chat.chat_template_content_format = "string"
+ serving_chat.default_chat_template_kwargs = {}
+ serving_chat._validate_chat_template = mocker.MagicMock(return_value=None)
+ serving_chat._prepare_extra_chat_template_kwargs = mocker.MagicMock(return_value={})
+ serving_chat._preprocess_chat = mocker.AsyncMock(
+ side_effect=ValueError("Invalid speaker 'uncle_fu'. Supported: ethan, vivian")
+ )
+ serving_chat.create_error_response = mocker.MagicMock(return_value="error-response")
+
+ request = SimpleNamespace(
+ tool_choice=None,
+ tools=None,
+ chat_template=None,
+ chat_template_kwargs=None,
+ reasoning_effort=None,
+ messages=[],
+ add_generation_prompt=False,
+ continue_final_message=False,
+ add_special_tokens=False,
+ request_id="speaker-test",
+ )
+
+ result = asyncio.run(serving_chat.create_chat_completion(request))
+
+ assert result == "error-response"
+ serving_chat.create_error_response.assert_called_once_with("Invalid speaker 'uncle_fu'. Supported: ethan, vivian")
diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py
index 57aeef8f9d..b388b18606 100644
--- a/tests/entrypoints/openai_api/test_serving_speech.py
+++ b/tests/entrypoints/openai_api/test_serving_speech.py
@@ -6,7 +6,6 @@
from inspect import Signature, signature
from pathlib import Path
from types import SimpleNamespace
-from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@@ -63,14 +62,11 @@ def test_stereo_to_mono_conversion(self, audio_mixin, mocker: MockerFixture):
adjusted_tensor = mock_speed.call_args[0][0]
assert len(adjusted_tensor) == 24000
- def test_speed_adjustment(self, audio_mixin, mocker: MockerFixture):
- mock_time_stretch = mocker.patch("librosa.effects.time_stretch")
- mock_time_stretch.return_value = np.zeros(12000)
+ def test_speed_adjustment(self, audio_mixin):
audio_tensor = np.random.rand(24000).astype(np.float32)
adjusted_audio, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=2.0, sample_rate=24000)
- mock_time_stretch.assert_called_with(y=audio_tensor, rate=2.0)
assert adjusted_audio.shape == (12000,)
def test_unsupported_format_fallback(self, audio_mixin, caplog, mocker: MockerFixture):
@@ -117,30 +113,22 @@ def test_stereo_audio_preservation(self, audio_mixin, mocker: MockerFixture):
assert np.array_equal(output_tensor, stereo_tensor)
def test_speed_adjustment_bypass(self, audio_mixin, mocker: MockerFixture):
- """Test that speed=1.0 bypasses the expensive librosa time stretching."""
+ """Test that speed=1.0 bypasses the expensive torchaudio time stretching."""
audio_tensor = np.random.rand(24000).astype(np.float32)
- mock_time_stretch = mocker.patch("librosa.effects.time_stretch")
- # speed=1.0 should return immediately without calling librosa
+ mock_time_stretch = mocker.patch("torchaudio.transforms.TimeStretch")
+ # speed=1.0 should return immediately without calling torchaudio
result, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=1.0, sample_rate=24000)
mock_time_stretch.assert_not_called()
assert np.array_equal(result, audio_tensor)
- def test_speed_adjustment_stereo_handling(self, audio_mixin, mocker: MockerFixture):
- """Test that speed adjustment is attempted on stereo inputs."""
- mock_time_stretch = mocker.patch("librosa.effects.time_stretch")
+ def test_speed_adjustment_stereo_handling(self, audio_mixin):
+ """Test that speed adjustment handles stereo (channels-last) input."""
stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
- # Mock return value representing a sped-up version (half length)
- mock_time_stretch.return_value = np.zeros((12000, 2), dtype=np.float32)
result, _ = audio_mixin._apply_speed_adjustment(stereo_tensor, speed=2.0, sample_rate=24000)
- mock_time_stretch.assert_called_once()
- # Ensure the stereo tensor was passed to librosa
- call_args = mock_time_stretch.call_args
- assert np.array_equal(call_args.kwargs["y"], stereo_tensor)
- assert call_args.kwargs["rate"] == 2.0
assert result.shape == (12000, 2)
@@ -696,6 +684,32 @@ def test_is_tts_detection_with_tts_stage(self, mocker: MockerFixture):
assert server._is_tts is True
assert server._tts_stage is mock_stage
+ def test_prepare_speech_rejects_non_tts_omni_model(self, mocker: MockerFixture):
+ """Multi-stage omni models (e.g. Qwen3-Omni) must not use /v1/audio/speech."""
+ mock_engine_client = mocker.MagicMock()
+ mock_engine_client.errored = False
+ mock_engine_client.tts_max_instructions_length = None
+
+ # Simulate Qwen3-Omni: multiple stages, none in _TTS_MODEL_STAGES
+ thinker = SimpleNamespace(engine_args=SimpleNamespace(model_stage="thinker"), tts_args={})
+ talker = SimpleNamespace(engine_args=SimpleNamespace(model_stage="talker"), tts_args={})
+ code2wav = SimpleNamespace(engine_args=SimpleNamespace(model_stage="code2wav"), tts_args={})
+ mock_engine_client.stage_configs = [thinker, talker, code2wav]
+
+ mock_models = mocker.MagicMock()
+ mock_models.is_base_model.return_value = True
+ server = OmniOpenAIServingSpeech(
+ engine_client=mock_engine_client,
+ models=mock_models,
+ request_logger=mocker.MagicMock(),
+ )
+ assert server._is_tts is False
+
+ request = OpenAICreateSpeechRequest(input="Hello world")
+ with pytest.raises(ValueError, match="only supported for dedicated TTS models"):
+ asyncio.run(server._prepare_speech_generation(request))
+ server.shutdown()
+
def test_estimate_prompt_len_fallback(self, speech_server):
"""Test prompt length estimation falls back to 2048 when model is unavailable."""
tts_params = {"text": ["Hello"], "task_type": ["CustomVoice"]}
@@ -763,6 +777,26 @@ def test_validate_tts_request_base_empty_ref_text(self, speech_server):
)
assert speech_server._validate_tts_request(req) is None
+ @pytest.mark.parametrize(
+ "ref_text",
+ [None, "", " "],
+ ids=["none", "empty", "whitespace"],
+ )
+ def test_validate_base_task_missing_ref_text_returns_400(self, speech_server, ref_text):
+ """Regression: Base task without ref_text must return 400, not crash EngineCore.
+
+ See https://github.com/vllm-project/vllm-omni/pull/2203
+ """
+ req = OpenAICreateSpeechRequest(
+ input="Hello",
+ task_type="Base",
+ ref_audio="data:audio/wav;base64,abc",
+ ref_text=ref_text,
+ )
+ result = speech_server._validate_tts_request(req)
+ assert result is not None, f"ref_text={ref_text!r} should be rejected"
+ assert "ref_text" in result
+
def test_validate_tts_request_customvoice_no_speakers(self, speech_server):
"""CustomVoice on a model with no speakers returns 400 instead of crashing engine."""
req = OpenAICreateSpeechRequest(input="Hello", task_type="CustomVoice")
@@ -892,7 +926,7 @@ def test_load_supported_speakers(self, mocker: MockerFixture):
# Verify speakers are normalized to lowercase
assert server.supported_speakers == {"ryan", "vivian", "aiden"}
- def test_build_tts_params_with_uploaded_voice(self, speech_server):
+ def test_build_tts_params_with_uploaded_voice(self, speech_server, mocker: MockerFixture):
"""Test _build_tts_params auto-sets ref_audio for uploaded voices (x_vector only)."""
speech_server.uploaded_speakers = {
"custom_voice": {
@@ -905,18 +939,18 @@ def test_build_tts_params_with_uploaded_voice(self, speech_server):
}
speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"}
- with patch.object(speech_server, "_get_uploaded_audio_data") as mock_get_audio:
- mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
- req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
- params = speech_server._build_tts_params(req)
+ mock_get_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
+ mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
+ params = speech_server._build_tts_params(req)
- assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
- assert params["x_vector_only_mode"] == [True]
- assert params["task_type"] == ["Base"]
- assert params["voice_created_at"] == [1711234567.89]
- assert "ref_text" not in params
+ assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
+ assert params["x_vector_only_mode"] == [True]
+ assert params["task_type"] == ["Base"]
+ assert params["voice_created_at"] == [1711234567.89]
+ assert "ref_text" not in params
- def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server):
+ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server, mocker: MockerFixture):
"""Test _build_tts_params enables in-context cloning when ref_text is stored."""
speech_server.uploaded_speakers = {
"custom_voice": {
@@ -929,16 +963,16 @@ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server):
}
speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"}
- with patch.object(speech_server, "_get_uploaded_audio_data") as mock_get_audio:
- mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
- req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
- params = speech_server._build_tts_params(req)
+ mock_get_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
+ mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
+ params = speech_server._build_tts_params(req)
- assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
- assert params["x_vector_only_mode"] == [False]
- assert params["task_type"] == ["Base"]
- assert params["ref_text"] == ["Hello world transcript"]
- assert params["voice_created_at"] == [1711234567.89]
+ assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
+ assert params["x_vector_only_mode"] == [False]
+ assert params["task_type"] == ["Base"]
+ assert params["ref_text"] == ["Hello world transcript"]
+ assert params["voice_created_at"] == [1711234567.89]
def test_build_tts_params_without_uploaded_voice(self, speech_server):
"""Test _build_tts_params does not auto-set ref_audio for non-uploaded voices."""
@@ -980,45 +1014,43 @@ def test_build_tts_params_with_explicit_ref_audio(self, speech_server):
# x_vector_only_mode should not be set when explicit ref_audio is provided
assert "x_vector_only_mode" not in params
- def test_get_uploaded_audio_data(self, speech_server):
+ def test_get_uploaded_audio_data(self, speech_server, mocker: MockerFixture):
"""Test _get_uploaded_audio_data function."""
# Mock file operations
- with (
- patch("builtins.open", create=True) as mock_open,
- patch("base64.b64encode") as mock_b64encode,
- patch("pathlib.Path.exists") as mock_exists,
- ):
- mock_exists.return_value = True
- mock_b64encode.return_value = b"ZmFrZWF1ZGlv"
-
- # Setup mock file
- mock_file = MagicMock()
- mock_file.read.return_value = b"fakeaudio"
- mock_open.return_value.__enter__.return_value = mock_file
-
- # Setup uploaded speaker
- speech_server.uploaded_speakers = {
- "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
- }
- result = speech_server._get_uploaded_audio_data("test_voice")
+ mock_open = mocker.patch("builtins.open", create=True)
+ mock_b64encode = mocker.patch("base64.b64encode")
+ mock_exists = mocker.patch("pathlib.Path.exists")
+ mock_exists.return_value = True
+ mock_b64encode.return_value = b"ZmFrZWF1ZGlv"
+
+ # Setup mock file
+ mock_file = mocker.MagicMock()
+ mock_file.read.return_value = b"fakeaudio"
+ mock_open.return_value.__enter__.return_value = mock_file
+
+ # Setup uploaded speaker
+ speech_server.uploaded_speakers = {
+ "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
+ }
+ result = speech_server._get_uploaded_audio_data("test_voice")
- assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv"
- mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb")
- mock_b64encode.assert_called_once_with(b"fakeaudio")
+ assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb")
+ mock_b64encode.assert_called_once_with(b"fakeaudio")
- def test_get_uploaded_audio_data_missing_file(self, speech_server):
+ def test_get_uploaded_audio_data_missing_file(self, speech_server, mocker: MockerFixture):
"""Test _get_uploaded_audio_data when file is missing."""
- with patch("pathlib.Path.exists") as mock_exists:
- mock_exists.return_value = False
+ mock_exists = mocker.patch("pathlib.Path.exists")
+ mock_exists.return_value = False
- # Setup uploaded speaker
- speech_server.uploaded_speakers = {
- "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
- }
+ # Setup uploaded speaker
+ speech_server.uploaded_speakers = {
+ "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
+ }
- result = speech_server._get_uploaded_audio_data("test_voice")
+ result = speech_server._get_uploaded_audio_data("test_voice")
- assert result is None
+ assert result is None
def test_get_uploaded_audio_data_voice_not_found(self, speech_server):
"""Test _get_uploaded_audio_data when voice is not in uploaded_speakers."""
@@ -1040,7 +1072,7 @@ def test_voice_field_still_accepted(self):
req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "voice": "custom_voice"})
assert req.voice == "custom_voice"
- def test_speaker_alias_in_base_task_with_uploaded_voice(self, speech_server):
+ def test_speaker_alias_in_base_task_with_uploaded_voice(self, speech_server, mocker: MockerFixture):
"""Using 'speaker' key with an uploaded voice should work for Base task."""
speech_server.uploaded_speakers = {
"utesf": {
@@ -1052,13 +1084,13 @@ def test_speaker_alias_in_base_task_with_uploaded_voice(self, speech_server):
}
req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "UTESF", "task_type": "Base"})
assert req.voice == "UTESF"
- with patch("pathlib.Path.exists", return_value=True):
- result = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ result = speech_server._validate_qwen_tts_request(req)
assert result is None
# ── uploaded voice with embedding ──
- def test_build_tts_params_with_uploaded_voice_embedding(self, speech_server):
+ def test_build_tts_params_with_uploaded_voice_embedding(self, speech_server, mocker: MockerFixture):
"""Test _build_tts_params loads embedding for embedding-uploaded voices."""
speech_server.uploaded_speakers = {
"emb_voice": {
@@ -1074,20 +1106,20 @@ def test_build_tts_params_with_uploaded_voice_embedding(self, speech_server):
speech_server.supported_speakers = {"ryan", "vivian", "emb_voice"}
fake_embedding = [0.1] * 1024
- with patch.object(speech_server, "_get_uploaded_speaker_embedding") as mock_get_emb:
- mock_get_emb.return_value = fake_embedding
- req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice")
- params = speech_server._build_tts_params(req)
+ mock_get_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
+ mock_get_emb.return_value = fake_embedding
+ req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice")
+ params = speech_server._build_tts_params(req)
- assert "voice_clone_prompt" in params
- assert params["voice_clone_prompt"][0]["ref_spk_embedding"] == fake_embedding
- assert params["task_type"] == ["Base"]
- assert params["x_vector_only_mode"] == [True]
- assert "ref_audio" not in params
+ assert "voice_clone_prompt" in params
+ assert params["voice_clone_prompt"][0]["ref_spk_embedding"] == fake_embedding
+ assert params["task_type"] == ["Base"]
+ assert params["x_vector_only_mode"] == [True]
+ assert "ref_audio" not in params
# ── regression: full flow from issue #1603 ──
- def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_server):
+ def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_server, mocker: MockerFixture):
"""Regression test for #1603: upload audio voice, then invoke TTS with 'speaker' key.
Verifies the full validate → build_params pipeline works end-to-end.
@@ -1107,14 +1139,14 @@ def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_serv
assert req.voice == "UTESF"
# Validation should pass (file exists)
- with patch("pathlib.Path.exists", return_value=True):
- err = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ err = speech_server._validate_qwen_tts_request(req)
assert err is None, f"Validation failed: {err}"
# Build params should auto-set ref_audio from stored file
- with patch.object(speech_server, "_get_uploaded_audio_data") as mock_audio:
- mock_audio.return_value = "data:audio/wav;base64,ZmFrZQ=="
- params = speech_server._build_tts_params(req)
+ mock_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
+ mock_audio.return_value = "data:audio/wav;base64,ZmFrZQ=="
+ params = speech_server._build_tts_params(req)
assert params["task_type"] == ["Base"]
assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZQ=="]
@@ -1122,7 +1154,7 @@ def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_serv
assert params["x_vector_only_mode"] == [False]
assert params["speaker"] == ["utesf"]
- def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_server):
+ def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_server, mocker: MockerFixture):
"""Regression test for #1603: upload embedding voice, then invoke TTS with 'speaker' key.
Verifies embedding-uploaded voices are loaded as voice_clone_prompt, not as audio.
@@ -1145,15 +1177,15 @@ def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_
assert req.voice == "myvoice"
# Validation should pass
- with patch("pathlib.Path.exists", return_value=True):
- err = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ err = speech_server._validate_qwen_tts_request(req)
assert err is None, f"Validation failed: {err}"
# Build params should use embedding, NOT audio
fake_emb = [0.1] * 1024
- with patch.object(speech_server, "_get_uploaded_speaker_embedding") as mock_emb:
- mock_emb.return_value = fake_emb
- params = speech_server._build_tts_params(req)
+ mock_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
+ mock_emb.return_value = fake_emb
+ params = speech_server._build_tts_params(req)
assert params["task_type"] == ["Base"]
assert params["x_vector_only_mode"] == [True]
@@ -1162,7 +1194,7 @@ def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_
# Must NOT have ref_audio — that would fail for safetensors files
assert "ref_audio" not in params
- def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server):
+ def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server, mocker: MockerFixture):
"""Validation should reject embedding voices whose cache is not yet ready."""
speech_server.uploaded_speakers = {
"myvoice": {
@@ -1175,12 +1207,12 @@ def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server
}
}
req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "myvoice", "task_type": "Base"})
- with patch("pathlib.Path.exists", return_value=True):
- err = speech_server._validate_qwen_tts_request(req)
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ err = speech_server._validate_qwen_tts_request(req)
assert err is not None
assert "not yet ready" in err
- def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_server):
+ def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_server, mocker: MockerFixture):
"""x_vector_only_mode set by uploaded embedding must not be overwritten by request field."""
speech_server.uploaded_speakers = {
"emb_voice": {
@@ -1194,11 +1226,11 @@ def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_
}
}
fake_emb = [0.1] * 1024
- with patch.object(speech_server, "_get_uploaded_speaker_embedding") as mock_emb:
- mock_emb.return_value = fake_emb
- # Client explicitly sends x_vector_only_mode=False, but embedding requires True
- req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice", x_vector_only_mode=False)
- params = speech_server._build_tts_params(req)
+ mock_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
+ mock_emb.return_value = fake_emb
+ # Client explicitly sends x_vector_only_mode=False, but embedding requires True
+ req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice", x_vector_only_mode=False)
+ params = speech_server._build_tts_params(req)
assert params["x_vector_only_mode"] == [True]
assert "voice_clone_prompt" in params
@@ -1645,9 +1677,9 @@ async def test_omni_model_includes_generate(self):
assert "generate" in tasks
-def test_api_server_create_speech_wraps_error_response_status():
- handler = MagicMock()
- handler.create_speech = AsyncMock(
+def test_api_server_create_speech_wraps_error_response_status(mocker: MockerFixture):
+ handler = mocker.MagicMock()
+ handler.create_speech = mocker.AsyncMock(
return_value=ErrorResponse(
error=ErrorInfo(message="bad request", type="BadRequestError", param=None, code=400),
)
@@ -1842,9 +1874,9 @@ def test_build_fish_prompt_normalizes_legacy_speaker_tags(self, fish_speech_serv
assert "<|speaker:0|>你好,[laughing]欢迎回来。<|speaker:1|>我也来了。" in encoded_texts
assert all(allowed_special is None for _, _, allowed_special in tokenizer.calls)
- def test_build_fish_clone_prompt_normalizes_text_fields(self, fish_speech_server):
+ def test_build_fish_clone_prompt_normalizes_text_fields(self, fish_speech_server, mocker: MockerFixture):
fish_speech_server._fish_speech_tokenizer = _FakeFishTokenizer()
- fish_speech_server._estimate_fish_prompt_len = MagicMock(return_value=123)
+ fish_speech_server._estimate_fish_prompt_len = mocker.MagicMock(return_value=123)
request = OpenAICreateSpeechRequest(
input="你好,欢迎回来。",
@@ -1895,8 +1927,10 @@ def test_build_fish_prompt_rejects_unsafe_control_tokens(self, fish_speech_serve
with pytest.raises(ValueError, match="unsupported control token"):
fish_speech_server._build_fish_speech_prompt(request)
- def test_prepare_speech_generation_overrides_fish_default_max_tokens(self, fish_speech_server):
- fish_speech_server._build_fish_speech_prompt_async = AsyncMock(
+ def test_prepare_speech_generation_overrides_fish_default_max_tokens(
+ self, fish_speech_server, mocker: MockerFixture
+ ):
+ fish_speech_server._build_fish_speech_prompt_async = mocker.AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
@@ -1915,8 +1949,8 @@ def test_prepare_speech_generation_overrides_fish_default_max_tokens(self, fish_
assert sampling_params_list[0].max_tokens == 4096
assert fish_speech_server.engine_client.default_sampling_params_list[0].max_tokens == 2048
- def test_prepare_speech_generation_uses_stage_default_max_tokens(self, fish_speech_server):
- fish_speech_server._build_fish_speech_prompt_async = AsyncMock(
+ def test_prepare_speech_generation_uses_stage_default_max_tokens(self, fish_speech_server, mocker: MockerFixture):
+ fish_speech_server._build_fish_speech_prompt_async = mocker.AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
@@ -1947,9 +1981,9 @@ def test_prepare_speech_generation_rejects_invalid_fish_max_new_tokens(self, fis
fish_speech_server.engine_client.generate.assert_not_called()
- def test_create_speech_batch_allows_fish_text_only_items(self, fish_speech_server):
- fish_speech_server._check_model = AsyncMock(return_value=None)
- fish_speech_server._generate_audio_bytes = AsyncMock(return_value=("YWJj", "audio/wav"))
+ def test_create_speech_batch_allows_fish_text_only_items(self, fish_speech_server, mocker: MockerFixture):
+ fish_speech_server._check_model = mocker.AsyncMock(return_value=None)
+ fish_speech_server._generate_audio_bytes = mocker.AsyncMock(return_value=("YWJj", "audio/wav"))
batch = BatchSpeechRequest(items=[SpeechBatchItem(input="hello fish")])
response = asyncio.run(fish_speech_server.create_speech_batch(batch))
@@ -2145,8 +2179,8 @@ def test_validate_cosyvoice3_max_new_tokens_range(self, cosyvoice3_server):
assert error is not None
assert "max_new_tokens" in error
- def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server):
- cosyvoice3_server._build_cosyvoice3_prompt = AsyncMock(
+ def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server, mocker: MockerFixture):
+ cosyvoice3_server._build_cosyvoice3_prompt = mocker.AsyncMock(
return_value={
"prompt": "Hello",
"multi_modal_data": {"audio": (np.zeros(24000), 24000)},
@@ -2227,9 +2261,9 @@ def qwen3_tts_server(self, mocker: MockerFixture):
yield server
server.shutdown()
- def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server):
+ def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server, mocker: MockerFixture):
"""Voxtral path in _prepare_speech_generation should call the async wrapper."""
- voxtral_server._build_voxtral_prompt_async = AsyncMock(
+ voxtral_server._build_voxtral_prompt_async = mocker.AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {"voice": ["test"]},
@@ -2239,13 +2273,13 @@ def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server):
asyncio.run(voxtral_server._prepare_speech_generation(request))
voxtral_server._build_voxtral_prompt_async.assert_awaited_once()
- def test_prepare_speech_generation_awaits_qwen3_tts_async(self, qwen3_tts_server):
+ def test_prepare_speech_generation_awaits_qwen3_tts_async(self, qwen3_tts_server, mocker: MockerFixture):
"""Qwen3 TTS path should call _estimate_prompt_len_async."""
- qwen3_tts_server._validate_tts_request = MagicMock(return_value=None)
- qwen3_tts_server._build_tts_params = MagicMock(
+ qwen3_tts_server._validate_tts_request = mocker.MagicMock(return_value=None)
+ qwen3_tts_server._build_tts_params = mocker.MagicMock(
return_value={"text": ["hello"], "task_type": ["CustomVoice"], "speaker": ["Vivian"]}
)
- qwen3_tts_server._estimate_prompt_len_async = AsyncMock(return_value=512)
+ qwen3_tts_server._estimate_prompt_len_async = mocker.AsyncMock(return_value=512)
request = OpenAICreateSpeechRequest(input="hello")
asyncio.run(qwen3_tts_server._prepare_speech_generation(request))
qwen3_tts_server._build_tts_params.assert_called_once()
@@ -2272,8 +2306,8 @@ def test_shutdown_is_idempotent(self, mocker: MockerFixture):
server.shutdown() # Should not raise
assert server._tts_executor is None
- def test_diffusion_instance_shutdown_safe(self):
+ def test_diffusion_instance_shutdown_safe(self, mocker: MockerFixture):
"""Diffusion instances (created via for_diffusion) should have safe shutdown."""
- server = OmniOpenAIServingSpeech.for_diffusion(diffusion_engine=MagicMock(), model_name="test-model")
+ server = OmniOpenAIServingSpeech.for_diffusion(diffusion_engine=mocker.MagicMock(), model_name="test-model")
assert server._tts_executor is None
server.shutdown() # Should not raise
diff --git a/tests/entrypoints/openai_api/test_serving_speech_stream.py b/tests/entrypoints/openai_api/test_serving_speech_stream.py
index 1d26b5855f..1b93ef58e2 100644
--- a/tests/entrypoints/openai_api/test_serving_speech_stream.py
+++ b/tests/entrypoints/openai_api/test_serving_speech_stream.py
@@ -1,8 +1,8 @@
import asyncio
-from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI, WebSocket
+from pytest_mock import MockerFixture
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
@@ -13,19 +13,26 @@
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-def _build_test_app(speech_service=None, *, idle_timeout=30.0, config_timeout=10.0):
+def _build_test_app(
+ speech_service=None,
+ *,
+ idle_timeout=30.0,
+ config_timeout=10.0,
+ mocker: MockerFixture | None = None,
+):
if speech_service is None:
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"RIFF" + b"\x00" * 32, "audio/wav"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-1", object(), {}))
+ assert mocker is not None
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"RIFF" + b"\x00" * 32, "audio/wav"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-1", object(), {}))
async def mock_generate_pcm_chunks(_generator, _request_id):
for chunk in (b"\x01\x02", b"\x03\x04\x05"):
yield chunk
speech_service._generate_pcm_chunks = mock_generate_pcm_chunks
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
handler = OmniStreamingSpeechHandler(
speech_service=speech_service,
@@ -42,8 +49,8 @@ async def ws_endpoint(websocket: WebSocket):
class TestStreamingSpeechWebSocket:
- def test_non_streaming_single_frame(self):
- app, speech_service = _build_test_app()
+ def test_non_streaming_single_frame(self, mocker: MockerFixture):
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -68,13 +75,13 @@ def test_non_streaming_single_frame(self):
assert speech_service._generate_audio_bytes.await_count == 1
- def test_streaming_multiple_binary_frames(self):
+ def test_streaming_multiple_binary_frames(self, mocker: MockerFixture):
captured_requests = []
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
async def mock_prepare_speech_generation(request):
captured_requests.append(request)
@@ -123,8 +130,8 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
assert captured_requests[0].initial_codec_chunk_frames == 12
assert speech_service._generate_audio_bytes.await_count == 0
- def test_flush_on_input_done(self):
- app, _ = _build_test_app()
+ def test_flush_on_input_done(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -142,8 +149,8 @@ def test_flush_on_input_done(self):
}
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_invalid_streaming_config(self):
- app, _ = _build_test_app()
+ def test_invalid_streaming_config(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -159,8 +166,8 @@ def test_invalid_streaming_config(self):
assert error["type"] == "error"
assert "response_format='pcm'" in error["message"]
- def test_empty_input_text_emits_no_audio(self):
- app, speech_service = _build_test_app()
+ def test_empty_input_text_emits_no_audio(self, mocker: MockerFixture):
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -172,8 +179,8 @@ def test_empty_input_text_emits_no_audio(self):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_multiple_sentences_increment_indices(self):
- app, _ = _build_test_app()
+ def test_multiple_sentences_increment_indices(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -203,8 +210,8 @@ def test_multiple_sentences_increment_indices(self):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 2}
- def test_unknown_message_type_keeps_session_open(self):
- app, _ = _build_test_app()
+ def test_unknown_message_type_keeps_session_open(self, mocker: MockerFixture):
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -227,21 +234,21 @@ def test_unknown_message_type_keeps_session_open(self):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_config_timeout_closes_session(self):
- app, _ = _build_test_app(config_timeout=0.01)
+ def test_config_timeout_closes_session(self, mocker: MockerFixture):
+ app, _ = _build_test_app(config_timeout=0.01, mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
error = ws.receive_json()
assert error == {"type": "error", "message": "Timeout waiting for session.config"}
- def test_generation_error_marks_audio_done(self):
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(side_effect=RuntimeError("boom"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-err", object(), {}))
- speech_service._generate_pcm_chunks = AsyncMock()
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ def test_generation_error_marks_audio_done(self, mocker: MockerFixture):
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(side_effect=RuntimeError("boom"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-err", object(), {}))
+ speech_service._generate_pcm_chunks = mocker.AsyncMock()
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
app, _ = _build_test_app(speech_service)
with TestClient(app) as client:
@@ -256,12 +263,12 @@ def test_generation_error_marks_audio_done(self):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_streaming_generation_error_marks_audio_done(self):
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-stream-err", object(), {}))
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ def test_streaming_generation_error_marks_audio_done(self, mocker: MockerFixture):
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-stream-err", object(), {}))
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
async def mock_generate_pcm_chunks(_generator, _request_id):
yield b"\x01\x02"
@@ -298,8 +305,8 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_invalid_input_text_type_returns_validation_error(self):
- app, speech_service = _build_test_app()
+ def test_invalid_input_text_type_returns_validation_error(self, mocker: MockerFixture):
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -316,9 +323,9 @@ def test_invalid_input_text_type_returns_validation_error(self):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_input_text_message_too_large(self, monkeypatch):
+ def test_input_text_message_too_large(self, monkeypatch, mocker: MockerFixture):
monkeypatch.setattr(streaming_speech_module, "_MAX_INPUT_TEXT_MESSAGE_SIZE", 32)
- app, speech_service = _build_test_app()
+ app, speech_service = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -335,9 +342,9 @@ def test_input_text_message_too_large(self, monkeypatch):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_session_config_message_too_large(self, monkeypatch):
+ def test_session_config_message_too_large(self, monkeypatch, mocker: MockerFixture):
monkeypatch.setattr(streaming_speech_module, "_MAX_CONFIG_MESSAGE_SIZE", 64)
- app, _ = _build_test_app()
+ app, _ = _build_test_app(mocker=mocker)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -348,12 +355,12 @@ def test_session_config_message_too_large(self, monkeypatch):
"message": "session.config message too large",
}
- def test_disconnect_aborts_streaming_request(self):
- speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
- speech_service._prepare_speech_generation = AsyncMock(return_value=("req-abort", object(), {}))
- speech_service.engine_client = MagicMock()
- speech_service.engine_client.abort = AsyncMock()
+ def test_disconnect_aborts_streaming_request(self, mocker: MockerFixture):
+ speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-abort", object(), {}))
+ speech_service.engine_client = mocker.MagicMock()
+ speech_service.engine_client.abort = mocker.AsyncMock()
async def mock_generate_pcm_chunks(_generator, _request_id):
yield b"\x01\x02"
@@ -361,11 +368,11 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
speech_service._generate_pcm_chunks = mock_generate_pcm_chunks
handler = OmniStreamingSpeechHandler(speech_service=speech_service)
- websocket = MagicMock()
- websocket.send_json = AsyncMock(side_effect=[None, WebSocketDisconnect()])
- websocket.send_bytes = AsyncMock(side_effect=WebSocketDisconnect())
+ websocket = mocker.MagicMock()
+ websocket.send_json = mocker.AsyncMock(side_effect=[None, WebSocketDisconnect()])
+ websocket.send_bytes = mocker.AsyncMock(side_effect=WebSocketDisconnect())
- config = MagicMock()
+ config = mocker.MagicMock()
config.model = None
config.voice = "Vivian"
config.task_type = None
diff --git a/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py b/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py
new file mode 100644
index 0000000000..48660b6d1c
--- /dev/null
+++ b/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""UTs for VoxCPM OpenAI speech serving behavior."""
+
+import asyncio
+from types import SimpleNamespace
+from unittest.mock import AsyncMock
+
+import pytest
+from pytest_mock import MockerFixture
+
+from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest
+from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture
+def voxcpm_server(mocker: MockerFixture):
+ mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value=set())
+ mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None)
+
+ mock_engine_client = mocker.MagicMock()
+ mock_engine_client.errored = False
+ mock_engine_client.model_config = mocker.MagicMock(model="OpenBMB/VoxCPM1.5")
+ mock_engine_client.default_sampling_params_list = [SimpleNamespace(max_tokens=2048)]
+ mock_engine_client.tts_batch_max_items = 32
+ mock_engine_client.generate = mocker.MagicMock(return_value="generator")
+ mock_engine_client.stage_configs = [
+ SimpleNamespace(
+ engine_args=SimpleNamespace(
+ model_stage="latent_generator",
+ model_arch="VoxCPMForConditionalGeneration",
+ ),
+ tts_args={},
+ ),
+ SimpleNamespace(
+ engine_args=SimpleNamespace(model_stage="vae"),
+ tts_args={},
+ ),
+ ]
+
+ mock_models = mocker.MagicMock()
+ mock_models.is_base_model.return_value = True
+
+ return OmniOpenAIServingSpeech(
+ engine_client=mock_engine_client,
+ models=mock_models,
+ request_logger=mocker.MagicMock(),
+ )
+
+
+class TestVoxCPMServing:
+ def test_voxcpm_model_type_detection(self, voxcpm_server):
+ assert voxcpm_server._tts_model_type == "voxcpm"
+ assert voxcpm_server._is_tts is True
+ assert voxcpm_server.supported_speakers == set()
+
+ @pytest.mark.parametrize(
+ ("request_kwargs", "expected_substring"),
+ [
+ ({"voice": "alice"}, "voice"),
+ ({"instructions": "whisper"}, "instructions"),
+ ({"language": "en"}, "language"),
+ ({"task_type": "CustomVoice"}, "plain tts"),
+ ({"x_vector_only_mode": True}, "x_vector_only_mode"),
+ ({"speaker_embedding": [0.1, 0.2]}, "speaker_embedding"),
+ ({"initial_codec_chunk_frames": 4}, "initial_codec_chunk_frames"),
+ ({"ref_text": "reference"}, "ref_audio"),
+ ],
+ )
+ def test_validate_voxcpm_rejects_unsupported_fields(self, voxcpm_server, request_kwargs, expected_substring):
+ request = OpenAICreateSpeechRequest(input="hello voxcpm", **request_kwargs)
+ error = voxcpm_server._validate_voxcpm_request(request)
+ assert error is not None
+ assert expected_substring in error.lower()
+
+ def test_validate_voxcpm_accepts_plain_tts_request(self, voxcpm_server):
+ request = OpenAICreateSpeechRequest(input="hello voxcpm", max_new_tokens=256)
+ assert voxcpm_server._validate_voxcpm_request(request) is None
+
+ def test_validate_voxcpm_accepts_voice_clone_request(self, voxcpm_server):
+ request = OpenAICreateSpeechRequest(
+ input="clone this voice",
+ ref_audio="data:audio/wav;base64,QUJD",
+ ref_text="reference transcript",
+ max_new_tokens=256,
+ )
+ assert voxcpm_server._validate_voxcpm_request(request) is None
+
+ def test_prepare_speech_generation_voxcpm_text_only(self, voxcpm_server):
+ request = OpenAICreateSpeechRequest(input="hello voxcpm", max_new_tokens=321)
+
+ request_id, generator, tts_params = asyncio.run(voxcpm_server._prepare_speech_generation(request))
+
+ assert request_id.startswith("speech-")
+ assert generator == "generator"
+ assert tts_params == {
+ "text": ["hello voxcpm"],
+ "cfg_value": [2.0],
+ "inference_timesteps": [10],
+ "min_len": [2],
+ "max_new_tokens": [321],
+ }
+
+ voxcpm_server.engine_client.generate.assert_called_once()
+ call = voxcpm_server.engine_client.generate.call_args
+ assert call.kwargs["prompt"] == {
+ "prompt_token_ids": [1],
+ "additional_information": tts_params,
+ }
+ assert call.kwargs["output_modalities"] == ["audio"]
+
+ def test_prepare_speech_generation_voxcpm_voice_clone_resolves_ref_audio(self, voxcpm_server):
+ voxcpm_server._resolve_ref_audio = AsyncMock(return_value=([0.1, -0.1, 0.2], 16000))
+ request = OpenAICreateSpeechRequest(
+ input="clone this voice",
+ ref_audio="data:audio/wav;base64,QUJD",
+ ref_text="reference transcript",
+ max_new_tokens=512,
+ )
+
+ request_id, generator, tts_params = asyncio.run(voxcpm_server._prepare_speech_generation(request))
+
+ assert request_id.startswith("speech-")
+ assert generator == "generator"
+ assert tts_params == {
+ "text": ["clone this voice"],
+ "cfg_value": [2.0],
+ "inference_timesteps": [10],
+ "min_len": [2],
+ "max_new_tokens": [512],
+ "ref_text": ["reference transcript"],
+ "ref_audio": [[[0.1, -0.1, 0.2], 16000]],
+ }
+
+ voxcpm_server._resolve_ref_audio.assert_awaited_once_with("data:audio/wav;base64,QUJD")
+ call = voxcpm_server.engine_client.generate.call_args
+ assert call.kwargs["prompt"] == {
+ "prompt_token_ids": [1],
+ "additional_information": tts_params,
+ }
diff --git a/tests/entrypoints/openai_api/test_video_api_utils.py b/tests/entrypoints/openai_api/test_video_api_utils.py
new file mode 100644
index 0000000000..9e732403fb
--- /dev/null
+++ b/tests/entrypoints/openai_api/test_video_api_utils.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for OpenAI-compatible video API encoding helpers."""
+
+import numpy as np
+import pytest
+import torch
+
+from vllm_omni.diffusion.postprocess import rife_interpolator
+from vllm_omni.entrypoints.openai import video_api_utils
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _install_fake_video_mux(monkeypatch, mux_calls):
+ def _fake_mux_video_audio_bytes(frames, audio, fps, audio_sample_rate, video_codec_options=None):
+ mux_calls.append(
+ {
+ "frames": frames,
+ "audio": audio,
+ "fps": fps,
+ "audio_sample_rate": audio_sample_rate,
+ "video_codec_options": video_codec_options,
+ }
+ )
+ return b"fake-video"
+
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.utils.media_utils.mux_video_audio_bytes",
+ _fake_mux_video_audio_bytes,
+ )
+
+
+def test_encode_video_bytes_exports_frames_without_interpolation(monkeypatch):
+ mux_calls = []
+ _install_fake_video_mux(monkeypatch, mux_calls)
+
+ frames = [np.full((2, 2, 3), fill_value=i / 5, dtype=np.float32) for i in range(5)]
+ video_bytes = video_api_utils._encode_video_bytes(
+ frames,
+ fps=8,
+ )
+
+ assert video_bytes == b"fake-video"
+ assert mux_calls[0]["frames"].shape == (5, 2, 2, 3)
+ assert mux_calls[0]["frames"].dtype == np.uint8
+ assert mux_calls[0]["fps"] == 8.0
+ assert mux_calls[0]["audio"] is None
+
+
+def test_rife_model_inference_runs_on_dummy_tensors():
+ model = rife_interpolator.Model().eval()
+ img0 = torch.rand(1, 3, 32, 32)
+ img1 = torch.rand(1, 3, 32, 32)
+
+ output = model.inference(img0, img1, scale=1.0)
+
+ assert output.shape == (1, 3, 32, 32)
+ assert torch.isfinite(output).all()
+
+
+def test_frame_interpolator_runs_actual_torch_tensor_path(monkeypatch):
+ model = rife_interpolator.Model().eval()
+ interpolator = rife_interpolator.FrameInterpolator()
+ monkeypatch.setattr(interpolator, "_ensure_model_loaded", lambda preferred_device=None: model)
+
+ video = torch.zeros(1, 3, 2, 32, 32)
+ output_video, multiplier = interpolator.interpolate_tensor(video, exp=1, scale=1.0)
+
+ assert multiplier == 2
+ assert output_video.shape == (1, 3, 3, 32, 32)
+ assert torch.isfinite(output_video).all()
+
+
+def test_frame_interpolator_uses_platform_device_when_tensor_is_cpu(monkeypatch):
+ chosen_devices = []
+ model = rife_interpolator.Model().eval()
+
+ def _fake_ensure_model_loaded(*, preferred_device=None):
+ chosen_devices.append(preferred_device)
+ return model
+
+ interpolator = rife_interpolator.FrameInterpolator()
+ monkeypatch.setattr(interpolator, "_ensure_model_loaded", _fake_ensure_model_loaded)
+ monkeypatch.setattr(model.flownet, "to", lambda device: model.flownet)
+ monkeypatch.setattr(rife_interpolator, "_select_torch_device", lambda: torch.device("cuda"))
+
+ video = torch.zeros(1, 3, 2, 32, 32)
+ output_video, multiplier = interpolator.interpolate_tensor(video, exp=1, scale=1.0)
+
+ assert chosen_devices == [torch.device("cuda")]
+ assert multiplier == 2
+ assert output_video.shape == (1, 3, 3, 32, 32)
diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py
index 7200b38abb..6157d82313 100644
--- a/tests/entrypoints/openai_api/test_video_server.py
+++ b/tests/entrypoints/openai_api/test_video_server.py
@@ -34,12 +34,27 @@
class MockVideoResult:
- def __init__(self, videos, audios=None, sample_rate=None):
+ def __init__(
+ self,
+ videos,
+ audios=None,
+ sample_rate=None,
+ custom_output=None,
+ stage_durations=None,
+ peak_memory_mb=0.0,
+ ):
self.multimodal_output = {"video": videos}
if audios is not None:
self.multimodal_output["audio"] = audios
if sample_rate is not None:
self.multimodal_output["audio_sample_rate"] = sample_rate
+ self._custom_output = custom_output or {}
+ self.stage_durations = stage_durations or {}
+ self.peak_memory_mb = peak_memory_mb
+
+ @property
+ def custom_output(self):
+ return self._custom_output
class FakeAsyncOmni:
@@ -67,7 +82,7 @@ def set_stage_configs_if_missing(self, stage_configs):
if self.stage_configs is None:
self.stage_configs = stage_configs
- async def generate_videos(self, request, reference_id, *, reference_image=None):
+ async def generate_video_bytes(self, request, reference_id, *, reference_image=None):
self.started.set()
try:
await asyncio.Future()
@@ -135,15 +150,81 @@ def _wait_until(predicate, timeout_s: float = 2.0, interval_s: float = 0.02):
raise AssertionError("Timed out waiting for condition")
+def test_async_video_generation_bypasses_base64(test_client, mocker: MockerFixture):
+ """Regression test: Ensure async video generation saves raw bytes directly
+ without bouncing through base64 encoding."""
+ # We mock _encode_video_bytes (the correct path)
+ mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"raw-mp4-bytes",
+ )
+
+ # We assert that encode_video_base64 is never called
+ mock_base64 = mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ side_effect=RuntimeError("Regression: async video path should not base64 encode"),
+ )
+
+ response = test_client.post(
+ "/v1/videos",
+ data={"prompt": "A base64 test."},
+ )
+ assert response.status_code == 200
+ video_id = response.json()["id"]
+
+ # Wait for completion. If it used base64, the RuntimeError would fail the task
+ _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
+ mock_base64.assert_not_called()
+
+
+def test_async_video_generation_with_audio_bypasses_base64(test_client, mocker: MockerFixture):
+ """Regression test: Ensure async video generation passes audio through
+ generate_video_bytes without bouncing through base64 encoding."""
+ mock_encode = mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"raw-mp4-bytes",
+ )
+
+ mock_base64 = mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ side_effect=RuntimeError("Regression: async video path should not base64 encode"),
+ )
+
+ engine = test_client.app.state.openai_serving_video._engine_client
+
+ async def _generate(prompt, request_id, sampling_params_list):
+ engine.captured_prompt = prompt
+ engine.captured_sampling_params_list = sampling_params_list
+ yield MockVideoResult([object()], audios=[object()], sample_rate=48000)
+
+ engine.generate = _generate
+
+ response = test_client.post(
+ "/v1/videos",
+ data={"prompt": "A base64 test with audio."},
+ )
+ assert response.status_code == 200
+ video_id = response.json()["id"]
+
+ _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
+ mock_base64.assert_not_called()
+
+ mock_encode.assert_called_once()
+ kwargs = mock_encode.call_args.kwargs
+ assert "audio" in kwargs
+ assert kwargs["audio"] is not None
+ assert kwargs["audio_sample_rate"] == 48000
+
+
def test_t2v_video_generation_form(test_client, mocker: MockerFixture):
fps_values = []
- def _fake_encode(video, fps):
+ def _fake_encode(video, fps, audio=None, audio_sample_rate=None, **kwargs):
fps_values.append(fps)
- return "Zg=="
+ return b"fake-video"
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
side_effect=_fake_encode,
)
response = test_client.post(
@@ -175,8 +256,8 @@ def test_i2v_video_generation_form(test_client, mocker: MockerFixture):
image_bytes = _make_test_image_bytes((48, 32))
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -201,8 +282,8 @@ def test_i2v_video_generation_resizes_input_to_requested_dimensions(test_client,
image_bytes = _make_test_image_bytes((48, 32))
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -227,8 +308,8 @@ def test_i2v_video_generation_resizes_input_to_requested_dimensions(test_client,
def test_i2v_video_generation_with_image_reference_form(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -252,12 +333,12 @@ def test_i2v_video_generation_with_image_reference_form(test_client, mocker: Moc
def test_seconds_defaults_fps_and_frames(test_client, mocker: MockerFixture):
fps_values = []
- def _fake_encode(video, fps):
+ def _fake_encode(video, fps, audio=None, audio_sample_rate=None, **kwargs):
fps_values.append(fps)
- return "Zg=="
+ return b"fake-video"
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
side_effect=_fake_encode,
)
response = test_client.post(
@@ -281,8 +362,8 @@ def _fake_encode(video, fps):
def test_size_param_sets_width_height(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -303,8 +384,8 @@ def test_size_param_sets_width_height(test_client, mocker: MockerFixture):
def test_sampling_params_pass_through(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -332,13 +413,74 @@ def test_sampling_params_pass_through(test_client, mocker: MockerFixture):
assert captured.extra_args["flow_shift"] == 0.25
+def test_frame_interpolation_params_pass_to_diffusion_sampling_params(test_client, mocker: MockerFixture):
+ """Frame interpolation parameters should be forwarded to diffusion worker sampling params."""
+ mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
+ )
+ response = test_client.post(
+ "/v1/videos",
+ data={
+ "prompt": "smooth motion",
+ "fps": "8",
+ "enable_frame_interpolation": "true",
+ "frame_interpolation_exp": "2",
+ "frame_interpolation_scale": "0.5",
+ "frame_interpolation_model_path": "local-rife",
+ },
+ )
+
+ assert response.status_code == 200
+ video_id = response.json()["id"]
+ _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
+
+ engine = test_client.app.state.openai_serving_video._engine_client
+ captured = engine.captured_sampling_params_list[0]
+ assert captured.enable_frame_interpolation is True
+ assert captured.frame_interpolation_exp == 2
+ assert captured.frame_interpolation_scale == 0.5
+ assert captured.frame_interpolation_model_path == "local-rife"
+
+
+def test_worker_fps_multiplier_is_applied_to_async_encoding(test_client, mocker: MockerFixture):
+ fps_values = []
+ engine = test_client.app.state.openai_serving_video._engine_client
+
+ async def _generate(prompt, request_id, sampling_params_list):
+ engine.captured_prompt = prompt
+ engine.captured_sampling_params_list = sampling_params_list
+ import numpy as np
+
+ yield MockVideoResult([np.zeros((1, 64, 64, 3), dtype=np.uint8)], custom_output={"video_fps_multiplier": 2})
+
+ engine.generate = _generate
+
+ def _fake_encode(video, fps, **kwargs):
+ del video, kwargs
+ fps_values.append(fps)
+ return b"fake-video"
+
+ mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ side_effect=_fake_encode,
+ )
+
+ response = test_client.post("/v1/videos", data={"prompt": "fps multiplier", "fps": "8"})
+
+ assert response.status_code == 200
+ video_id = response.json()["id"]
+ _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
+ assert fps_values == [16]
+
+
def test_audio_sample_rate_comes_from_model_config(test_client, mocker: MockerFixture):
audio_sample_rates = []
- def _fake_encode(video, fps, audio=None, audio_sample_rate=None):
- del video, fps, audio
+ def _fake_encode(video, fps, audio=None, audio_sample_rate=None, video_codec_options=None):
+ del video, fps, audio, video_codec_options
audio_sample_rates.append(audio_sample_rate)
- return "Zg=="
+ return b"fake-video"
engine = test_client.app.state.openai_serving_video._engine_client
engine.model_config = SimpleNamespace(
@@ -352,12 +494,14 @@ def _fake_encode(video, fps, audio=None, audio_sample_rate=None):
async def _generate(prompt, request_id, sampling_params_list):
engine.captured_prompt = prompt
engine.captured_sampling_params_list = sampling_params_list
- yield MockVideoResult([object()], audios=[object()])
+ import numpy as np
+
+ yield MockVideoResult([np.zeros((1, 64, 64, 3), dtype=np.uint8)], audios=[object()])
engine.generate = _generate
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
side_effect=_fake_encode,
)
response = test_client.post(
@@ -371,6 +515,33 @@ async def _generate(prompt, request_id, sampling_params_list):
assert audio_sample_rates == [16000]
+def test_video_job_persists_profiler_metadata(test_client, mocker: MockerFixture):
+ engine = test_client.app.state.openai_serving_video._engine_client
+
+ async def _generate(prompt, request_id, sampling_params_list):
+ engine.captured_prompt = prompt
+ engine.captured_sampling_params_list = sampling_params_list
+ yield MockVideoResult(
+ [object()],
+ stage_durations={"diffuse": 2.5, "vae.decode": 0.3},
+ peak_memory_mb=4096.5,
+ )
+
+ engine.generate = _generate
+ mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
+ )
+
+ response = test_client.post("/v1/videos", data={"prompt": "profile me"})
+ assert response.status_code == 200
+ video_id = response.json()["id"]
+ completed = _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
+
+ assert completed["stage_durations"] == {"diffuse": 2.5, "vae.decode": 0.3}
+ assert completed["peak_memory_mb"] == 4096.5
+
+
def test_missing_handler_returns_503():
app = FastAPI()
app.include_router(router)
@@ -393,6 +564,18 @@ def test_missing_prompt_returns_422(test_client):
assert response.status_code == 422
+def test_video_generation_rejects_model_mismatch(test_client):
+ response = test_client.post(
+ "/v1/videos",
+ data={
+ "prompt": "bad model",
+ "model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ },
+ )
+ assert response.status_code == 400
+ assert "model mismatch" in response.json()["detail"].lower()
+
+
def test_invalid_size_parse_returns_422(test_client):
response = test_client.post(
"/v1/videos",
@@ -428,8 +611,8 @@ def test_invalid_seconds_returns_422(test_client):
def test_negative_prompt_and_seed_pass_through(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -498,12 +681,16 @@ def test_video_request_validation():
with pytest.raises(ValueError):
VideoGenerationRequest(prompt="test", image_reference={"file_id": "file-1", "image_url": "https://example.com"})
+ with pytest.raises(ValueError):
+ VideoGenerationRequest(prompt="test", frame_interpolation_exp=0)
+ with pytest.raises(ValueError):
+ VideoGenerationRequest(prompt="test", frame_interpolation_scale=0)
def test_list_videos_supports_order_after_and_limit(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
ids = []
for i in range(3):
@@ -571,8 +758,8 @@ def test_list_videos_supports_order_after_and_limit(test_client, mocker: MockerF
def test_delete_completed_job_removes_file_and_metadata(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
create_resp = test_client.post("/v1/videos", data={"prompt": "Delete this video"})
assert create_resp.status_code == 200
@@ -643,8 +830,8 @@ def test_video_response_file_extension_is_robust():
def test_extra_params_merged_into_extra_args(test_client, mocker: MockerFixture):
"""extra_params JSON object is merged into sampling_params.extra_args."""
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
extra_params = {
"is_enable_stage2": True,
@@ -674,8 +861,8 @@ def test_extra_params_merged_into_extra_args(test_client, mocker: MockerFixture)
def test_extra_params_none_by_default(test_client, mocker: MockerFixture):
"""When extra_params is omitted, extra_args stays empty."""
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -715,8 +902,8 @@ def test_extra_params_invalid_json(test_client):
def test_extra_params_merged_with_existing_extra_args(test_client, mocker: MockerFixture):
"""extra_params is merged on top of existing extra_args (e.g. flow_shift)."""
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- return_value="Zg==",
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
)
response = test_client.post(
"/v1/videos",
@@ -737,6 +924,28 @@ def test_extra_params_merged_with_existing_extra_args(test_client, mocker: Mocke
assert captured.extra_args["zero_steps"] == 2
+def test_sample_solver_forwarded_via_extra_params(test_client, mocker: MockerFixture):
+ """sample_solver can be passed through existing extra_params for Wan2.2 online serving."""
+ mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ return_value=b"fake-video",
+ )
+ response = test_client.post(
+ "/v1/videos",
+ data={
+ "prompt": "A fox running through snow.",
+ "extra_params": json.dumps({"sample_solver": "euler"}),
+ },
+ )
+
+ assert response.status_code == 200
+ video_id = response.json()["id"]
+ _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
+ engine = test_client.app.state.openai_serving_video._engine_client
+ captured = engine.captured_sampling_params_list[0]
+ assert captured.extra_args["sample_solver"] == "euler"
+
+
# ---------------------------------------------------------------------------
# Sync endpoint tests (POST /v1/videos/sync)
# ---------------------------------------------------------------------------
@@ -770,6 +979,31 @@ def test_sync_t2v_returns_video_bytes(test_client, mocker: MockerFixture):
assert response.headers["x-request-id"].startswith("video_sync-")
assert response.headers["x-model"] == "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
assert float(response.headers["x-inference-time-s"]) >= 0
+ assert json.loads(response.headers["x-stage-durations"]) == {}
+ assert float(response.headers["x-peak-memory-mb"]) == 0.0
+
+
+def test_sync_t2v_returns_profiler_headers(test_client, mocker: MockerFixture):
+ engine = test_client.app.state.openai_serving_video._engine_client
+
+ async def _generate(prompt, request_id, sampling_params_list):
+ engine.captured_prompt = prompt
+ engine.captured_sampling_params_list = sampling_params_list
+ yield MockVideoResult(
+ [object()],
+ stage_durations={"diffuse": 1.75},
+ peak_memory_mb=1234.25,
+ )
+
+ engine.generate = _generate
+ _mock_encode_video_bytes(mocker, b"profiled-video")
+
+ response = test_client.post("/v1/videos/sync", data={"prompt": "sync profile"})
+
+ assert response.status_code == 200
+ assert response.content == b"profiled-video"
+ assert json.loads(response.headers["x-stage-durations"]) == {"diffuse": 1.75}
+ assert float(response.headers["x-peak-memory-mb"]) == pytest.approx(1234.25, rel=0, abs=1e-3)
def test_sync_i2v_returns_video_bytes(test_client, mocker: MockerFixture):
@@ -888,3 +1122,57 @@ def test_sync_sampling_params_pass_through(test_client, mocker: MockerFixture):
assert captured.num_inference_steps == 30
assert captured.guidance_scale == 6.5
assert captured.seed == 42
+
+
+def test_sync_frame_interpolation_params_pass_to_sampling_params(test_client, mocker: MockerFixture):
+ """Frame interpolation parameters should be forwarded on the sync path."""
+ encode_mock = _mock_encode_video_bytes(mocker)
+ response = test_client.post(
+ "/v1/videos/sync",
+ data={
+ "prompt": "smooth sync",
+ "fps": "8",
+ "enable_frame_interpolation": "true",
+ "frame_interpolation_exp": "2",
+ "frame_interpolation_scale": "0.5",
+ "frame_interpolation_model_path": "local-rife",
+ },
+ )
+
+ assert response.status_code == 200
+ engine = test_client.app.state.openai_serving_video._engine_client
+ captured = engine.captured_sampling_params_list[0]
+ assert captured.enable_frame_interpolation is True
+ assert captured.frame_interpolation_exp == 2
+ assert captured.frame_interpolation_scale == 0.5
+ assert captured.frame_interpolation_model_path == "local-rife"
+ _, kwargs = encode_mock.call_args
+ assert kwargs["fps"] == 8
+
+
+def test_worker_fps_multiplier_is_applied_to_sync_encoding(test_client, mocker: MockerFixture):
+ engine = test_client.app.state.openai_serving_video._engine_client
+ fps_values = []
+
+ async def _generate(prompt, request_id, sampling_params_list):
+ engine.captured_prompt = prompt
+ engine.captured_sampling_params_list = sampling_params_list
+ yield MockVideoResult([object()], custom_output={"video_fps_multiplier": 2})
+
+ engine.generate = _generate
+
+ def _fake_encode(video, fps, **kwargs):
+ del video, kwargs
+ fps_values.append(fps)
+ return b"fps-multiplied"
+
+ mocker.patch(
+ "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ side_effect=_fake_encode,
+ )
+
+ response = test_client.post("/v1/videos/sync", data={"prompt": "fps multiplier", "fps": "8"})
+
+ assert response.status_code == 200
+ assert response.content == b"fps-multiplied"
+ assert fps_values == [16]
diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py
index ca5624f2d4..7ed8128260 100644
--- a/tests/entrypoints/test_async_omni_diffusion_config.py
+++ b/tests/entrypoints/test_async_omni_diffusion_config.py
@@ -69,6 +69,20 @@ def test_default_stage_config_propagates_ulysses_mode():
assert parallel_config.ulysses_mode == "advanced_uaa"
+def test_default_stage_config_includes_default_sampling_params():
+ """Ensure default sampling params survive the default diffusion-stage builder."""
+ stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg(
+ {
+ "default_sampling_params": '{"0": {"generator_device":"cpu", "guidance_scale":7.5}}',
+ }
+ )[0]
+
+ assert stage_cfg["default_sampling_params"] == {
+ "generator_device": "cpu",
+ "guidance_scale": 7.5,
+ }
+
+
def test_serve_cli_accepts_ulysses_mode():
"""Ensure diffusion serve CLI exposes ulysses_mode and wires it to parallel_config."""
parser = FlexibleArgumentParser()
@@ -93,3 +107,24 @@ def test_serve_cli_accepts_ulysses_mode():
assert args.ulysses_mode == "advanced_uaa"
assert parallel_config.ulysses_degree == 4
assert parallel_config.ulysses_mode == "advanced_uaa"
+
+
+def test_serve_cli_accepts_diffusion_pipeline_profiler_flag():
+ """Ensure diffusion serve CLI exposes the profiler switch."""
+ parser = FlexibleArgumentParser()
+ subparsers = parser.add_subparsers(dest="command")
+ OmniServeCommand().subparser_init(subparsers)
+
+ args = parser.parse_args(
+ [
+ "serve",
+ "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ "--omni",
+ "--enable-diffusion-pipeline-profiler",
+ ]
+ )
+
+ stage_cfg = _create_default_diffusion_stage_cfg(args)[0]
+
+ assert args.enable_diffusion_pipeline_profiler is True
+ assert stage_cfg["engine_args"]["enable_diffusion_pipeline_profiler"] is True
diff --git a/tests/entrypoints/test_cfg_companion_tracker.py b/tests/entrypoints/test_cfg_companion_tracker.py
deleted file mode 100644
index 941ead41ff..0000000000
--- a/tests/entrypoints/test_cfg_companion_tracker.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import time
-from types import SimpleNamespace
-
-import pytest
-
-from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def dummy_expand_func(prompt, sp0):
- if prompt == "expand_me":
- return [SimpleNamespace(prompt={"prompt": "neg"}, role="cfg_text", request_id_suffix="__cfg_text")]
- return []
-
-
-@pytest.fixture
-def tracker():
- sp0 = SimpleNamespace()
- return CfgCompanionTracker(prompt_expand_func=dummy_expand_func, stage0_sampling_params=sp0, timeout_s=0.1)
-
-
-def test_companion_tracker_initialization(tracker):
- assert not tracker.is_active
- assert tracker.num_companions == 0
-
-
-def test_expand_prompts_registers_companions(tracker):
- request_id_to_prompt = {"req1": "expand_me", "req2": "do_not_expand"}
-
- pairs = tracker.expand_prompts(request_id_to_prompt)
-
- assert len(pairs) == 1
- companion_id, prompt = pairs[0]
- assert companion_id == "req1__cfg_text"
- assert prompt == {"prompt": "neg"}
-
- assert tracker.is_active
- assert tracker.num_companions == 1
- assert tracker.is_companion("req1__cfg_text")
- assert not tracker.is_companion("req2__cfg_text")
- assert tracker.has_companions("req1")
- assert not tracker.has_companions("req2")
-
- comp_map = tracker.get_companion_request_ids("req1")
- assert comp_map == {"cfg_text": "req1__cfg_text"}
-
-
-def test_companion_lifecycle_success(tracker):
- request_id_to_prompt = {"req1": "expand_me"}
- tracker.expand_prompts(request_id_to_prompt)
-
- # Defer parent
- engine_outputs = {"out": 123}
- tracker.defer_parent("req1", engine_outputs, stage_id=0)
-
- # Initially not done
- assert not tracker.all_companions_done("req1")
-
- # Companion completes
- parent_id = tracker.on_companion_completed("req1__cfg_text")
-
- # Parent should be returned since all companions are done and it is pending
- assert parent_id == "req1"
- assert tracker.all_companions_done("req1")
-
- # Pop pending parent
- popped = tracker.pop_pending_parent("req1")
- assert popped is not None
- assert popped["engine_outputs"] == engine_outputs
- assert popped["stage_id"] == 0
-
-
-def test_companion_lifecycle_failure(tracker):
- request_id_to_prompt = {"req1": "expand_me"}
- tracker.expand_prompts(request_id_to_prompt)
-
- tracker.defer_parent("req1", {"out": 123}, stage_id=0)
-
- # Companion fails
- parent_id, aborted = tracker.on_companion_error("req1__cfg_text")
-
- assert parent_id == "req1"
- assert aborted is True
- assert tracker.is_parent_failed("req1")
-
- # Parent should be removed from pending list
- assert tracker.pop_pending_parent("req1") is None
-
- # Consume failure
- tracker.consume_parent_failure("req1")
- assert not tracker.is_parent_failed("req1")
-
-
-def test_companion_lifecycle_timeout(tracker):
- request_id_to_prompt = {"req1": "expand_me"}
- tracker.expand_prompts(request_id_to_prompt)
-
- tracker.defer_parent("req1", {"out": 123}, stage_id=0)
-
- # Initially no timeouts
- timeouts = tracker.check_timeouts()
- assert len(timeouts) == 0
-
- # Wait for timeout
- time.sleep(0.15)
-
- # Check timeouts again
- timeouts = tracker.check_timeouts()
- assert len(timeouts) == 1
- assert timeouts[0] == "req1"
-
- # Should be removed from pending
- assert tracker.pop_pending_parent("req1") is None
diff --git a/tests/entrypoints/test_omni_base_profiler.py b/tests/entrypoints/test_omni_base_profiler.py
index 0c1ddc6a5d..ca10eed91f 100644
--- a/tests/entrypoints/test_omni_base_profiler.py
+++ b/tests/entrypoints/test_omni_base_profiler.py
@@ -1,8 +1,7 @@
"""Unit tests for OmniBase and AsyncOmni profiler methods."""
-from unittest.mock import MagicMock, patch
-
import pytest
+from pytest_mock import MockerFixture
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -11,12 +10,12 @@ class TestOmniBaseProfiler:
"""Test suite for OmniBase profiler methods (start_profile, stop_profile)."""
@pytest.fixture
- def mock_engine(self):
+ def mock_engine(self, mocker: MockerFixture):
"""Create a mock AsyncOmniEngine for testing."""
- engine = MagicMock()
+ engine = mocker.MagicMock()
engine.num_stages = 3
engine.is_alive.return_value = True
- engine.default_sampling_params_list = [MagicMock() for _ in range(3)]
+ engine.default_sampling_params_list = [mocker.MagicMock() for _ in range(3)]
engine.get_stage_metadata.side_effect = lambda i: {
"final_output_type": "text" if i == 0 else "audio",
"final_output": True,
@@ -25,17 +24,15 @@ def mock_engine(self):
return engine
@pytest.fixture
- def omni_base_instance(self, mock_engine):
+ def omni_base_instance(self, mock_engine, mocker: MockerFixture):
"""Create an OmniBase instance with mocked dependencies."""
- with (
- patch("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", return_value=mock_engine),
- patch("vllm_omni.entrypoints.omni_base.omni_snapshot_download", side_effect=lambda x: x),
- patch("vllm_omni.entrypoints.omni_base.weakref.finalize"),
- ):
- from vllm_omni.entrypoints.omni_base import OmniBase
-
- instance = OmniBase(model="test-model")
- return instance
+ mocker.patch("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", return_value=mock_engine)
+ mocker.patch("vllm_omni.entrypoints.omni_base.omni_snapshot_download", side_effect=lambda x: x)
+ mocker.patch("vllm_omni.entrypoints.omni_base.weakref.finalize")
+ from vllm_omni.entrypoints.omni_base import OmniBase
+
+ instance = OmniBase(model="test-model")
+ return instance
def test_start_profile_calls_collective_rpc(self, omni_base_instance, mock_engine):
"""Test that start_profile calls collective_rpc with correct arguments."""
diff --git a/tests/entrypoints/test_pd_disaggregation.py b/tests/entrypoints/test_pd_disaggregation.py
new file mode 100644
index 0000000000..5ffabfbf2a
--- /dev/null
+++ b/tests/entrypoints/test_pd_disaggregation.py
@@ -0,0 +1,1222 @@
+"""Unit tests for PD (Prefill-Decode) disaggregation in the Omni orchestrator.
+
+Tests the PD detection, validation, config parsing, sampling param
+preparation, and routing logic added by the PD disaggregation feature
+(issue #1188). All tests run without GPU.
+
+NOTE (v1908 adaptation): Tests that relied on the old OmniStage / stage_list
+architecture (removed in PR #1908) are marked xfail with
+``reason="Requires migration to v1908 Orchestrator architecture"``.
+The remaining tests exercise PDDisaggregationMixin directly and work
+without spinning up a real engine.
+"""
+
+import uuid
+import warnings
+from queue import Empty, Queue
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+from vllm import SamplingParams
+
+from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
+
+pytestmark = pytest.mark.skip(reason="Temporarily skip PD entrypoint tests while PD config is being removed.")
+
+# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies.
+warnings.filterwarnings(
+ "ignore",
+ message=r"builtin type SwigPy.*has no __module__ attribute",
+ category=DeprecationWarning,
+)
+
+
+def _ns(**kwargs):
+ """Create a lightweight attribute object for tests."""
+ return SimpleNamespace(**kwargs)
+
+
+# ---------------------------------------------------------------------------
+# Fake helpers (same pattern as test_omni_llm.py)
+# ---------------------------------------------------------------------------
+
+
+class _FakeEngineArgs(dict):
+ """Fake engine args that supports both attribute and dict access."""
+
+ def __init__(self, args_dict: dict[str, Any]):
+ super().__init__(args_dict)
+ if "model_stage" not in self:
+ self["model_stage"] = None
+ if "engine_output_type" not in self:
+ self["engine_output_type"] = None
+ for key, value in self.items():
+ setattr(self, key, value)
+
+
+class _FakeStageConfig:
+ def __init__(self, config_dict: dict[str, Any]):
+ engine_args_dict = config_dict.get("engine_args", {})
+ self.engine_args = _FakeEngineArgs(engine_args_dict)
+ self.final_output = config_dict.get("final_output", False)
+ self.final_output_type = config_dict.get("final_output_type", None)
+ self.stage_id = config_dict.get("stage_id", 0)
+ self.is_prefill_only = config_dict.get("is_prefill_only", False)
+ self.is_decode_only = config_dict.get("is_decode_only", False)
+ self.engine_input_source = config_dict.get("engine_input_source", [])
+ self.is_comprehension = config_dict.get("is_comprehension", False)
+ self._config_dict = config_dict
+
+
+class _FakeQueue:
+ def __init__(self, maxsize=0):
+ self._queue = Queue(maxsize=maxsize)
+
+ def put(self, item):
+ self._queue.put(item)
+
+ def put_nowait(self, item):
+ self._queue.put_nowait(item)
+
+ def get(self):
+ return self._queue.get()
+
+ def get_nowait(self):
+ return self._queue.get_nowait()
+
+ def empty(self):
+ return self._queue.empty()
+
+
+class _FakeStage:
+ """Lightweight stage stub with PD disaggregation flag support."""
+
+ def __init__(self, config, stage_init_timeout: int = 300):
+ if isinstance(config, dict):
+ config = _FakeStageConfig(config)
+ self.config = config
+ self.stage_config = config
+ self.engine = None
+ self.engine_outputs = None
+ self.stage_id = getattr(config, "stage_id", 0)
+ self.engine_args = config.engine_args
+ self.model_stage = getattr(config.engine_args, "model_stage", None)
+ self.stage_type = "llm"
+ self.default_sampling_params = SamplingParams(temperature=1.0)
+ self.final_output = config.final_output if hasattr(config, "final_output") else False
+ self.final_output_type = getattr(config, "final_output_type", None)
+ self.is_prefill_only = getattr(config, "is_prefill_only", False)
+ self.is_decode_only = getattr(config, "is_decode_only", False)
+ self.engine_input_source = getattr(config, "engine_input_source", [])
+ self.is_comprehension = getattr(config, "is_comprehension", False)
+ processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"])
+ self._processed_input = processed_input
+ self._in_q = None
+ self._out_q = None
+ self._proc = None
+ self._stage_init_timeout = max(0, int(stage_init_timeout))
+
+ def attach_queues(self, in_q, out_q):
+ self._in_q = in_q
+ self._out_q = out_q
+
+ def init_stage_worker(
+ self, model: str, *, is_async=False, shm_threshold_bytes=65536, ctx=None, batch_timeout=10, **kwargs
+ ):
+ self._proc = _ns(
+ start=lambda: None,
+ join=lambda timeout=None: None,
+ is_alive=lambda: False,
+ terminate=lambda: None,
+ )
+ if self._out_q is not None:
+ try:
+ self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id})
+ except Exception:
+ pass
+
+ def stop_stage_worker(self):
+ if self._in_q is not None:
+ try:
+ self._in_q.put_nowait({"type": "shutdown"})
+ except Exception:
+ pass
+
+ def submit(self, payload: dict[str, Any]):
+ if self._in_q is not None:
+ self._in_q.put(payload)
+
+ def try_collect(self) -> Any:
+ if self._out_q is None:
+ return None
+ try:
+ return self._out_q.get_nowait()
+ except Empty:
+ return None
+
+ def set_engine_outputs(self, outputs):
+ self.engine_outputs = outputs
+
+ def process_engine_inputs(self, stage_list, prompts):
+ return self._processed_input
+
+
+# ---------------------------------------------------------------------------
+# Shared mock setup helpers
+# ---------------------------------------------------------------------------
+
+
+def _setup_engine_mocks(monkeypatch):
+ fake_engine = _ns()
+ fake_engine.tokenizer = _ns()
+ fake_engine.log_stats = False
+ fake_engine.vllm_config = _ns()
+ fake_engine.vllm_config.model_config = _ns()
+ fake_engine.vllm_config.model_config.io_processor_plugin = None
+ fake_engine.get_supported_tasks = lambda: []
+ fake_engine.model_config = _ns()
+ fake_engine.model_config.io_processor_plugin = None
+ fake_registry = _ns()
+ fake_registry.resolve_model_cls = lambda *args, **kwargs: (_ns(), "test_arch")
+ fake_engine.model_config.registry = fake_registry
+ fake_engine.vllm_config.model_config.registry = fake_registry
+
+ monkeypatch.setattr(
+ "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args",
+ lambda **kw: fake_engine,
+ raising=False,
+ )
+
+ class FakeModelClass:
+ pass
+
+ monkeypatch.setattr(
+ "vllm.model_executor.model_loader.utils.get_model_architecture",
+ lambda model_config: (FakeModelClass, "test_arch"),
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.model_executor.model_loader.utils._get_model_architecture",
+ lambda model_config: (FakeModelClass, "test_arch"),
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls",
+ lambda model_cls: model_cls,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.multimodal.cache._enable_processor_cache",
+ lambda model_config, mm_registry: False,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.plugins.io_processors.get_io_processor",
+ lambda vllm_config, io_processor_plugin: None,
+ raising=False,
+ )
+
+
+def _setup_multiprocessing_mocks(monkeypatch):
+ import multiprocessing as mp
+
+ fake_process_instance = _ns(
+ start=lambda: None,
+ join=lambda timeout=None: None,
+ is_alive=lambda: False,
+ terminate=lambda: None,
+ )
+
+ def fake_process_class(*args, **kwargs):
+ return fake_process_instance
+
+ fake_ctx = _ns()
+ fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize)
+ fake_ctx.Process = fake_process_class
+
+ monkeypatch.setattr(mp, "get_context", lambda method: fake_ctx, raising=False)
+ monkeypatch.setattr(mp, "Process", fake_process_class, raising=False)
+
+
+def _setup_ipc_mocks(monkeypatch):
+ # These IPC helpers existed in the old architecture; no-op in new arch.
+ pass
+
+
+def _setup_log_mocks(monkeypatch):
+ class _FakeOrchestratorAggregator:
+ def __init__(self, num_stages, enable_stats, wall_start_ts, final_stage_id_for_e2e=None):
+ self.num_stages = num_stages
+ self.enable_stats = enable_stats
+ self.stage_first_ts = [None] * num_stages
+ self.stage_last_ts = [None] * num_stages
+ self.stage_total_tokens = [0] * num_stages
+ self.accumulated_gen_time_ms = {}
+ self.e2e_done = set()
+ self.e2e_count = 0
+ self.e2e_total_ms = 0.0
+
+ def on_stage_metrics(self, stage_id, req_id, metrics, final_output_type=None):
+ pass
+
+ def on_finalize_request(self, stage_id, req_id, start_ts):
+ self.e2e_done.add(req_id)
+
+ def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm):
+ pass
+
+ def accumulate_diffusion_metrics(self, stage_type, req_id, engine_outputs):
+ pass
+
+ def record_audio_generated_frames(self, output, stage_id, req_id):
+ pass
+
+ def stage_postprocess_timer(self, stage_id, req_id):
+ from contextlib import contextmanager
+
+ @contextmanager
+ def _noop():
+ yield
+
+ return _noop()
+
+ def build_and_log_summary(self):
+ return "Fake summary"
+
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.omni.OrchestratorAggregator",
+ _FakeOrchestratorAggregator,
+ raising=False,
+ )
+
+
+def _clear_modules():
+ import sys
+
+ for module_name in [
+ "vllm_omni.entrypoints.utils",
+ "vllm_omni.entrypoints.omni",
+ ]:
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+
+@pytest.fixture(autouse=True)
+def mock_get_config(monkeypatch):
+ """Auto-mock get_config and related model loading functions."""
+ import sys
+
+ fake_tokenizer = _ns()
+ fake_tokenizer.encode = lambda *args, **kwargs: [1, 2, 3]
+ fake_tokenizer.decode = lambda *args, **kwargs: "test"
+
+ def _mock_init_tokenizer_from_configs(model_config=None, **kwargs):
+ return fake_tokenizer
+
+ monkeypatch.setattr(
+ "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs",
+ _mock_init_tokenizer_from_configs,
+ raising=False,
+ )
+ tokenizer_module_path = "vllm.transformers_utils.tokenizer"
+ if tokenizer_module_path in sys.modules:
+ setattr(sys.modules[tokenizer_module_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
+
+ def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None):
+ if prompt_token_ids is not None:
+ if isinstance(prompt_token_ids, list):
+ return len(prompt_token_ids)
+ return 10
+
+ monkeypatch.setattr(
+ "vllm.utils.length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds, raising=False
+ )
+ monkeypatch.setattr(
+ "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds",
+ _mock_length_from_prompt_token_ids_or_embeds,
+ raising=False,
+ )
+
+ processor_module_path = "vllm_omni.engine.input_processor"
+ if processor_module_path in sys.modules:
+ setattr(
+ sys.modules[processor_module_path],
+ "length_from_prompt_token_ids_or_embeds",
+ _mock_length_from_prompt_token_ids_or_embeds,
+ )
+
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", _mock_init_tokenizer_from_configs, raising=False
+ )
+ async_omni_path = "vllm_omni.entrypoints.async_omni"
+ if async_omni_path in sys.modules:
+ setattr(sys.modules[async_omni_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
+
+ fake_hf_config = _ns()
+ fake_hf_config.model_type = "qwen2_5_omni"
+
+ monkeypatch.setattr(
+ "vllm.transformers_utils.config.get_config", lambda model, **kwargs: fake_hf_config, raising=False
+ )
+ monkeypatch.setattr("vllm_omni.entrypoints.utils.get_config", lambda model, **kwargs: fake_hf_config, raising=False)
+
+ def _mock_cached_file(path_or_repo_id, *args, **kwargs):
+ import os
+ import tempfile
+
+ fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json")
+ if not os.path.exists(fake_config_file):
+ with open(fake_config_file, "w") as f:
+ f.write('{"model_type": "qwen2_5_omni"}')
+ return fake_config_file
+
+ monkeypatch.setattr("transformers.utils.hub.cached_file", _mock_cached_file, raising=False)
+ monkeypatch.setattr(
+ "transformers.utils.hub.cached_files",
+ lambda path_or_repo_id, filenames, **kwargs: (
+ [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None
+ ),
+ raising=False,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Helper to build an Omni instance with PD stage configs
+# ---------------------------------------------------------------------------
+
+
+def _make_pd_omni(monkeypatch, stage_configs, *, extra_setup=None):
+ """Create a lightweight PDDisaggregationMixin instance for unit tests.
+
+ Bypasses the full OmniBase / AsyncOmniEngine init chain so tests run
+ without GPU. Returns an object that has all PDDisaggregationMixin
+ methods and state (``_pd_separation_pair``, ``_pd_kv_params_by_req``,
+ etc.) initialised from *stage_configs*.
+
+ Tests that need the full ``Omni.generate()`` loop (old stage_list / queue
+ infrastructure) are marked ``xfail`` and not covered here.
+ """
+ configs = [_FakeStageConfig(c) for c in stage_configs]
+
+ class _LightweightOmni(PDDisaggregationMixin):
+ """Minimal shim: exposes stage_configs so PDDisaggregationMixin works."""
+
+ def __init__(self):
+ self._name = "Omni"
+ self._stage_configs = configs
+ self._init_pd_state()
+
+ @property
+ def stage_configs(self):
+ return self._stage_configs
+
+ if extra_setup:
+ import vllm_omni.entrypoints.omni as omni_module
+
+ extra_setup(monkeypatch, omni_module)
+
+ return _LightweightOmni()
+
+
+# ---------------------------------------------------------------------------
+# Stage config templates
+# ---------------------------------------------------------------------------
+
+
+def _prefill_stage_cfg(stage_id=0, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {
+ "model_stage": "thinker",
+ "kv_transfer_config": {
+ "kv_connector": "MooncakeConnector",
+ "kv_role": "kv_producer",
+ "kv_rank": 0,
+ "kv_parallel_size": 2,
+ "kv_connector_extra_config": {"mooncake_bootstrap_port": 25201},
+ },
+ },
+ "is_prefill_only": True,
+ "final_output": False,
+ "is_comprehension": True,
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+def _decode_stage_cfg(stage_id=1, engine_input_source=None, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {
+ "model_stage": "thinker",
+ "kv_transfer_config": {
+ "kv_connector": "MooncakeConnector",
+ "kv_role": "kv_consumer",
+ "kv_rank": 1,
+ "kv_parallel_size": 2,
+ "kv_connector_extra_config": {"mooncake_bootstrap_port": 25202},
+ },
+ },
+ "is_decode_only": True,
+ "engine_input_source": engine_input_source if engine_input_source is not None else [0],
+ "final_output": True,
+ "final_output_type": "text",
+ "is_comprehension": True,
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+def _talker_stage_cfg(stage_id=2, engine_input_source=None, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {"model_stage": "talker"},
+ "engine_input_source": engine_input_source if engine_input_source is not None else [1],
+ "final_output": False,
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+def _code2wav_stage_cfg(stage_id=3, engine_input_source=None, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {"model_stage": "code2wav"},
+ "engine_input_source": engine_input_source if engine_input_source is not None else [2],
+ "final_output": True,
+ "final_output_type": "audio",
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+# ===================================================================
+# Tests: PD pair detection
+# ===================================================================
+
+
+class TestDetectPDSeparation:
+ """Tests for Omni._detect_pd_separation()."""
+
+ def test_detects_pd_pair(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_no_pd_pair_without_flags(self, monkeypatch):
+ """Normal (non-PD) pipeline has no PD pair."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ {
+ "stage_id": 0,
+ "engine_args": {"model_stage": "thinker"},
+ "final_output": True,
+ "final_output_type": "text",
+ },
+ {
+ "stage_id": 1,
+ "engine_args": {"model_stage": "talker"},
+ "engine_input_source": [0],
+ "final_output": True,
+ "final_output_type": "audio",
+ },
+ ],
+ )
+ assert omni._pd_separation_pair is None
+
+ def test_detects_pd_pair_in_4_stage_pipeline(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ _talker_stage_cfg(stage_id=2, engine_input_source=[1]),
+ _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_pd_pair_uses_stage_id_for_input_source(self, monkeypatch):
+ """engine_input_source references stage_id, not list index."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=10),
+ _decode_stage_cfg(stage_id=20, engine_input_source=[10]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
+
+
+# ===================================================================
+# Tests: PD config validation
+# ===================================================================
+
+
+class TestValidatePDConfig:
+ """Tests for Omni._validate_pd_separation_config()."""
+
+ def test_valid_config_passes(self, monkeypatch):
+ """Valid PD config should not raise."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ # If we got here without error, validation passed
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_mismatched_connector_raises(self, monkeypatch):
+ """Different kv_connector types should raise ValueError."""
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["kv_transfer_config"]["kv_connector"] = "NixlConnector"
+
+ with pytest.raises(ValueError, match="connector mismatch"):
+ _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg])
+
+ def test_wrong_prefill_role_raises(self, monkeypatch):
+ """Prefill with kv_consumer role should raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_consumer"
+
+ with pytest.raises(ValueError, match="kv_role must be"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])])
+
+ def test_wrong_decode_role_raises(self, monkeypatch):
+ """Decode with kv_producer role should raise."""
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_producer"
+
+ with pytest.raises(ValueError, match="kv_role must be"):
+ _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg])
+
+ def test_missing_kv_transfer_config_raises(self, monkeypatch):
+ """Missing kv_transfer_config should raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ del prefill_cfg["engine_args"]["kv_transfer_config"]
+
+ with pytest.raises(ValueError, match="kv_transfer_config"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])])
+
+ def test_mismatched_buffer_device_raises(self, monkeypatch):
+ """Mismatched kv_buffer_device should raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cuda"
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cpu"
+
+ with pytest.raises(ValueError, match="kv_buffer_device mismatch"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
+
+
+# ===================================================================
+# Tests: Connector info extraction
+# ===================================================================
+
+
+class TestGetPDConnectorInfo:
+ """Tests for Omni._get_pd_connector_info()."""
+
+ def test_extracts_bootstrap_addr_for_mooncake(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ info = omni._pd_connector_info
+ assert "prefill_bootstrap_addr" in info
+ assert info["prefill_bootstrap_addr"] == "127.0.0.1:25201"
+
+ def test_none_for_non_pd_pipeline(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ {"stage_id": 0, "engine_args": {}, "final_output": True, "final_output_type": "text"},
+ ],
+ )
+ assert omni._pd_connector_info is None
+
+
+# ===================================================================
+# Tests: Prefill sampling params preparation
+# ===================================================================
+
+
+class TestPreparePrefillSamplingParams:
+ """Tests for Omni._prepare_prefill_sampling_params()."""
+
+ def test_sets_max_tokens_to_1(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048)
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ assert result.max_tokens == 1
+ assert result is not sp # should be cloned
+
+ def test_injects_kv_transfer_params(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048)
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ kv_params = result.extra_args["kv_transfer_params"]
+ assert kv_params["do_remote_decode"] is True
+ assert kv_params["do_remote_prefill"] is False
+ assert kv_params["transfer_id"] == "xfer-req-1"
+
+ def test_preserves_existing_extra_args(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, extra_args={"custom_key": "value"})
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ assert result.extra_args["custom_key"] == "value"
+ assert "kv_transfer_params" in result.extra_args
+
+ def test_does_not_mutate_original(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048)
+ _ = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ assert sp.max_tokens == 2048
+ assert sp.extra_args is None
+
+
+# ===================================================================
+# Tests: Sampling params auto-duplication for PD split
+# ===================================================================
+
+
+@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)")
+class TestSamplingParamsAutoDuplication:
+ """When user provides N-1 sampling params (for logical stages), the
+ orchestrator should auto-duplicate the thinker params for the decode stage.
+ """
+
+ def test_auto_duplicates_for_4_stage_pipeline(self, monkeypatch):
+ """User provides 3 params for 4 physical stages -> auto-insert decode params."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000001")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ _talker_stage_cfg(stage_id=2, engine_input_source=[1]),
+ _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ assert omni._pd_separation_pair == (0, 1)
+ assert len(omni.stage_list) == 4
+
+ # Simulate outputs for all stages
+ expected_rid = f"0_{test_uuid}"
+ for i in range(4):
+ omni.stage_list[i]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+
+ # Provide 3 params (one less than 4 stages) - should auto-duplicate
+ sp_thinker = SamplingParams(temperature=0.4, max_tokens=2048)
+ sp_talker = SamplingParams(temperature=0.9, max_tokens=4096)
+ sp_code2wav = SamplingParams(temperature=0.0, max_tokens=65536)
+
+ # This should NOT raise ValueError about param count mismatch
+ outputs = omni.generate(
+ prompts=["hello"],
+ sampling_params_list=[sp_thinker, sp_talker, sp_code2wav],
+ )
+ assert isinstance(outputs, list)
+
+
+# ===================================================================
+# Tests: KV transfer params normalization
+# ===================================================================
+
+
+class TestNormalizeKVTransferParams:
+ def test_dict_passthrough(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ d = {"transfer_id": "test", "do_remote_decode": True}
+ assert omni._normalize_kv_transfer_params(d) is d
+
+ def test_none_returns_none(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ assert omni._normalize_kv_transfer_params(None) is None
+
+ def test_dataclass_to_dict(self, monkeypatch):
+ from dataclasses import dataclass
+
+ @dataclass
+ class FakeKVParams:
+ transfer_id: str = "test"
+ do_remote_decode: bool = True
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ result = omni._normalize_kv_transfer_params(FakeKVParams())
+ assert isinstance(result, dict)
+ assert result["transfer_id"] == "test"
+
+
+# ===================================================================
+# Tests: _kv_cfg_to_dict
+# ===================================================================
+
+
+class TestKvCfgToDict:
+ def test_dict_passthrough(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ d = {"kv_connector": "MooncakeConnector"}
+ assert omni._kv_cfg_to_dict(d) is d
+
+ def test_none_returns_empty(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ assert omni._kv_cfg_to_dict(None) == {}
+
+ def test_dataclass_converted(self, monkeypatch):
+ from dataclasses import dataclass
+
+ @dataclass
+ class FakeCfg:
+ kv_connector: str = "TestConnector"
+ kv_role: str = "kv_producer"
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ result = omni._kv_cfg_to_dict(FakeCfg())
+ assert result["kv_connector"] == "TestConnector"
+ assert result["kv_role"] == "kv_producer"
+
+
+# ===================================================================
+# Tests: PD routing in scheduling loop
+# ===================================================================
+
+
+@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)")
+class TestPDRouting:
+ """Test that the scheduling loop correctly routes requests from
+ prefill to decode stage with proper kv_transfer_params.
+ """
+
+ def test_prefill_stage_receives_max_tokens_1(self, monkeypatch):
+ """Stage 0 (prefill) should receive max_tokens=1."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000002")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ expected_rid = f"0_{test_uuid}"
+
+ # Put stage outputs in both queues
+ omni.stage_list[0]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+ omni.stage_list[1]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
+ "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
+ }
+ )
+
+ sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
+ omni.generate(prompts=["hello"], sampling_params_list=sp_list)
+
+ # Check what was submitted to stage 0's input queue
+ # (skip the stage_ready message first)
+ task = omni.stage_list[0]._in_q.get_nowait()
+ assert task["sampling_params"].max_tokens == 1
+ kv_params = task["sampling_params"].extra_args["kv_transfer_params"]
+ assert kv_params["do_remote_decode"] is True
+ assert kv_params["do_remote_prefill"] is False
+ assert kv_params["transfer_id"] == f"xfer-{expected_rid}"
+
+ def test_decode_stage_receives_original_prompt(self, monkeypatch):
+ """Decode stage should get the original prompt (not processed outputs)."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000003")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ expected_rid = f"0_{test_uuid}"
+ original_prompt = "test prompt for PD"
+
+ omni.stage_list[0]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+ omni.stage_list[1]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
+ "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
+ }
+ )
+
+ sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
+ omni.generate(prompts=[original_prompt], sampling_params_list=sp_list)
+
+ # Check what was forwarded to stage 1 (decode)
+ # The connector sends tasks to stage 1's input queue
+ task = omni.stage_list[1]._in_q.get_nowait()
+ # The engine_inputs should contain the original prompt
+ engine_inputs = task.get("engine_inputs")
+ # For PD routing, the original prompt is wrapped in a list
+ if isinstance(engine_inputs, list):
+ assert original_prompt in engine_inputs
+ else:
+ assert engine_inputs == original_prompt
+
+ def test_decode_kv_params_have_correct_flags(self, monkeypatch):
+ """Decode stage kv_transfer_params should have correct role flags."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000004")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ expected_rid = f"0_{test_uuid}"
+
+ omni.stage_list[0]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+ omni.stage_list[1]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
+ "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
+ }
+ )
+
+ sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
+ omni.generate(prompts=["hello"], sampling_params_list=sp_list)
+
+ # Check decode task's kv_transfer_params
+ task = omni.stage_list[1]._in_q.get_nowait()
+ kv_params = task["sampling_params"].extra_args["kv_transfer_params"]
+ assert kv_params["do_remote_prefill"] is True
+ assert kv_params["do_remote_decode"] is False
+ assert kv_params["transfer_id"] == f"xfer-{expected_rid}"
+ assert kv_params["remote_bootstrap_addr"] == "127.0.0.1:25201"
+
+
+# ===================================================================
+# Tests: KV params cleanup
+# ===================================================================
+
+
+class TestKVParamsCleanup:
+ def test_drop_cleans_up(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ omni._pd_kv_params_by_req["req-1"] = {"transfer_id": "xfer-1"}
+ omni._drop_pd_kv_params("req-1")
+ assert "req-1" not in omni._pd_kv_params_by_req
+
+ def test_drop_nonexistent_is_noop(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ omni._drop_pd_kv_params("nonexistent") # should not raise
+
+ def test_pop_returns_stored_params(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ stored = {"transfer_id": "xfer-1", "extra_field": "value"}
+ omni._pd_kv_params_by_req["req-1"] = stored
+
+ result = omni._pop_pd_kv_params("req-1")
+ assert result == stored
+ assert "req-1" not in omni._pd_kv_params_by_req
+
+ def test_pop_uses_fallback_when_no_stored(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ fallback = {"transfer_id": "xfer-fallback"}
+ result = omni._pop_pd_kv_params("req-1", fallback=fallback)
+ assert result == fallback
+
+
+# ===================================================================
+# Tests: Config YAML loads without error
+# ===================================================================
+
+
+class TestPDYAMLConfig:
+ def test_pd_yaml_loads(self):
+ """The PD separation YAML config should load without errors."""
+ import os
+
+ yaml_path = os.path.join(
+ os.path.dirname(__file__),
+ "../../vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml",
+ )
+ yaml_path = os.path.abspath(yaml_path)
+ if not os.path.exists(yaml_path):
+ pytest.skip("PD separation YAML not found")
+
+ from omegaconf import OmegaConf
+
+ cfg = OmegaConf.load(yaml_path)
+ stages = cfg.stage_args
+ assert len(stages) == 4
+
+ # Prefill stage
+ assert stages[0].is_prefill_only is True
+ assert stages[0].final_output is False
+ assert stages[0].is_comprehension is True
+
+ # Decode stage
+ assert stages[1].is_decode_only is True
+ assert stages[1].final_output is True
+ assert stages[1].final_output_type == "text"
+ assert stages[1].is_comprehension is True
+ assert 0 in stages[1].engine_input_source
+
+ # KV transfer configs
+ assert stages[0].engine_args.kv_transfer_config.kv_role == "kv_producer"
+ assert stages[1].engine_args.kv_transfer_config.kv_role == "kv_consumer"
+ assert stages[0].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector"
+ assert stages[1].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector"
+
+
+class TestPrefillStopNeutralization:
+ """Tests that _prepare_prefill_sampling_params neutralizes stop
+ conditions to ensure finish_reason='length'.
+ """
+
+ def test_clears_stop_strings(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, stop=[" ", "STOP"])
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert result.stop == []
+
+ def test_clears_stop_token_ids(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, stop_token_ids=[151643, 151644])
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert result.stop_token_ids == []
+
+ def test_clears_include_stop_str_in_output(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, include_stop_str_in_output=True)
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert result.include_stop_str_in_output is False
+
+ def test_original_sp_unchanged(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, stop=[" "], stop_token_ids=[151643])
+ _ = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert sp.stop == [""]
+ assert sp.stop_token_ids == [151643]
+
+
+# ===================================================================
+# Tests: Failure mode & memory leak prevention
+# ===================================================================
+# NOTE: Full generate()-level failure mode tests are removed for now.
+# The _run_generation error handler (line 1344-1350 in omni.py) calls
+# _drop_pd_kv_params but does not increment completed_requests, causing
+# the while-loop to hang. These tests need to be revisited once the
+# production error-handling path is fixed to properly terminate on
+# stage errors.
+
+
+# ===================================================================
+# Tests: TP size validation
+# ===================================================================
+
+
+class TestTPSizeValidation:
+ """Tests that _validate_pd_separation_config checks tensor_parallel_size."""
+
+ def test_matching_tp_passes(self, monkeypatch):
+ """Same TP size should not raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["tensor_parallel_size"] = 2
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["tensor_parallel_size"] = 2
+ omni = _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_mismatched_tp_raises(self, monkeypatch):
+ """Different TP sizes should raise ValueError."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["tensor_parallel_size"] = 2
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["tensor_parallel_size"] = 4
+ with pytest.raises(ValueError, match="tensor_parallel_size"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
+
+ def test_default_tp_no_error(self, monkeypatch):
+ """Stages without explicit TP (defaults to 1) should pass."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
diff --git a/tests/entrypoints/test_realtime_connection_helpers.py b/tests/entrypoints/test_realtime_connection_helpers.py
new file mode 100644
index 0000000000..e795aa92d0
--- /dev/null
+++ b/tests/entrypoints/test_realtime_connection_helpers.py
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for realtime streaming helpers (PR #2581 /v1/realtime path)."""
+
+from __future__ import annotations
+
+import base64
+
+import numpy as np
+import pytest
+import torch
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+
+from vllm_omni.entrypoints.async_omni import AsyncOmni
+from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture
+def realtime_conn() -> RealtimeConnection:
+ return RealtimeConnection.__new__(RealtimeConnection)
+
+
+class TestRealtimeConnectionTensorAndPcm:
+ def test_tensor_to_numpy_none(self) -> None:
+ assert RealtimeConnection._tensor_to_numpy(None) is None
+
+ def test_tensor_to_numpy_1d_numpy(self) -> None:
+ arr = np.array([1.0, 2.0], dtype=np.float64)
+ out = RealtimeConnection._tensor_to_numpy(arr)
+ assert out is not None
+ assert out.dtype == np.float32
+ assert out.shape == (2,)
+
+ def test_tensor_to_numpy_2d_numpy_flattened(self) -> None:
+ arr = np.array([[0.5], [-0.5]], dtype=np.float32)
+ out = RealtimeConnection._tensor_to_numpy(arr)
+ assert out is not None
+ assert out.shape == (2,)
+
+ def test_tensor_to_numpy_torch(self) -> None:
+ t = torch.tensor([[0.25, -0.25]], dtype=torch.float32)
+ out = RealtimeConnection._tensor_to_numpy(t)
+ assert out is not None
+ assert out.shape == (2,)
+ np.testing.assert_allclose(out, [0.25, -0.25], rtol=1e-5)
+
+ def test_pcm16_b64_roundtrip(self) -> None:
+ audio = np.array([0.0, 1.0, -1.0], dtype=np.float32)
+ b64 = RealtimeConnection._pcm16_b64(audio)
+ raw = base64.b64decode(b64)
+ assert len(raw) == 6
+ pcm = np.frombuffer(raw, dtype=np.int16)
+ assert pcm[0] == 0
+ assert pcm[1] == 32767
+ assert pcm[2] == -32767
+
+
+class TestAsyncOmniStreamingParamsValidation:
+ def test_accepts_streaming_friendly_params(self) -> None:
+ p = SamplingParams(
+ n=1,
+ stop=[],
+ output_kind=RequestOutputKind.DELTA,
+ )
+ AsyncOmni._validate_streaming_input_sampling_params(p)
+
+ def test_rejects_non_sampling_params(self) -> None:
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(object()) # type: ignore[arg-type]
+
+ def test_rejects_n_greater_than_one(self) -> None:
+ p = SamplingParams(n=2, stop=[], output_kind=RequestOutputKind.DELTA)
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(p)
+
+ def test_rejects_final_only(self) -> None:
+ p = SamplingParams(n=1, stop=[], output_kind=RequestOutputKind.FINAL_ONLY)
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(p)
+
+ def test_rejects_stop_strings(self) -> None:
+ p = SamplingParams(n=1, stop=["\n"], output_kind=RequestOutputKind.DELTA)
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(p)
diff --git a/tests/entrypoints/test_serve.py b/tests/entrypoints/test_serve.py
index 916db3cc22..e60afc9cd7 100644
--- a/tests/entrypoints/test_serve.py
+++ b/tests/entrypoints/test_serve.py
@@ -3,15 +3,37 @@
from __future__ import annotations
import argparse
-from unittest.mock import Mock, patch
import pytest
+from pytest_mock import MockerFixture
-from vllm_omni.entrypoints.cli.serve import run_headless
+from vllm_omni.entrypoints.cli.serve import OmniServeCommand, run_headless
+from vllm_omni.entrypoints.utils import detect_explicit_cli_keys
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None:
+ """``--no-async-chunk`` should parse to ``async_chunk=False`` and mark the
+ shared deploy-level dest as explicitly provided by the user."""
+ try:
+ from vllm.utils.argparse_utils import FlexibleArgumentParser
+ except Exception as exc:
+ pytest.skip(f"Cannot build parser in this environment: {exc}")
+
+ root = FlexibleArgumentParser()
+ subparsers = root.add_subparsers(dest="subcommand")
+ cmd = OmniServeCommand()
+ serve_parser = cmd.subparser_init(subparsers)
+
+ argv = ["serve", "fake-model", "--omni", "--no-async-chunk"]
+ args = root.parse_args(argv)
+
+ assert args.async_chunk is False
+ explicit = detect_explicit_cli_keys(argv, serve_parser)
+ assert "async_chunk" in explicit
+
+
def _make_headless_args() -> argparse.Namespace:
return argparse.Namespace(
model="fake-model",
@@ -26,45 +48,43 @@ def _make_headless_args() -> argparse.Namespace:
)
-def test_run_headless_registers_stage_once_and_launches_all_local_engines() -> None:
+def test_run_headless_registers_stage_once_and_launches_all_local_engines(mocker: MockerFixture) -> None:
args = _make_headless_args()
- stage_cfg = Mock(stage_id=3)
+ stage_cfg = mocker.Mock(stage_id=3)
stage_cfgs = [stage_cfg]
- parallel_config = Mock(
+ parallel_config = mocker.Mock(
data_parallel_size_local=2,
data_parallel_rank=4,
data_parallel_rank_local=1,
node_rank_within_dp=0,
)
- vllm_config = Mock(parallel_config=parallel_config)
- executor_class = Mock()
- engine_manager = Mock()
-
- with (
- patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- ),
- patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"),
- patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()),
- patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}),
- patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}),
- patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch(
- "vllm_omni.engine.stage_init_utils.build_vllm_config",
- return_value=(vllm_config, executor_class),
- ) as mock_build_vllm_config,
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ) as mock_register,
- patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) as mock_manager_cls,
- patch("signal.signal"),
- ):
- run_headless(args)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
+ executor_class = mocker.Mock()
+ engine_manager = mocker.Mock()
+
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
+ return_value=("/fake/stages.yaml", stage_cfgs),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
+ mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
+ mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={})
+ mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={})
+ mocker.patch(
+ "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mock_build_vllm_config = mocker.patch(
+ "vllm_omni.engine.stage_init_utils.build_vllm_config",
+ return_value=(vllm_config, executor_class),
+ )
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager)
+ mocker.patch("signal.signal")
+ run_headless(args)
mock_build_vllm_config.assert_called_once_with(
stage_cfg,
@@ -92,89 +112,85 @@ def test_run_headless_registers_stage_once_and_launches_all_local_engines() -> N
engine_manager.shutdown.assert_called_once_with()
-def test_run_headless_honors_explicit_log_stats_flag() -> None:
+def test_run_headless_honors_explicit_log_stats_flag(mocker: MockerFixture) -> None:
args = _make_headless_args()
args.log_stats = True
- stage_cfg = Mock(stage_id=3)
+ stage_cfg = mocker.Mock(stage_id=3)
stage_cfgs = [stage_cfg]
- parallel_config = Mock(
+ parallel_config = mocker.Mock(
data_parallel_size_local=2,
data_parallel_rank=4,
data_parallel_rank_local=1,
node_rank_within_dp=0,
)
- vllm_config = Mock(parallel_config=parallel_config)
- executor_class = Mock()
- engine_manager = Mock()
-
- with (
- patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- ),
- patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"),
- patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()),
- patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}),
- patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}),
- patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch(
- "vllm_omni.engine.stage_init_utils.build_vllm_config",
- return_value=(vllm_config, executor_class),
- ),
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- ),
- patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) as mock_manager_cls,
- patch("signal.signal"),
- ):
- run_headless(args)
+ vllm_config = mocker.Mock(parallel_config=parallel_config)
+ executor_class = mocker.Mock()
+ engine_manager = mocker.Mock()
+
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
+ return_value=("/fake/stages.yaml", stage_cfgs),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
+ mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
+ mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={})
+ mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={})
+ mocker.patch(
+ "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_init_utils.build_vllm_config",
+ return_value=(vllm_config, executor_class),
+ )
+ mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value="tcp://127.0.0.1:26001",
+ )
+ mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager)
+ mocker.patch("signal.signal")
+ run_headless(args)
manager_kwargs = mock_manager_cls.call_args.kwargs
assert manager_kwargs["log_stats"] is True
-def test_run_headless_launches_diffusion_stage_via_omni_master() -> None:
+def test_run_headless_launches_diffusion_stage_via_omni_master(mocker: MockerFixture) -> None:
args = _make_headless_args()
- stage_cfg = Mock(stage_id=3, stage_type="diffusion")
- stage_cfg.engine_args = Mock()
+ stage_cfg = mocker.Mock(stage_id=3, stage_type="diffusion")
+ stage_cfg.engine_args = mocker.Mock()
stage_cfg.engine_input_source = []
stage_cfgs = [stage_cfg]
- metadata = Mock(stage_id=3)
- od_config = Mock()
- proc = Mock()
+ metadata = mocker.Mock(stage_id=3)
+ od_config = mocker.Mock()
+ proc = mocker.Mock()
proc.exitcode = 0
proc.is_alive.return_value = False
- with (
- patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- ),
- patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"),
- patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()),
- patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- ),
- patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata),
- patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info") as mock_inject_stage_info,
- patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config),
- patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
- ) as mock_register,
- patch(
- "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc",
- return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
- ) as mock_spawn,
- patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake") as mock_handshake,
- patch("signal.signal"),
- ):
- run_headless(args)
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
+ return_value=("/fake/stages.yaml", stage_cfgs),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
+ mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
+ mocker.patch(
+ "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
+ return_value=(None, None, None),
+ )
+ mocker.patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata)
+ mock_inject_stage_info = mocker.patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info")
+ mocker.patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config)
+ mock_register = mocker.patch(
+ "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
+ return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
+ )
+ mock_spawn = mocker.patch(
+ "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc",
+ return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
+ )
+ mock_handshake = mocker.patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake")
+ mocker.patch("signal.signal")
+ run_headless(args)
mock_inject_stage_info.assert_called_once_with(stage_cfg, 3)
mock_register.assert_called_once_with(
diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py
index 94e254c250..248629d51d 100644
--- a/tests/entrypoints/test_utils.py
+++ b/tests/entrypoints/test_utils.py
@@ -310,6 +310,39 @@ def mock_exists(path):
assert result is not None
assert "glm_image.yaml" in result
+ def test_voxcpm_transformers_format_resolution(self, mocker: MockerFixture):
+ """Test VoxCPM transformers config resolves to the voxcpm stage config."""
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.get_config",
+ side_effect=ValueError("missing transformers config"),
+ )
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.file_or_path_exists",
+ side_effect=lambda _model, filename, revision=None: filename == "config.json",
+ )
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.get_hf_file_to_dict",
+ return_value={"model_type": "voxcpm"},
+ )
+ mocker.patch(
+ "vllm_omni.entrypoints.utils.current_omni_platform.get_default_stage_config_path",
+ return_value="vllm_omni/model_executor/stage_configs",
+ )
+
+ original_exists = os.path.exists
+
+ def mock_exists(path):
+ if "voxcpm.yaml" in str(path):
+ return True
+ return original_exists(path)
+
+ mocker.patch("os.path.exists", side_effect=mock_exists)
+
+ result = resolve_model_config_path("OpenBMB/VoxCPM1.5")
+
+ assert result is not None
+ assert "voxcpm.yaml" in result
+
class TestLoadAndResolveStageConfigs:
def test_load_and_resolve_with_kwargs(self):
diff --git a/tests/examples/online_serving/test_qwen2_5_omni.py b/tests/examples/online_serving/test_qwen2_5_omni.py
index a78ccf5924..2813b2fda8 100644
--- a/tests/examples/online_serving/test_qwen2_5_omni.py
+++ b/tests/examples/online_serving/test_qwen2_5_omni.py
@@ -5,8 +5,6 @@
import os
-from vllm_omni.platforms import current_omni_platform
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from pathlib import Path
@@ -19,19 +17,15 @@
run_cmd,
strip_trailing_audio_saved_line,
)
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
pytestmark = [pytest.mark.advanced_model, pytest.mark.example]
models = ["Qwen/Qwen2.5-Omni-7B"]
-
-stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen2_5_omni_ci.yaml")]
-
-if current_omni_platform.is_xpu():
- stage_configs = [
- str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen2_5_omni_ci.yaml")
- ]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section in vllm_omni/deploy/ci/qwen2_5_omni.yaml.
+stage_configs = [get_deploy_config_path("ci/qwen2_5_omni.yaml")]
example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving")
# Create parameter combinations for model and stage config
diff --git a/tests/examples/online_serving/test_qwen3_omni.py b/tests/examples/online_serving/test_qwen3_omni.py
index 65f99d7bf2..e9ee2763bb 100644
--- a/tests/examples/online_serving/test_qwen3_omni.py
+++ b/tests/examples/online_serving/test_qwen3_omni.py
@@ -5,8 +5,6 @@
import os
-from vllm_omni.platforms import current_omni_platform
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from pathlib import Path
@@ -19,17 +17,14 @@
run_cmd,
strip_trailing_audio_saved_line,
)
-from tests.utils import hardware_test
+from tests.utils import get_deploy_config_path, hardware_test
pytestmark = [pytest.mark.advanced_model, pytest.mark.example]
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
-stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen3_omni_ci.yaml")]
-
-if current_omni_platform.is_xpu():
- stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
+stage_configs = [get_deploy_config_path("ci/qwen3_omni_moe.yaml")]
example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving")
diff --git a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
index 85c0e8b56e..8858d1f8f1 100644
--- a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
+++ b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
-from unittest.mock import Mock
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID
from vllm_omni.model_executor.models.mimo_audio.mimo_audio_code2wav import (
@@ -51,7 +51,7 @@ def _make_invalid_flat_immediate_eostm(eostm_id: int = 666) -> torch.Tensor:
return g.reshape(-1)
-def _minimal_model():
+def _minimal_model(mocker: MockerFixture):
"""Avoid __init__ (HF tokenizer paths); only fields used by _batch_decode_waveforms."""
model = object.__new__(MiMoAudioToken2WavForConditionalGenerationVLLM)
model.device = torch.device("cpu")
@@ -59,7 +59,7 @@ def _minimal_model():
model.streamer_config = AudioStreamerConfig(group_size=_GROUP, audio_channels=_AC)
model.codes = _codes_ns()
- decode_vq = Mock(
+ decode_vq = mocker.Mock(
side_effect=lambda audio_codes: torch.ones(
audio_codes.shape[1],
7,
@@ -67,7 +67,7 @@ def _minimal_model():
device=audio_codes.device,
)
)
- decoder = Mock()
+ decoder = mocker.Mock()
audio_tok = SimpleNamespace(
encoder=SimpleNamespace(decode_vq=decode_vq),
@@ -78,9 +78,9 @@ def _minimal_model():
return model, audio_tok
-def test_batch_decode_waveforms_empty_input_list():
+def test_batch_decode_waveforms_empty_input_list(mocker: MockerFixture):
"""Empty input list returns a single zero-length float32 tensor on model device."""
- model, _ = _minimal_model()
+ model, _ = _minimal_model(mocker)
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, [])
assert len(out) == 1
assert out[0].dtype == torch.float32
@@ -88,9 +88,9 @@ def test_batch_decode_waveforms_empty_input_list():
assert out[0].device == model.device
-def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes():
+def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes(mocker: MockerFixture):
"""Single and multi-request batches produce correctly shaped packed hidden states and trimmed waveforms."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
decoder = audio_tok.decoder
# Single valid request: decoder output rank-3 for double squeeze path
@@ -118,9 +118,9 @@ def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes():
assert out2[1].shape == (8 * _FTP,)
-def test_batch_decode_waveforms_mixed_valid_invalid_requests():
+def test_batch_decode_waveforms_mixed_valid_invalid_requests(mocker: MockerFixture):
"""Mixed valid and invalid requests: invalid slots get empty tensors, valid slots get decoded waveforms."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
valid_a = _make_valid_flat_codes(1)
valid_b = _make_valid_flat_codes(1)
dummy = _make_dummy_code_tensor()
@@ -151,9 +151,9 @@ def test_batch_decode_waveforms_mixed_valid_invalid_requests():
assert input_lengths.tolist() == [4, 4]
-def test_batch_decode_waveforms_all_invalid_returns_per_request_empty():
+def test_batch_decode_waveforms_all_invalid_returns_per_request_empty(mocker: MockerFixture):
"""All-invalid batch skips decoder entirely and returns empty tensors for every slot."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(
model,
[None, _make_dummy_code_tensor(), torch.tensor([], dtype=torch.long)],
@@ -163,9 +163,9 @@ def test_batch_decode_waveforms_all_invalid_returns_per_request_empty():
audio_tok.decoder.assert_not_called()
-def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples():
+def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples(mocker: MockerFixture):
"""Decoder output longer than valid_len is trimmed to the exact expected waveform length."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
flat = _make_valid_flat_codes(1)
# Longer than valid_len so branch wav = wav[:valid_len] runs
audio_tok.decoder.return_value = torch.ones(1, 1, 10_000, dtype=torch.float32)
@@ -175,9 +175,9 @@ def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_sam
assert out[0].dtype == torch.float32
-def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra():
+def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra(mocker: MockerFixture):
"""Else-branch split: per-request wav[:valid_len] when decoder pads each batch row."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
a = _make_valid_flat_codes(1)
b = _make_valid_flat_codes(2)
audio_tok.decoder.return_value = torch.ones(2, 1, 10_000, dtype=torch.float32)
@@ -189,9 +189,9 @@ def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_return
assert out[1].dtype == torch.float32
-def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices():
+def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices(mocker: MockerFixture):
"""Tensor packing order must match valid_indices when invalid requests are in the middle."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
first = _make_valid_flat_codes(1)
last = _make_valid_flat_codes(2)
inputs = [
@@ -212,9 +212,9 @@ def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices():
assert input_lengths.tolist() == [4, 8]
-def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots():
+def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots(mocker: MockerFixture):
"""Every slot is a 1-D float32 vector (empty or waveform), matching downstream expectations."""
- model, audio_tok = _minimal_model()
+ model, audio_tok = _minimal_model(mocker)
inputs = [_make_valid_flat_codes(1), None, _make_valid_flat_codes(1)]
audio_tok.decoder.return_value = torch.ones(2, 1, 5000, dtype=torch.float32)
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, inputs)
diff --git a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
index 8e04b04966..587e7f7f8b 100644
--- a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
+++ b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
@@ -10,10 +10,9 @@
- Interleaved (use_audio_in_video) should also work correctly.
"""
-from unittest.mock import Mock
-
import pytest
import torch
+from pytest_mock import MockerFixture
from vllm.model_executor.models.qwen2_5_omni_thinker import (
check_interleaved_audio_video,
merge_interleaved_embeddings,
@@ -107,7 +106,7 @@ def test_interleaved(self):
# ---------------------------------------------------------------------------
-def make_mock_model(hidden: int = 8):
+def make_mock_model(mocker: MockerFixture, hidden: int = 8):
"""
Return a minimal mock of Qwen2_5OmniThinkerForConditionalGeneration
that has enough structure to run embed_input_ids.
@@ -116,10 +115,10 @@ def make_mock_model(hidden: int = 8):
Qwen2_5OmniThinkerForConditionalGeneration,
)
- model = Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration)
+ model = mocker.Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration)
# Config with token IDs
- cfg = Mock()
+ cfg = mocker.Mock()
cfg.video_token_index = VIDEO_TOKEN_ID
cfg.audio_token_index = AUDIO_TOKEN_ID
model.config = cfg
@@ -130,9 +129,9 @@ def fake_lm_embed(ids: torch.Tensor) -> torch.Tensor:
# view with shared memory, which masked_scatter_ cannot handle).
return ids.float().unsqueeze(-1).expand(-1, hidden).clone()
- lang_model = Mock()
+ lang_model = mocker.Mock()
lang_model.embed_input_ids = fake_lm_embed
- model.get_language_model = Mock(return_value=lang_model)
+ model.get_language_model = mocker.Mock(return_value=lang_model)
from vllm.model_executor.models.interfaces import SupportsMultiModal
@@ -169,7 +168,7 @@ def build_mm_embeds(audio_n, image_n, video_n, hidden, audio_val=10.0, image_val
class TestEmbedInputIds:
- def _run(self, audio_n, image_n, video_n, hidden=8):
+ def _run(self, mocker: MockerFixture, audio_n, image_n, video_n, hidden=8):
"""
Run embed_input_ids for a non-interleaved mixed-modality sequence.
Returns (result_embeds, input_ids, is_multimodal).
@@ -177,33 +176,33 @@ def _run(self, audio_n, image_n, video_n, hidden=8):
input_ids, is_multimodal = make_token_seq(audio_n, image_n, video_n)
mm_embeds = build_mm_embeds(audio_n, image_n, video_n, hidden)
- model, _ = make_mock_model(hidden)
+ model, _ = make_mock_model(mocker, hidden)
result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal)
return result, input_ids, is_multimodal
- def test_audio_only(self):
+ def test_audio_only(self, mocker: MockerFixture):
"""Audio-only: audio positions get audio embeddings."""
audio_n, hidden = 5, 8
audio_val = 10.0
- result, input_ids, is_multimodal = self._run(audio_n, 0, 0, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, audio_n, 0, 0, hidden)
audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0]
assert result[audio_pos].allclose(torch.full((audio_n, hidden), audio_val)), (
"Audio positions should get audio embeddings"
)
- def test_video_only(self):
+ def test_video_only(self, mocker: MockerFixture):
"""Video-only: video positions get video embeddings."""
video_n, hidden = 6, 8
video_val = 30.0
- result, input_ids, is_multimodal = self._run(0, 0, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, 0, 0, video_n, hidden)
video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0]
assert result[video_pos].allclose(torch.full((video_n, hidden), video_val)), (
"Video positions should get video embeddings"
)
- def test_mixed_modalities_audio_goes_to_audio_pos(self):
+ def test_mixed_modalities_audio_goes_to_audio_pos(self, mocker: MockerFixture):
"""
Regression test for GitHub issue #34506:
With audio + image + video (non-interleaved), audio positions must
@@ -212,7 +211,7 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self):
audio_n, image_n, video_n, hidden = 5, 4, 6, 8
audio_val, image_val, video_val = 10.0, 20.0, 30.0
- result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden)
audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0]
image_pos = (input_ids == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
@@ -233,10 +232,10 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self):
f"Video emb wrong: expected {video_val}, got mean={mean_v:.1f}"
)
- def test_text_positions_unchanged(self):
+ def test_text_positions_unchanged(self, mocker: MockerFixture):
"""Text positions should keep their text embeddings."""
audio_n, image_n, video_n, hidden = 3, 2, 4, 8
- result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden)
text_pos = (~is_multimodal).nonzero(as_tuple=True)[0]
# Text tokens have value TEXT_TOKEN_ID=0, so embed -> 0.0
@@ -244,7 +243,7 @@ def test_text_positions_unchanged(self):
"Text positions should keep text embeddings"
)
- def test_interleaved_use_audio_in_video(self):
+ def test_interleaved_use_audio_in_video(self, mocker: MockerFixture):
"""
Interleaved (use_audio_in_video): video chunks interleaved with audio.
Video embeddings must go to video positions, audio to audio positions.
@@ -263,7 +262,7 @@ def test_interleaved_use_audio_in_video(self):
torch.full((audio_n, hidden), audio_val),
]
- model, _ = make_mock_model(hidden)
+ model, _ = make_mock_model(mocker, hidden)
result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal)
video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0]
diff --git a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
index e2970dcb2d..8798cb3ca9 100644
--- a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
+++ b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
@@ -15,12 +15,13 @@
import os
import sys
import types
-from unittest.mock import MagicMock, patch
+import pytest
import torch
+from pytest_mock import MockerFixture
# Direct file import to avoid vllm_omni.__init__ patch dependencies.
-_BASE = os.path.join(
+_MODELS = os.path.join(
os.path.dirname(__file__),
os.pardir,
os.pardir,
@@ -29,80 +30,116 @@
"vllm_omni",
"model_executor",
"models",
- "qwen3_tts",
)
+_BASE = os.path.join(_MODELS, "qwen3_tts")
+_COMMON = os.path.join(_MODELS, "common")
def _load_module(name: str, filename: str):
path = os.path.abspath(os.path.join(_BASE, filename))
spec = importlib.util.spec_from_file_location(name, path)
mod = importlib.util.module_from_spec(spec)
+ sys.modules[name] = mod # register before exec (needed for dataclasses etc.)
spec.loader.exec_module(mod)
return mod
-def _build_mock_modules() -> dict[str, object]:
+def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]:
"""Build the dict of modules to inject into sys.modules."""
- platforms_mock = MagicMock()
+ platforms_mock = mocker.MagicMock()
platforms_mock.current_omni_platform.supports_torch_inductor.return_value = False
- logger_mock = MagicMock()
- logger_mock.init_logger = lambda name: MagicMock()
+ logger_mock = mocker.MagicMock()
+ logger_mock.init_logger = lambda name: mocker.MagicMock()
- vllm_config_mod = MagicMock()
- vllm_config_mod.set_current_vllm_config = lambda cfg: MagicMock(__enter__=MagicMock(), __exit__=MagicMock())
+ vllm_config_mod = mocker.MagicMock()
+ vllm_config_mod.set_current_vllm_config = lambda cfg: mocker.MagicMock(
+ __enter__=mocker.MagicMock(),
+ __exit__=mocker.MagicMock(),
+ )
- weight_utils_mock = MagicMock()
+ weight_utils_mock = mocker.MagicMock()
weight_utils_mock.default_weight_loader = lambda p, w: None
- pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts")
- pkg.__path__ = [os.path.abspath(_BASE)]
+ tts_pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts")
+ tts_pkg.__path__ = [os.path.abspath(_BASE)]
+
+ common_pkg = types.ModuleType("vllm_omni.model_executor.models.common")
+ common_pkg.__path__ = [os.path.abspath(_COMMON)]
+
+ models_pkg = types.ModuleType("vllm_omni.model_executor.models")
+ models_pkg.__path__ = [os.path.abspath(_MODELS)]
+
+ vllm_parallel_mock = mocker.MagicMock()
+ vllm_parallel_mock.VocabParallelEmbedding = torch.nn.Embedding
return {
- "vllm_omni": MagicMock(),
+ "vllm_omni": mocker.MagicMock(),
"vllm_omni.platforms": platforms_mock,
"vllm.logger": logger_mock,
- "vllm.config": MagicMock(),
+ "vllm.config": mocker.MagicMock(),
"vllm.config.vllm": vllm_config_mod,
"vllm.model_executor.model_loader.weight_utils": weight_utils_mock,
+ "vllm.model_executor.layers.vocab_parallel_embedding": vllm_parallel_mock,
"vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"),
- "vllm_omni.model_executor.models": types.ModuleType("vllm_omni.model_executor.models"),
- "vllm_omni.model_executor.models.qwen3_tts": pkg,
+ "vllm_omni.model_executor.models": models_pkg,
+ "vllm_omni.model_executor.models.common": common_pkg,
+ "vllm_omni.model_executor.models.qwen3_tts": tts_pkg,
}
-def _load_target_classes():
+def _load_target_classes(mocker: MockerFixture):
"""Load config and code predictor modules with mocked dependencies.
- Uses patch.dict to ensure sys.modules is always restored, even on failure.
+ Uses mocker.patch.dict to ensure sys.modules is always restored, even on failure.
"""
- mocks = _build_mock_modules()
- with patch.dict(sys.modules, mocks):
- config_mod = _load_module(
- "vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts",
- "configuration_qwen3_tts.py",
- )
- sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod
+ mocks = _build_mock_modules(mocker)
+ mocker.patch.dict(sys.modules, mocks)
+ config_mod = _load_module(
+ "vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts",
+ "configuration_qwen3_tts.py",
+ )
+ sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod
- cp_mod = _load_module(
- "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm",
- "qwen3_tts_code_predictor_vllm.py",
- )
+ # Load the shared common module (thin wrappers import from it)
+ common_cp_path = os.path.abspath(os.path.join(_COMMON, "qwen3_code_predictor.py"))
+ common_spec = importlib.util.spec_from_file_location(
+ "vllm_omni.model_executor.models.common.qwen3_code_predictor", common_cp_path
+ )
+ common_cp_mod = importlib.util.module_from_spec(common_spec)
+ sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"] = common_cp_mod
+ common_spec.loader.exec_module(common_cp_mod)
- return config_mod, cp_mod
+ cp_mod = _load_module(
+ "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm",
+ "qwen3_tts_code_predictor_vllm.py",
+ )
+ return config_mod, cp_mod
-_config_mod, _cp_mod = _load_target_classes()
-Qwen3TTSTalkerCodePredictorConfig = _config_mod.Qwen3TTSTalkerCodePredictorConfig
-Qwen3TTSTalkerConfig = _config_mod.Qwen3TTSTalkerConfig
-CodePredictorWrapper = _cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM
-CodePredictorModel = _cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM
+@pytest.fixture
+def loaded_target_classes(mocker: MockerFixture):
+ config_mod, cp_mod = _load_target_classes(mocker)
+ return (
+ config_mod.Qwen3TTSTalkerCodePredictorConfig,
+ config_mod.Qwen3TTSTalkerConfig,
+ cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM,
+ cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM,
+ cp_mod.CodePredictorWrapperConfig,
+ )
-def _make_tiny_config() -> tuple:
+def _make_tiny_config(loaded_target_classes) -> tuple:
"""Create minimal configs for a tiny code predictor model."""
- cp_config = Qwen3TTSTalkerCodePredictorConfig(
+ (
+ qwen3_tts_talker_code_predictor_config,
+ qwen3_tts_talker_config,
+ _,
+ _,
+ _,
+ ) = loaded_target_classes
+ cp_config = qwen3_tts_talker_code_predictor_config(
vocab_size=64,
hidden_size=32,
intermediate_size=64,
@@ -113,16 +150,16 @@ def _make_tiny_config() -> tuple:
num_code_groups=4,
rms_norm_eps=1e-6,
)
- talker_config = Qwen3TTSTalkerConfig(
+ talker_config = qwen3_tts_talker_config(
hidden_size=32,
num_code_groups=4,
)
return cp_config, talker_config
-def _make_vllm_config(max_num_seqs: int = 4) -> MagicMock:
+def _make_vllm_config(mocker: MockerFixture, max_num_seqs: int = 4):
"""Create a mock VllmConfig with scheduler_config."""
- vllm_config = MagicMock()
+ vllm_config = mocker.MagicMock()
vllm_config.scheduler_config.max_num_seqs = max_num_seqs
return vllm_config
@@ -130,32 +167,34 @@ def _make_vllm_config(max_num_seqs: int = 4) -> MagicMock:
class TestCodePredictorDtypeAlignment:
"""Test that code predictor buffers match model parameter dtype."""
- def test_ensure_buffers_uses_given_dtype(self) -> None:
+ def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_ensure_buffers should create proj_buf with the given dtype."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config()
+ _, _, code_predictor_wrapper, _, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
)
# Create buffer in float16
- predictor._ensure_buffers(torch.device("cpu"), torch.float16)
+ predictor._ensure_buffers(torch.device("cpu"), torch.float16, 4)
assert predictor._proj_buf is not None
assert predictor._proj_buf.dtype == torch.float16
# Re-create buffer in float32 (different dtype triggers re-allocation)
- predictor._ensure_buffers(torch.device("cpu"), torch.float32)
+ predictor._ensure_buffers(torch.device("cpu"), torch.float32, 4)
assert predictor._proj_buf.dtype == torch.float32
- def test_warmup_aligns_buffer_to_model_params(self) -> None:
+ def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_warmup_buckets should align proj_buf dtype to model parameters."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config(max_num_seqs=2)
+ _, _, code_predictor_wrapper, _, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
@@ -165,7 +204,7 @@ def test_warmup_aligns_buffer_to_model_params(self) -> None:
predictor = predictor.to(torch.float16)
# Pre-create proj_buf with WRONG dtype (float32) — simulating the bug
- predictor._ensure_buffers(torch.device("cpu"), torch.float32)
+ predictor._ensure_buffers(torch.device("cpu"), torch.float32, 2)
assert predictor._proj_buf.dtype == torch.float32
# Simulate _setup_compile having cached model dtype and compiled forward
@@ -177,12 +216,13 @@ def test_warmup_aligns_buffer_to_model_params(self) -> None:
assert predictor._proj_buf.dtype == torch.float16
- def test_setup_compile_caches_model_dtype(self) -> None:
+ def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_setup_compile should cache model parameter dtype."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config(max_num_seqs=2)
+ _, _, code_predictor_wrapper, _, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
@@ -193,12 +233,13 @@ def test_setup_compile_caches_model_dtype(self) -> None:
predictor._setup_compile()
assert predictor._model_dtype == torch.float16
- def test_forward_with_mismatched_input_dtype(self) -> None:
+ def test_forward_with_mismatched_input_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""forward() should not crash when inputs are float32 but model is float16."""
- cp_config, talker_config = _make_tiny_config()
- vllm_config = _make_vllm_config(max_num_seqs=2)
+ _, _, code_predictor_wrapper, _, _ = loaded_target_classes
+ cp_config, talker_config = _make_tiny_config(loaded_target_classes)
+ vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
- predictor = CodePredictorWrapper(
+ predictor = code_predictor_wrapper(
vllm_config=vllm_config,
config=cp_config,
talker_config=talker_config,
@@ -231,10 +272,11 @@ def test_forward_with_mismatched_input_dtype(self) -> None:
class TestCodePredictorModelDtype:
"""Test the inner model forward with different dtypes."""
- def test_model_forward_float16(self) -> None:
+ def test_model_forward_float16(self, loaded_target_classes) -> None:
"""Inner model forward should work in float16."""
- cp_config, _ = _make_tiny_config()
- model = CodePredictorModel(cp_config, talker_hidden_size=32).to(torch.float16)
+ _, _, _, code_predictor_model, _ = loaded_target_classes
+ cp_config, _ = _make_tiny_config(loaded_target_classes)
+ model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float16)
bsz, seq_len = 1, 4
inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float16)
@@ -244,10 +286,11 @@ def test_model_forward_float16(self) -> None:
assert output.dtype == torch.float16
assert output.shape == (bsz, seq_len, 32)
- def test_model_forward_float32(self) -> None:
+ def test_model_forward_float32(self, loaded_target_classes) -> None:
"""Inner model forward should work in float32."""
- cp_config, _ = _make_tiny_config()
- model = CodePredictorModel(cp_config, talker_hidden_size=32).to(torch.float32)
+ _, _, _, code_predictor_model, _ = loaded_target_classes
+ cp_config, _ = _make_tiny_config(loaded_target_classes)
+ model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float32)
bsz, seq_len = 1, 4
inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float32)
@@ -256,3 +299,37 @@ def test_model_forward_float32(self) -> None:
output = model(inputs, pos_ids)
assert output.dtype == torch.float32
assert output.shape == (bsz, seq_len, 32)
+
+
+class TestCodePredictorWrapperConfig:
+ """Test wrapper configuration for different models."""
+
+ def test_omni_config(self, loaded_target_classes) -> None:
+ """Qwen3-Omni uses correct wrapper config."""
+ _, _, _, _, code_predictor_wrapper_config = loaded_target_classes
+ config = code_predictor_wrapper_config(
+ use_cuda_graphs=False,
+ use_parallel_embedding=True,
+ use_projection=False,
+ return_proj_buf=True,
+ sampling_mode="stored",
+ )
+ assert config.use_cuda_graphs is False
+ assert config.use_parallel_embedding is True
+ assert config.return_proj_buf is True
+ assert config.sampling_mode == "stored"
+
+ def test_tts_config(self, loaded_target_classes) -> None:
+ """Qwen3-TTS uses correct wrapper config."""
+ _, _, _, _, code_predictor_wrapper_config = loaded_target_classes
+ config = code_predictor_wrapper_config(
+ use_cuda_graphs=True,
+ use_parallel_embedding=False,
+ use_projection=True,
+ return_proj_buf=False,
+ sampling_mode="per_call",
+ )
+ assert config.use_cuda_graphs is True
+ assert config.use_parallel_embedding is False
+ assert config.return_proj_buf is False
+ assert config.sampling_mode == "per_call"
diff --git a/tests/model_executor/models/test_encoder_quant_config.py b/tests/model_executor/models/test_encoder_quant_config.py
new file mode 100644
index 0000000000..8020184986
--- /dev/null
+++ b/tests/model_executor/models/test_encoder_quant_config.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Regression test for #2686: pre-quantized methods must not apply
+quant config to vision / audio encoders.
+
+For modelopt FP8/FP4/MXFP8 checkpoints the Thinker LM is the only
+quantized component. Vision and audio encoder weights are BF16 with no
+FP8 scale tensors — passing quant_config to them causes FP8 kernels to
+run on BF16 weights, producing garbage embeddings.
+"""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from vllm_omni.quantization.component_config import (
+ PRE_QUANTIZED_METHODS,
+ ComponentQuantizationConfig,
+ resolve_encoder_quant_config,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+# ---------------------------------------------------------------------------
+# resolve_encoder_quant_config — the core routing logic for encoder quant
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("method", sorted(PRE_QUANTIZED_METHODS))
+def test_pre_quantized_returns_none(method: str) -> None:
+ """visual_quant_config and audio_quant_config must be None for
+ pre-quantized methods (modelopt, modelopt_fp4, modelopt_mxfp8)."""
+ mock_config = MagicMock()
+ mock_config.get_name.return_value = method
+
+ assert resolve_encoder_quant_config(mock_config) is None
+
+
+@pytest.mark.parametrize("method", ["fp8", "awq", "gptq", "bitsandbytes"])
+def test_non_pre_quantized_preserves_config(method: str) -> None:
+ """Non-pre-quantized methods should pass through the original config."""
+ mock_config = MagicMock()
+ mock_config.get_name.return_value = method
+
+ assert resolve_encoder_quant_config(mock_config) is mock_config
+
+
+def test_none_input_returns_none() -> None:
+ """No quantization → None for encoders."""
+ assert resolve_encoder_quant_config(None) is None
+
+
+def test_component_config_passed_through() -> None:
+ """ComponentQuantizationConfig should be returned as-is so the caller
+ can call .resolve() with the appropriate prefix."""
+ inner = MagicMock()
+ inner.get_name.return_value = "modelopt" # would be None if not Component
+ component = ComponentQuantizationConfig(
+ component_configs={"language_model": inner},
+ default_config=None,
+ )
+
+ result = resolve_encoder_quant_config(component)
+ assert result is component
+
+
+# ---------------------------------------------------------------------------
+# PRE_QUANTIZED_METHODS constant — exhaustiveness check
+# ---------------------------------------------------------------------------
+
+
+def test_pre_quantized_methods_contains_expected() -> None:
+ """Guard against accidental removal of a known pre-quantized method."""
+ expected = {"modelopt", "modelopt_fp4", "modelopt_mxfp8"}
+ assert PRE_QUANTIZED_METHODS == expected
diff --git a/tests/model_executor/models/test_fish_speech_voice_cache.py b/tests/model_executor/models/test_fish_speech_voice_cache.py
index 8fe7a4a4d1..fef4b551ab 100644
--- a/tests/model_executor/models/test_fish_speech_voice_cache.py
+++ b/tests/model_executor/models/test_fish_speech_voice_cache.py
@@ -10,11 +10,11 @@
import os
import tempfile
-from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
+from pytest_mock import MockerFixture
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -61,18 +61,18 @@ class TestFishSpeechVoiceCacheIntegration:
"""Test the cache-hit / cache-miss / no-cache paths in the model."""
@pytest.fixture
- def mock_model(self):
+ def mock_model(self, mocker: MockerFixture):
"""Create a mock FishSpeechSlowARForConditionalGeneration with cache."""
from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
- model = MagicMock()
+ model = mocker.MagicMock()
model._voice_cache = VoiceEmbeddingCache(max_entries=4)
model._semantic_begin_id = 151678
model._num_codebooks = 10
model._codebook_size = 4096
model.model_path = "/fake/model"
- model.codebook_embeddings = MagicMock()
- model.codebook_embeddings.weight = MagicMock()
+ model.codebook_embeddings = mocker.MagicMock()
+ model.codebook_embeddings.weight = mocker.MagicMock()
model.codebook_embeddings.weight.device = torch.device("cpu")
return model
@@ -166,9 +166,9 @@ def test_created_at_zero_disables_cache(self, mock_model):
class TestFishSpeechValidatorUploadedVoice:
"""Test _validate_fish_tts_request uploaded voice resolution."""
- def test_uploaded_voice_resolves_ref_audio(self):
+ def test_uploaded_voice_resolves_ref_audio(self, mocker: MockerFixture):
"""When voice matches an uploaded speaker, ref_audio should be auto-set."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "alice"
request.ref_audio = None
@@ -185,17 +185,17 @@ def test_uploaded_voice_resolves_ref_audio(self):
}
# Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL.
- with patch("pathlib.Path.exists", return_value=True):
- voice_lower = request.voice.lower()
- assert voice_lower in uploaded_speakers
+ mocker.patch("pathlib.Path.exists", return_value=True)
+ voice_lower = request.voice.lower()
+ assert voice_lower in uploaded_speakers
- speaker_info = uploaded_speakers[voice_lower]
- ref_text_from_upload = speaker_info.get("ref_text")
- assert ref_text_from_upload == "Hi this is Alice"
+ speaker_info = uploaded_speakers[voice_lower]
+ ref_text_from_upload = speaker_info.get("ref_text")
+ assert ref_text_from_upload == "Hi this is Alice"
- def test_uploaded_voice_without_ref_text_uses_request_ref_text(self):
+ def test_uploaded_voice_without_ref_text_uses_request_ref_text(self, mocker: MockerFixture):
"""If upload has no ref_text but request provides it, use request's."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "bob"
request.ref_audio = None
diff --git a/tests/model_executor/models/voxcpm2/__init__.py b/tests/model_executor/models/voxcpm2/__init__.py
new file mode 100644
index 0000000000..208f01a7cb
--- /dev/null
+++ b/tests/model_executor/models/voxcpm2/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
new file mode 100644
index 0000000000..5d8a35636b
--- /dev/null
+++ b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Regression tests for VoxCPM2 talker per-request state lifecycle."""
+
+from __future__ import annotations
+
+import pytest
+
+torch = pytest.importorskip("torch")
+pytest.importorskip("librosa")
+
+from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( # noqa: E402
+ VoxCPM2TalkerForConditionalGeneration,
+ _RequestState,
+)
+
+
+def _make_bare_talker() -> VoxCPM2TalkerForConditionalGeneration:
+ talker = VoxCPM2TalkerForConditionalGeneration.__new__(VoxCPM2TalkerForConditionalGeneration)
+ talker._active_states = {}
+ talker._current_request_id = None
+ talker._pending_requests = []
+ talker._results_queue = []
+ talker._audio_queue = []
+ talker._deferred_cleanup_ids = set()
+ talker._max_batch_size = 4
+ talker._active_state_warn_threshold = 512
+ talker._active_state_warned = False
+ return talker
+
+
+def _seed_cached_decode(talker, req_id: str) -> _RequestState:
+ state = _RequestState(request_id=req_id)
+ state.prefill_completed = True
+ state.decode_step_count = 5
+ talker._active_states[req_id] = state
+ return state
+
+
+class TestStateEvictionContract:
+ def test_pending_requests_is_not_used_for_eviction(self) -> None:
+ talker = _make_bare_talker()
+
+ cached_ids = [f"req-{i}" for i in range(4)]
+ for rid in cached_ids:
+ _seed_cached_decode(talker, rid)
+
+ walked_so_far = ["req-new", cached_ids[0], cached_ids[1]]
+ talker._pending_requests = [(rid, False, None, 0) for rid in walked_so_far]
+
+ for rid in cached_ids:
+ assert rid in talker._active_states
+ assert talker._active_states[rid].prefill_completed is True
+
+ def test_on_requests_finished_defers_cleanup(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ _seed_cached_decode(talker, "req-B")
+
+ talker.on_requests_finished({"req-A"})
+
+ assert "req-A" in talker._active_states
+ assert "req-A" in talker._deferred_cleanup_ids
+
+ def test_flush_deferred_cleanup_removes_only_finished(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ _seed_cached_decode(talker, "req-B")
+ talker.on_requests_finished(["req-A"])
+
+ talker._flush_deferred_cleanup()
+
+ assert "req-A" not in talker._active_states
+ assert "req-B" in talker._active_states
+ assert talker._deferred_cleanup_ids == set()
+
+ def test_current_request_id_cleared_when_matching(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ talker._current_request_id = "req-A"
+
+ talker.on_requests_finished({"req-A"})
+ talker._flush_deferred_cleanup()
+
+ assert talker._current_request_id is None
+
+ def test_current_request_id_preserved_when_not_finished(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ _seed_cached_decode(talker, "req-B")
+ talker._current_request_id = "req-B"
+
+ talker.on_requests_finished({"req-A"})
+ talker._flush_deferred_cleanup()
+
+ assert talker._current_request_id == "req-B"
+
+
+class TestLeakWarnGuard:
+ def test_warn_fires_once_over_threshold(self, monkeypatch) -> None:
+ from vllm_omni.model_executor.models.voxcpm2 import voxcpm2_talker as tk
+
+ calls: list[str] = []
+
+ def _capture(msg, *args, **kwargs):
+ calls.append(msg % args if args else msg)
+
+ monkeypatch.setattr(tk.logger, "warning", _capture)
+
+ talker = _make_bare_talker()
+ talker._active_state_warn_threshold = 3
+
+ for i in range(4):
+ talker._active_states[f"seed-{i}"] = _RequestState(request_id=f"seed-{i}")
+
+ talker._get_or_create_state("new-1")
+ talker._get_or_create_state("new-2")
+
+ leak_warnings = [m for m in calls if "cleanup path leak" in m]
+ assert len(leak_warnings) == 1
+ assert talker._active_state_warned is True
diff --git a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
index 6f072944d9..847adae06f 100644
--- a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
+++ b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
@@ -78,6 +78,13 @@
AudioSpecialTokens = _mod2.AudioSpecialTokens
+class SyntheticAcousticTransformerArgs:
+ """Mimics AcousticTransformerArgs interface."""
+
+ def __init__(self):
+ self.n_decoding_steps = 7
+
+
class SyntheticModelArgs:
"""Mimics MultimodalAudioModelArgs interface."""
@@ -96,6 +103,7 @@ class SyntheticAcousticTransformer(nn.Module):
def __init__(self):
super().__init__()
self.model_args = SyntheticModelArgs()
+ self.acoustic_transformer_args = SyntheticAcousticTransformerArgs()
self.acoustic_embeddings_levels = ACOUSTIC_EMBEDDINGS_LEVELS
# semantic_codebook_output: hidden_dim -> padded_codebook_size
diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py
new file mode 100644
index 0000000000..18972c91d5
--- /dev/null
+++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for Qwen3-Omni streaming thinker→talker / talker→codec helpers (PR #2581)."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+import vllm_omni.model_executor.stage_input_processors.qwen3_omni as q3
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture(autouse=True)
+def _streaming_context() -> SimpleNamespace:
+ return SimpleNamespace(bridge_states={})
+
+
+def test_get_streaming_talker_tokens_first_segment(_streaming_context: SimpleNamespace) -> None:
+ inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
+ "r1",
+ [1, 2],
+ [10, 11],
+ streaming_context=_streaming_context,
+ )
+ assert inc_p == [1, 2]
+ assert inc_o == [10, 11]
+ assert merged == [1, 2, 10, 11]
+ assert thinker_in == [1, 2]
+
+
+def test_get_streaming_talker_tokens_second_segment_accumulates(_streaming_context: SimpleNamespace) -> None:
+ q3._get_streaming_talker_tokens("r2", [1, 2], [10, 11], streaming_context=_streaming_context)
+ inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
+ "r2",
+ [1, 2, 3, 4],
+ [10, 11, 12, 13],
+ streaming_context=_streaming_context,
+ )
+ assert inc_p == [3, 4]
+ assert inc_o == [12, 13]
+ assert merged == [1, 2, 10, 3, 4, 12, 13]
+ assert thinker_in == [1, 2, 10, 3, 4]
+
+
+def test_get_streaming_talker_tokens_new_prompt_len_snapshot_truncates(
+ _streaming_context: SimpleNamespace,
+) -> None:
+ inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
+ "r3",
+ [1, 2, 3, 4, 5, 6],
+ [10],
+ new_prompt_len_snapshot=2,
+ streaming_context=_streaming_context,
+ )
+ assert inc_p == [1, 2, 3, 4]
+ assert inc_o == [10]
+ assert merged == [1, 2, 3, 4, 10]
+ assert thinker_in == [1, 2, 3, 4]
+
+
+def test_get_streaming_talker_tokens_clear_state(_streaming_context: SimpleNamespace) -> None:
+ q3._get_streaming_talker_tokens("r4", [1], [2], streaming_context=_streaming_context, clear_state=True)
+ state = q3._get_qwen3_streaming_state("r4", _streaming_context).thinker2talker
+ assert state.last_prompt_len == 0
+ assert state.last_output_len == 0
+ assert state.merged_sequences == []
+
+
+def test_get_streaming_codec_delta_len_increments_and_finishes(_streaming_context: SimpleNamespace) -> None:
+ d1 = q3._get_streaming_codec_delta_len(5, "c1", SimpleNamespace(finished=False), _streaming_context)
+ assert d1 == 5
+ d2 = q3._get_streaming_codec_delta_len(8, "c1", SimpleNamespace(finished=False), _streaming_context)
+ assert d2 == 2
+ # After d2, stored cursor is cur_seq_len + 1 == 9; next delta uses new cur_seq_len - 9.
+ d3 = q3._get_streaming_codec_delta_len(10, "c1", SimpleNamespace(finished=True), _streaming_context)
+ assert d3 == 1
+ state = q3._get_qwen3_streaming_state("c1", _streaming_context)
+ assert state.talker2code2wav_last_seq_len == 0
diff --git a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py
new file mode 100644
index 0000000000..7d6fc6e74c
--- /dev/null
+++ b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""UTs for VoxCPM async-chunk stage input processing."""
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+from vllm_omni.model_executor.stage_input_processors.voxcpm import (
+ _VOXCPM_LATENT_MAGIC,
+ _coerce_finished_flag,
+ latent2vae_async_chunk,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _request(*, finished):
+ return SimpleNamespace(is_finished=lambda: finished)
+
+
+def _decode_serialized_latent(codes: list[int]) -> torch.Tensor:
+ assert codes[0] == _VOXCPM_LATENT_MAGIC
+ latent_dim = codes[1]
+ time_dim = codes[2]
+ payload = torch.tensor(codes[3:], dtype=torch.int32).to(torch.uint16)
+ return payload.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim)
+
+
+@pytest.mark.parametrize(
+ ("value", "expected"),
+ [
+ (None, False),
+ (False, False),
+ (True, True),
+ (torch.tensor(False), False),
+ (torch.tensor(True), True),
+ ([torch.tensor(True)], True),
+ (([True],), True),
+ ([], False),
+ ],
+)
+def test_coerce_finished_flag(value, expected):
+ assert _coerce_finished_flag(value) is expected
+
+
+def test_latent2vae_async_chunk_serializes_latent_payload():
+ latent = torch.arange(6, dtype=torch.float32).reshape(2, 3)
+
+ payload = latent2vae_async_chunk(
+ transfer_manager=None,
+ pooling_output={"latent_audio_feat": latent},
+ request=_request(finished=False),
+ is_finished=torch.tensor(False),
+ )
+
+ assert payload is not None
+ assert torch.equal(payload["finished"], torch.tensor(False, dtype=torch.bool))
+ recovered = _decode_serialized_latent(payload["code_predictor_codes"])
+ torch.testing.assert_close(recovered, latent.to(torch.bfloat16).to(torch.float32).unsqueeze(0))
+
+
+def test_latent2vae_async_chunk_returns_terminal_marker_without_latent():
+ payload = latent2vae_async_chunk(
+ transfer_manager=None,
+ pooling_output=None,
+ request=_request(finished=[torch.tensor(True)]),
+ is_finished=False,
+ )
+
+ assert payload == {
+ "code_predictor_codes": [],
+ "finished": torch.tensor(True, dtype=torch.bool),
+ }
+
+
+def test_latent2vae_async_chunk_returns_none_for_nonterminal_empty_chunk():
+ payload = latent2vae_async_chunk(
+ transfer_manager=None,
+ pooling_output={"latent_audio_feat": torch.zeros((0,), dtype=torch.float32)},
+ request=_request(finished=False),
+ is_finished=False,
+ )
+
+ assert payload is None
diff --git a/tests/test_arg_utils.py b/tests/test_arg_utils.py
new file mode 100644
index 0000000000..dab5ed6878
--- /dev/null
+++ b/tests/test_arg_utils.py
@@ -0,0 +1,353 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for vllm_omni.engine.arg_utils — invariants that must
+hold for the orchestrator/engine/server CLI flag partition."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, fields
+
+import pytest
+
+from vllm_omni.engine.arg_utils import (
+ SHARED_FIELDS,
+ derive_server_dests_from_vllm_parser,
+ internal_blacklist_keys,
+ orchestrator_args_from_argparse,
+ orchestrator_field_names,
+ split_kwargs,
+)
+
+# ---------------------------------------------------------------------------
+# Fake engine class for unit testing — avoids pulling in the full vllm
+# EngineArgs and its heavy __post_init__ at test time.
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _FakeEngineArgs:
+ """Stand-in for OmniEngineArgs with a representative subset of fields."""
+
+ model: str = ""
+ stage_id: int = 0
+ max_num_seqs: int = 64
+ gpu_memory_utilization: float = 0.9
+ async_chunk: bool = False # also in OrchestratorArgs → shared
+ log_stats: bool = False # also in OrchestratorArgs → shared
+ stage_configs_path: str | None = None
+
+
+# ============================================================================
+# Invariant 1 — OrchestratorArgs and engine must not ambiguously overlap.
+# ============================================================================
+
+
+def test_no_ambiguous_overlap_with_fake_engine():
+ """OrchestratorArgs ∩ engine fields must be ⊆ SHARED_FIELDS."""
+ orch = orchestrator_field_names()
+ engine = {f.name for f in fields(_FakeEngineArgs)}
+ overlap = orch & engine
+ unexpected = overlap - SHARED_FIELDS
+ assert not unexpected, (
+ f"Fields declared in both OrchestratorArgs and the engine class "
+ f"but not in SHARED_FIELDS: {sorted(unexpected)}. These cause "
+ f"double-routing — either remove the duplicate declaration or add "
+ f"to SHARED_FIELDS if sharing is intentional."
+ )
+
+
+def test_no_ambiguous_overlap_with_real_engine():
+ """Same check, but against the real OmniEngineArgs."""
+ try:
+ from vllm_omni.engine.arg_utils import OmniEngineArgs
+ except Exception as exc:
+ pytest.skip(f"OmniEngineArgs not importable: {exc}")
+
+ orch = orchestrator_field_names()
+ engine = {f.name for f in fields(OmniEngineArgs)}
+ overlap = orch & engine
+ unexpected = overlap - SHARED_FIELDS
+ assert not unexpected, (
+ f"Real OmniEngineArgs has ambiguous overlap with OrchestratorArgs: "
+ f"{sorted(unexpected)}. Update SHARED_FIELDS or remove duplication."
+ )
+
+
+# ============================================================================
+# Invariant 2 — split_kwargs partitions correctly.
+# ============================================================================
+
+
+def test_split_orchestrator_only():
+ """Pure orchestrator fields go to OrchestratorArgs, not engine_kwargs."""
+ raw = {"stage_init_timeout": 500, "worker_backend": "ray"}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert orch.stage_init_timeout == 500
+ assert orch.worker_backend == "ray"
+ assert "stage_init_timeout" not in engine
+ assert "worker_backend" not in engine
+
+
+def test_split_engine_only():
+ """Pure engine fields go to engine_kwargs, not OrchestratorArgs."""
+ raw = {"max_num_seqs": 128, "gpu_memory_utilization": 0.85}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert engine["max_num_seqs"] == 128
+ assert engine["gpu_memory_utilization"] == 0.85
+ # These fields don't exist on OrchestratorArgs at all.
+
+
+def test_split_shared_fields_go_to_both():
+ """Fields in SHARED_FIELDS are copied to both buckets."""
+ raw = {"model": "Qwen/Qwen2.5-Omni-7B", "log_stats": True}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert orch.log_stats is True
+ assert engine["model"] == "Qwen/Qwen2.5-Omni-7B"
+ assert engine["log_stats"] is True
+
+
+def test_split_drops_unclassified():
+ """Unclassified fields (uvicorn/server) are dropped silently."""
+ raw = {
+ "max_num_seqs": 64, # engine
+ "host": "0.0.0.0", # unclassified (server)
+ "port": 8091, # unclassified (server)
+ "ssl_keyfile": "key.pem", # unclassified (server)
+ }
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert engine == {"max_num_seqs": 64}
+ assert "host" not in engine
+ assert "port" not in engine
+ assert "ssl_keyfile" not in engine
+
+
+def test_split_mixed_real_world():
+ """End-to-end: raw CLI kwargs with all three classes present."""
+ raw = {
+ # orchestrator
+ "stage_init_timeout": 400,
+ "deploy_config": "/tmp/deploy.yaml",
+ "worker_backend": "multi_process",
+ "async_chunk": True,
+ # engine
+ "max_num_seqs": 32,
+ "gpu_memory_utilization": 0.8,
+ # shared
+ "model": "Qwen/Qwen3-Omni",
+ "log_stats": False,
+ # server / unclassified
+ "host": "0.0.0.0",
+ "port": 8091,
+ "api_key": "secret",
+ # None values
+ "ray_address": None,
+ }
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+
+ # Orchestrator side
+ assert orch.stage_init_timeout == 400
+ assert orch.deploy_config == "/tmp/deploy.yaml"
+ assert orch.worker_backend == "multi_process"
+ assert orch.async_chunk is True
+ assert orch.log_stats is False # shared, read from raw
+ assert orch.ray_address is None # default preserved
+
+ # Engine side
+ assert engine["max_num_seqs"] == 32
+ assert engine["gpu_memory_utilization"] == 0.8
+ assert engine["model"] == "Qwen/Qwen3-Omni"
+ assert engine["log_stats"] is False
+ assert "host" not in engine
+ assert "port" not in engine
+ assert "api_key" not in engine
+ # orchestrator-only keys never reach engine
+ assert "stage_init_timeout" not in engine
+ assert "deploy_config" not in engine
+ assert "async_chunk" not in engine
+
+
+# ============================================================================
+# Invariant 3 — user-typed unclassifiable flags warn (don't fail silently).
+# ============================================================================
+
+
+def test_user_typed_unclassified_warns(caplog):
+ """If the user types a flag we can't route, warn — don't silently drop."""
+ raw = {"bogus_flag": "value", "max_num_seqs": 64}
+ with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
+ split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed={"bogus_flag"})
+ assert any("bogus_flag" in rec.message for rec in caplog.records), (
+ f"Expected warning mentioning 'bogus_flag', got: {[rec.message for rec in caplog.records]}"
+ )
+
+
+def test_unclassified_without_user_typed_silent(caplog):
+ """Without user_typed, unclassified keys drop silently (argparse defaults
+ for server flags shouldn't spam logs on every launch)."""
+ raw = {"host": "0.0.0.0", "port": 8091, "max_num_seqs": 64}
+ with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
+ split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed=None)
+ # No warnings because we don't know these were user-typed.
+ assert not any("host" in rec.message or "port" in rec.message for rec in caplog.records)
+
+
+# ============================================================================
+# Invariant 4 — CLI flag classification completeness.
+# Catches new flags added without updating OrchestratorArgs or OmniEngineArgs.
+# ============================================================================
+
+
+def test_all_omni_cli_flags_classified():
+ """Every vllm-omni-added CLI flag must be classifiable.
+
+ Runs ``OmniServeCommand.subparser_init`` and checks that every new
+ argument (compared to vllm's base parser) is either:
+ - a field on OrchestratorArgs, OR
+ - a field on OmniEngineArgs, OR
+ - in SHARED_FIELDS
+ """
+ try:
+ from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+ from vllm_omni.engine.arg_utils import OmniEngineArgs
+ from vllm_omni.entrypoints.cli.serve import OmniServeCommand
+ except Exception as exc:
+ pytest.skip(f"Cannot build parser in this environment: {exc}")
+
+ # Build the serve parser
+ root = FlexibleArgumentParser()
+ subparsers = root.add_subparsers()
+ cmd = OmniServeCommand()
+ try:
+ parser = cmd.subparser_init(subparsers)
+ except Exception as exc:
+ pytest.skip(f"subparser_init failed (dev env issue): {exc}")
+
+ all_dests = {a.dest for a in parser._actions if a.dest and a.dest not in {"help", "model_tag"}}
+
+ orch = orchestrator_field_names()
+ engine = {f.name for f in fields(OmniEngineArgs)}
+ server_derived = derive_server_dests_from_vllm_parser()
+
+ unclassified = all_dests - orch - engine - SHARED_FIELDS - server_derived
+ # Some argparse-internal dests (suppressed, private) may not match —
+ # filter those out.
+ unclassified = {d for d in unclassified if not d.startswith("_")}
+
+ assert not unclassified, (
+ f"These CLI flags are not classified as "
+ f"orchestrator/engine/shared/server: {sorted(unclassified)}. "
+ f"Add them to OrchestratorArgs (if consumed by orchestrator), "
+ f"OmniEngineArgs (if consumed by per-stage engine), or the known-server "
+ f"allowlist (if they're vllm frontend flags). "
+ f"If intentional (e.g. a new CLI-only flag that doesn't map to either "
+ f"dataclass), add it to a KNOWN_UNROUTED allowlist."
+ )
+
+
+# ============================================================================
+# argparse interop (Phase 3).
+# ============================================================================
+
+
+def test_orchestrator_args_from_argparse():
+ """Can build OrchestratorArgs from an argparse.Namespace."""
+ import argparse
+
+ ns = argparse.Namespace(
+ stage_init_timeout=500,
+ deploy_config="/tmp/x.yaml",
+ max_num_seqs=64, # engine field — ignored
+ host="0.0.0.0", # server field — ignored
+ )
+ orch = orchestrator_args_from_argparse(ns)
+ assert orch.stage_init_timeout == 500
+ assert orch.deploy_config == "/tmp/x.yaml"
+ assert orch.worker_backend == "multi_process" # default
+
+
+def test_derive_server_dests_returns_frozenset():
+ """Server-dest derivation returns a frozenset (possibly empty)."""
+ result = derive_server_dests_from_vllm_parser()
+ assert isinstance(result, frozenset)
+
+
+# ============================================================================
+# internal_blacklist_keys — single source of truth for per-stage forwarding.
+# ============================================================================
+
+
+def test_internal_blacklist_keys_derived_from_orchestrator():
+ """Blacklist is exactly OrchestratorArgs fields minus SHARED_FIELDS.
+
+ This function replaces the old hardcoded INTERNAL_STAGE_OVERRIDE_KEYS
+ frozenset. Asserts the contract so future changes to OrchestratorArgs
+ automatically propagate to the blacklist.
+ """
+ blacklist = internal_blacklist_keys()
+ assert blacklist == orchestrator_field_names() - SHARED_FIELDS
+ # Spot-check expected entries
+ assert "stage_init_timeout" in blacklist
+ assert "deploy_config" in blacklist
+ assert "async_chunk" in blacklist
+ # Shared fields must NOT appear — they flow to both orchestrator and engine
+ assert "model" not in blacklist
+ assert "log_stats" not in blacklist
+
+
+# ============================================================================
+# Boundary value analysis — edge cases around split_kwargs.
+# ============================================================================
+
+
+def test_split_empty_kwargs():
+ """Empty kwargs yields default OrchestratorArgs and empty engine dict."""
+ orch, engine = split_kwargs({}, engine_cls=_FakeEngineArgs)
+ assert orch.stage_init_timeout == 300 # dataclass default
+ assert orch.worker_backend == "multi_process" # dataclass default
+ assert engine == {}
+
+
+def test_split_all_none_values_preserved_on_orchestrator():
+ """None values for orchestrator fields are kept (represents 'not set')."""
+ raw = {"ray_address": None, "deploy_config": None, "max_num_seqs": None}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert orch.ray_address is None
+ assert orch.deploy_config is None
+ # Engine-side None still passes through; caller decides semantics downstream.
+ assert engine.get("max_num_seqs") is None
+
+
+def test_split_user_typed_with_empty_kwargs_no_warn(caplog):
+ """user_typed non-empty but kwargs empty — no warnings emitted."""
+ with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
+ split_kwargs({}, engine_cls=_FakeEngineArgs, user_typed={"nothing"})
+ assert not caplog.records
+
+
+def test_ambiguous_field_strict_raises():
+ """strict=True raises ValueError on overlap outside SHARED_FIELDS."""
+
+ # deploy_config is on OrchestratorArgs; declaring it on the engine class
+ # too (without adding to SHARED_FIELDS) creates an ambiguous route.
+ @dataclass
+ class _AmbiguousEngine:
+ deploy_config: str | None = None
+
+ with pytest.raises(ValueError, match="both OrchestratorArgs and"):
+ split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=True)
+
+
+def test_ambiguous_field_non_strict_routes_to_orchestrator(caplog):
+ """strict=False logs ERROR but routes the ambiguous field to orchestrator."""
+
+ @dataclass
+ class _AmbiguousEngine:
+ deploy_config: str | None = None
+
+ with caplog.at_level(logging.ERROR, logger="vllm_omni.engine.arg_utils"):
+ orch, engine = split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=False)
+ assert orch.deploy_config == "x"
+ assert "deploy_config" not in engine
+ assert any("both OrchestratorArgs" in r.message for r in caplog.records)
diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py
index e284de48d0..1d65d3acd2 100644
--- a/tests/test_config_factory.py
+++ b/tests/test_config_factory.py
@@ -4,12 +4,26 @@
Unit tests for StageConfigFactory and related classes.
"""
+from dataclasses import dataclass
+from pathlib import Path
+
+import pytest
+
from vllm_omni.config.stage_config import (
+ _EXECUTION_TYPE_TO_SCHEDULER,
+ _PIPELINE_REGISTRY,
ModelPipeline,
+ PipelineConfig,
StageConfig,
StageConfigFactory,
+ StageExecutionType,
+ StagePipelineConfig,
StageType,
+ build_stage_runtime_overrides,
+ register_pipeline,
+ strip_parent_engine_args,
)
+from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
class TestStageType:
@@ -241,8 +255,9 @@ def test_default_diffusion_no_yaml(self):
def test_default_diffusion_with_parallel_config(self):
"""Test diffusion config calculates devices from parallel_config."""
+ @dataclass
class MockParallelConfig:
- world_size = 4
+ world_size: int = 4
kwargs = {
"parallel_config": MockParallelConfig(),
@@ -270,7 +285,7 @@ def test_cli_override_forwards_engine_registered_args(self):
stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[])
cli_overrides = {
"gpu_memory_utilization": 0.9, # Well-known param
- "custom_engine_flag": True, # Not in _INTERNAL_KEYS, so forwarded
+ "custom_engine_flag": True, # Not orchestrator-owned, so forwarded
}
overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides)
@@ -311,6 +326,56 @@ def test_per_stage_override_excludes_internal_keys(self):
assert "batch_timeout" not in overrides
+class TestStageResolutionHelpers:
+ """Tests for shared stage override / filtering helpers."""
+
+ def test_build_stage_runtime_overrides_ignores_other_stage_and_internal_keys(self):
+ # Pass the same filter set the function uses by default
+ # (orchestrator-only fields plus SHARED_FIELDS so ``model`` is
+ # treated as not-per-stage-overridable).
+ overrides = build_stage_runtime_overrides(
+ 0,
+ {
+ "gpu_memory_utilization": 0.5,
+ "stage_0_gpu_memory_utilization": 0.9,
+ "stage_1_gpu_memory_utilization": 0.1,
+ "stage_0_model": "should_be_ignored",
+ "parallel_config": {"world_size": 2},
+ },
+ internal_keys=internal_blacklist_keys() | SHARED_FIELDS,
+ )
+
+ assert overrides["gpu_memory_utilization"] == 0.9
+ assert "model" not in overrides
+ assert "parallel_config" not in overrides
+
+ def test_strip_parent_engine_args_reports_only_surprising_parent_overrides(self):
+ from dataclasses import fields as dc_fields
+
+ from vllm.engine.arg_utils import EngineArgs
+
+ parent_fields = {f.name: f for f in dc_fields(EngineArgs)}
+ filtered, overridden = strip_parent_engine_args(
+ {
+ "model": "some/model",
+ "stage_configs_path": "/tmp/stages.yaml",
+ "tensor_parallel_size": 4,
+ "worker_extension_cls": "some.Extension",
+ "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"},
+ },
+ parent_fields=parent_fields,
+ keep_keys={"worker_extension_cls"},
+ strip_keys={"stage_configs_path"},
+ no_warn_keys={"model"},
+ )
+
+ assert filtered == {
+ "worker_extension_cls": "some.Extension",
+ "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"},
+ }
+ assert overridden == ["tensor_parallel_size"]
+
+
class TestPipelineYamlParsing:
"""Tests for pipeline YAML file parsing (@ZJY0516)."""
@@ -609,16 +674,617 @@ def test_parse_missing_async_chunk_defaults_false(self, tmp_path):
assert pipeline.async_chunk is False
-class TestArchitectureFallback:
- """Tests for architecture-based model detection fallback."""
+class TestPipelineDiscovery:
+ """Tests for the central pipeline registry (``pipeline_registry._VLLM_OMNI_PIPELINES``)."""
+
+ def test_registry_has_known_models(self):
+ """Built-in pipelines are lazy-loaded from the central declaration
+ on first access; no eager import or discovery walk needed."""
+ # ``in`` triggers the lazy-map lookup without forcing a load.
+ assert "qwen2_5_omni" in _PIPELINE_REGISTRY
+ assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
+ assert "qwen3_tts" in _PIPELINE_REGISTRY
+
+ def test_registry_loads_pipeline_on_getitem(self):
+ """Looking up a registered model_type returns the matching PipelineConfig."""
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ assert pipeline.model_type == "qwen3_omni_moe"
+ assert len(pipeline.stages) == 3 # thinker + talker + code2wav
+
+ def test_registry_returns_none_for_unknown(self):
+ """Unknown model_types aren't found; ``get()`` returns None."""
+ assert "definitely_not_a_real_model" not in _PIPELINE_REGISTRY
+ assert _PIPELINE_REGISTRY.get("definitely_not_a_real_model") is None
+
+ def test_pipeline_config_supports_hf_architectures(self):
+ """PipelineConfig accepts hf_architectures for HF-arch fallback
+ (replaces the old _ARCHITECTURE_MODELS dict)."""
+ p = PipelineConfig(
+ model_type="custom_collide",
+ hf_architectures=("SomeCollidingArch",),
+ )
+ assert p.hf_architectures == ("SomeCollidingArch",)
+
+
+class TestStagePipelineConfig:
+ def test_frozen(self):
+ s = StagePipelineConfig(stage_id=0, model_stage="a")
+ with pytest.raises(AttributeError):
+ s.model_stage = "changed"
+
+ def test_defaults(self):
+ s = StagePipelineConfig(stage_id=0, model_stage="a")
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.input_sources == ()
+ assert s.final_output is False
+ assert s.sampling_constraints == {}
+ assert s.engine_output_type is None
+
+
+class TestPipelineConfigNew:
+ def test_frozen(self):
+ p = PipelineConfig(model_type="t", model_arch="A")
+ with pytest.raises(AttributeError):
+ p.model_type = "changed"
+
+ def test_validate_valid(self):
+ p = PipelineConfig(
+ model_type="t",
+ model_arch="A",
+ stages=(
+ StagePipelineConfig(stage_id=0, model_stage="a"),
+ StagePipelineConfig(stage_id=1, model_stage="b", input_sources=(0,)),
+ ),
+ )
+ assert p.validate() == []
+
+ def test_validate_no_stages(self):
+ p = PipelineConfig(model_type="t", model_arch="A")
+ assert any("no stages" in e.lower() for e in p.validate())
+
+ def test_get_scheduler_cls(self):
+ p = PipelineConfig(
+ model_type="t",
+ model_arch="A",
+ stages=(
+ StagePipelineConfig(stage_id=0, model_stage="a", execution_type=StageExecutionType.LLM_AR),
+ StagePipelineConfig(
+ stage_id=1, model_stage="b", execution_type=StageExecutionType.LLM_GENERATION, input_sources=(0,)
+ ),
+ ),
+ )
+ assert "OmniARScheduler" in p.get_scheduler_cls(0)
+ assert "OmniGenerationScheduler" in p.get_scheduler_cls(1)
+
+
+class TestExecutionTypeToScheduler:
+ def test_all_types_mapped(self):
+ for et in StageExecutionType:
+ assert et in _EXECUTION_TYPE_TO_SCHEDULER
+
+
+class TestPipelineRegistry:
+ def test_register_and_lookup(self):
+ p = PipelineConfig(
+ model_type="__test_only__",
+ model_arch="A",
+ stages=(StagePipelineConfig(stage_id=0, model_stage="a"),),
+ )
+ register_pipeline(p)
+ assert _PIPELINE_REGISTRY["__test_only__"] is p
+ del _PIPELINE_REGISTRY["__test_only__"]
+
+
+class TestDeployConfigLoading:
+ def test_load_deploy_config(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ assert len(deploy.stages) == 3
+ assert deploy.async_chunk is True
+ assert deploy.connectors is not None
+ assert deploy.platforms is not None
+
+ def test_merge_pipeline_deploy(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ assert len(stages) == 3
+ s0 = stages[0]
+ assert s0.model_stage == "thinker"
+ assert s0.yaml_engine_args["model_arch"] == "Qwen3OmniMoeForConditionalGeneration"
+ assert s0.yaml_engine_args["engine_output_type"] == "latent"
+ assert s0.yaml_extras["default_sampling_params"]["detokenize"] is True
+
+
+class TestQwen3OmniPipeline:
+ def test_registered(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ p = _PIPELINE_REGISTRY.get("qwen3_omni_moe")
+ assert p is not None
+ assert p.model_arch == "Qwen3OmniMoeForConditionalGeneration"
+ assert len(p.stages) == 3
+ assert p.validate() == []
+
+ def test_thinker(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(0)
+ assert s.model_stage == "thinker"
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.owns_tokenizer is True
+ assert s.engine_output_type == "latent"
+ assert s.sampling_constraints["detokenize"] is True
+
+ def test_talker(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(1)
+ assert s.input_sources == (0,)
+ assert s.sampling_constraints["stop_token_ids"] == [2150]
+ assert s.custom_process_input_func is not None
+ assert s.custom_process_next_stage_input_func is not None
+
+ def test_code2wav(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(2)
+ assert s.execution_type == StageExecutionType.LLM_GENERATION
+ assert s.final_output_type == "audio"
+ assert s.custom_process_input_func is not None
+
+
+class TestQwen2_5OmniPipeline:
+ def test_registered(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ p = _PIPELINE_REGISTRY.get("qwen2_5_omni")
+ assert p is not None
+ assert p.model_arch == "Qwen2_5OmniForConditionalGeneration"
+ assert len(p.stages) == 3
+ assert p.validate() == []
+
+ def test_thinker(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(0)
+ assert s.model_stage == "thinker"
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.owns_tokenizer is True
+ assert s.engine_output_type == "latent"
+ assert s.requires_multimodal_data is True
+
+ def test_talker(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(1)
+ assert s.input_sources == (0,)
+ assert s.sampling_constraints["stop_token_ids"] == [8294]
+ assert s.custom_process_input_func is not None
+
+ def test_code2wav(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(2)
+ assert s.execution_type == StageExecutionType.LLM_GENERATION
+ assert s.final_output_type == "audio"
+ assert s.engine_output_type == "audio"
+
+
+class TestQwen3TTSPipeline:
+ def test_registered(self):
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+
+ p = _PIPELINE_REGISTRY.get("qwen3_tts")
+ assert p is not None
+ assert p.model_arch == "Qwen3TTSTalkerForConditionalGeneration"
+ assert len(p.stages) == 2
+ assert p.validate() == []
+
+ def test_talker_stage(self):
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(0)
+ assert s.model_stage == "qwen3_tts"
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.owns_tokenizer is True
+ assert s.engine_output_type == "latent"
+ assert s.sampling_constraints["stop_token_ids"] == [2150]
+ # Stage 0 inherits the pipeline-level model_arch
+ assert s.model_arch is None
+
+ def test_code2wav_stage_has_per_stage_model_arch(self):
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(1)
+ assert s.execution_type == StageExecutionType.LLM_GENERATION
+ assert s.final_output_type == "audio"
+ assert s.engine_output_type == "audio"
+ # Per-stage model_arch override (different from pipeline-level talker)
+ assert s.model_arch == "Qwen3TTSCode2Wav"
+ # tts_args is passed through via extras
+ assert s.extras["tts_args"]["max_instructions_length"] == 500
+
+ def test_per_stage_model_arch_flows_through_merge(self, tmp_path):
+ """Verify the new ps.model_arch override survives merge_pipeline_deploy."""
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_tts.yaml"
+ if not deploy_path.exists():
+ pytest.skip("qwen3_tts deploy yaml not found")
+
+ deploy = load_deploy_config(deploy_path)
+ pipeline = _PIPELINE_REGISTRY["qwen3_tts"]
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ # Stage 0 inherits pipeline-level model_arch
+ assert stages[0].yaml_engine_args["model_arch"] == "Qwen3TTSTalkerForConditionalGeneration"
+ # Stage 1 uses its per-stage override
+ assert stages[1].yaml_engine_args["model_arch"] == "Qwen3TTSCode2Wav"
+
+
+class TestBaseConfigInheritance:
+ """Test deploy YAML base_config inheritance."""
+
+ def test_ci_inherits_from_main(self):
+ from tests.utils import get_deploy_config_path
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml"))
+ if not ci_path.exists():
+ pytest.skip("CI deploy config not found")
+
+ deploy = load_deploy_config(ci_path)
+ assert len(deploy.stages) == 3
+ # CI overrides
+ assert deploy.stages[0].engine_extras.get("load_format") == "dummy"
+ assert deploy.stages[0].max_num_seqs == 5
+ # Inherited from base
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+ assert deploy.connectors is not None
+ assert "connector_of_shared_memory" in deploy.connectors
+ # CI overlay explicitly sets async_chunk: False (see
+ # tests/utils.py::_CI_OVERLAYS and PR #2383 discussion). Overlay
+ # bool overrides base even when the base yaml has async_chunk: true.
+ assert deploy.async_chunk is False
+
+ def test_ci_sampling_merge(self):
+ from tests.utils import get_deploy_config_path
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml"))
+ if not ci_path.exists():
+ pytest.skip("CI deploy config not found")
+
+ deploy = load_deploy_config(ci_path)
+ s0 = deploy.stages[0].default_sampling_params
+ # CI overrides max_tokens
+ assert s0["max_tokens"] == 150
+ # Inherited from base
+ assert s0["temperature"] == 0.4
+ assert s0["seed"] == 42
+
+ def test_pure_inheritance_overlay(self, tmp_path):
+ """An overlay with only ``base_config`` inherits everything."""
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not base.exists():
+ pytest.skip("Base deploy config not found")
+
+ overlay = tmp_path / "overlay.yaml"
+ overlay.write_text(f"base_config: {base}\n")
+
+ deploy = load_deploy_config(overlay)
+ assert len(deploy.stages) == 3
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+
+ def test_single_field_overlay(self, tmp_path):
+ """An overlay overriding one stage field merges with the base."""
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not base.exists():
+ pytest.skip("Base deploy config not found")
+
+ overlay = tmp_path / "overlay.yaml"
+ overlay.write_text(f"base_config: {base}\nstages:\n - stage_id: 2\n max_num_batched_tokens: 1000000\n")
+
+ deploy = load_deploy_config(overlay)
+ assert deploy.stages[2].max_num_batched_tokens == 1000000
+ # Rest inherited
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+
+
+class TestPlatformOverrides:
+ """Test platform-specific deploy config overrides."""
+
+ def test_npu_overrides(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ deploy = _apply_platform_overrides(deploy, platform="npu")
+
+ assert deploy.stages[0].gpu_memory_utilization == 0.6
+ assert deploy.stages[0].tensor_parallel_size == 2
+ assert deploy.stages[0].devices == "0,1"
+ # Stage 2 unaffected fields stay at base
+ assert deploy.stages[2].enforce_eager is True
+
+ def test_xpu_overrides(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ deploy = _apply_platform_overrides(deploy, platform="xpu")
+
+ assert deploy.stages[0].tensor_parallel_size == 4
+ assert deploy.stages[0].devices == "0,1,2,3"
+ assert deploy.stages[0].engine_extras.get("max_cudagraph_capture_size") == 0
+
+ def test_unknown_platform_noop(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ original_mem = deploy.stages[0].gpu_memory_utilization
+ deploy = _apply_platform_overrides(deploy, platform="unknown_hw")
+ assert deploy.stages[0].gpu_memory_utilization == original_mem
+
+ def test_platforms_deep_merge_inheritance(self, tmp_path):
+ """Overlay's platforms: block layers onto base's, per-stage."""
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ base = tmp_path / "base.yaml"
+ base.write_text(
+ "stages:\n"
+ " - stage_id: 0\n"
+ " gpu_memory_utilization: 0.9\n"
+ "platforms:\n"
+ " rocm:\n"
+ " stages:\n"
+ " - stage_id: 0\n"
+ " enforce_eager: true\n"
+ )
+ overlay = tmp_path / "overlay.yaml"
+ overlay.write_text(
+ f"base_config: {base.name}\n"
+ "platforms:\n"
+ " rocm:\n"
+ " stages:\n"
+ " - stage_id: 0\n"
+ " max_num_seqs: 1\n"
+ )
+
+ deploy = load_deploy_config(overlay)
+ deploy = _apply_platform_overrides(deploy, platform="rocm")
+ # Both base's enforce_eager and overlay's max_num_seqs should apply.
+ assert deploy.stages[0].enforce_eager is True
+ assert deploy.stages[0].max_num_seqs == 1
+ # Inherited stage default not touched by overlay platforms section.
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+
+
+class TestCLIOverrideFlow:
+ """Test --stage-overrides JSON merge into StageConfig."""
+
+ def test_stage_overrides_merge(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ # Simulate --stage-overrides '{"0": {"gpu_memory_utilization": 0.5}}'
+ overrides = {"stage_0_gpu_memory_utilization": 0.5}
+ stages[0].runtime_overrides = StageConfigFactory._merge_cli_overrides(stages[0], overrides)
+ assert stages[0].runtime_overrides["gpu_memory_utilization"] == 0.5
+
+ def test_global_override_applies_to_all(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ overrides = {"enforce_eager": True}
+ for s in stages:
+ s.runtime_overrides = StageConfigFactory._merge_cli_overrides(s, overrides)
+ assert s.runtime_overrides["enforce_eager"] is True
+
+
+class TestCLIExplicitPrecedence:
+ """Verify YAML > argparse defaults; explicit CLI args > YAML."""
+
+ def _stages(self, cli_overrides, cli_explicit_keys):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ return StageConfigFactory._create_from_registry(
+ "qwen3_omni_moe",
+ cli_overrides=cli_overrides,
+ cli_explicit_keys=cli_explicit_keys,
+ )
+
+ def test_explicit_cli_overrides_yaml(self):
+ """User-typed --max-num-seqs wins over the deploy YAML value."""
+ stages = self._stages(
+ cli_overrides={"max_num_seqs": 999},
+ cli_explicit_keys={"max_num_seqs"},
+ )
+ # Stage 2 yaml has max_num_seqs=1; explicit CLI must beat it.
+ assert stages[2].runtime_overrides.get("max_num_seqs") == 999
+
+ def test_default_cli_does_not_override_yaml(self):
+ """Argparse defaults must NOT clobber values that are present in YAML."""
+ stages = self._stages(
+ cli_overrides={"max_num_seqs": 256},
+ cli_explicit_keys=set(), # user typed nothing
+ )
+ # Stage 2's YAML value (1) should win because the user didn't type --max-num-seqs.
+ assert stages[2].runtime_overrides.get("max_num_seqs") != 256
+
+ def test_default_cli_fills_missing_yaml_field(self):
+ """Argparse defaults still fill fields the YAML doesn't set."""
+ stages = self._stages(
+ cli_overrides={"some_unrelated_knob": "fallback"},
+ cli_explicit_keys=set(),
+ )
+ # Field absent from YAML → CLI default flows through as a fallback.
+ assert stages[0].runtime_overrides.get("some_unrelated_knob") == "fallback"
+
+ def test_per_stage_overrides_always_explicit(self):
+ """``stage__*`` keys are always treated as explicit."""
+ stages = self._stages(
+ cli_overrides={"stage_0_gpu_memory_utilization": 0.42},
+ cli_explicit_keys=set(), # not in the explicit set, but per-stage
+ )
+ assert stages[0].runtime_overrides.get("gpu_memory_utilization") == 0.42
+
+ def test_none_explicit_set_treats_all_as_explicit(self):
+ """Programmatic Omni() callers (cli_explicit_keys=None) keep current behavior."""
+ stages = self._stages(
+ cli_overrides={"max_num_seqs": 999},
+ cli_explicit_keys=None,
+ )
+ assert stages[2].runtime_overrides.get("max_num_seqs") == 999
+
+ def test_explicit_async_chunk_false_overrides_yaml(self):
+ """``--no-async-chunk`` flips the deploy-level async_chunk to False even
+ when the YAML sets it to True. Verifies that the per-stage
+ ``async_chunk: True`` injection in ``merge_pipeline_deploy`` is skipped
+ and that ``async_chunk`` does not leak through ``_merge_cli_overrides``.
+ """
+ stages = self._stages(
+ cli_overrides={"async_chunk": False},
+ cli_explicit_keys={"async_chunk"},
+ )
+ # qwen3_omni_moe.yaml has `async_chunk: true`, so by default every
+ # stage's engine_args would carry it. With the explicit override, it
+ # must NOT show up.
+ for stage in stages:
+ assert stage.yaml_engine_args.get("async_chunk") is not True
+ assert stage.runtime_overrides.get("async_chunk") is None
+
+ def test_default_async_chunk_leaves_yaml_alone(self):
+ """An unset ``--async-chunk`` (default None) must leave the YAML's True
+ in force on every stage."""
+ stages = self._stages(
+ cli_overrides={"async_chunk": None},
+ cli_explicit_keys=set(),
+ )
+ # qwen3_omni_moe.yaml: `async_chunk: true` → injected on every stage.
+ for stage in stages:
+ assert stage.yaml_engine_args.get("async_chunk") is True
+
+ def test_explicit_enable_prefix_caching_overrides_yaml(self):
+ """``--enable-prefix-caching`` (global) flips every stage's
+ ``enable_prefix_caching`` to True regardless of the YAML default."""
+ stages = self._stages(
+ cli_overrides={"enable_prefix_caching": True},
+ cli_explicit_keys={"enable_prefix_caching"},
+ )
+ for stage in stages:
+ assert stage.runtime_overrides.get("enable_prefix_caching") is True
+
+ def test_async_chunk_dispatches_processors(self):
+ """A single ``qwen3_tts`` pipeline picks per-chunk vs end-to-end
+ processors based on ``deploy.async_chunk``, without needing a
+ separate variant pipeline registration."""
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import (
+ _PIPELINE_REGISTRY,
+ DeployConfig,
+ merge_pipeline_deploy,
+ )
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_tts"]
+
+ # async_chunk=True → stage 0's per-chunk processor wires up, stage 1
+ # has no sync input processor.
+ async_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=True))
+ assert (
+ async_stages[0]
+ .yaml_engine_args.get("custom_process_next_stage_input_func", "")
+ .endswith("talker2code2wav_async_chunk")
+ )
+ assert async_stages[1].custom_process_input_func is None
+
+ # async_chunk=False → stage 0 has no streaming processor, stage 1's
+ # batch-end processor wires up.
+ sync_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False))
+ assert "custom_process_next_stage_input_func" not in sync_stages[0].yaml_engine_args
+ assert sync_stages[1].custom_process_input_func is not None
+ assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav")
+
+
+class TestSamplingConstraintsPrecedence:
+ """Test that pipeline sampling_constraints override deploy defaults."""
+
+ def test_constraints_win(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
- def test_architecture_models_mapping_exists(self):
- """Test that _ARCHITECTURE_MODELS contains expected entries."""
- assert "MiMoAudioForConditionalGeneration" in StageConfigFactory._ARCHITECTURE_MODELS
- assert StageConfigFactory._ARCHITECTURE_MODELS["MiMoAudioForConditionalGeneration"] == "mimo_audio"
- assert "HunyuanImage3ForCausalMM" in StageConfigFactory._ARCHITECTURE_MODELS
- assert StageConfigFactory._ARCHITECTURE_MODELS["HunyuanImage3ForCausalMM"] == "hunyuan_image3"
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
- def test_mimo_audio_in_pipeline_models(self):
- """Test that mimo_audio is registered in PIPELINE_MODELS."""
- assert "mimo_audio" in StageConfigFactory.PIPELINE_MODELS
+ # Pipeline says detokenize=True for thinker, deploy can't override
+ assert stages[0].yaml_extras["default_sampling_params"]["detokenize"] is True
+ # Pipeline says stop_token_ids=[2150] for talker
+ assert stages[1].yaml_extras["default_sampling_params"]["stop_token_ids"] == [2150]
+ # Deploy temperature still flows through
+ assert stages[0].yaml_extras["default_sampling_params"]["temperature"] == 0.4
diff --git a/tests/test_diffusion_config_fields.py b/tests/test_diffusion_config_fields.py
new file mode 100644
index 0000000000..b87ceec1df
--- /dev/null
+++ b/tests/test_diffusion_config_fields.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Ensure diffusion stage YAML configs only use valid OmniDiffusionConfig fields.
+
+Regression test for https://github.com/vllm-project/vllm-omni/issues/2563
+"""
+
+from dataclasses import fields
+from pathlib import Path
+
+import pytest
+import yaml
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+try:
+ from vllm_omni.diffusion.data import OmniDiffusionConfig
+except Exception:
+ OmniDiffusionConfig = None
+
+
+@pytest.mark.skipif(
+ OmniDiffusionConfig is None,
+ reason="OmniDiffusionConfig could not be imported (missing torch?)",
+)
+def test_diffusion_stage_configs_only_contain_valid_fields():
+ """Diffusion stage engine_args must only contain OmniDiffusionConfig fields.
+
+ Regression test for https://github.com/vllm-project/vllm-omni/issues/2563
+ """
+ # Scan both main configs and test configs
+ repo_root = Path(__file__).parent.parent
+ config_dirs = [
+ repo_root / "vllm_omni" / "model_executor" / "stage_configs",
+ ]
+ # Also scan test directories recursively
+ test_dir = repo_root / "tests"
+
+ yaml_paths: list[Path] = []
+ for config_dir in config_dirs:
+ yaml_paths.extend(sorted(config_dir.glob("*.yaml")))
+ yaml_paths.extend(sorted(test_dir.rglob("*.yaml")))
+
+ valid_fields = {f.name for f in fields(OmniDiffusionConfig)}
+ # model_stage is consumed by the stage init layer, not OmniDiffusionConfig
+ valid_fields.add("model_stage")
+ # model_arch is consumed by the stage init layer for diffusion model class resolution
+ valid_fields.add("model_arch")
+ # "quantization" is mapped to "quantization_config" by from_kwargs() backwards-compat
+ valid_fields.add("quantization")
+
+ invalid_entries: list[tuple[str, set[str]]] = []
+ for yaml_path in yaml_paths:
+ with open(yaml_path) as fh:
+ config = yaml.safe_load(fh)
+
+ stages = config.get("stage_args", config.get("stages", []))
+ for stage in stages:
+ if stage.get("stage_type") != "diffusion":
+ continue
+ engine_args = stage.get("engine_args", {})
+ invalid = set(engine_args.keys()) - valid_fields
+ if invalid:
+ invalid_entries.append((yaml_path.relative_to(repo_root), invalid))
+
+ assert not invalid_entries, "Diffusion stage configs contain fields not in OmniDiffusionConfig:\n" + "\n".join(
+ f" {name}: {sorted(bad)}" for name, bad in invalid_entries
+ )
diff --git a/tests/test_diffusion_config_propagation.py b/tests/test_diffusion_config_propagation.py
index 7d6d9c43f0..eeb3505efe 100644
--- a/tests/test_diffusion_config_propagation.py
+++ b/tests/test_diffusion_config_propagation.py
@@ -15,6 +15,7 @@
DiffusionParallelConfig,
OmniDiffusionConfig,
)
+from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -109,3 +110,12 @@ def test_extra_kwargs_forwarded(self):
ea = stages[0]["engine_args"]
assert ea["enforce_eager"] is True
assert ea["lora_path"] == "/tmp/lora"
+
+
+def test_qwen_image_edit_plus_sets_generic_multimodal_limit():
+ od_config = OmniDiffusionConfig(model="Qwen/Qwen-Image-Edit-2511", model_class_name="QwenImageEditPlusPipeline")
+
+ od_config.update_multimodal_support()
+
+ assert od_config.supports_multimodal_inputs is True
+ assert od_config.max_multimodal_image_inputs == QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
diff --git a/tests/test_fish_speech_voice_cache.py b/tests/test_fish_speech_voice_cache.py
index 8fe7a4a4d1..1c299d8014 100644
--- a/tests/test_fish_speech_voice_cache.py
+++ b/tests/test_fish_speech_voice_cache.py
@@ -10,11 +10,12 @@
import os
import tempfile
-from unittest.mock import MagicMock, patch
+from pathlib import Path
import numpy as np
import pytest
import torch
+from pytest_mock import MockerFixture
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -61,18 +62,18 @@ class TestFishSpeechVoiceCacheIntegration:
"""Test the cache-hit / cache-miss / no-cache paths in the model."""
@pytest.fixture
- def mock_model(self):
+ def mock_model(self, mocker: MockerFixture):
"""Create a mock FishSpeechSlowARForConditionalGeneration with cache."""
from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
- model = MagicMock()
+ model = mocker.MagicMock()
model._voice_cache = VoiceEmbeddingCache(max_entries=4)
model._semantic_begin_id = 151678
model._num_codebooks = 10
model._codebook_size = 4096
model.model_path = "/fake/model"
- model.codebook_embeddings = MagicMock()
- model.codebook_embeddings.weight = MagicMock()
+ model.codebook_embeddings = mocker.MagicMock()
+ model.codebook_embeddings.weight = mocker.MagicMock()
model.codebook_embeddings.weight.device = torch.device("cpu")
return model
@@ -166,9 +167,13 @@ def test_created_at_zero_disables_cache(self, mock_model):
class TestFishSpeechValidatorUploadedVoice:
"""Test _validate_fish_tts_request uploaded voice resolution."""
- def test_uploaded_voice_resolves_ref_audio(self):
+ def test_uploaded_voice_resolves_ref_audio(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ mocker: MockerFixture,
+ ):
"""When voice matches an uploaded speaker, ref_audio should be auto-set."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "alice"
request.ref_audio = None
@@ -185,17 +190,21 @@ def test_uploaded_voice_resolves_ref_audio(self):
}
# Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL.
- with patch("pathlib.Path.exists", return_value=True):
- voice_lower = request.voice.lower()
- assert voice_lower in uploaded_speakers
+ monkeypatch.setattr(Path, "exists", lambda self: True)
- speaker_info = uploaded_speakers[voice_lower]
- ref_text_from_upload = speaker_info.get("ref_text")
- assert ref_text_from_upload == "Hi this is Alice"
+ voice_lower = request.voice.lower()
+ assert voice_lower in uploaded_speakers
+
+ speaker_info = uploaded_speakers[voice_lower]
+ ref_text_from_upload = speaker_info.get("ref_text")
+ assert ref_text_from_upload == "Hi this is Alice"
- def test_uploaded_voice_without_ref_text_uses_request_ref_text(self):
+ def test_uploaded_voice_without_ref_text_uses_request_ref_text(
+ self,
+ mocker: MockerFixture,
+ ):
"""If upload has no ref_text but request provides it, use request's."""
- request = MagicMock()
+ request = mocker.MagicMock()
request.input = "Hello"
request.voice = "bob"
request.ref_audio = None
diff --git a/tests/test_generate_nightly_perf_excel.py b/tests/test_generate_nightly_perf_excel.py
new file mode 100644
index 0000000000..9b05d6de0f
--- /dev/null
+++ b/tests/test_generate_nightly_perf_excel.py
@@ -0,0 +1,71 @@
+import importlib.util
+import json
+import sys
+from pathlib import Path
+
+import pytest
+from openpyxl import load_workbook
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _load_excel_module():
+ repo_root = Path(__file__).resolve().parents[1]
+ module_path = repo_root / "tools" / "nightly" / "generate_nightly_perf_excel.py"
+ spec = importlib.util.spec_from_file_location("generate_nightly_perf_excel", module_path)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
+def _cell_value_by_header(ws, header_name: str, row_idx: int = 2):
+ headers = [c.value for c in ws[1]]
+ col_idx = headers.index(header_name) + 1
+ return ws.cell(row=row_idx, column=col_idx).value
+
+
+def test_generate_excel_report_with_perf_templates(tmp_path: Path):
+ module = _load_excel_module()
+ repo_root = Path(__file__).resolve().parents[1]
+ perf_scripts_dir = repo_root / "tests" / "dfx" / "perf" / "scripts"
+
+ omni_template_path = perf_scripts_dir / "result_omni_template.json"
+ diffusion_template_path = perf_scripts_dir / "diffusion_result_template.json"
+
+ omni_record = json.loads(omni_template_path.read_text(encoding="utf-8"))
+ diffusion_records = json.loads(diffusion_template_path.read_text(encoding="utf-8"))
+
+ input_dir = tmp_path / "input"
+ diffusion_input_dir = tmp_path / "diffusion_input"
+ input_dir.mkdir()
+ diffusion_input_dir.mkdir()
+
+ # Keep file names compatible with parser conventions in generate_nightly_perf_excel.py
+ omni_result_file = input_dir / "result_test_perf_random_1_4_in2500_out900_20260415-185642.json"
+ diffusion_result_file = diffusion_input_dir / "diffusion_result_qwen_image_edit_20260415-193200.json"
+ omni_result_file.write_text(json.dumps(omni_record, ensure_ascii=False, indent=2), encoding="utf-8")
+ diffusion_result_file.write_text(json.dumps(diffusion_records, ensure_ascii=False, indent=2), encoding="utf-8")
+
+ output_file = tmp_path / "nightly_perf.xlsx"
+ module.generate_excel_report(
+ input_dir=str(input_dir),
+ diffusion_input_dir=str(diffusion_input_dir),
+ output_file=str(output_file),
+ commit_sha="test_commit_sha",
+ build_id="test_build_id",
+ build_url="https://example.com/build/123",
+ )
+
+ assert output_file.exists()
+
+ wb = load_workbook(output_file)
+ assert set(wb.sheetnames) >= {"omni_summary", "diffusion_summary", "omni_raw", "diffusion_raw"}
+
+ ws_omni_raw = wb["omni_raw"]
+ baseline_cell = _cell_value_by_header(ws_omni_raw, "baseline")
+ assert baseline_cell == json.dumps(omni_record["baseline"], ensure_ascii=False, sort_keys=True)
+
+ ws_omni_summary = wb["omni_summary"]
+ assert _cell_value_by_header(ws_omni_summary, "commit_sha") == "test_commit_sha"
+ assert _cell_value_by_header(ws_omni_summary, "build_id") == "test_build_id"
diff --git a/tests/test_generate_nightly_perf_html.py b/tests/test_generate_nightly_perf_html.py
new file mode 100644
index 0000000000..4e77eb3adf
--- /dev/null
+++ b/tests/test_generate_nightly_perf_html.py
@@ -0,0 +1,54 @@
+import importlib.util
+import json
+import sys
+from pathlib import Path
+
+import pytest
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _load_html_module():
+ repo_root = Path(__file__).resolve().parents[1]
+ module_path = repo_root / "tools" / "nightly" / "generate_nightly_perf_html.py"
+ spec = importlib.util.spec_from_file_location("generate_nightly_perf_html", module_path)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
+def test_generate_html_report_with_perf_templates(tmp_path: Path):
+ module = _load_html_module()
+ repo_root = Path(__file__).resolve().parents[1]
+ perf_scripts_dir = repo_root / "tests" / "dfx" / "perf" / "scripts"
+
+ omni_template_path = perf_scripts_dir / "result_omni_template.json"
+ diffusion_template_path = perf_scripts_dir / "diffusion_result_template.json"
+
+ omni_record = json.loads(omni_template_path.read_text(encoding="utf-8"))
+ diffusion_records = json.loads(diffusion_template_path.read_text(encoding="utf-8"))
+
+ input_dir = tmp_path / "input"
+ diffusion_input_dir = tmp_path / "diffusion_input"
+ input_dir.mkdir()
+ diffusion_input_dir.mkdir()
+
+ omni_result_file = input_dir / "result_test_perf_random_1_4_in2500_out900_20260415-185642.json"
+ diffusion_result_file = diffusion_input_dir / "diffusion_result_qwen_image_edit_20260415-193200.json"
+ omni_result_file.write_text(json.dumps(omni_record, ensure_ascii=False, indent=2), encoding="utf-8")
+ diffusion_result_file.write_text(json.dumps(diffusion_records, ensure_ascii=False, indent=2), encoding="utf-8")
+
+ output_file = tmp_path / "nightly_perf_v2.html"
+ module.generate_html_report(
+ input_dir=str(input_dir),
+ diffusion_input_dir=str(diffusion_input_dir),
+ output_file=str(output_file),
+ )
+
+ assert output_file.exists()
+ html = output_file.read_text(encoding="utf-8")
+ assert "Nightly Performance Report" in html
+ assert "Omni records 1 " in html
+ assert f"Diffusion records {len(diffusion_records)} " in html
+ assert "const DIFF_DATA =" in html
diff --git a/tests/test_version.py b/tests/test_version.py
new file mode 100644
index 0000000000..07e622a7d1
--- /dev/null
+++ b/tests/test_version.py
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for version compatibility warnings."""
+
+import warnings
+from unittest import mock
+
+import pytest
+
+from vllm_omni.version import warn_if_misaligned_vllm_version
+
+
+@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0))
+@mock.patch("vllm_omni.version.__version__", "0.19.0")
+@mock.patch("vllm.__version_tuple__", (0, 18, 0))
+@mock.patch("vllm.__version__", "0.18.0")
+def test_version_mismatch_warning():
+ """Ensure that we warn when vLLM and vLLM-Omni major/minor versions differ."""
+ with pytest.warns(RuntimeWarning, match="mismatched major/minor versions"):
+ warn_if_misaligned_vllm_version()
+
+
+@pytest.mark.parametrize(
+ "vllm_ver,vllm_tuple,omni_ver,omni_tuple",
+ [
+ ("0.19.0", (0, 19, 0), "0.19.5", (0, 19, 5)), # Patch differs
+ ("0.18.0", (0, 18, 0), "dev", (0, 0, "dev")), # Omni dev
+ ("dev", (0, 0, "dev"), "0.19.0", (0, 19, 0)), # vLLM dev
+ # Ensure local identifies don't matter for the warning
+ ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5", (0, 19, 0)),
+ ("0.19.0", (0, 19, 0), "0.19.5+bar", (0, 19, 0, "bar")),
+ ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5+bar", (0, 19, 0, "bar")),
+ ],
+)
+def test_no_warning_cases(vllm_ver, vllm_tuple, omni_ver, omni_tuple):
+ """Ensure we don't warn when minor versions match or either is a dev build."""
+ with (
+ mock.patch.multiple("vllm", __version__=vllm_ver, __version_tuple__=vllm_tuple),
+ mock.patch.multiple("vllm_omni.version", __version__=omni_ver, __version_tuple__=omni_tuple),
+ ):
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ warn_if_misaligned_vllm_version()
+
+
+@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0))
+@mock.patch("vllm_omni.version.__version__", "0.19.0rc2.dev21")
+@mock.patch("vllm.__version_tuple__", (0, 18, 0))
+@mock.patch("vllm.__version__", "0.18.0")
+def test_warning_contains_version_strings():
+ """Ensure that the warning contains the full version strings."""
+ with pytest.warns(RuntimeWarning) as record:
+ warn_if_misaligned_vllm_version()
+
+ assert len(record) == 1
+ msg = str(record[0].message)
+ assert "0.19.0rc2.dev21" in msg
+ assert "0.18.0" in msg
diff --git a/tests/utils.py b/tests/utils.py
index 84edbbf3d1..d8137cf963 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -11,6 +11,7 @@
import time
from collections.abc import Callable
from contextlib import ExitStack, contextmanager, suppress
+from pathlib import Path
from typing import Any, Literal
import cloudpickle
@@ -24,6 +25,221 @@
_P = ParamSpec("_P")
+_REPO_ROOT = Path(__file__).resolve().parent.parent
+_DEPLOY_DIR = _REPO_ROOT / "vllm_omni" / "deploy"
+_CI_GENERATED_DIR = _REPO_ROOT / "tests" / ".ci_generated"
+
+
+# CI overlays as Python dicts (LSP-friendly). Materialized on demand to
+# tests/.ci_generated/.yaml via get_deploy_config_path("ci/.yaml").
+_CI_OVERLAYS: dict[str, dict[str, Any]] = {
+ "qwen2_5_omni": {
+ "base_config": "qwen2_5_omni.yaml",
+ "async_chunk": False,
+ "stages": [
+ {
+ "stage_id": 0,
+ "max_model_len": 16384,
+ "max_num_batched_tokens": 16384,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.9,
+ "skip_mm_profiling": True,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 128},
+ },
+ {
+ "stage_id": 1,
+ "max_model_len": 16384,
+ "max_num_batched_tokens": 16384,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.4,
+ "skip_mm_profiling": True,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 4096},
+ },
+ {
+ "stage_id": 2,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.5,
+ "max_num_batched_tokens": 8192,
+ "max_model_len": 8192,
+ "load_format": "dummy",
+ "devices": "2",
+ "default_sampling_params": {"max_tokens": 8192},
+ },
+ ],
+ "platforms": {
+ "rocm": {
+ "stages": [
+ {"stage_id": 0, "gpu_memory_utilization": 0.9},
+ {"stage_id": 1, "gpu_memory_utilization": 0.4},
+ {"stage_id": 2, "gpu_memory_utilization": 0.5, "devices": "2"},
+ ],
+ },
+ "xpu": {
+ "stages": [
+ {
+ "stage_id": 0,
+ "gpu_memory_utilization": 0.9,
+ "max_num_batched_tokens": 16384,
+ "max_model_len": 16384,
+ },
+ {"stage_id": 1, "gpu_memory_utilization": 0.5},
+ {
+ "stage_id": 2,
+ "gpu_memory_utilization": 0.3,
+ "max_num_batched_tokens": 4096,
+ "max_model_len": 4096,
+ "devices": "2",
+ },
+ ],
+ },
+ },
+ },
+ "qwen3_omni_moe": {
+ "base_config": "qwen3_omni_moe.yaml",
+ "async_chunk": False,
+ "stages": [
+ {
+ "stage_id": 0,
+ "max_num_seqs": 5,
+ "max_model_len": 32768,
+ "mm_processor_cache_gb": 0,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 150, "ignore_eos": False},
+ },
+ {
+ "stage_id": 1,
+ "gpu_memory_utilization": 0.5,
+ "max_num_seqs": 5,
+ "max_model_len": 32768,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 1000},
+ },
+ {
+ "stage_id": 2,
+ "max_num_seqs": 5,
+ "max_num_batched_tokens": 100000,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 2000},
+ },
+ ],
+ "platforms": {
+ "rocm": {
+ "stages": [
+ {"stage_id": 0, "max_num_seqs": 1, "default_sampling_params": {"max_tokens": 100}},
+ {
+ "stage_id": 1,
+ "max_num_seqs": 1,
+ "enforce_eager": True,
+ "default_sampling_params": {"max_tokens": 100},
+ },
+ {
+ "stage_id": 2,
+ "max_num_seqs": 1,
+ "max_num_batched_tokens": 1000000,
+ "default_sampling_params": {"max_tokens": 200},
+ },
+ ],
+ },
+ "xpu": {
+ "stages": [
+ {
+ "stage_id": 0,
+ "gpu_memory_utilization": 0.85,
+ "max_num_seqs": 1,
+ "tensor_parallel_size": 4,
+ "enforce_eager": True,
+ "max_num_batched_tokens": 4096,
+ "max_model_len": 4096,
+ "max_cudagraph_capture_size": 0,
+ "skip_mm_profiling": True,
+ "devices": "0,1,2,3",
+ "default_sampling_params": {"max_tokens": 100, "ignore_eos": False},
+ },
+ {
+ "stage_id": 1,
+ "gpu_memory_utilization": 0.6,
+ "max_num_seqs": 1,
+ "enforce_eager": True,
+ "max_num_batched_tokens": 4096,
+ "max_model_len": 4096,
+ "max_cudagraph_capture_size": 0,
+ "skip_mm_profiling": True,
+ "devices": "4",
+ },
+ {
+ "stage_id": 2,
+ "gpu_memory_utilization": 0.3,
+ "max_num_seqs": 1,
+ "max_num_batched_tokens": 100000,
+ "max_cudagraph_capture_size": 0,
+ "skip_mm_profiling": True,
+ "devices": "5",
+ "default_sampling_params": {"max_tokens": 2000},
+ },
+ ],
+ },
+ },
+ },
+ # Single-stage thinker-only topology for the abort test.
+ "qwen2_5_omni_thinker_only": {
+ "async_chunk": False,
+ "pipeline": "qwen2_5_omni_thinker_only",
+ "stages": [
+ {
+ "stage_id": 0,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.9,
+ "enforce_eager": True,
+ "max_num_batched_tokens": 16384,
+ "max_model_len": 16384,
+ "skip_mm_profiling": True,
+ "mm_processor_cache_gb": 0,
+ "load_format": "dummy",
+ "devices": "0",
+ "default_sampling_params": {
+ "temperature": 0.0,
+ "top_p": 1.0,
+ "top_k": -1,
+ "max_tokens": 128,
+ "seed": 42,
+ "repetition_penalty": 1.1,
+ },
+ },
+ ],
+ },
+}
+
+
+def _materialize_ci_overlay(model_type: str) -> Path:
+ import yaml
+
+ if model_type not in _CI_OVERLAYS:
+ raise KeyError(f"No CI overlay registered for {model_type!r}. Available: {sorted(_CI_OVERLAYS)}")
+
+ _CI_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
+ out = _CI_GENERATED_DIR / f"{model_type}.yaml"
+
+ overlay = {**_CI_OVERLAYS[model_type]}
+ base = overlay.get("base_config")
+ if base:
+ overlay["base_config"] = str(_DEPLOY_DIR / base)
+
+ with open(out, "w", encoding="utf-8") as f:
+ yaml.safe_dump(overlay, f, sort_keys=False)
+ return out
+
+
+def get_deploy_config_path(rel_path: str) -> str:
+ """Resolve a deploy yaml; ``ci/.yaml`` materializes from ``_CI_OVERLAYS``."""
+ if rel_path.startswith("ci/") and rel_path.endswith(".yaml"):
+ model_type = rel_path[len("ci/") : -len(".yaml")]
+ if model_type in _CI_OVERLAYS:
+ return str(_materialize_ci_overlay(model_type))
+ return str(_DEPLOY_DIR / rel_path)
+
+
if current_platform.is_rocm():
from amdsmi import (
amdsmi_get_gpu_vram_usage,
diff --git a/tests/utils/test_audio.py b/tests/utils/test_audio.py
new file mode 100644
index 0000000000..0e483e6468
--- /dev/null
+++ b/tests/utils/test_audio.py
@@ -0,0 +1,79 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Unit tests for vllm_omni.utils.audio."""
+
+import numpy as np
+import pytest
+import torch
+
+from vllm_omni.utils.audio import mel_filter_bank, peak_normalize
+
+# Parameter combinations used across the codebase.
+_PARAM_SETS = [
+ # Qwen3-TTS talker / speaker encoder (sr=24000)
+ dict(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000),
+ # CosyVoice3 whisper encoder, Qwen3-TTS 25Hz tokenizer (sr=16000, 80 mels)
+ dict(sr=16000, n_fft=400, n_mels=80),
+ # CosyVoice3 whisper encoder (sr=16000, 128 mels)
+ dict(sr=16000, n_fft=400, n_mels=128),
+]
+
+_parametrize_params = pytest.mark.parametrize(
+ "params", _PARAM_SETS, ids=lambda p: f"{p['sr']}_{p['n_fft']}_{p['n_mels']}"
+)
+
+
+class TestMelFilterBank:
+ @_parametrize_params
+ def test_output_shape(self, params):
+ fb = mel_filter_bank(**params)
+ n_freqs = params["n_fft"] // 2 + 1
+ assert fb.shape == (params["n_mels"], n_freqs)
+
+ @_parametrize_params
+ def test_non_negative(self, params):
+ fb = mel_filter_bank(**params)
+ assert (fb >= 0).all()
+
+ def test_dtype_is_float(self):
+ fb = mel_filter_bank(sr=16000, n_fft=400, n_mels=80)
+ assert fb.dtype == torch.float32
+
+ def test_fmax_defaults_to_nyquist(self):
+ """When fmax is omitted it should equal sr / 2."""
+ fb_default = mel_filter_bank(sr=16000, n_fft=400, n_mels=80)
+ fb_explicit = mel_filter_bank(sr=16000, n_fft=400, n_mels=80, fmax=8000.0)
+ torch.testing.assert_close(fb_default, fb_explicit)
+
+ def test_each_mel_band_has_nonzero_energy(self):
+ """Every mel band should have at least one nonzero frequency bin."""
+ fb = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
+ for i in range(fb.shape[0]):
+ assert fb[i].sum() > 0, f"mel band {i} is all zeros"
+
+ def test_higher_fmax_extends_coverage(self):
+ """A higher fmax should produce nonzero weights at higher frequency bins."""
+ fb_low = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=6000)
+ fb_high = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
+ # The highest nonzero column should be larger for fb_high.
+ last_nonzero_low = (fb_low.sum(dim=0) > 0).nonzero()[-1].item()
+ last_nonzero_high = (fb_high.sum(dim=0) > 0).nonzero()[-1].item()
+ assert last_nonzero_high > last_nonzero_low
+
+
+class TestPeakNormalize:
+ def test_silence_unchanged(self):
+ """All-zero input should remain all-zero."""
+ audio = np.zeros(1600, dtype=np.float32)
+ result = peak_normalize(audio, db_level=-6.0)
+ np.testing.assert_array_equal(result, audio)
+
+ def test_peak_reaches_target(self):
+ """After normalization, peak amplitude should be at target dB."""
+ rng = np.random.default_rng(7)
+ audio = rng.uniform(-0.4, 0.4, size=16000).astype(np.float32)
+
+ result = peak_normalize(audio, db_level=-6.0)
+ peak_db = 20 * np.log10(np.abs(result).max())
+ np.testing.assert_allclose(peak_db, -6.0, atol=1e-4)
diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py
new file mode 100644
index 0000000000..0e162a37e5
--- /dev/null
+++ b/tests/worker/test_omni_connector_mixin.py
@@ -0,0 +1,1419 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for OmniConnectorModelRunnerMixin.
+
+These tests use a mock connector (in-memory dict store) and do not require
+GPU or vLLM runtime.
+"""
+
+from __future__ import annotations
+
+import time
+import unittest
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm_omni.outputs import OmniConnectorOutput
+from vllm_omni.worker.omni_connector_model_runner_mixin import (
+ OmniConnectorModelRunnerMixin,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+# ------------------------------------------------------------------ #
+# Mock helpers
+# ------------------------------------------------------------------ #
+
+
+class MockConnector:
+ """In-memory connector for testing (mimics OmniConnectorBase)."""
+
+ def __init__(self, stage_id: int = 0):
+ self.stage_id = stage_id
+ self._store: dict[str, Any] = {}
+
+ def put(self, from_stage, to_stage, put_key, data):
+ key = f"{from_stage}_{to_stage}_{put_key}"
+ self._store[key] = data
+ return True, len(str(data)), None
+
+ def get(self, from_stage, to_stage, get_key, metadata=None):
+ key = f"{from_stage}_{to_stage}_{get_key}"
+ data = self._store.pop(key, None)
+ if data is None:
+ return None
+ return data, len(str(data))
+
+ def close(self):
+ pass
+
+
+def _make_model_config(
+ stage_id: int = 0,
+ async_chunk: bool = False,
+ worker_type: str = "ar",
+ custom_func: str | None = None,
+) -> SimpleNamespace:
+ return SimpleNamespace(
+ stage_connector_config=None,
+ async_chunk=async_chunk,
+ worker_type=worker_type,
+ custom_process_next_stage_input_func=custom_func,
+ )
+
+
+def _make_request(req_id: str, external_req_id: str | None = None):
+ r = SimpleNamespace(
+ request_id=req_id,
+ external_req_id=external_req_id or req_id,
+ additional_information=None,
+ prompt_token_ids=[],
+ num_computed_tokens=0,
+ )
+ return r
+
+
+class MixinHost(OmniConnectorModelRunnerMixin):
+ """Minimal class that mixes in the mixin for testing."""
+
+ pass
+
+
+class _FakeTPGroup:
+ def __init__(self, *, world_size: int, rank_in_group: int, follower_result: Any = None):
+ self.world_size = world_size
+ self.rank_in_group = rank_in_group
+ self.follower_result = follower_result
+ self.broadcast_inputs: list[Any] = []
+
+ def broadcast_object(self, obj: Any | None = None, src: int = 0):
+ self.broadcast_inputs.append(obj)
+ if self.rank_in_group == src:
+ return obj
+ return self.follower_result
+
+
+# ------------------------------------------------------------------ #
+# Test cases
+# ------------------------------------------------------------------ #
+
+
+class TestMixinAsyncChunkSendRecv(unittest.TestCase):
+ """Test 2: Async chunk send/recv + bg threads."""
+
+ def test_send_chunk_passes_is_finished_and_connector(self):
+ connector = MockConnector(stage_id=0)
+
+ sender = MixinHost()
+ sender.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0, async_chunk=True),
+ )
+ sender._omni_connector = connector
+ sender._stage_id = 0
+ sender._async_chunk = True
+
+ seen = {}
+
+ def mock_process(transfer_manager, pooling_output, request, is_finished=False):
+ seen["connector"] = transfer_manager.connector
+ seen["is_finished"] = is_finished
+ return {"data": pooling_output, "finished": is_finished}
+
+ sender._custom_process_func = mock_process
+
+ request = _make_request("req-1", "ext-req-1")
+ request.is_finished = lambda: True
+ sender._send_single_request(
+ {
+ "stage_id": 0,
+ "next_stage_id": 1,
+ "request_id": "ext-req-1",
+ "request": request,
+ "pooling_output": {"value": 42},
+ }
+ )
+ self.assertIs(seen["connector"], connector)
+ self.assertTrue(seen["is_finished"])
+
+ sender.shutdown_omni_connectors()
+
+ def test_send_chunk_does_not_retry_real_type_error(self):
+ connector = MockConnector(stage_id=0)
+
+ sender = MixinHost()
+ sender.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0, async_chunk=True),
+ )
+ sender._omni_connector = connector
+ sender._stage_id = 0
+ sender._async_chunk = True
+
+ seen = {"calls": 0}
+
+ def broken_process(transfer_manager, pooling_output, request, is_finished=""):
+ seen["calls"] += 1
+ return {"data": is_finished + "tail"}
+
+ sender._custom_process_func = broken_process
+
+ request = _make_request("req-1", "ext-req-1")
+ request.is_finished = lambda: True
+ ok = sender.send_chunk(request, pooling_output={"value": 42})
+ self.assertFalse(ok)
+ self.assertEqual(seen["calls"], 1)
+
+ sender.shutdown_omni_connectors()
+
+
+class TestMixinKVCacheTransfer(unittest.TestCase):
+ """Test 3: KV cache delegation to OmniKVTransferManager."""
+
+ def test_send_kv_delegates(self):
+ mock_kvm = MagicMock()
+ mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1"]
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ result = host.send_kv_cache(
+ finished_reqs={"req-1": {"seq_len": 10, "block_ids": [0]}},
+ kv_caches=[],
+ block_size=16,
+ cache_dtype="float16",
+ )
+ self.assertEqual(result, ["req-1"])
+ mock_kvm.handle_finished_requests_kv_transfer.assert_called_once()
+
+ host.shutdown_omni_connectors()
+
+ def test_recv_kv_delegates(self):
+ mock_kvm = MagicMock()
+ mock_kvm.receive_kv_cache_for_request.return_value = ({"layer_blocks": {}}, 100)
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ data, size = host.recv_kv_cache("req-1")
+ self.assertIsNotNone(data)
+ self.assertEqual(size, 100)
+ mock_kvm.receive_kv_cache_for_request.assert_called_once()
+
+ host.shutdown_omni_connectors()
+
+ def test_receive_multi_kv_fetches_companions_via_mixin(self):
+ mock_kvm = MagicMock()
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ host.recv_kv_cache = MagicMock(
+ side_effect=[({"layer_blocks": {"k": [1]}}, 64), ({"layer_blocks": {"k": [2]}}, 32)]
+ )
+ seen = {}
+
+ def collect_cfg(request_id, cfg_role_payloads):
+ seen["request_id"] = request_id
+ seen["cfg_role_payloads"] = cfg_role_payloads
+ return {"cfg_text_kv_metadata": {"seq_len": 3}}
+
+ req = SimpleNamespace(
+ request_id="req-1",
+ sampling_params=SimpleNamespace(cfg_kv_request_ids={"cfg_text": "req-1__cfg_text"}),
+ )
+ ok = host.receive_multi_kv_cache(req, cfg_kv_collect_func=collect_cfg)
+ self.assertTrue(ok)
+ host.recv_kv_cache.assert_any_call("req-1", target_device=None)
+ host.recv_kv_cache.assert_any_call("req-1__cfg_text", target_device=None)
+ mock_kvm.apply_kv_cache_to_request.assert_called_once_with(req, {"layer_blocks": {"k": [1]}})
+ self.assertEqual(seen["request_id"], "req-1")
+ self.assertEqual(
+ seen["cfg_role_payloads"],
+ {"cfg_text": ({"layer_blocks": {"k": [2]}}, 32)},
+ )
+ self.assertEqual(req.sampling_params.cfg_text_kv_metadata, {"seq_len": 3})
+
+ host.shutdown_omni_connectors()
+
+ def test_receive_multi_kv_skips_inactive_request(self):
+ mock_kvm = MagicMock()
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ host.requests = {}
+ host.recv_kv_cache = MagicMock(return_value=({"layer_blocks": {"k": [1]}}, 64))
+ req = SimpleNamespace(request_id="req-1", sampling_params=None)
+
+ ok = host.receive_multi_kv_cache(req)
+
+ self.assertFalse(ok)
+ host.recv_kv_cache.assert_not_called()
+ mock_kvm.apply_kv_cache_to_request.assert_not_called()
+
+ host.shutdown_omni_connectors()
+
+
+class TestOmniConnectorOutput(unittest.TestCase):
+ """Test 4: Output aggregation across transfer modes."""
+
+ def test_output_aggregation(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+
+ host._chunk_ready_req_ids.add("req-1")
+ host._chunk_finished_req_ids.add("req-2")
+ host._local_request_metadata["req-1"] = {"next_stage_prompt_len": 10}
+ host._stage_recv_req_ids.add("req-3")
+
+ output = host.get_omni_connector_output()
+ self.assertIsInstance(output, OmniConnectorOutput)
+ self.assertEqual(output.chunk_ready_req_ids, {"req-1"})
+ self.assertEqual(output.chunk_finished_req_ids, {"req-2"})
+ self.assertEqual(output.request_metadata, {"req-1": {"next_stage_prompt_len": 10}})
+ self.assertEqual(output.stage_recv_req_ids, {"req-3"})
+
+ output2 = host.get_omni_connector_output()
+ self.assertEqual(output2.chunk_ready_req_ids, set())
+ self.assertEqual(output2.request_metadata, {})
+
+ host.shutdown_omni_connectors()
+
+
+class TestMixinNoConnector(unittest.TestCase):
+ """Edge case: mixin works gracefully without a connector."""
+
+ def test_no_connector(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+ self.assertIsNone(host._omni_connector)
+
+ results = host.recv_full_payload_inputs(scheduler_output=None)
+ self.assertIsNone(results)
+
+ sent = host.send_full_payload_outputs(None, {"req-1": {}})
+ self.assertEqual(sent, [])
+
+ ok = host.send_chunk(_make_request("req-1"), pooling_output={})
+ self.assertFalse(ok)
+
+ output = host.get_omni_connector_output()
+ self.assertIsInstance(output, OmniConnectorOutput)
+
+ host.shutdown_omni_connectors()
+
+
+class TestFinishedLoadReqsDrain(unittest.TestCase):
+ """Test A1 fix: get_omni_connector_output drains _finished_load_reqs."""
+
+ def test_finished_load_reqs_flow_to_chunk_ready(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+
+ host._finished_load_reqs.add("req-1")
+ host._finished_load_reqs.add("req-2")
+
+ output = host.get_omni_connector_output()
+ self.assertIn("req-1", output.chunk_ready_req_ids)
+ self.assertIn("req-2", output.chunk_ready_req_ids)
+
+ self.assertEqual(len(host._finished_load_reqs), 0)
+ self.assertEqual(len(host._chunk_ready_req_ids), 0)
+
+ host.shutdown_omni_connectors()
+
+
+class TestLoadCustomFuncSelection(unittest.TestCase):
+ def test_skips_legacy_stage_list_processors_for_full_payload_mode(self):
+ legacy_paths = [
+ "vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav",
+ "vllm_omni.model_executor.stage_input_processors.mammoth_moda2.ar2dit",
+ "vllm_omni.model_executor.stage_input_processors.cosyvoice3.text2flow",
+ "vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion",
+ ]
+
+ for func_path in legacy_paths:
+ selected_path, func = MixinHost._load_custom_func(
+ SimpleNamespace(
+ async_chunk=False,
+ custom_process_input_func=func_path,
+ custom_process_next_stage_input_func=None,
+ )
+ )
+ assert selected_path != func_path
+ assert func is None or MixinHost._is_connector_payload_builder(func)
+
+
+class TestFullPayloadSendWithCustomFunc(unittest.TestCase):
+ """Test B4: send_full_payload_outputs with full_payload_mode custom process func."""
+
+ def test_full_payload_send_passes_is_finished_and_connector(self):
+ seen = {}
+
+ def full_payload_func(transfer_manager, pooling_output, request, is_finished=False):
+ seen["connector"] = transfer_manager.connector
+ seen["is_finished"] = is_finished
+ seen["data"] = pooling_output
+ seen["rid"] = request.request_id if request else None
+ return {"processed": True, "finished": is_finished}
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ host._custom_process_func = full_payload_func
+
+ req = _make_request("req-1")
+ req.is_finished = lambda: True
+ sent = host.send_full_payload_outputs(
+ scheduler_output=None,
+ outputs={"req-1": ({"raw": 100}, req)},
+ )
+ self.assertEqual(sent, ["req-1"])
+ self.assertEqual(
+ seen,
+ {
+ "connector": host._omni_connector,
+ "is_finished": True,
+ "data": {"raw": 100},
+ "rid": "req-1",
+ },
+ )
+
+ host.shutdown_omni_connectors()
+
+ def test_accumulate_and_flush(self):
+ call_log = []
+
+ def full_payload_func(transfer_manager, pooling_output, request):
+ call_log.append(request.request_id if request else None)
+ return {"processed": True}
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ host._custom_process_func = full_payload_func
+
+ req = _make_request("req-1")
+ host.accumulate_full_payload_output("req-1", {"raw": 42}, req)
+ self.assertEqual(len(host._pending_full_payload_send), 1)
+
+ host.flush_full_payload_outputs({"req-1"})
+ self.assertEqual(len(host._pending_full_payload_send), 0)
+ self.assertEqual(len(call_log), 1)
+ self.assertEqual(call_log[0], "req-1")
+
+ time.sleep(0.1)
+ host.shutdown_omni_connectors()
+
+
+class TestKVSentReqIdsAccumulation(unittest.TestCase):
+ """Test that kv_sent_req_ids accumulates results from send_kv_cache."""
+
+ def test_kv_sent_accumulation(self):
+ mock_kvm = MagicMock()
+ mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1", "req-2"]
+
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(),
+ kv_transfer_manager=mock_kvm,
+ )
+
+ host.send_kv_cache(
+ finished_reqs={"req-1": {}, "req-2": {}},
+ kv_caches=[],
+ block_size=16,
+ cache_dtype="float16",
+ )
+
+ output = host.get_omni_connector_output()
+ self.assertIn("req-1", output.kv_sent_req_ids)
+ self.assertIn("req-2", output.kv_sent_req_ids)
+
+ output2 = host.get_omni_connector_output()
+ self.assertEqual(output2.kv_sent_req_ids, [])
+
+ host.shutdown_omni_connectors()
+
+
+class TestChunkStreamCompletedGuard(unittest.TestCase):
+ """Test that register_chunk_recv is skipped after finish sentinel.
+
+ This validates the fix for the race condition where the scheduling
+ coordinator re-registers a request for chunk polling after its
+ upstream chunk stream has already finished (is_finished sentinel
+ received), causing the bg recv thread to poll for a non-existent
+ shared-memory segment (e.g. ``_0_7`` when only 7 chunks 0–6 exist).
+ """
+
+ def _make_host(self, stage_id: int = 1) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=stage_id, async_chunk=True),
+ )
+ host._omni_connector = MockConnector(stage_id=stage_id)
+ host._stage_id = stage_id
+ host._async_chunk = True
+ return host
+
+ def test_register_blocked_after_finish_sentinel(self):
+ """register_chunk_recv must be a no-op after the finish sentinel."""
+ host = self._make_host(stage_id=1)
+
+ req = _make_request("req-1", "ext-req-1")
+
+ # Simulate the bg thread having received the finish sentinel:
+ with host._lock:
+ host._chunk_stream_completed.add("req-1")
+
+ # Now try to re-register — this mimics the coordinator asking
+ # the model runner to poll for the next (non-existent) chunk.
+ host.register_chunk_recv(req)
+
+ # The request must NOT appear in _pending_load_reqs
+ self.assertNotIn(
+ "req-1",
+ host._pending_load_reqs,
+ "register_chunk_recv should skip requests whose chunk stream is already complete",
+ )
+
+ host.shutdown_omni_connectors()
+
+ def test_register_allowed_before_finish(self):
+ """register_chunk_recv works normally before finish sentinel."""
+ host = self._make_host(stage_id=1)
+ req = _make_request("req-1", "ext-req-1")
+
+ host.register_chunk_recv(req)
+ self.assertIn(
+ "req-1",
+ host._pending_load_reqs,
+ "register_chunk_recv should add request to pending when stream is not yet complete",
+ )
+
+ host.shutdown_omni_connectors()
+
+ def test_finish_sentinel_populates_completed_set(self):
+ """Receiving is_finished=True adds to _chunk_stream_completed."""
+ host = self._make_host(stage_id=1)
+
+ # Simulate _poll_single_request receiving is_finished=True
+ req_id = "req-1"
+ with host._lock:
+ host._chunk_finished_req_ids.add(req_id)
+ host._chunk_stream_completed.add(req_id)
+ host._local_stage_payload_cache[req_id] = {"finished": True}
+ host._local_request_metadata[req_id] = {}
+ host._finished_load_reqs.add(req_id)
+ host._pending_load_reqs.pop(req_id, None)
+
+ self.assertIn(req_id, host._chunk_stream_completed)
+
+ # Subsequent register_chunk_recv should be blocked
+ req = _make_request(req_id, f"ext-{req_id}")
+ host.register_chunk_recv(req)
+ self.assertNotIn(req_id, host._pending_load_reqs)
+
+ host.shutdown_omni_connectors()
+
+ def test_stage_0_always_skipped(self):
+ """Stage-0 has no upstream, register_chunk_recv is always no-op."""
+ host = self._make_host(stage_id=0)
+ host._stage_id = 0
+
+ req = _make_request("req-1")
+ host.register_chunk_recv(req)
+ self.assertNotIn("req-1", host._pending_load_reqs)
+
+ host.shutdown_omni_connectors()
+
+ def test_full_payload_recv_guard_still_works(self):
+ """Pre-existing guard: staged full-payload results prevent registration."""
+ host = self._make_host(stage_id=1)
+
+ with host._lock:
+ host._stage_recv_req_ids.add("req-1")
+
+ req = _make_request("req-1", "ext-req-1")
+ host.register_chunk_recv(req)
+ self.assertNotIn("req-1", host._pending_load_reqs)
+
+ host.shutdown_omni_connectors()
+
+
+class TestCleanupFinishedRequest(unittest.TestCase):
+ """Test cleanup_finished_request frees per-request mixin state."""
+
+ def _make_host(self, stage_id: int = 1) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=stage_id, async_chunk=True),
+ )
+ host._omni_connector = MockConnector(stage_id=stage_id)
+ host._stage_id = stage_id
+ host._async_chunk = True
+ return host
+
+ def test_cleanup_removes_all_state(self):
+ """cleanup_finished_request removes all tracking dicts/sets."""
+ host = self._make_host(stage_id=1)
+ req_id = "req-1"
+ ext_id = "ext-req-1"
+
+ # Simulate state accumulated during a request's lifetime
+ host._request_ids_mapping[req_id] = ext_id
+ host._put_req_chunk[ext_id] = 5
+ host._get_req_chunk[req_id] = 3
+ host._send_side_request_payload[ext_id] = {"some": "data"}
+ host._code_prompt_token_ids[ext_id] = [[1, 2, 3]]
+ host._chunk_stream_completed.add(req_id)
+ host._stage_recv_req_ids.add(req_id)
+ host._local_stage_payload_cache[req_id] = {"engine_inputs": {}}
+ host._local_request_metadata[req_id] = {"prompt_len": 10}
+
+ # Cleanup
+ host.cleanup_finished_request(req_id)
+
+ # All state should be gone
+ self.assertNotIn(req_id, host._request_ids_mapping)
+ self.assertNotIn(ext_id, host._put_req_chunk)
+ self.assertNotIn(req_id, host._get_req_chunk)
+ self.assertNotIn(ext_id, host._send_side_request_payload)
+ self.assertNotIn(ext_id, host._code_prompt_token_ids)
+ self.assertNotIn(req_id, host._chunk_stream_completed)
+ self.assertNotIn(req_id, host._stage_recv_req_ids)
+ self.assertNotIn(req_id, host._local_stage_payload_cache)
+ self.assertNotIn(req_id, host._local_request_metadata)
+
+ host.shutdown_omni_connectors()
+
+ def test_cleanup_removes_per_cycle_ready_state(self):
+ """cleanup_finished_request clears ready/finished carry-over for req-id reuse."""
+ host = self._make_host(stage_id=1)
+ req_id = "req-1"
+
+ host._pending_load_reqs[req_id] = _make_request(req_id, "ext-req-1")
+ host._finished_load_reqs.add(req_id)
+ host._chunk_ready_req_ids.add(req_id)
+ host._chunk_finished_req_ids.add(req_id)
+
+ host.cleanup_finished_request(req_id)
+
+ self.assertNotIn(req_id, host._pending_load_reqs)
+ self.assertNotIn(req_id, host._finished_load_reqs)
+ self.assertNotIn(req_id, host._chunk_ready_req_ids)
+ self.assertNotIn(req_id, host._chunk_finished_req_ids)
+
+ host.shutdown_omni_connectors()
+
+ def test_cleanup_without_mapping(self):
+ """cleanup works for Stage-0 where _request_ids_mapping isn't set."""
+ host = self._make_host(stage_id=0)
+ host._stage_id = 0
+ req_id = "req-1"
+
+ # Stage-0 uses req_id directly (no ext_id mapping)
+ host._put_req_chunk[req_id] = 3
+ host._get_req_chunk[req_id] = 0
+
+ host.cleanup_finished_request(req_id)
+
+ self.assertNotIn(req_id, host._put_req_chunk)
+ self.assertNotIn(req_id, host._get_req_chunk)
+
+ host.shutdown_omni_connectors()
+
+ def test_prune_inactive_requests_cleans_stale_state_but_keeps_active(self):
+ """Inactive request IDs should be pruned without touching active ones."""
+ host = self._make_host(stage_id=1)
+ active_req_id = "req-active"
+ stale_req_id = "req-stale"
+ stale_ext_id = "ext-stale"
+
+ host._request_ids_mapping[active_req_id] = "ext-active"
+ host._request_ids_mapping[stale_req_id] = stale_ext_id
+ host._put_req_chunk[stale_ext_id] = 2
+ host._get_req_chunk[stale_req_id] = 1
+ host._finished_load_reqs.add(stale_req_id)
+ host._chunk_ready_req_ids.update({active_req_id, stale_req_id})
+ host._chunk_finished_req_ids.add(stale_req_id)
+ host._chunk_stream_completed.add(stale_req_id)
+ host._stage_recv_req_ids.add(active_req_id)
+ host._send_side_request_payload[stale_ext_id] = {"stale": True}
+ host._code_prompt_token_ids[stale_ext_id] = [[1, 2, 3]]
+
+ pruned = host.prune_inactive_requests({active_req_id})
+
+ self.assertEqual(pruned, {stale_req_id})
+ self.assertIn(active_req_id, host._request_ids_mapping)
+ self.assertIn(active_req_id, host._chunk_ready_req_ids)
+ self.assertIn(active_req_id, host._stage_recv_req_ids)
+ self.assertNotIn(stale_req_id, host._request_ids_mapping)
+ self.assertNotIn(stale_ext_id, host._put_req_chunk)
+ self.assertNotIn(stale_req_id, host._get_req_chunk)
+ self.assertNotIn(stale_req_id, host._pending_load_reqs)
+ self.assertNotIn(stale_req_id, host._finished_load_reqs)
+ self.assertNotIn(stale_req_id, host._chunk_ready_req_ids)
+ self.assertNotIn(stale_req_id, host._chunk_finished_req_ids)
+ self.assertNotIn(stale_req_id, host._chunk_stream_completed)
+ self.assertNotIn(stale_req_id, host._stage_recv_req_ids)
+ self.assertNotIn(stale_ext_id, host._send_side_request_payload)
+ self.assertNotIn(stale_ext_id, host._code_prompt_token_ids)
+
+ host.shutdown_omni_connectors()
+
+ def test_prune_inactive_requests_keeps_recently_received_full_payload_state(self):
+ """Late bg-thread receives must survive until the scheduler catches up."""
+ host = self._make_host(stage_id=1)
+ req_id = "req-recv-race"
+ ext_id = "ext-recv-race"
+
+ host._request_ids_mapping[req_id] = ext_id
+ host._put_req_chunk[ext_id] = 1
+ host._local_stage_payload_cache[req_id] = {"engine_inputs": {"ids": [1, 2, 3]}}
+ host._local_request_metadata[req_id] = {"next_stage_prompt_len": 3}
+ host._stage_recv_req_ids.add(req_id)
+
+ pruned = host.prune_inactive_requests(set())
+
+ self.assertEqual(pruned, set())
+ self.assertIn(req_id, host._request_ids_mapping)
+ self.assertIn(req_id, host._local_stage_payload_cache)
+ self.assertIn(req_id, host._local_request_metadata)
+ self.assertIn(req_id, host._stage_recv_req_ids)
+ self.assertIn(ext_id, host._put_req_chunk)
+
+ # Once the scheduler has consumed the wake-up and the request really
+ # disappears from all protected sets, prune should clean it up.
+ host._stage_recv_req_ids.clear()
+ host._local_stage_payload_cache.clear()
+ host._local_request_metadata.clear()
+
+ pruned = host.prune_inactive_requests(set())
+
+ self.assertEqual(pruned, {req_id})
+ self.assertNotIn(req_id, host._request_ids_mapping)
+ self.assertNotIn(ext_id, host._put_req_chunk)
+
+ host.shutdown_omni_connectors()
+
+
+class TestSendChunkCachesMapping(unittest.TestCase):
+ """Test that send_chunk caches internal→external req ID mapping."""
+
+ def test_send_chunk_populates_request_ids_mapping(self):
+ """send_chunk should cache the internal→external mapping."""
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0, async_chunk=True),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ host._async_chunk = True
+
+ def mock_process(transfer_manager, pooling_output, request):
+ return {"data": "test", "finished": False}
+
+ host._custom_process_func = mock_process
+
+ request = _make_request("internal-1", "external-1")
+ host.send_chunk(request, pooling_output={"v": 1})
+
+ # The mapping should be cached
+ self.assertEqual(
+ host._request_ids_mapping.get("internal-1"),
+ "external-1",
+ )
+
+ time.sleep(0.1)
+ host.shutdown_omni_connectors()
+
+
+class TestLocalPayloadCacheLifecycle(unittest.TestCase):
+ """Unit tests for the local payload cache API (RFC §2.4)."""
+
+ def _make_host(self) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0),
+ )
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+ return host
+
+ def test_put_get_pop(self):
+ host = self._make_host()
+ payload = {"engine_inputs": {"ids": [1, 2, 3]}}
+ host.put_local_stage_payload("r1", payload)
+
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ popped = host.pop_local_stage_payload("r1")
+ self.assertEqual(popped, payload)
+ self.assertIsNone(host.get_local_stage_payload("r1"))
+ host.shutdown_omni_connectors()
+
+ def test_recv_full_payload_inputs_populates_local_cache(self):
+ host = self._make_host()
+ host._omni_connector = MockConnector(stage_id=0)
+ host._stage_id = 0
+
+ # Simulate a full payload already staged by the bg recv path
+ with host._lock:
+ host._local_stage_payload_cache["r1"] = {"tok": [10]}
+ host._stage_recv_req_ids.add("r1")
+
+ host.recv_full_payload_inputs(scheduler_output=None)
+ self.assertEqual(host.get_local_stage_payload("r1"), {"tok": [10]})
+ host.shutdown_omni_connectors()
+
+ def test_rank0_only_polls_connector_for_tp_full_payload(self):
+ host = self._make_host()
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._local_rank = 0
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+ payload = {"tok": [10], "finished": torch.tensor(True)}
+ connector_result = (payload, 123)
+ host._omni_connector.get.return_value = connector_result
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=0)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertTrue(made_progress)
+ host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0")
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertIn("r1", host._full_payload_pending_broadcast_req_ids)
+ self.assertNotIn("r1", host._stage_recv_req_ids)
+ self.assertIsNone(host.get_local_request_metadata("r1"))
+ host.shutdown_omni_connectors()
+
+ def test_tp_follower_skips_connector_poll_for_full_payload(self):
+ host = self._make_host()
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._local_rank = 1
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertFalse(made_progress)
+ host._omni_connector.get.assert_not_called()
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ self.assertNotIn("r1", host._local_stage_payload_cache)
+ host.shutdown_omni_connectors()
+
+ def test_recv_full_payload_inputs_broadcasts_tp_leader_results_to_followers(self):
+ host = self._make_host()
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._local_rank = 1
+ host._pending_load_reqs["r1"] = object()
+ payload = {"tok": [10], "finished": torch.tensor(True)}
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result={"r1": payload})
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ results = host.recv_full_payload_inputs(scheduler_output=None)
+
+ self.assertEqual(results, {"r1": payload})
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertEqual(host.get_local_request_metadata("r1"), {})
+ self.assertEqual(host._stage_recv_req_ids, {"r1"})
+ self.assertNotIn("r1", host._pending_load_reqs)
+ self.assertEqual(tp_group.broadcast_inputs, [None])
+ host.shutdown_omni_connectors()
+
+
+class TestTPAsyncChunkFanout(unittest.TestCase):
+ def _make_host(self, rank: int) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
+ )
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._async_chunk = True
+ host._model_mode = "gen"
+ host._local_rank = rank
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+ return host
+
+ def test_rank0_only_polls_connector_for_tp_async_chunk(self):
+ host = self._make_host(rank=0)
+ payload = {
+ "code_predictor_codes": [10, 11],
+ "left_context_size": 0,
+ "finished": torch.tensor(False),
+ }
+ host._omni_connector.get.return_value = (payload, 123)
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=0)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertTrue(made_progress)
+ host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0")
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertIn("r1", host._finished_load_reqs)
+ self.assertIn("r1", host._async_chunk_updated_req_ids)
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ host.shutdown_omni_connectors()
+
+ def test_tp_follower_skips_connector_poll_for_async_chunk(self):
+ host = self._make_host(rank=1)
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ made_progress = host._poll_single_request("r1")
+
+ self.assertFalse(made_progress)
+ host._omni_connector.get.assert_not_called()
+ self.assertIsNone(host.get_local_stage_payload("r1"))
+ self.assertEqual(tp_group.broadcast_inputs, [])
+ host.shutdown_omni_connectors()
+
+ def test_get_output_broadcasts_tp_async_chunk_payloads_to_followers(self):
+ host = self._make_host(rank=1)
+ host._pending_load_reqs["r1"] = object()
+ payload = {
+ "code_predictor_codes": [10, 11],
+ "left_context_size": 0,
+ "finished": torch.tensor(True),
+ }
+ packet = {
+ "staged_payloads": {"r1": payload},
+ "request_metadata": {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}},
+ "newly_finished": {"r1"},
+ "chunk_finished": {"r1"},
+ }
+ tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result=packet)
+
+ with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
+ output = host.get_omni_connector_output()
+
+ self.assertEqual(output.chunk_ready_req_ids, {"r1"})
+ self.assertEqual(output.chunk_finished_req_ids, {"r1"})
+ self.assertEqual(
+ output.request_metadata,
+ {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}},
+ )
+ self.assertEqual(host.get_local_stage_payload("r1"), payload)
+ self.assertNotIn("r1", host._pending_load_reqs)
+ self.assertIn("r1", host._chunk_stream_completed)
+ self.assertEqual(tp_group.broadcast_inputs, [None])
+ host.shutdown_omni_connectors()
+
+
+class TestKVTransferLifecycle(unittest.TestCase):
+ """Unit tests for KV transfer lifecycle methods."""
+
+ def _make_host(self) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0),
+ )
+ return host
+
+ def test_mark_drain_ack_complete(self):
+ host = self._make_host()
+ self.assertFalse(host.has_pending_kv_work())
+
+ host.mark_kv_transfer("r1", seq_len=100, block_ids=[0, 1, 2])
+ self.assertTrue(host.has_pending_kv_work())
+ self.assertTrue(host.is_kv_transfer_triggered("r1"))
+
+ # Drain moves pending → active
+ pending = host.drain_pending_kv_transfers()
+ self.assertEqual(pending, {"r1": {"seq_len": 100, "block_ids": [0, 1, 2]}})
+ self.assertIn("r1", host._kv_active_transfers)
+ self.assertTrue(host.has_pending_kv_work())
+
+ # Ack moves active → completed
+ host.ack_kv_transfers(["r1"])
+ self.assertNotIn("r1", host._kv_active_transfers)
+ self.assertIn("r1", host._kv_completed_transfers)
+
+ # Drain completed
+ completed = host.drain_completed_kv_transfers()
+ self.assertEqual(completed, {"r1"})
+ self.assertFalse(host.has_pending_kv_work())
+ host.shutdown_omni_connectors()
+
+ def test_mark_dedup(self):
+ host = self._make_host()
+ host.mark_kv_transfer("r1", seq_len=100, block_ids=[0])
+ host.mark_kv_transfer("r1", seq_len=200, block_ids=[0, 1])
+ # Second mark is a no-op
+ self.assertEqual(host._kv_pending_transfers["r1"]["seq_len"], 100)
+ host.shutdown_omni_connectors()
+
+ def test_cleanup_removes_kv_state(self):
+ host = self._make_host()
+ host.mark_kv_transfer("r1", seq_len=50, block_ids=[0])
+ host.drain_pending_kv_transfers()
+ host.cleanup_finished_request("r1")
+ self.assertFalse(host.is_kv_transfer_triggered("r1"))
+ self.assertNotIn("r1", host._kv_active_transfers)
+ self.assertFalse(host.has_pending_kv_work())
+ host.shutdown_omni_connectors()
+
+
+class TestAsyncPayloadLifecycle(unittest.TestCase):
+ """Regression tests for async payload delivery lifecycle."""
+
+ def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ host._request_ids_mapping["r1"] = "r1"
+ payload = {
+ "thinker_decode_embeddings": torch.ones(1, 2),
+ "thinker_output_token_ids": [1],
+ "override_keys": ["thinker_decode_embeddings", "thinker_output_token_ids"],
+ "finished": torch.tensor(False),
+ }
+
+ host._accumulate_payload("r1", dict(payload))
+ with host._lock:
+ host._finished_load_reqs.add("r1")
+
+ host.get_omni_connector_output()
+ self.assertIn("r1", host._send_side_request_payload)
+ host.shutdown_omni_connectors()
+
+ def test_payload_consumable_ignores_token_horizon_only_updates(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ payload = {
+ "thinker_output_token_ids": [1, 2, 3],
+ "finished": torch.tensor(False),
+ "override_keys": [
+ "thinker_output_token_ids",
+ "thinker_decode_embeddings_token_start",
+ "thinker_decode_embeddings_token_end",
+ ],
+ "thinker_decode_embeddings_token_start": 2,
+ "thinker_decode_embeddings_token_end": 3,
+ }
+ self.assertFalse(host._payload_is_consumable(payload))
+ host.shutdown_omni_connectors()
+
+ def test_payload_consumable_accepts_decode_embeddings(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ payload = {
+ "thinker_output_token_ids": [1, 2, 3],
+ "thinker_decode_embeddings": torch.ones(1, 2),
+ "finished": torch.tensor(False),
+ }
+ self.assertTrue(host._payload_is_consumable(payload))
+ host.shutdown_omni_connectors()
+
+ def test_ar_metadata_only_followup_chunk_does_not_rewake_request(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
+ )
+ host._omni_connector = MagicMock()
+ host._stage_id = 1
+ host._async_chunk = True
+ host._model_mode = "ar"
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 0
+
+ host._omni_connector.get.side_effect = [
+ (
+ {
+ "thinker_decode_embeddings": torch.ones(1, 2),
+ "finished": torch.tensor(False),
+ },
+ 1,
+ ),
+ (
+ {
+ "next_stage_prompt_len": 7,
+ "finished": torch.tensor(False),
+ },
+ 1,
+ ),
+ ]
+
+ host._poll_single_request("r1")
+ output1 = host.get_omni_connector_output()
+ self.assertEqual(output1.chunk_ready_req_ids, {"r1"})
+
+ host._poll_single_request("r1")
+ output2 = host.get_omni_connector_output()
+ self.assertEqual(output2.chunk_ready_req_ids, set())
+ self.assertEqual(output2.request_metadata, {"r1": {"next_stage_prompt_len": 7}})
+
+ host.shutdown_omni_connectors()
+
+ def test_non_ar_recv_does_not_overwrite_unconsumed_staged_chunk(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
+ )
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._async_chunk = True
+ host._model_mode = "gen"
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 1
+ host._local_stage_payload_cache["r1"] = {
+ "code_predictor_codes": [1, 2, 3],
+ "left_context_size": 0,
+ "finished": torch.tensor(False),
+ }
+
+ made_progress = host._poll_single_request("r1")
+
+ self.assertFalse(made_progress)
+ host._omni_connector.get.assert_not_called()
+ self.assertEqual(host._get_req_chunk["r1"], 1)
+
+ host.shutdown_omni_connectors()
+
+ def test_non_ar_recv_waits_for_scheduler_handoff_before_fetching_next_chunk(self):
+ host = MixinHost()
+ host.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
+ )
+ host._omni_connector = MagicMock()
+ host._stage_id = 2
+ host._async_chunk = True
+ host._model_mode = "gen"
+ host._request_ids_mapping["r1"] = "ext-r1"
+ host._get_req_chunk["r1"] = 1
+ host._local_request_metadata["r1"] = {
+ "code_predictor_codes": [10, 11, 12],
+ "left_context_size": 0,
+ }
+ host._finished_load_reqs.add("r1")
+
+ made_progress = host._poll_single_request("r1")
+
+ self.assertFalse(made_progress)
+ host._omni_connector.get.assert_not_called()
+ self.assertEqual(host._get_req_chunk["r1"], 1)
+
+ output = host.get_omni_connector_output()
+ self.assertEqual(output.request_metadata["r1"]["code_predictor_codes"], [10, 11, 12])
+ self.assertEqual(output.chunk_ready_req_ids, {"r1"})
+
+ host._omni_connector.get.return_value = (
+ {
+ "code_predictor_codes": [20, 21, 22],
+ "left_context_size": 0,
+ "finished": torch.tensor(False),
+ },
+ 1,
+ )
+ made_progress = host._poll_single_request("r1")
+
+ self.assertTrue(made_progress)
+ host._omni_connector.get.assert_called_once()
+ self.assertEqual(host._get_req_chunk["r1"], 2)
+
+ host.shutdown_omni_connectors()
+
+
+class TestRankAwareKVRouting(unittest.TestCase):
+ def _make_host(self, *, from_tp: int, to_tp: int, local_rank: int) -> MixinHost:
+ host = MixinHost()
+ host.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=1))
+ host._from_tp = from_tp
+ host._to_tp = to_tp
+ host._local_rank = local_rank
+ return host
+
+ def test_recv_keys_use_remote_rank_as_from_rank(self):
+ host = self._make_host(from_tp=4, to_tp=2, local_rank=1)
+ self.assertEqual(
+ host.get_rank_aware_kv_keys("req", from_stage=0),
+ ["req_0_0_2_1", "req_0_0_3_1"],
+ )
+ host.shutdown_omni_connectors()
+
+ def test_send_keys_route_from_rank_gt_to_rank(self):
+ host = self._make_host(from_tp=4, to_tp=2, local_rank=3)
+ self.assertEqual(host.get_rank_aware_kv_send_keys("req", from_stage=0), ["req_0_0_3_1"])
+ host.shutdown_omni_connectors()
+
+ def test_invalid_recv_rank_mapping_raises(self):
+ host = self._make_host(from_tp=3, to_tp=2, local_rank=1)
+ with self.assertRaises(ValueError):
+ host.get_rank_aware_kv_keys("req", from_stage=0)
+ host.shutdown_omni_connectors()
+
+ def test_invalid_send_rank_mapping_raises(self):
+ host = self._make_host(from_tp=3, to_tp=2, local_rank=1)
+ with self.assertRaises(ValueError):
+ host.get_rank_aware_kv_send_keys("req", from_stage=0)
+ host.shutdown_omni_connectors()
+
+ def test_merge_rank_sharded_payloads_concatenates_head_dimension(self):
+ host = self._make_host(from_tp=4, to_tp=2, local_rank=0)
+ payloads = [
+ {"layer_blocks": {"key_cache": [torch.ones(2, 1, 3)], "value_cache": [torch.ones(2, 1, 3)]}},
+ {"layer_blocks": {"key_cache": [torch.full((2, 1, 3), 2.0)], "value_cache": [torch.full((2, 1, 3), 2.0)]}},
+ ]
+ merged = host._merge_rank_sharded_kv_payloads(payloads)
+ self.assertEqual(tuple(merged["layer_blocks"]["key_cache"][0].shape), (2, 2, 3))
+ self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 0], torch.ones(2, 3)))
+ self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 1], torch.full((2, 3), 2.0)))
+ host.shutdown_omni_connectors()
+
+ def test_slice_rank_sharded_payload_splits_head_dimension(self):
+ host = self._make_host(from_tp=2, to_tp=4, local_rank=1)
+ payload = {
+ "layer_blocks": {
+ "key_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)],
+ "value_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)],
+ },
+ "metadata": {},
+ }
+ sliced = host._slice_rank_sharded_kv_payload(payload)
+ self.assertEqual(tuple(sliced["layer_blocks"]["key_cache"][0].shape), (2, 2, 3))
+ expected = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)[:, 2:4, :]
+ self.assertTrue(torch.equal(sliced["layer_blocks"]["key_cache"][0], expected))
+ host.shutdown_omni_connectors()
+
+
+class TestAttachOmniConnectorOutput(unittest.TestCase):
+ def test_wraps_empty_model_runner_output_when_signals_exist(self):
+ from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT
+
+ host = MixinHost()
+ host.get_omni_connector_output = lambda: OmniConnectorOutput(chunk_ready_req_ids={"req-1"})
+
+ wrapped = host.attach_omni_connector_output(EMPTY_MODEL_RUNNER_OUTPUT)
+
+ self.assertIsNot(wrapped, EMPTY_MODEL_RUNNER_OUTPUT)
+ self.assertEqual(wrapped.omni_connector_output.chunk_ready_req_ids, {"req-1"})
+
+
+class TestConnectorConfigValidation(unittest.TestCase):
+ def test_invalid_connector_name_raises(self):
+ host = MixinHost()
+ model_config = _make_model_config(stage_id=1)
+ model_config.stage_connector_config = {"name": " "}
+
+ with self.assertRaisesRegex(RuntimeError, "missing connector name"):
+ host.init_omni_connectors(vllm_config=None, model_config=model_config)
+
+
+class _FailingConnector:
+ """Connector whose put() fails a configurable number of times."""
+
+ def __init__(self, fail_count: int = 1, raise_on_fail: bool = False):
+ self._fail_count = fail_count
+ self._raise_on_fail = raise_on_fail
+ self.attempt = 0
+
+ def put(self, from_stage, to_stage, put_key, data):
+ self.attempt += 1
+ if self.attempt <= self._fail_count:
+ if self._raise_on_fail:
+ raise ConnectionError("transient connector error")
+ return False, 0, None
+ return True, len(str(data)), None
+
+ def get(self, *a, **kw):
+ return None
+
+ def close(self):
+ pass
+
+
+class TestSendRetry(unittest.TestCase):
+ """Tests for P1-2: failed connector sends must be retried."""
+
+ def _make_sender(self, connector):
+ sender = MixinHost()
+ sender.init_omni_connectors(
+ vllm_config=None,
+ model_config=_make_model_config(stage_id=0, async_chunk=True),
+ )
+ sender._omni_connector = connector
+ sender._stage_id = 0
+ sender._async_chunk = True
+ return sender
+
+ def _make_task(self, req_id="r1"):
+ return {
+ "stage_id": 0,
+ "next_stage_id": 1,
+ "request_id": req_id,
+ "data": {"payload": "test"},
+ }
+
+ def test_send_single_request_returns_false_on_put_failure(self):
+ connector = _FailingConnector(fail_count=999)
+ sender = self._make_sender(connector)
+
+ result = sender._send_single_request(self._make_task())
+ self.assertFalse(result)
+ sender.shutdown_omni_connectors()
+
+ def test_send_single_request_does_not_decrement_on_failure(self):
+ connector = _FailingConnector(fail_count=999)
+ sender = self._make_sender(connector)
+ sender._pending_save_counts["r1"] = 1
+
+ sender._send_single_request(self._make_task())
+ self.assertEqual(sender._pending_save_counts.get("r1"), 1, "pending count must NOT be decremented on failure")
+ sender.shutdown_omni_connectors()
+
+ def test_send_single_request_decrements_on_success(self):
+ connector = MockConnector(stage_id=0)
+ sender = self._make_sender(connector)
+ sender._pending_save_counts["r1"] = 1
+
+ result = sender._send_single_request(self._make_task())
+ self.assertTrue(result)
+ self.assertNotIn("r1", sender._pending_save_counts, "pending count should be zero/removed on success")
+ sender.shutdown_omni_connectors()
+
+ def test_requeue_or_drop_requeues_on_first_failure(self):
+ sender = self._make_sender(MockConnector(stage_id=0))
+ task = self._make_task()
+
+ sender._requeue_or_drop_failed_send(task)
+
+ self.assertEqual(task.get("_retry_count"), 1)
+ with sender._lock:
+ dq = sender._pending_save_reqs.get("r1")
+ self.assertIsNotNone(dq)
+ self.assertEqual(len(dq), 1)
+ sender.shutdown_omni_connectors()
+
+ def test_requeue_or_drop_drops_after_max_retries(self):
+ sender = self._make_sender(MockConnector(stage_id=0))
+ sender._pending_save_counts["r1"] = 1
+ task = self._make_task()
+ task["_retry_count"] = sender._MAX_SEND_RETRIES # already at max
+
+ sender._requeue_or_drop_failed_send(task)
+
+ with sender._lock:
+ dq = sender._pending_save_reqs.get("r1")
+ self.assertTrue(dq is None or len(dq) == 0, "task should NOT be re-enqueued after max retries")
+ self.assertNotIn("r1", sender._pending_save_counts, "pending count should be cleaned up on final drop")
+ sender.shutdown_omni_connectors()
+
+ def test_save_loop_retries_on_exception(self):
+ """Integration: _save_loop retries a task when put() raises."""
+ from collections import deque
+
+ connector = _FailingConnector(fail_count=1, raise_on_fail=True)
+ sender = self._make_sender(connector)
+ task = self._make_task()
+
+ with sender._lock:
+ sender._pending_save_reqs["r1"] = deque([task])
+ sender._pending_save_counts["r1"] = 1
+
+ sender._stop_event.clear()
+
+ def run_one_loop():
+ sender._save_loop()
+
+ sender._stop_event.set() # will exit after one iteration
+ # Run manually instead of threading
+ # Simulate: pop task, send fails, requeue
+ popped_task = None
+ with sender._lock:
+ dq = sender._pending_save_reqs.get("r1")
+ if dq:
+ popped_task = dq.popleft()
+ if not dq:
+ del sender._pending_save_reqs["r1"]
+
+ if popped_task is not None:
+ success = False
+ try:
+ success = sender._send_single_request(popped_task)
+ except Exception:
+ pass
+ if not success:
+ sender._requeue_or_drop_failed_send(popped_task)
+
+ # After first failure, task should be re-enqueued
+ with sender._lock:
+ dq = sender._pending_save_reqs.get("r1")
+ self.assertIsNotNone(dq)
+ self.assertEqual(len(dq), 1)
+ requeued = dq[0]
+ self.assertEqual(requeued.get("_retry_count"), 1)
+
+ # Second attempt should succeed (connector now returns True)
+ success = sender._send_single_request(requeued)
+ self.assertTrue(success)
+ sender.shutdown_omni_connectors()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tools/nightly/generate_nightly_perf_excel.py b/tools/nightly/generate_nightly_perf_excel.py
index 817f37f664..6ba1d1eef0 100644
--- a/tools/nightly/generate_nightly_perf_excel.py
+++ b/tools/nightly/generate_nightly_perf_excel.py
@@ -23,6 +23,22 @@
GREY_BLOCK_FILL = PatternFill(start_color="D3D3D3", fill_type="solid")
# Diffusion sheet columns (Qwen-Image diffusion benchmark).
+# Per-stage latency metrics. Unpack from stage_durations_mean/p50/p99 dicts
+DIFFUSION_STAGE_LATENCY_COLUMNS: tuple[str, ...] = (
+ # "vae.encode_mean",
+ # "vae.encode_p50",
+ # "vae.encode_p99",
+ "vae.decode_mean",
+ "vae.decode_p50",
+ "vae.decode_p99",
+ "diffuse_mean",
+ "diffuse_p50",
+ "diffuse_p99",
+ "text_encoder.forward_mean",
+ "text_encoder.forward_p50",
+ "text_encoder.forward_p99",
+)
+
DIFFUSION_BENCHMARK_COLUMNS: tuple[str, ...] = (
"duration",
"completed_requests",
@@ -36,7 +52,7 @@
"peak_memory_mb_mean",
"peak_memory_mb_median",
"slo_attainment_rate",
-)
+) + DIFFUSION_STAGE_LATENCY_COLUMNS
DIFFUSION_NUMERIC_FORMAT_COLUMNS: tuple[str, ...] = DIFFUSION_BENCHMARK_COLUMNS
@@ -63,7 +79,7 @@
"build_id",
"build_url",
"source_file",
-)
+) + DIFFUSION_STAGE_LATENCY_COLUMNS
# Benchmark metric columns: grey the latest row's cell when value changed vs previous date.
BENCHMARK_COLUMNS: tuple[str, ...] = (
@@ -106,7 +122,7 @@
_COLUMNS_FILENAME = "nightly_perf_summary_columns.txt"
_RESULT_JSON_PREFIX = "result_test_"
-_DIFFUSION_JSON_PREFIX = "diffusion_perf_"
+_DIFFUSION_RESULT_PREFIX = "diffusion_result_"
DEFAULT_INPUT_DIR = os.getenv("DEFAULT_INPUT_DIR") if os.getenv("DEFAULT_INPUT_DIR") else "tests"
DEFAULT_OUTPUT_DIR = os.getenv("DEFAULT_OUTPUT_DIR") if os.getenv("DEFAULT_OUTPUT_DIR") else "tests"
DEFAULT_DIFFUSION_INPUT_DIR = os.getenv("DIFFUSION_BENCHMARK_DIR")
@@ -252,7 +268,7 @@ def parse_args() -> argparse.Namespace:
type=str,
default=None,
help=(
- "Directory containing diffusion_perf_*.json files; default is "
+ "Directory containing diffusion_result_*.json files; default is "
"DIFFUSION_BENCHMARK_DIR, fallback to --input-dir."
),
)
@@ -286,7 +302,7 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
-def _load_json_file(path: str) -> dict[str, Any] | None:
+def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None:
"""Safely load a single JSON file; return None and log a warning on failure."""
try:
with open(path, encoding="utf-8") as f:
@@ -295,18 +311,18 @@ def _load_json_file(path: str) -> dict[str, Any] | None:
LOGGER.warning("failed to load json '%s': %s", path, exc)
return None
- if not isinstance(data, dict):
- LOGGER.warning("json root in '%s' is not an object, skip", path)
+ if not isinstance(data, (dict, list)):
+ LOGGER.warning("json root in '%s' is not a dict or list, skip", path)
return None
return data
def _parse_from_filename(filename: str) -> dict[str, Any]:
- """Parse test-related metadata from a result JSON filename.
+ """Parse test-related metadata from a ``result_test_*.json`` filename.
- Expected pattern (after prefix/suffix stripped):
- ____
+ Matches ``tests/dfx/perf/scripts/run_benchmark.py`` naming, including optional
+ ``_in{X}_out{Y}_`` before the timestamp (``na`` when unset).
"""
name, ext = os.path.splitext(filename)
if ext != ".json" or not name.startswith(_RESULT_JSON_PREFIX):
@@ -315,22 +331,42 @@ def _parse_from_filename(filename: str) -> dict[str, Any]:
core = name[len(_RESULT_JSON_PREFIX) :]
parts = core.split("_")
if len(parts) < 5:
- LOGGER.warning("filename '%s' does not match expected pattern, skip parsing test metadata", filename)
+ LOGGER.warning(
+ "filename '%s' does not match expected pattern (need >= 5 segments), skip parsing",
+ filename,
+ )
return {}
- timestamp = parts[-1]
- num_prompts_str = parts[-2]
- max_concurrency_str = parts[-3]
- dataset_name = parts[-4]
- test_name = "_".join(parts[:-4]) if parts[:-4] else ""
+ idx = len(parts) - 1
+ timestamp = parts[idx]
+ idx -= 1
parsed: dict[str, Any] = {}
-
if len(timestamp) >= 15:
parsed["date"] = timestamp
- if dataset_name in DATASET_NAME_ALLOWED:
- parsed["dataset_name"] = dataset_name
+ if idx >= 0 and parts[idx].startswith("out"):
+ parsed["random_output_len"] = parts[idx][3:]
+ idx -= 1
+ if idx >= 0 and parts[idx].startswith("in"):
+ parsed["random_input_len"] = parts[idx][2:]
+ idx -= 1
+
+ if idx < 3:
+ LOGGER.warning(
+ "filename '%s' has too few segments after timestamp / optional in-out (idx=%s)",
+ filename,
+ idx,
+ )
+ return parsed
+
+ num_prompts_str = parts[idx]
+ idx -= 1
+ flow_str = parts[idx]
+ idx -= 1
+ dataset_name = parts[idx]
+ idx -= 1
+ test_name = "_".join(parts[: idx + 1]) if idx >= 0 else ""
try:
parsed["num_prompts"] = int(num_prompts_str)
@@ -338,13 +374,16 @@ def _parse_from_filename(filename: str) -> dict[str, Any]:
pass
try:
- parsed["max_concurrency"] = int(max_concurrency_str)
+ parsed["max_concurrency"] = int(flow_str)
except (TypeError, ValueError):
pass
if test_name:
parsed["test_name"] = test_name
+ if dataset_name in DATASET_NAME_ALLOWED:
+ parsed["dataset_name"] = dataset_name
+
return parsed
@@ -396,27 +435,29 @@ def _iter_omni_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
yield record
-def _parse_diffusion_from_filename(filename: str) -> dict[str, Any]:
- """Parse diffusion test_name/date from filename: diffusion_perf__.json"""
+def _parse_diffusion_result_from_filename(filename: str) -> dict[str, Any]:
+ """Parse test_name/date from filename: diffusion_result__.json"""
name, ext = os.path.splitext(filename)
- if ext != ".json" or not name.startswith(_DIFFUSION_JSON_PREFIX):
+ if ext != ".json" or not name.startswith(_DIFFUSION_RESULT_PREFIX):
return {}
- core = name[len(_DIFFUSION_JSON_PREFIX) :]
+ core = name[len(_DIFFUSION_RESULT_PREFIX) :]
parts = core.split("_")
if len(parts) < 2:
return {}
timestamp = parts[-1]
- test_name = "_".join(parts[:-1]) if parts[:-1] else ""
parsed: dict[str, Any] = {}
if len(timestamp) >= 15:
parsed["date"] = timestamp
- if test_name:
- parsed["test_name"] = test_name
return parsed
-def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
- """Iterate over diffusion_perf_*.json files and yield normalized diffusion records."""
+def _iter_diffusion_records(input_dir: str) -> Iterable[dict[str, Any]]:
+ """Iterate over diffusion_result_*.json files and yield normalized records.
+
+ Unlike omni format where each JSON file contains one test case, diffusion format
+ produces a single JSON file containing a list of all test case records.
+ Test params (feature toggles) are NOT embedded in the filename.
+ """
if not os.path.isdir(input_dir):
LOGGER.warning("diffusion input dir '%s' does not exist or is not a directory", input_dir)
return
@@ -424,7 +465,7 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
for entry in sorted(os.listdir(input_dir)):
if not entry.endswith(".json"):
continue
- if not entry.startswith(_DIFFUSION_JSON_PREFIX):
+ if not entry.startswith(_DIFFUSION_RESULT_PREFIX):
continue
full_path = os.path.join(input_dir, entry)
if not os.path.isfile(full_path):
@@ -434,23 +475,63 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
if data is None:
continue
- record: dict[str, Any] = dict(data)
- filename_meta = _parse_diffusion_from_filename(os.path.basename(full_path))
- if "date" not in record or not record.get("date"):
- record["date"] = filename_meta.get("date") or datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
- if "test_name" not in record or not record.get("test_name"):
- if "test_name" in filename_meta:
- record["test_name"] = filename_meta["test_name"]
- record["source_file"] = os.path.basename(full_path)
- yield record
+ filename_meta = _parse_diffusion_result_from_filename(os.path.basename(full_path))
+ if not isinstance(data, list):
+ LOGGER.warning("diffusion result file '%s' root is not a list, skip", full_path)
+ continue
+
+ for record in data:
+ if not isinstance(record, dict):
+ continue
+ record = dict(record)
+ if "date" not in record or not record.get("date"):
+ record["date"] = filename_meta.get("date") or datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
+ record["source_file"] = os.path.basename(full_path)
+ yield record
-def _collect_records(input_dir: str) -> list[dict[str, Any]]:
+
+def _collect_omni_records(input_dir: str) -> list[dict[str, Any]]:
return list(_iter_omni_json_records(input_dir))
def _collect_diffusion_records(diffusion_input_dir: str) -> list[dict[str, Any]]:
- return list(_iter_diffusion_json_records(diffusion_input_dir))
+ """Collect diffusion records from diffusion_result_*.json files.
+ Their format is different from omni JSON files.
+ """
+ return [_process_diffusion_record(r) for r in _iter_diffusion_records(diffusion_input_dir)]
+
+
+def _flatten_stage_durations(record: dict[str, Any]) -> dict[str, Any]:
+ """Flatten stage_durations dict into individual columns matching DIFFUSION_STAGE_LATENCY_COLUMNS."""
+ result = dict(record)
+
+ for prefix in ("stage_durations_mean", "stage_durations_p50", "stage_durations_p99"):
+ durations = result.pop(prefix, None)
+ if not isinstance(durations, dict):
+ continue
+
+ suffix = prefix.replace("stage_durations_", "") # "mean", "p50", "p99"
+
+ for stage_key, value in durations.items(): # e.g., "SomePipeline.vae.decode_mean": 100.0
+ stage_key = stage_key.split(".", 1)[-1] # "decode_mean"
+ col_name = f"{stage_key}_{suffix}"
+ if col_name not in DIFFUSION_STAGE_LATENCY_COLUMNS:
+ print(f"skipping stage_key: {col_name}")
+ continue
+ result[col_name] = value
+
+ return result
+
+
+def _process_diffusion_record(record: dict[str, Any]) -> dict[str, Any]:
+ """Normalize a diffusion record by merging `result` and flattening stage metrics."""
+ flat = record.copy()
+ flat.update(flat.pop("result", {}))
+ flat = _flatten_stage_durations(flat)
+ flat.pop("benchmark_params", None)
+ flat.pop("server_params", None)
+ return flat
def _apply_build_metadata_to_latest_only(
@@ -493,7 +574,7 @@ def _apply_build_metadata_to_latest_only(
def _sort_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Sort so that same test configuration is grouped, newest date first within each group."""
- by_date_desc = sorted(records, key=lambda r: (r.get("date") or ""), reverse=True)
+ by_date_desc = sorted(records, key=lambda r: r.get("date") or "", reverse=True)
return sorted(
by_date_desc,
key=_omni_group_key,
@@ -501,7 +582,7 @@ def _sort_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, A
def _sort_diffusion_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
- by_date_desc = sorted(records, key=lambda r: (r.get("date") or ""), reverse=True)
+ by_date_desc = sorted(records, key=lambda r: r.get("date") or "", reverse=True)
return sorted(by_date_desc, key=_diffusion_group_key)
@@ -584,6 +665,21 @@ def _to_float_if_numeric(value: Any) -> Any:
return value
+def _to_excel_compatible(value: Any) -> Any:
+ """Convert non-scalar objects to JSON string for openpyxl cell values."""
+ if isinstance(value, (dict, list, tuple)):
+ try:
+ return json.dumps(value, ensure_ascii=False, sort_keys=True)
+ except (TypeError, ValueError):
+ return str(value)
+ if isinstance(value, set):
+ try:
+ return json.dumps(sorted(value), ensure_ascii=False)
+ except (TypeError, ValueError):
+ return str(value)
+ return value
+
+
def _write_sheet(
ws,
columns: Sequence[str],
@@ -599,6 +695,7 @@ def _write_sheet(
v = record.get(col)
if col in numeric_set:
v = _to_float_if_numeric(v)
+ v = _to_excel_compatible(v)
row_values.append(v)
ws.append(row_values)
@@ -678,7 +775,7 @@ def generate_excel_report(
script_dir = os.path.dirname(os.path.abspath(__file__))
omni_summary_columns = _ensure_omni_summary_columns(_load_summary_columns(script_dir))
- omni_records = _collect_records(input_dir)
+ omni_records = _collect_omni_records(input_dir)
diffusion_records = _collect_diffusion_records(diffusion_input_dir)
if not omni_records:
diff --git a/tools/nightly/generate_nightly_perf_html.py b/tools/nightly/generate_nightly_perf_html.py
index 05dc48d717..a50a462550 100644
--- a/tools/nightly/generate_nightly_perf_html.py
+++ b/tools/nightly/generate_nightly_perf_html.py
@@ -17,7 +17,7 @@
LOGGER = logging.getLogger(__name__)
_RESULT_JSON_PREFIX = "result_test_"
-_DIFFUSION_JSON_PREFIX = "diffusion_perf_"
+_DIFFUSION_JSON_PREFIXES = ("diffusion_perf_", "diffusion_result_")
DEFAULT_INPUT_DIR = os.getenv("DEFAULT_INPUT_DIR") or "tests"
DEFAULT_OUTPUT_DIR = os.getenv("DEFAULT_OUTPUT_DIR") or "tests"
DEFAULT_DIFFUSION_INPUT_DIR = os.getenv("DIFFUSION_BENCHMARK_DIR")
@@ -51,7 +51,7 @@ def _default_diffusion_input_dir(input_dir: str) -> str:
return DEFAULT_DIFFUSION_INPUT_DIR if DEFAULT_DIFFUSION_INPUT_DIR else input_dir
-def _load_json_file(path: str) -> dict[str, Any] | None:
+def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None:
try:
with open(path, encoding="utf-8") as f:
data = json.load(f)
@@ -59,14 +59,15 @@ def _load_json_file(path: str) -> dict[str, Any] | None:
LOGGER.warning("failed to load json '%s': %s", path, exc)
return None
- if not isinstance(data, dict):
- LOGGER.warning("json root in '%s' is not an object, skip", path)
+ if not isinstance(data, (dict, list)):
+ LOGGER.warning("json root in '%s' is not an object or list, skip", path)
return None
return data
def _parse_from_filename(filename: str) -> dict[str, Any]:
+ """Parse ``result_test_*.json`` filenames; same rules as ``generate_nightly_perf_excel``."""
name, ext = os.path.splitext(filename)
if ext != ".json" or not name.startswith(_RESULT_JSON_PREFIX):
return {}
@@ -75,32 +76,58 @@ def _parse_from_filename(filename: str) -> dict[str, Any]:
parts = core.split("_")
if len(parts) < 5:
LOGGER.warning(
- "filename '%s' does not match expected pattern, skip parsing test metadata",
+ "filename '%s' does not match expected pattern (need >= 5 segments), skip parsing",
filename,
)
return {}
- timestamp = parts[-1]
- num_prompts_str = parts[-2]
- max_concurrency_str = parts[-3]
- dataset_name = parts[-4]
- test_name = "_".join(parts[:-4]) if parts[:-4] else ""
+ idx = len(parts) - 1
+ timestamp = parts[idx]
+ idx -= 1
parsed: dict[str, Any] = {}
if len(timestamp) >= 15:
parsed["date"] = timestamp
- if dataset_name in ("random", "random-mm"):
- parsed["dataset_name"] = dataset_name
+
+ if idx >= 0 and parts[idx].startswith("out"):
+ parsed["random_output_len"] = parts[idx][3:]
+ idx -= 1
+ if idx >= 0 and parts[idx].startswith("in"):
+ parsed["random_input_len"] = parts[idx][2:]
+ idx -= 1
+
+ if idx < 3:
+ LOGGER.warning(
+ "filename '%s' has too few segments after timestamp / optional in-out (idx=%s)",
+ filename,
+ idx,
+ )
+ return parsed
+
+ num_prompts_str = parts[idx]
+ idx -= 1
+ flow_str = parts[idx]
+ idx -= 1
+ dataset_name = parts[idx]
+ idx -= 1
+ test_name = "_".join(parts[: idx + 1]) if idx >= 0 else ""
+
try:
parsed["num_prompts"] = int(num_prompts_str)
except (TypeError, ValueError):
pass
+
try:
- parsed["max_concurrency"] = int(max_concurrency_str)
+ parsed["max_concurrency"] = int(flow_str)
except (TypeError, ValueError):
pass
+
if test_name:
parsed["test_name"] = test_name
+
+ if dataset_name in ("random", "random-mm"):
+ parsed["dataset_name"] = dataset_name
+
return parsed
@@ -143,9 +170,10 @@ def _iter_omni_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
def _parse_diffusion_from_filename(filename: str) -> dict[str, Any]:
name, ext = os.path.splitext(filename)
- if ext != ".json" or not name.startswith(_DIFFUSION_JSON_PREFIX):
+ if ext != ".json" or not any(name.startswith(prefix) for prefix in _DIFFUSION_JSON_PREFIXES):
return {}
- core = name[len(_DIFFUSION_JSON_PREFIX) :]
+ matched_prefix = next(prefix for prefix in _DIFFUSION_JSON_PREFIXES if name.startswith(prefix))
+ core = name[len(matched_prefix) :]
parts = core.split("_")
if len(parts) < 2:
return {}
@@ -168,7 +196,7 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
return
for entry in sorted(os.listdir(input_dir)):
- if not entry.endswith(".json") or not entry.startswith(_DIFFUSION_JSON_PREFIX):
+ if not entry.endswith(".json") or not any(entry.startswith(prefix) for prefix in _DIFFUSION_JSON_PREFIXES):
continue
full_path = os.path.join(input_dir, entry)
if not os.path.isfile(full_path):
@@ -177,17 +205,32 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
if data is None:
continue
- record: dict[str, Any] = dict(data)
filename_meta = _parse_diffusion_from_filename(os.path.basename(full_path))
- if "date" not in record or not record.get("date"):
- record["date"] = filename_meta.get("date") or datetime.now(
- timezone.utc,
- ).strftime("%Y%m%d-%H%M%S")
- if "test_name" not in record or not record.get("test_name"):
- if "test_name" in filename_meta:
- record["test_name"] = filename_meta["test_name"]
- record["source_file"] = os.path.basename(full_path)
- yield record
+ if isinstance(data, dict):
+ records = [data]
+ elif isinstance(data, list):
+ records = [r for r in data if isinstance(r, dict)]
+ else:
+ records = []
+
+ if not records:
+ LOGGER.warning("diffusion json '%s' has no valid records, skip", full_path)
+ continue
+
+ for record in records:
+ flat: dict[str, Any] = dict(record)
+ result = flat.pop("result", None)
+ if isinstance(result, dict):
+ flat.update(result)
+ if "date" not in flat or not flat.get("date"):
+ flat["date"] = filename_meta.get("date") or datetime.now(
+ timezone.utc,
+ ).strftime("%Y%m%d-%H%M%S")
+ if "test_name" not in flat or not flat.get("test_name"):
+ if "test_name" in filename_meta:
+ flat["test_name"] = filename_meta["test_name"]
+ flat["source_file"] = os.path.basename(full_path)
+ yield flat
def _collect_omni_records(input_dir: str) -> list[dict[str, Any]]:
diff --git a/tools/wan22/assemble_wan22_i2v_diffusers.py b/tools/wan22/assemble_wan22_i2v_diffusers.py
new file mode 100644
index 0000000000..8e14ca3c26
--- /dev/null
+++ b/tools/wan22/assemble_wan22_i2v_diffusers.py
@@ -0,0 +1,385 @@
+#!/usr/bin/env python3
+"""
+Assemble a Wan2.2-I2V-A14B-Diffusers-style model directory using a Diffusers
+skeleton and optional replacement transformer checkpoints.
+
+This tool does NOT run any external conversion step. You can use it in two
+ways:
+- keep the original weights from the Diffusers skeleton
+- replace transformer/transformer_2 with converted checkpoints such as
+ LightX2V outputs
+- use legacy LightX2V arg names (--high-noise-weight/--low-noise-weight),
+ which are accepted as aliases
+
+Typical use:
+ python tools/wan22/assemble_wan22_i2v_diffusers.py \
+ --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \
+ --transformer-weight /path/to/high_noise_out/diffusion_pytorch_model.safetensors \
+ --transformer-2-weight /path/to/low_noise_out/diffusion_pytorch_model.safetensors \
+ --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import shutil
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+
+WEIGHT_CANDIDATES = (
+ "diffusion_pytorch_model.safetensors",
+ "diffusion_pytorch_model.bin",
+ "diffusion_pytorch_model.pt",
+ "model.safetensors",
+ "pytorch_model.bin",
+ "model.pt",
+)
+WEIGHT_INDEX_CANDIDATES = (
+ "diffusion_pytorch_model.safetensors.index.json",
+ "model.safetensors.index.json",
+ "pytorch_model.bin.index.json",
+)
+
+ROOT_REQUIRED_FILES = ("model_index.json",)
+ROOT_REQUIRED_DIRS = ("tokenizer", "text_encoder", "vae", "transformer", "transformer_2")
+OPTIONAL_DIRS = ("image_encoder", "image_processor", "scheduler", "feature_extractor")
+
+
+class AssembleError(RuntimeError):
+ pass
+
+
+@dataclass(frozen=True)
+class WeightSpec:
+ kind: str # "single" | "sharded"
+ single_file: Path | None = None
+ index_file: Path | None = None
+ shard_files: tuple[Path, ...] = ()
+
+
+def _load_shard_files_from_index(index_file: Path, role: str) -> tuple[Path, ...]:
+ try:
+ with index_file.open(encoding="utf-8") as f:
+ payload = json.load(f)
+ except Exception as exc:
+ raise AssembleError(f"Failed to parse {role} index file: {index_file}. error={exc}") from exc
+
+ weight_map = payload.get("weight_map")
+ if not isinstance(weight_map, dict) or not weight_map:
+ raise AssembleError(f"Invalid {role} index file (missing/empty weight_map): {index_file}")
+
+ shard_names = sorted({str(v) for v in weight_map.values()})
+ shard_paths: list[Path] = []
+ missing: list[str] = []
+ for shard_name in shard_names:
+ shard_path = index_file.parent / shard_name
+ if not shard_path.is_file():
+ missing.append(str(shard_path))
+ else:
+ shard_paths.append(shard_path)
+
+ if missing:
+ raise AssembleError(f"{role} index references missing shard file(s): " + ", ".join(missing))
+
+ if not shard_paths:
+ raise AssembleError(f"No shard files referenced by {role} index: {index_file}")
+
+ return tuple(shard_paths)
+
+
+def _resolve_weight_spec(path: Path, role: str) -> WeightSpec:
+ if path.is_file():
+ return WeightSpec(kind="single", single_file=path)
+
+ if path.is_dir():
+ for name in WEIGHT_CANDIDATES:
+ candidate = path / name
+ if candidate.is_file():
+ return WeightSpec(kind="single", single_file=candidate)
+
+ for index_name in WEIGHT_INDEX_CANDIDATES:
+ index_file = path / index_name
+ if not index_file.is_file():
+ continue
+ shard_files = _load_shard_files_from_index(index_file, role=role)
+ return WeightSpec(
+ kind="sharded",
+ index_file=index_file,
+ shard_files=shard_files,
+ )
+
+ shard_candidates = sorted(path.glob("diffusion_pytorch_model-*.safetensors"))
+ if shard_candidates:
+ raise AssembleError(
+ f"Detected sharded {role} files under {path}, but index json is missing. "
+ f"Expected one of: {', '.join(WEIGHT_INDEX_CANDIDATES)}"
+ )
+
+ raise AssembleError(
+ f"Cannot find {role} weight under directory: {path}. "
+ f"Expected one of single files [{', '.join(WEIGHT_CANDIDATES)}] "
+ f"or sharded index files [{', '.join(WEIGHT_INDEX_CANDIDATES)}]."
+ )
+
+ raise AssembleError(f"{role} path does not exist: {path}")
+
+
+def _canonical_weight_name(weight_file: Path) -> str:
+ suffix = weight_file.suffix.lower()
+ if suffix == ".safetensors":
+ return "diffusion_pytorch_model.safetensors"
+ if suffix == ".bin":
+ return "diffusion_pytorch_model.bin"
+ if suffix == ".pt":
+ return "diffusion_pytorch_model.pt"
+ return weight_file.name
+
+
+def _validate_skeleton(skeleton: Path) -> None:
+ if not skeleton.is_dir():
+ raise AssembleError(f"--diffusers-skeleton is not a directory: {skeleton}")
+
+ for file_name in ROOT_REQUIRED_FILES:
+ if not (skeleton / file_name).is_file():
+ raise AssembleError(f"Missing required file in skeleton: {skeleton / file_name}")
+
+ for dir_name in ROOT_REQUIRED_DIRS:
+ if not (skeleton / dir_name).is_dir():
+ raise AssembleError(f"Missing required directory in skeleton: {skeleton / dir_name}")
+
+ if not (skeleton / "transformer" / "config.json").is_file():
+ raise AssembleError(f"Missing transformer config: {skeleton / 'transformer/config.json'}")
+
+ if not (skeleton / "transformer_2" / "config.json").is_file():
+ raise AssembleError(f"Missing transformer_2 config: {skeleton / 'transformer_2/config.json'}")
+
+
+def _ensure_clean_output(output_dir: Path, overwrite: bool) -> None:
+ if output_dir.exists():
+ if not overwrite:
+ raise AssembleError(
+ f"Output directory already exists: {output_dir}. Use --overwrite to remove and recreate it."
+ )
+ shutil.rmtree(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+
+def _copy_or_link_dir(src: Path, dst: Path, asset_mode: str) -> None:
+ if asset_mode == "copy":
+ shutil.copytree(src, dst)
+ elif asset_mode == "symlink":
+ dst.symlink_to(src, target_is_directory=True)
+ else:
+ raise AssembleError(f"Unknown asset mode: {asset_mode}")
+
+
+def _materialize_weight(weight: WeightSpec, dst_dir: Path, role: str) -> tuple[Path, ...]:
+ if weight.kind == "single":
+ assert weight.single_file is not None
+ dst = dst_dir / _canonical_weight_name(weight.single_file)
+ shutil.copy2(weight.single_file, dst)
+ return (dst,)
+
+ if weight.kind == "sharded":
+ assert weight.index_file is not None
+ copied: list[Path] = []
+ index_dst = dst_dir / weight.index_file.name
+ shutil.copy2(weight.index_file, index_dst)
+ copied.append(index_dst)
+ for shard_file in weight.shard_files:
+ shard_dst = dst_dir / shard_file.name
+ shutil.copy2(shard_file, shard_dst)
+ copied.append(shard_dst)
+ return tuple(copied)
+
+ raise AssembleError(f"Unknown {role} weight kind: {weight.kind}")
+
+
+def _assemble(
+ skeleton: Path,
+ output_dir: Path,
+ transformer_weight: WeightSpec,
+ transformer_2_weight: WeightSpec,
+ asset_mode: str,
+) -> tuple[tuple[Path, ...], tuple[Path, ...]]:
+ shutil.copy2(skeleton / "model_index.json", output_dir / "model_index.json")
+
+ for dir_name in ROOT_REQUIRED_DIRS:
+ if dir_name in ("transformer", "transformer_2"):
+ continue
+ _copy_or_link_dir(skeleton / dir_name, output_dir / dir_name, asset_mode)
+
+ for dir_name in OPTIONAL_DIRS:
+ src_dir = skeleton / dir_name
+ if src_dir.is_dir():
+ _copy_or_link_dir(src_dir, output_dir / dir_name, asset_mode)
+
+ (output_dir / "transformer").mkdir(parents=True, exist_ok=True)
+ (output_dir / "transformer_2").mkdir(parents=True, exist_ok=True)
+
+ shutil.copy2(skeleton / "transformer" / "config.json", output_dir / "transformer" / "config.json")
+ shutil.copy2(skeleton / "transformer_2" / "config.json", output_dir / "transformer_2" / "config.json")
+
+ transformer_copied = _materialize_weight(transformer_weight, output_dir / "transformer", role="transformer")
+ transformer_2_copied = _materialize_weight(
+ transformer_2_weight,
+ output_dir / "transformer_2",
+ role="transformer_2",
+ )
+
+ return transformer_copied, transformer_2_copied
+
+
+def _validate_output(
+ output_dir: Path,
+ transformer_copied: tuple[Path, ...],
+ transformer_2_copied: tuple[Path, ...],
+) -> None:
+ if not (output_dir / "model_index.json").is_file():
+ raise AssembleError("Output validation failed: model_index.json missing")
+
+ required_paths = (
+ output_dir / "tokenizer",
+ output_dir / "text_encoder",
+ output_dir / "vae",
+ output_dir / "transformer" / "config.json",
+ output_dir / "transformer_2" / "config.json",
+ *transformer_copied,
+ *transformer_2_copied,
+ )
+ missing = [str(p) for p in required_paths if not p.exists()]
+ if missing:
+ raise AssembleError("Output validation failed, missing: " + ", ".join(missing))
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description=(
+ "Assemble a Wan2.2-I2V-A14B-Diffusers directory while optionally "
+ "replacing transformer and transformer_2 weights."
+ )
+ )
+ parser.add_argument(
+ "--diffusers-skeleton",
+ type=Path,
+ required=True,
+ help="Path to a local Wan-AI/Wan2.2-I2V-A14B-Diffusers directory.",
+ )
+ parser.add_argument(
+ "--transformer-weight",
+ type=Path,
+ help=(
+ "Optional checkpoint file, or directory containing either a single-file "
+ "weight or sharded index+shards for transformer/. If omitted, keep the "
+ "skeleton's original transformer weights."
+ ),
+ )
+ parser.add_argument(
+ "--transformer-2-weight",
+ type=Path,
+ help=(
+ "Optional checkpoint file, or directory containing either a single-file "
+ "weight or sharded index+shards for transformer_2/. If omitted, keep the "
+ "skeleton's original transformer_2 weights."
+ ),
+ )
+ parser.add_argument(
+ "--high-noise-weight",
+ type=Path,
+ help=argparse.SUPPRESS,
+ )
+ parser.add_argument(
+ "--low-noise-weight",
+ type=Path,
+ help=argparse.SUPPRESS,
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=Path,
+ required=True,
+ help="Output directory for the assembled model.",
+ )
+ parser.add_argument(
+ "--asset-mode",
+ choices=("symlink", "copy"),
+ default="symlink",
+ help=(
+ "How to materialize non-transformer assets (tokenizer/text_encoder/vae/optional dirs). "
+ "symlink saves disk and is default."
+ ),
+ )
+ parser.add_argument(
+ "--overwrite",
+ action="store_true",
+ help="Overwrite output-dir if it exists.",
+ )
+ return parser.parse_args()
+
+
+def main() -> int:
+ args = parse_args()
+
+ skeleton = args.diffusers_skeleton.resolve()
+ output_dir = args.output_dir.resolve()
+
+ if args.transformer_weight is not None and args.high_noise_weight is not None:
+ print(
+ "[ERROR] --transformer-weight and --high-noise-weight are aliases; please provide only one.",
+ file=sys.stderr,
+ )
+ return 2
+ if args.transformer_2_weight is not None and args.low_noise_weight is not None:
+ print(
+ "[ERROR] --transformer-2-weight and --low-noise-weight are aliases; please provide only one.",
+ file=sys.stderr,
+ )
+ return 2
+
+ transformer_weight_arg = args.transformer_weight if args.transformer_weight is not None else args.high_noise_weight
+ transformer_2_weight_arg = (
+ args.transformer_2_weight if args.transformer_2_weight is not None else args.low_noise_weight
+ )
+
+ transformer_input = (
+ transformer_weight_arg.resolve() if transformer_weight_arg is not None else skeleton / "transformer"
+ )
+ transformer_2_input = (
+ transformer_2_weight_arg.resolve() if transformer_2_weight_arg is not None else skeleton / "transformer_2"
+ )
+
+ try:
+ _validate_skeleton(skeleton)
+ transformer_weight = _resolve_weight_spec(transformer_input, role="transformer")
+ transformer_2_weight = _resolve_weight_spec(transformer_2_input, role="transformer_2")
+
+ _ensure_clean_output(output_dir, overwrite=args.overwrite)
+ transformer_copied, transformer_2_copied = _assemble(
+ skeleton=skeleton,
+ output_dir=output_dir,
+ transformer_weight=transformer_weight,
+ transformer_2_weight=transformer_2_weight,
+ asset_mode=args.asset_mode,
+ )
+ _validate_output(output_dir, transformer_copied, transformer_2_copied)
+ except AssembleError as exc:
+ print(f"[ERROR] {exc}", file=sys.stderr)
+ return 2
+
+ def _weight_summary(copied: tuple[Path, ...]) -> str:
+ if len(copied) == 1:
+ return copied[0].name
+ return f"{copied[0].name} + {len(copied) - 1} shard files"
+
+ print("[OK] Assembled Wan2.2 I2V Diffusers directory:")
+ print(f" output_dir: {output_dir}")
+ print(f" transformer weight: {_weight_summary(transformer_copied)}")
+ print(f" transformer_2 weight: {_weight_summary(transformer_2_copied)}")
+ print("\nUse it with vLLM-Omni, for example:")
+ print(f" vllm serve {output_dir} --omni --port 8091")
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py
index cec8b0af7e..65ad79c725 100644
--- a/vllm_omni/__init__.py
+++ b/vllm_omni/__init__.py
@@ -12,6 +12,12 @@
processing
"""
+# We import version early, because it will warn if vLLM / vLLM Omni
+# are not using the same major + minor version (if vLLM is installed).
+# We should do this before applying patch, because vLLM imports might
+# throw in patch if the versions differ.
+from .version import __version__, __version_tuple__ # isort:skip # noqa: F401
+
try:
from . import patch # noqa: F401
except ModuleNotFoundError as exc: # pragma: no cover - optional dependency
@@ -25,8 +31,6 @@
from .config import OmniModelConfig
-from .version import __version__, __version_tuple__ # isort:skip
-
def __getattr__(name: str):
# Lazy import for AsyncOmni and Omni to avoid pulling in heavy
diff --git a/vllm_omni/assets/video.py b/vllm_omni/assets/video.py
index 98b1f7e4e2..6a5f3204a9 100644
--- a/vllm_omni/assets/video.py
+++ b/vllm_omni/assets/video.py
@@ -1,6 +1,6 @@
-import librosa
import numpy as np
from vllm.assets.video import VideoAsset
+from vllm.multimodal.media.audio import load_audio
def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndarray:
@@ -12,5 +12,5 @@ def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndar
"""
if not path:
path = VideoAsset(name="baby_reading").video_path
- audio_signal, sr = librosa.load(path, sr=sampling_rate)
+ audio_signal, sr = load_audio(path, sr=sampling_rate)
return audio_signal
diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py b/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py
new file mode 100644
index 0000000000..01b86d0fd1
--- /dev/null
+++ b/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py
@@ -0,0 +1,887 @@
+"""Daily-Omni Dataset loader for benchmark.
+
+Daily-Omni is an audio-visual reasoning benchmark with 684 videos
+and 1,197 multiple-choice QA pairs across 6 major task types.
+
+Dataset source: https://huggingface.co/datasets/liarliar/Daily-Omni
+
+Supports loading QA metadata from:
+- Local JSON file (``qa_json_path``): recommended for offline/air-gapped environments
+- HuggingFace datasets (``dataset_path``): legacy online mode
+
+The videos must be separately downloaded and extracted from Videos.tar.
+
+Why ``BenchmarkDataset`` instead of ``HuggingFaceDataset``?
+ vLLM's ``HuggingFaceDataset`` is a thin wrapper whose ``__init__`` always ends by calling
+ ``load_data()`` → ``datasets.load_dataset(...)`` with a required Hub id and split. That
+ contract fits "Hub-only" benches, but Daily-Omni also needs **offline QA metadata** from a
+ local ``qa.json`` without touching the network. Subclassing ``HuggingFaceDataset`` would
+ mean fighting the parent constructor (fake ``dataset_path``, reordering ``load_data``, or
+ duplicating half the parent) and would still imply ``datasets`` is always relevant.
+
+ This class therefore inherits only ``BenchmarkDataset`` (minimal: ``dataset_path``,
+ ``random_seed``, ``self.data``) and implements **two explicit loaders**:
+ ``_load_from_local_json`` (default path for air-gapped runs) and ``_load_from_huggingface``
+ (optional legacy path for users who prefer ``datasets`` + Hub cache). The latter is **not**
+ inheritance; it is the same Hub rows as before, factored into a helper so one class can
+ serve both deployment modes without mandatory ``datasets`` when using ``qa_json_path``.
+
+Usage:
+ from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniDataset
+
+ # Local JSON mode (recommended)
+ dataset = DailyOmniDataset(
+ qa_json_path="/path/to/qa.json",
+ video_dir="/path/to/Videos",
+ random_seed=42,
+ )
+
+ # HuggingFace mode (legacy, requires network)
+ dataset = DailyOmniDataset(
+ dataset_path="liarliar/Daily-Omni",
+ dataset_split="train",
+ random_seed=42,
+ )
+ requests = dataset.sample(
+ tokenizer=tokenizer,
+ num_requests=100,
+ output_len=256,
+ )
+"""
+
+import base64
+import json
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Literal
+
+try:
+ from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest
+except ImportError:
+ # Fallback: if BenchmarkDataset not available, use base class from same module
+ from vllm.benchmarks.datasets import HuggingFaceDataset as BenchmarkDataset
+ from vllm.benchmarks.datasets import SampleRequest
+from vllm.tokenizers import TokenizerLike
+from vllm.tokenizers.hf import get_cached_tokenizer
+
+try:
+ from datasets import load_dataset
+except ImportError:
+ load_dataset = None
+
+logger = logging.getLogger(__name__)
+
+
+class _ListDatasetIterator:
+ """Simple iterator wrapper around a list to mimic HuggingFace streaming dataset behavior."""
+
+ def __init__(self, data: list[dict[str, Any]]) -> None:
+ self._data = data
+ self._index = 0
+
+ def __iter__(self):
+ self._index = 0
+ return self
+
+ def __next__(self) -> dict[str, Any]:
+ if self._index >= len(self._data):
+ raise StopIteration
+ item = self._data[self._index]
+ self._index += 1
+ return item
+
+ def __len__(self) -> int:
+ return len(self._data)
+
+ def __getitem__(self, idx: int | slice) -> dict[str, Any] | list[dict[str, Any]]:
+ return self._data[idx]
+
+
+# Aligns with Lliar-liar/Daily-Omni CLI ``--input_mode`` (test_model/*/testmodel.py).
+DailyOmniInputMode = Literal["all", "visual", "audio"]
+
+# ``build_conversation()`` in Daily-Omni ``test_model/Qwen2.5-Omni/testmodel.py`` (verbatim).
+DAILY_OMNI_SYSTEM_TEXT = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
+ "capable of perceiving auditory and visual inputs, as well as generating text and speech."
+)
+
+
+@dataclass
+class DailyOmniSampleRequest(SampleRequest):
+ """``SampleRequest`` with Daily-Omni gold labels for post-run accuracy scoring."""
+
+ daily_omni_gold_answer: str = ""
+ daily_omni_video_id: str = ""
+ daily_omni_task_type: str = ""
+ #: Official qa.json ``video_duration`` (e.g. ``30s``, ``60s``) for leaderboard-style breakdown.
+ daily_omni_video_duration: str = ""
+ #: Official ``video_category`` (YouTube-style category string) for per-category accuracy.
+ daily_omni_video_category: str = ""
+ #: Extra JSON fields merged into chat-completions ``extra_body`` (e.g. ``mm_processor_kwargs``).
+ omni_extra_body: dict[str, Any] | None = None
+ #: Full OpenAI ``messages`` (system + user) mirroring upstream Daily-Omni conversation.
+ omni_chat_messages: list[dict[str, Any]] | None = None
+ #: Used only when ``omni_chat_messages`` is None (non-Daily-Omni-style requests).
+ omni_chat_mm_position: Literal["first", "last"] = "last"
+
+
+class DailyOmniDataset(BenchmarkDataset):
+ """Daily-Omni audio-visual QA dataset for benchmarking.
+
+ Inherits ``BenchmarkDataset`` only (not ``HuggingFaceDataset``): see module docstring for why
+ Hub loading lives in ``_load_from_huggingface`` instead of subclassing the HF base class.
+
+ The dataset includes:
+ - 684 videos from daily life scenarios (available in Videos.tar)
+ - 1,197 multiple-choice QA pairs in qa.json
+ - 6 major task categories
+
+ QA metadata can be loaded from:
+ - Local JSON file (``qa_json_path``): recommended for offline/air-gapped environments
+ - HuggingFace datasets (``dataset_path``): legacy online mode
+
+ The videos must be separately downloaded and extracted from Videos.tar.
+
+ Args:
+ qa_json_path: Path to local qa.json file (offline mode, preferred). When provided,
+ ``dataset_path`` and ``dataset_split`` are ignored.
+ dataset_path: HuggingFace dataset path (e.g., "liarliar/Daily-Omni"). Used only if
+ ``qa_json_path`` is not provided (legacy online mode).
+ dataset_split: Dataset split to use (default: "train"). Used only in online mode.
+ random_seed: Random seed for shuffling
+ video_dir: Directory containing extracted video files (default: None)
+ input_mode: Which modalities to send, matching upstream Daily-Omni ``--input_mode``:
+ ``all`` — video + WAV (default; official audio-visual protocol);
+ ``visual`` — video only;
+ ``audio`` — extracted WAV only (requires ``{video_id}/{video_id}_audio.wav`` under ``video_dir``).
+ max_duration_seconds: Reserved for future ffprobe-based filtering; currently **not applied**
+ when building requests (metadata ``video_duration`` is still passed through for eval).
+ dataset_subset: Optional HuggingFace subset name (``load_dataset(..., name=...)``); used by bench
+ ``--hf-subset`` / patch.
+ no_stream: If True, load the Hub split non-streaming (matches bench ``--no-stream``).
+ inline_local_video: If True, embed local MP4 as ``data:video/mp4;base64,...`` in requests so
+ the API server does not need ``--allowed-local-media-path`` (large JSON; use for small runs).
+ When ``input_mode`` is ``audio`` or ``all``, local WAV is embedded the same way
+ (``data:audio/wav;base64,...``).
+ trust_remote_code: Whether to trust remote code when loading HuggingFace dataset
+ (online mode only).
+ """
+
+ SUPPORTED_DATASET_PATHS: set[str] = {
+ "liarliar/Daily-Omni",
+ }
+ #: Default Hub id for synthetic video URLs when ``qa_json_path`` is used (``dataset_path`` None).
+ DEFAULT_HF_DATASET_ID = "liarliar/Daily-Omni"
+ IS_MULTIMODAL = True
+ DEFAULT_OUTPUT_LEN = 256
+
+ def __init__(
+ self,
+ qa_json_path: str | None = None,
+ dataset_path: str | None = None,
+ dataset_split: str = "train",
+ random_seed: int = 0,
+ video_dir: str | None = None,
+ input_mode: DailyOmniInputMode = "all",
+ inline_local_video: bool = False,
+ trust_remote_code: bool = False,
+ max_duration_seconds: float | None = None,
+ dataset_subset: str | None = None,
+ no_stream: bool = False,
+ **kwargs,
+ ) -> None:
+ if input_mode not in ("all", "visual", "audio"):
+ raise ValueError(f"input_mode must be 'all', 'visual', or 'audio', got {input_mode!r}")
+
+ # Validate arguments: need either local JSON or HF path
+ if qa_json_path is None and dataset_path is None:
+ raise ValueError(
+ "Either 'qa_json_path' (local JSON) or 'dataset_path' (HuggingFace) must be provided. "
+ "For offline/air-gapped environments, download qa.json and use qa_json_path."
+ )
+
+ # Store configuration
+ self.qa_json_path = Path(qa_json_path) if qa_json_path else None
+ self.dataset_path = dataset_path
+ self.dataset_split = dataset_split
+ self.dataset_subset = dataset_subset
+ #: Match vLLM ``HuggingFaceDataset`` / bench CLI ``--no-stream``.
+ self._hf_streaming = not no_stream
+ self.video_dir = Path(video_dir) if video_dir else None
+ self.inline_local_video = inline_local_video
+ self.input_mode: DailyOmniInputMode = input_mode
+ self.max_duration_seconds = max_duration_seconds
+ self.trust_remote_code = trust_remote_code
+
+ #: In-process cache of ffprobe durations only (no disk persistence).
+ self._video_durations: dict[str, float] = {}
+
+ # Initialize parent BenchmarkDataset
+ super().__init__(
+ dataset_path=dataset_path if qa_json_path is None else None,
+ random_seed=random_seed,
+ **kwargs,
+ )
+
+ # Load data based on mode
+ self.load_data()
+
+ # Verify dataset info
+ logger.info(
+ "Loaded Daily-Omni dataset: mode=%s, source=%s, random_seed=%d, input_mode=%s, max_duration=%s",
+ "local_json" if self.qa_json_path else "huggingface",
+ str(self.qa_json_path) if self.qa_json_path else f"{dataset_path}/{dataset_split}",
+ random_seed,
+ input_mode,
+ f"{max_duration_seconds}s" if max_duration_seconds else "unlimited",
+ )
+
+ def load_data(self) -> None:
+ """Populate ``self.data`` from either local JSON or the Hub.
+
+ See module docstring: we do not subclass ``HuggingFaceDataset`` because Daily-Omni needs
+ a first-class offline path; Hub loading is an optional branch implemented below.
+ """
+ if self.qa_json_path is not None:
+ self._load_from_local_json()
+ else:
+ self._load_from_huggingface()
+
+ def _load_from_local_json(self) -> None:
+ """Load QA data from local JSON file."""
+ if not self.qa_json_path.exists():
+ raise FileNotFoundError(f"QA JSON file not found: {self.qa_json_path}")
+
+ with open(self.qa_json_path, encoding="utf-8") as f:
+ data = json.load(f)
+
+ # Support both list format and dict with "train"/"test" splits
+ if isinstance(data, dict):
+ # Try to get the requested split, fallback to first available
+ split_data = data.get(self.dataset_split)
+ if split_data is None:
+ available = list(data.keys())
+ if available:
+ logger.warning(
+ "Split '%s' not found in %s, using '%s' instead",
+ self.dataset_split,
+ self.qa_json_path,
+ available[0],
+ )
+ split_data = data[available[0]]
+ else:
+ split_data = []
+ data = split_data
+
+ if not isinstance(data, list):
+ raise ValueError(f"Expected list of QA items in JSON, got {type(data).__name__}")
+
+ # Shuffle if requested
+ if not getattr(self, "disable_shuffle", False) and self.random_seed is not None:
+ import random
+
+ rng = random.Random(self.random_seed)
+ shuffled = data[:]
+ rng.shuffle(shuffled)
+ data = shuffled
+
+ # Create an iterator-like wrapper for compatibility
+ self.data = _ListDatasetIterator(data)
+
+ def _load_from_huggingface(self) -> None:
+ """Load QA rows via ``datasets.load_dataset`` (legacy / convenience path).
+
+ Kept for backward compatibility: callers can still pass ``dataset_path=liarliar/Daily-Omni``
+ and get the same parquet-backed rows as the Hub dataset card, with streaming (or
+ non-streaming if ``no_stream=True``) and shuffle.
+
+ This is intentionally **not** implemented by subclassing ``HuggingFaceDataset``: that base
+ always runs Hub ``load_dataset`` from its constructor and expects a Hub id as the primary
+ API; Daily-Omni instead chooses the source in ``load_data()`` (JSON vs Hub) while sharing
+ one ``sample()`` / request-building implementation for both.
+ """
+ if load_dataset is None:
+ raise ImportError(
+ "datasets library is required for HuggingFace mode. "
+ "Install with: pip install datasets, or use local JSON mode instead."
+ )
+
+ ds = load_dataset(
+ self.dataset_path,
+ name=self.dataset_subset,
+ split=self.dataset_split,
+ streaming=self._hf_streaming,
+ trust_remote_code=self.trust_remote_code,
+ )
+ if not getattr(self, "disable_shuffle", False):
+ ds = ds.shuffle(seed=self.random_seed)
+ self.data = ds
+
+ def get_task_statistics(self) -> dict[str, int]:
+ """Get distribution of task types in the dataset.
+
+ Returns:
+ Dict mapping task type to count
+ """
+ stats: dict[str, int] = {}
+ for item in self.data:
+ row = self._coerce_row(item)
+ fields = self._normalize_qa_fields(row)
+ task_type = fields["task_type"] or "unknown"
+ stats[task_type] = stats.get(task_type, 0) + 1
+ return stats
+
+ @staticmethod
+ def _coerce_row(item: Any) -> dict[str, Any]:
+ """Turn a dataset row into a plain dict (Arrow / Mapping)."""
+ if isinstance(item, dict):
+ return item
+ if hasattr(item, "as_py"):
+ return dict(item.as_py()) # pyarrow Row
+ try:
+ return dict(item)
+ except (TypeError, ValueError):
+ return {k: item[k] for k in item} # type: ignore[misc]
+
+ @staticmethod
+ def _normalize_qa_fields(row: dict[str, Any]) -> dict[str, Any]:
+ """Map official Daily-Omni qa.json / Hub schema to internal fields.
+
+ Official fields (see liarliar/Daily-Omni ``qa.json``): ``Question``, ``Choice`` (list),
+ ``Answer``, ``video_id``, ``Type``, ``video_duration`` (``30s`` / ``60s``), ``video_category``,
+ plus other category columns. Legacy aliases (lowercase / older loaders) are still accepted.
+ """
+ out: dict[str, Any] = {}
+
+ out["question"] = str(row.get("Question") or row.get("question") or "").strip()
+ vid = row.get("video_id") if row.get("video_id") is not None else row.get("video")
+ out["video_id"] = str(vid).strip() if vid is not None else ""
+ out["task_type"] = str(row.get("Type") or row.get("task_type") or row.get("type") or "").strip()
+ vc = row.get("video_category") if row.get("video_category") is not None else row.get("videoCategory")
+ out["video_category"] = str(vc).strip() if vc is not None else ""
+ vd = row.get("video_duration") if row.get("video_duration") is not None else row.get("videoDuration")
+ out["video_duration"] = str(vd).strip() if vd is not None else ""
+ out["answer"] = str(row.get("Answer") or row.get("answer") or "").strip()
+ vu = row.get("video_url") if row.get("video_url") is not None else row.get("Video_URL")
+ out["video_url"] = str(vu).strip() if vu is not None and str(vu).strip() else None
+
+ choice = row.get("Choice")
+ if choice is None:
+ choice = row.get("options") or row.get("choice")
+ out["choice"] = choice
+
+ return out
+
+ def sample(
+ self,
+ tokenizer: TokenizerLike,
+ num_requests: int,
+ output_len: int | None = None,
+ request_id_prefix: str = "",
+ no_oversample: bool = False,
+ **kwargs,
+ ) -> list[SampleRequest]:
+ """Sample requests from Daily-Omni dataset.
+
+ Args:
+ tokenizer: Tokenizer for computing prompt length
+ num_requests: Number of requests to sample
+ output_len: Target output length in tokens (default: 256)
+ request_id_prefix: Prefix for request IDs
+ no_oversample: If True, do not oversample if fewer examples available
+ **kwargs: Additional arguments (ignored)
+
+ Returns:
+ List of SampleRequest objects with video URLs and prompts
+ """
+ if output_len is None:
+ output_len = self.DEFAULT_OUTPUT_LEN
+
+ sampled_requests: list[SampleRequest] = []
+ ind = 0
+ cached_tokenizer = get_cached_tokenizer(tokenizer)
+
+ # Iterate over shuffled dataset
+ for item in self.data:
+ if len(sampled_requests) >= num_requests:
+ break
+
+ request = self._create_sample_request(
+ self._coerce_row(item), cached_tokenizer, output_len, request_id_prefix, ind
+ )
+ if request:
+ sampled_requests.append(request)
+ ind += 1
+
+ logger.info("Created %d sample requests from Daily-Omni dataset", len(sampled_requests))
+
+ # Handle oversampling if needed
+ self.maybe_oversample_requests(sampled_requests, num_requests, request_id_prefix, no_oversample)
+
+ return sampled_requests
+
+ def _create_sample_request(
+ self,
+ qa_item: dict[str, Any],
+ tokenizer: TokenizerLike,
+ output_len: int,
+ request_id_prefix: str,
+ index: int,
+ ) -> SampleRequest | None:
+ """Create a SampleRequest from a QA item.
+
+ Args:
+ qa_item: QA pair from the dataset
+ tokenizer: Tokenizer
+ output_len: Target output length
+ request_id_prefix: Prefix for request ID
+ index: Request index
+
+ Returns:
+ SampleRequest or None if invalid
+ """
+ fields = self._normalize_qa_fields(qa_item)
+ video_id = fields["video_id"]
+ question = fields["question"]
+ choice = fields["choice"]
+ task_type = fields["task_type"]
+ video_url = fields["video_url"]
+ video_duration = fields.get("video_duration") or ""
+ video_category = fields.get("video_category") or ""
+
+ if not video_id and not video_url:
+ logger.warning("Skipping item: no video_id / video_url")
+ return None
+
+ if not question:
+ logger.warning("Skipping item: no question found")
+ return None
+
+ # Official layout after extracting Videos.tar (see Lliar-liar/Daily-Omni test_model):
+ # {video_base_dir}/{video_id}/{video_id}_video.mp4
+ mm_payload, omni_extra, mm_pos = self._compose_daily_omni_multimodal(video_id, video_url)
+ if not mm_payload:
+ return None
+
+ messages = self._build_daily_omni_openai_messages(mm_payload, question, choice)
+ user_text = self._official_daily_omni_user_prompt(question, choice)
+ # Text-only length estimate (same as before: no MM token count in bench).
+ prompt_len = len(tokenizer.encode(f"{DAILY_OMNI_SYSTEM_TEXT}\n{user_text}"))
+
+ return DailyOmniSampleRequest(
+ prompt=user_text,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ multi_modal_data=None,
+ request_id=f"{request_id_prefix}{index}",
+ daily_omni_gold_answer=fields["answer"],
+ daily_omni_video_id=video_id,
+ daily_omni_task_type=task_type,
+ daily_omni_video_duration=video_duration,
+ daily_omni_video_category=video_category,
+ omni_extra_body=omni_extra,
+ omni_chat_messages=messages,
+ omni_chat_mm_position=mm_pos,
+ )
+
+ @staticmethod
+ def _official_video_relpath(video_id: str) -> str:
+ """Relative path inside extracted ``Videos/`` per upstream Daily-Omni scripts."""
+ return f"{video_id}/{video_id}_video.mp4"
+
+ @staticmethod
+ def _official_audio_relpath(video_id: str) -> str:
+ """Relative path for extracted WAV per upstream ``get_audio_path``."""
+ return f"{video_id}/{video_id}_audio.wav"
+
+ def _resolve_local_video_path(self, video_id: str) -> Path | None:
+ """Pick an existing file under ``video_dir`` (official layout + flat fallback)."""
+ if not self.video_dir or not video_id:
+ return None
+
+ candidates = [
+ self.video_dir / self._official_video_relpath(video_id),
+ self.video_dir / f"{video_id}.mp4", # flat layout (custom mirrors / outdated docs)
+ ]
+ seen: set[Path] = set()
+ for p in candidates:
+ rp = p.resolve()
+ if rp in seen:
+ continue
+ seen.add(rp)
+ if p.exists():
+ return p
+ return None
+
+ def _resolve_local_audio_path(self, video_id: str) -> Path | None:
+ """Pick an existing WAV under ``video_dir`` (official layout + flat fallback)."""
+ if not self.video_dir or not video_id:
+ return None
+ candidates = [
+ self.video_dir / self._official_audio_relpath(video_id),
+ self.video_dir / f"{video_id}.wav",
+ ]
+ seen: set[Path] = set()
+ for p in candidates:
+ rp = p.resolve()
+ if rp in seen:
+ continue
+ seen.add(rp)
+ if p.exists():
+ return p
+ return None
+
+ def _local_file_to_video_url_payload(self, video_path: Path) -> dict[str, Any]:
+ """Build OpenAI-style video_url part for a resolved local file.
+
+ vLLM rejects ``file://`` unless the server was started with
+ ``--allowed-local-media-path`` set to a directory that **contains** the file
+ (typically the extracted ``Videos`` root). Use ``inline_local_video=True`` to
+ send base64 data URLs instead (no server path allowlist; larger requests).
+ """
+ path = video_path.expanduser().resolve()
+ if self.inline_local_video:
+ raw = path.read_bytes()
+ b64 = base64.b64encode(raw).decode("ascii")
+ return {
+ "type": "video_url",
+ "video_url": {"url": f"data:video/mp4;base64,{b64}"},
+ }
+ return {
+ "type": "video_url",
+ "video_url": {"url": path.as_uri()},
+ }
+
+ def _local_file_to_audio_url_payload(self, audio_path: Path) -> dict[str, Any]:
+ """Build OpenAI-style ``audio_url`` part for a resolved local WAV file."""
+ path = audio_path.expanduser().resolve()
+ if self.inline_local_video:
+ raw = path.read_bytes()
+ b64 = base64.b64encode(raw).decode("ascii")
+ return {
+ "type": "audio_url",
+ "audio_url": {"url": f"data:audio/wav;base64,{b64}"},
+ }
+ return {
+ "type": "audio_url",
+ "audio_url": {"url": path.as_uri()},
+ }
+
+ def _get_video_content(
+ self,
+ video_id: str,
+ video_url: str | None,
+ ) -> dict[str, Any] | None:
+ """Resolve video for OpenAI-style ``video_url`` content.
+
+ Upstream uses ``get_video_path(video_id, base) -> base/video_id/video_id_video.mp4``.
+ The Hub repo only publishes ``Videos.tar``; use ``--daily-omni-video-dir`` pointing
+ at the extracted ``Videos`` folder (parent of per-``video_id`` subdirs).
+
+ For ``file://`` URLs, start ``vllm serve`` with e.g.
+ ``--allowed-local-media-path /same/path/as/daily-omni-video-dir``.
+ """
+ if video_url:
+ url = video_url
+ if not url.startswith(("http://", "https://", "file://")):
+ url = f"https://{url.lstrip('/')}"
+ return {"type": "video_url", "video_url": {"url": url}}
+
+ if self.video_dir and video_id:
+ video_path = self._resolve_local_video_path(video_id)
+ if video_path is not None:
+ return self._local_file_to_video_url_payload(video_path)
+ logger.warning(
+ "Video not found under video_dir=%s for video_id=%r (expected %s or %s)",
+ self.video_dir,
+ video_id,
+ self._official_video_relpath(video_id),
+ f"{video_id}.mp4",
+ )
+
+ if video_id:
+ repo = self.dataset_path or self.DEFAULT_HF_DATASET_ID
+ rel = self._official_video_relpath(video_id)
+ hf_video_url = f"https://huggingface.co/datasets/{repo}/resolve/main/Videos/{rel}"
+ logger.debug(
+ "Using HF video URL (likely 404 — Hub ships Videos.tar only): %s",
+ hf_video_url,
+ )
+ return {"type": "video_url", "video_url": {"url": hf_video_url}}
+
+ logger.error("Could not determine video source for video_id=%r", video_id)
+ return None
+
+ def _get_audio_content(self, video_id: str) -> dict[str, Any] | None:
+ """Resolve extracted WAV for OpenAI-style ``audio_url`` (local files only)."""
+ if not self.video_dir or not video_id:
+ logger.warning(
+ "Daily-Omni input_mode %r requires --daily-omni-video-dir with %s",
+ self.input_mode,
+ self._official_audio_relpath(video_id),
+ )
+ return None
+ audio_path = self._resolve_local_audio_path(video_id)
+ if audio_path is not None:
+ return self._local_file_to_audio_url_payload(audio_path)
+ logger.warning(
+ "Audio not found under video_dir=%s for video_id=%r (expected %s or %s)",
+ self.video_dir,
+ video_id,
+ self._official_audio_relpath(video_id),
+ f"{video_id}.wav",
+ )
+ return None
+
+ def _compose_daily_omni_multimodal(
+ self,
+ video_id: str,
+ video_url: str | None,
+ ) -> tuple[dict[str, Any] | list[dict[str, Any]] | None, dict[str, Any] | None, Literal["first", "last"]]:
+ """Build ``multi_modal_data`` and request extras for the active ``input_mode``.
+
+ Mirrors upstream Daily-Omni: separate video + WAV with ``use_audio_in_video=False``.
+ """
+ extra: dict[str, Any] = {"mm_processor_kwargs": {"use_audio_in_video": False}}
+ mode = self.input_mode
+
+ if mode == "visual":
+ v = self._get_video_content(video_id, video_url)
+ return v, extra, "last"
+
+ if mode == "audio":
+ a = self._get_audio_content(video_id)
+ return a, extra, "first"
+
+ v = self._get_video_content(video_id, video_url)
+ a = self._get_audio_content(video_id)
+ if not v or not a:
+ return None, None, "first"
+ return [v, a], extra, "first"
+
+ @staticmethod
+ def _media_desc_for_official_prompt(mode: DailyOmniInputMode) -> str:
+ """``media_desc`` in upstream ``build_conversation``."""
+ if mode == "audio":
+ return "given audio"
+ if mode == "all":
+ return "given video and audio together"
+ return "given video"
+
+ @staticmethod
+ def _choices_repr_for_official_prompt(choice: Any) -> str:
+ """Format ``Choice`` from qa.json for the model (one option per line when possible).
+
+ Using ``str(list)`` embeds Python list brackets and quotes, which is poor for MCQ
+ reading; lists/tuples are joined with newlines instead. Other shapes fall back to
+ ``str(choice)`` for parity with exotic upstream payloads.
+ """
+ if choice is None:
+ return ""
+ if isinstance(choice, (list, tuple)):
+ lines = [str(x).strip() for x in choice if str(x).strip()]
+ return "\n".join(lines)
+ if isinstance(choice, dict):
+ return "\n".join(f"{k}. {v}" for k, v in choice.items())
+ return str(choice)
+
+ def _official_daily_omni_user_prompt(self, question: str, choice: Any) -> str:
+ """User text block from Daily-Omni ``build_conversation`` (after media parts)."""
+ task_prompt = self._media_desc_for_official_prompt(self.input_mode)
+ choices = self._choices_repr_for_official_prompt(choice)
+ # Single f-string with explicit newlines avoids accidental implicit concatenation
+ # gluing sentences (e.g. ``...media_desc.Select...``) when editing.
+ return (
+ "Your task is to accurately answer multiple-choice questions "
+ f"based on the {task_prompt}.\n"
+ "Select the single most accurate answer from the given choices.\n"
+ f"Question: {question}\n"
+ f"Choices: {choices}\n"
+ "Your answer should be a capital letter representing your choice: "
+ "A, B, C, or D. Don't generate any other text.\n"
+ )
+
+ def _build_daily_omni_openai_messages(
+ self,
+ mm_payload: dict[str, Any] | list[dict[str, Any]],
+ question: str,
+ choice: Any,
+ ) -> list[dict[str, Any]]:
+ """Map upstream conversation to OpenAI Chat Completions ``messages`` (video_url / audio_url parts)."""
+ user_text = self._official_daily_omni_user_prompt(question, choice)
+ mm_list: list[dict[str, Any]] = mm_payload if isinstance(mm_payload, list) else [mm_payload]
+ user_content: list[dict[str, Any]] = [*mm_list, {"type": "text", "text": user_text}]
+ return [
+ {"role": "system", "content": [{"type": "text", "text": DAILY_OMNI_SYSTEM_TEXT}]},
+ {"role": "user", "content": user_content},
+ ]
+
+ def sample_by_task_type(
+ self,
+ tokenizer: TokenizerLike,
+ task_type: str,
+ num_samples: int,
+ output_len: int | None = None,
+ request_id_prefix: str = "",
+ **kwargs,
+ ) -> list[SampleRequest]:
+ """Sample requests filtered by task type.
+
+ Args:
+ tokenizer: Tokenizer
+ task_type: Task type to filter by
+ num_samples: Number of samples
+ output_len: Target output length
+ request_id_prefix: Prefix for request IDs
+ **kwargs: Additional sampling arguments
+
+ Returns:
+ List of SampleRequest objects matching the task type
+ """
+ if output_len is None:
+ output_len = self.DEFAULT_OUTPUT_LEN
+
+ filtered = [
+ item for item in self.data if self._normalize_qa_fields(self._coerce_row(item))["task_type"] == task_type
+ ]
+
+ available = len(filtered)
+ if available < num_samples:
+ logger.warning(
+ "Only %d samples available for task type '%s', requested %d",
+ available,
+ task_type,
+ num_samples,
+ )
+ num_samples = available
+
+ sampled_requests: list[SampleRequest] = []
+ cached_tokenizer = get_cached_tokenizer(tokenizer)
+
+ for i, item in enumerate(filtered[:num_samples]):
+ request = self._create_sample_request(item, cached_tokenizer, output_len, request_id_prefix, i)
+ if request:
+ sampled_requests.append(request)
+
+ return sampled_requests
+
+ def __repr__(self) -> str:
+ return (
+ f"DailyOmniDataset("
+ f"dataset_path={self.dataset_path!r}, "
+ f"dataset_split={self.dataset_split!r}, "
+ f"video_dir={self.video_dir!r}, "
+ f"input_mode={self.input_mode!r}, "
+ f"inline_local_video={self.inline_local_video!r}, "
+ f"max_duration_seconds={self.max_duration_seconds}, "
+ f"random_seed={self.random_seed}"
+ f")"
+ )
+
+
+def load_daily_omni_dataset(
+ qa_json_path: str | None = None,
+ dataset_path: str | None = None,
+ dataset_split: str = "train",
+ random_seed: int = 0,
+ video_dir: str | None = None,
+ input_mode: DailyOmniInputMode = "all",
+ max_duration_seconds: float | None = None,
+ dataset_subset: str | None = None,
+ no_stream: bool = False,
+ **kwargs,
+) -> DailyOmniDataset:
+ """Convenience function to load Daily-Omni dataset.
+
+ Args:
+ qa_json_path: Path to local qa.json file (recommended for offline/air-gapped environments).
+ When provided, ``dataset_path`` is ignored.
+ dataset_path: HuggingFace dataset path (default: liarliar/Daily-Omni). Used only if
+ ``qa_json_path`` is not provided (legacy online mode).
+ dataset_split: Dataset split to use (default: "train")
+ random_seed: Random seed for shuffling
+ video_dir: Directory containing extracted ``Videos/`` tree (MP4 and, for ``all``/``audio``, WAV)
+ input_mode: ``visual`` | ``audio`` | ``all`` (same semantics as upstream Daily-Omni)
+ max_duration_seconds: Maximum video duration in seconds (e.g., 30 for 30s subset, 60 for 60s subset);
+ uses ffprobe on local files under ``video_dir`` (in-memory cache only for this process).
+ **kwargs: Additional arguments passed to DailyOmniDataset
+
+ Returns:
+ DailyOmniDataset instance
+
+ Example:
+ >>> from vllm_omni.benchmarks.data_modules.daily_omni_dataset import load_daily_omni_dataset
+
+ # Local JSON mode (recommended for offline)
+ >>> dataset = load_daily_omni_dataset(
+ ... qa_json_path="/path/to/qa.json",
+ ... video_dir="/path/to/Daily-Omni/Videos",
+ ... random_seed=42,
+ ... max_duration_seconds=30,
+ ... )
+
+ # HuggingFace mode (legacy online)
+ >>> dataset = load_daily_omni_dataset(
+ ... dataset_path="liarliar/Daily-Omni",
+ ... video_dir="/path/to/Daily-Omni/Videos",
+ ... random_seed=42,
+ ... )
+ >>> requests = dataset.sample(tokenizer, num_requests=100)
+ """
+ return DailyOmniDataset(
+ qa_json_path=qa_json_path,
+ dataset_path=dataset_path,
+ dataset_split=dataset_split,
+ random_seed=random_seed,
+ video_dir=video_dir,
+ input_mode=input_mode,
+ max_duration_seconds=max_duration_seconds,
+ dataset_subset=dataset_subset,
+ no_stream=no_stream,
+ **kwargs,
+ )
+
+
+def get_daily_omni_statistics(
+ qa_json_path: str | None = None,
+ dataset_path: str | None = DailyOmniDataset.DEFAULT_HF_DATASET_ID,
+ dataset_split: str = "train",
+) -> dict[str, Any]:
+ """Get statistics about the Daily-Omni dataset.
+
+ Args:
+ qa_json_path: Path to local qa.json file (recommended for offline/air-gapped environments).
+ When provided, ``dataset_path`` is ignored.
+ dataset_path: HuggingFace dataset path. Defaults to ``DailyOmniDataset.DEFAULT_HF_DATASET_ID``
+ when ``qa_json_path`` is omitted. Pass ``None`` only together with ``qa_json_path``.
+ dataset_split: Dataset split to use (default: "train")
+
+ Returns:
+ Statistics dict with task type distribution and other info
+
+ Example:
+ >>> from vllm_omni.benchmarks.data_modules.daily_omni_dataset import get_daily_omni_statistics
+
+ # Local JSON mode
+ >>> stats = get_daily_omni_statistics(qa_json_path="/path/to/qa.json")
+
+ # HuggingFace mode
+ >>> stats = get_daily_omni_statistics(dataset_path="liarliar/Daily-Omni")
+ >>> print(f"Total QA pairs: {stats['total_qa_pairs']}")
+ >>> print(f"Task distribution: {stats['task_distribution']}")
+ """
+ dataset = DailyOmniDataset(
+ qa_json_path=qa_json_path,
+ dataset_path=dataset_path,
+ dataset_split=dataset_split,
+ )
+ task_stats = dataset.get_task_statistics()
+
+ source = str(qa_json_path) if qa_json_path else f"{dataset_path}/{dataset_split}"
+ return {
+ "source": source,
+ "total_qa_pairs": len(list(dataset.data)),
+ "task_distribution": task_stats,
+ }
diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_eval.py b/vllm_omni/benchmarks/data_modules/daily_omni_eval.py
new file mode 100644
index 0000000000..ecc9edc844
--- /dev/null
+++ b/vllm_omni/benchmarks/data_modules/daily_omni_eval.py
@@ -0,0 +1,406 @@
+"""Daily-Omni multiple-choice accuracy scoring for vLLM-Omni bench serve.
+
+Compares model ``generated_text`` to dataset ``Answer`` (A/B/C/D).
+
+**Alignment with open-source** (`Lliar-liar/Daily-Omni` ``test_model/.../testmodel.py``):
+
+- Answer extraction defaults to the same rules as ``extract_choice_letter`` (strip after an
+ ``assistant`` marker, then leading ``A``–``D``, else first ``\\b[A-D]\\b``). Set env
+ ``DAILY_OMNI_EXTRACT_MODE=relaxed`` to use the older vLLM-Omni heuristics (last ``answer:``,
+ tail scan, etc.).
+- Overall accuracy comparable to the official script uses **successful HTTP responses only** as
+ the denominator (their ``valid_questions = total - failed`` excludes inference / I/O skips).
+ We also report ``daily_omni_accuracy_incl_http_fail`` where each failed request counts as a
+ wrong answer in the denominator (stricter throughput-bench view).
+- **By video length:** mirrors upstream ``--- Accuracy by Video Duration ---`` for ``30s`` /
+ ``60s`` (``qa.json`` ``video_duration``): ``daily_omni_per_duration*`` metrics and a printed block.
+- **By video category:** mirrors ``--- Accuracy by Video Category ---`` using ``video_category``
+ from ``qa.json`` (``daily_omni_per_category*``; empty category is bucketed as ``unknown``).
+- **Correctness:** uses the same ``evaluate_answer`` rule as upstream (truthy extracted letter vs
+ raw ``Answer`` string, both ``strip().upper()``). Rows with empty ``Answer`` are skipped
+ (``no_gold``), matching missing-field skips in the official loop.
+"""
+
+from __future__ import annotations
+
+import os
+import re
+from typing import Any
+
+from vllm.benchmarks.lib.endpoint_request_func import RequestFuncOutput
+
+from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniSampleRequest
+
+_VALID = frozenset("ABCD")
+
+# Official ``testmodel.py`` buckets (``qa.json`` ``video_duration``).
+DAILY_OMNI_DURATION_KEYS: tuple[str, ...] = ("30s", "60s")
+
+
+def extract_choice_letter_official(text: str | None) -> str | None:
+ """Port of Daily-Omni ``extract_choice_letter`` (first A–D, assistant-tail semantics)."""
+ if not text:
+ return None
+ raw = str(text).strip()
+ if not raw:
+ return None
+ match = re.search(r"assistant\s*([\s\S]*)$", raw, flags=re.IGNORECASE)
+ candidate = match.group(1).strip() if match else raw
+ direct = re.match(r"(?i)^\s*([A-D])(?:[\s\.\)::]|$)", candidate)
+ if direct:
+ return direct.group(1).upper()
+ fallback = re.search(r"\b([A-D])\b", candidate.upper())
+ if fallback:
+ return fallback.group(1)
+ return None
+
+
+def evaluate_answer_official(model_answer: str | None, correct_answer: str) -> bool:
+ """Port of Daily-Omni ``evaluate_answer`` (strict string match after strip/upper)."""
+ if not model_answer:
+ return False
+ return model_answer.strip().upper() == (correct_answer or "").strip().upper()
+
+
+def normalize_gold_answer(gold: str) -> str | None:
+ """Best-effort single letter from ``Answer`` (for ``gold_normalized`` in saved items only)."""
+ g = (gold or "").strip().upper()
+ if len(g) == 1 and g in _VALID:
+ return g
+ m = re.search(r"([ABCD])\b", g)
+ if m:
+ return m.group(1).upper()
+ return None
+
+
+def _extract_predicted_choice_relaxed(text: str) -> str | None:
+ """Legacy vLLM-Omni heuristics (last ``answer:`` patterns, tail scan)."""
+ if not text or not str(text).strip():
+ return None
+ t = str(text).strip()
+
+ strong_patterns = [
+ r"(?i)\*\*answer\*\*\s*[::]?\s*\(?([ABCD])\)?",
+ r"(?i)\banswer\s*[::]?\s*\(?([ABCD])\)?",
+ r"(?i)\bfinal\s+answer\s*[::]?\s*\(?([ABCD])\)?",
+ r"(?i)\bcorrect\s+(?:answer|option)\s*[::]?\s*\(?([ABCD])\)?",
+ r"(?i)\bthe\s+(?:correct\s+)?option\s+(?:is|would\s+be)\s*\(?([ABCD])\)?",
+ r"(?i)\bI\s+(?:would\s+)?(?:choose|select|pick)\s*\(?([ABCD])\)?",
+ ]
+ last_letter: str | None = None
+ for pat in strong_patterns:
+ for m in re.finditer(pat, t):
+ last_letter = m.group(1).upper()
+ if last_letter:
+ return last_letter
+
+ # Weaker phrases: first match can be spurious; still prefer last occurrence.
+ weak_patterns = [
+ r"(?i)\boption\s*[::]?\s*\(?([ABCD])\)?",
+ r"(?i)\bchoice\s*[::]?\s*\(?([ABCD])\)?",
+ ]
+ for pat in weak_patterns:
+ for m in re.finditer(pat, t):
+ last_letter = m.group(1).upper()
+ if last_letter:
+ return last_letter
+
+ paren = list(re.finditer(r"\(([ABCD])\)", t))
+ if paren:
+ return paren[-1].group(1).upper()
+
+ # First line sometimes is just "B" or "B." — allow if whole output is short
+ one_line = t.split("\n", 1)[0].strip()
+ if len(t) < 120 and len(one_line) <= 6:
+ m0 = re.match(r"^([ABCD])\s*[.:\)]?\s*$", one_line, re.I)
+ if m0:
+ return m0.group(1).upper()
+
+ # Tail-only: avoids matching echoed "A. ..." option blocks at the start
+ tail_len = min(500, len(t))
+ tail = t[-tail_len:]
+ # ``\b`` after the letter avoids "Because"/"Definitely" false positives
+ m = re.search(r"(?:^|[^\w])([ABCD])\b", tail, re.I)
+ if m:
+ return m.group(1).upper()
+
+ return None
+
+
+def extract_predicted_choice(text: str | None) -> str | None:
+ """Parse model output to A–D (official Daily-Omni rules by default)."""
+ if not text or not str(text).strip():
+ return None
+ mode = os.environ.get("DAILY_OMNI_EXTRACT_MODE", "official").strip().lower()
+ if mode in ("relaxed", "heuristic", "legacy"):
+ return _extract_predicted_choice_relaxed(str(text))
+ return extract_choice_letter_official(text)
+
+
+def compute_daily_omni_accuracy_metrics(
+ input_requests: list[Any],
+ outputs: list[RequestFuncOutput],
+ *,
+ include_per_item: bool = False,
+) -> dict[str, Any] | None:
+ """If all requests are :class:`DailyOmniSampleRequest`, compute accuracy stats.
+
+ Rows with empty ``Answer`` (after strip) are skipped as ``no_gold``, like upstream missing
+ ``correct_answer``.
+
+ **Denominators:** The open-source script excludes items that hit inference / I/O failures
+ from ``valid_questions``; we mirror that with ``daily_omni_accuracy`` (= correct /
+ successful responses). Failed HTTP requests are also tracked and used in
+ ``daily_omni_accuracy_incl_http_fail`` (each failure counts as incorrect in the
+ denominator).
+ """
+ if not input_requests or len(input_requests) != len(outputs):
+ return None
+ if not all(isinstance(r, DailyOmniSampleRequest) for r in input_requests):
+ return None
+
+ # total / correct: all rows with gold (incl. HTTP fail in total)
+ # total_ok / correct_ok: successful HTTP only (GitHub-style per-type denominator)
+ per_task: dict[str, dict[str, int]] = {}
+ per_category: dict[str, dict[str, int]] = {}
+ per_duration: dict[str, dict[str, int]] = {
+ k: {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0} for k in DAILY_OMNI_DURATION_KEYS
+ }
+ items: list[dict[str, Any]] = []
+ correct = 0
+ evaluated = 0
+ no_gold = 0
+ request_failed = 0
+ parse_failed = 0 # success but could not extract A–D
+
+ for req, out in zip(input_requests, outputs, strict=True):
+ assert isinstance(req, DailyOmniSampleRequest)
+ gold_raw = (req.daily_omni_gold_answer or "").strip()
+ gold_norm = normalize_gold_answer(req.daily_omni_gold_answer)
+ tt = (req.daily_omni_task_type or "unknown").strip() or "unknown"
+ dur_key = (req.daily_omni_video_duration or "").strip()
+ dur_active = dur_key in per_duration
+ cat_key = (req.daily_omni_video_category or "").strip() or "unknown"
+ if tt not in per_task:
+ per_task[tt] = {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0}
+ if cat_key not in per_category:
+ per_category[cat_key] = {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0}
+
+ if not gold_raw:
+ no_gold += 1
+ items.append(
+ {
+ "request_id": req.request_id,
+ "skipped": True,
+ "reason": "no_gold",
+ "task_type": tt,
+ "video_id": req.daily_omni_video_id,
+ "video_duration": dur_key or None,
+ "video_category": cat_key if cat_key != "unknown" else None,
+ }
+ )
+ continue
+
+ if not out.success:
+ request_failed += 1
+ evaluated += 1
+ per_task[tt]["total"] += 1
+ per_category[cat_key]["total"] += 1
+ if dur_active:
+ per_duration[dur_key]["total"] += 1
+ # GitHub: failed inference not in valid_questions — do not increment total_ok
+ items.append(
+ {
+ "request_id": req.request_id,
+ "gold": gold_raw,
+ "gold_normalized": gold_norm,
+ "predicted": None,
+ "correct": False,
+ "task_type": tt,
+ "video_id": req.daily_omni_video_id,
+ "video_duration": dur_key or None,
+ "video_category": cat_key if cat_key != "unknown" else None,
+ "error": (out.error or "")[:500],
+ }
+ )
+ continue
+
+ pred = extract_predicted_choice(out.generated_text)
+ evaluated += 1
+ per_task[tt]["total"] += 1
+ per_task[tt]["total_ok"] += 1
+ per_category[cat_key]["total"] += 1
+ per_category[cat_key]["total_ok"] += 1
+ if dur_active:
+ per_duration[dur_key]["total"] += 1
+ per_duration[dur_key]["total_ok"] += 1
+ if pred is None:
+ parse_failed += 1
+ is_correct = evaluate_answer_official(pred, req.daily_omni_gold_answer)
+ if is_correct:
+ correct += 1
+ per_task[tt]["correct"] += 1
+ per_task[tt]["correct_ok"] += 1
+ per_category[cat_key]["correct"] += 1
+ per_category[cat_key]["correct_ok"] += 1
+ if dur_active:
+ per_duration[dur_key]["correct"] += 1
+ per_duration[dur_key]["correct_ok"] += 1
+
+ items.append(
+ {
+ "request_id": req.request_id,
+ "gold": gold_raw,
+ "gold_normalized": gold_norm,
+ "predicted": pred,
+ "correct": is_correct,
+ "parse_failed": pred is None,
+ "task_type": tt,
+ "video_id": req.daily_omni_video_id,
+ "video_duration": dur_key or None,
+ "video_category": cat_key if cat_key != "unknown" else None,
+ }
+ )
+
+ evaluated_ok = evaluated - request_failed
+ accuracy_github = (correct / evaluated_ok) if evaluated_ok else None
+ accuracy_incl_fail = (correct / evaluated) if evaluated else None
+
+ per_task_accuracy: dict[str, float | None] = {}
+ per_task_accuracy_github: dict[str, float | None] = {}
+ for name, st in per_task.items():
+ tot = st["total"]
+ per_task_accuracy[name] = (st["correct"] / tot) if tot else None
+ tok = st["total_ok"]
+ per_task_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None
+
+ per_category_accuracy: dict[str, float | None] = {}
+ per_category_accuracy_github: dict[str, float | None] = {}
+ for name, st in per_category.items():
+ tot = st["total"]
+ per_category_accuracy[name] = (st["correct"] / tot) if tot else None
+ tok = st["total_ok"]
+ per_category_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None
+
+ per_duration_accuracy: dict[str, float | None] = {}
+ per_duration_accuracy_github: dict[str, float | None] = {}
+ for name, st in per_duration.items():
+ tot = st["total"]
+ per_duration_accuracy[name] = (st["correct"] / tot) if tot else None
+ tok = st["total_ok"]
+ per_duration_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None
+
+ out: dict[str, Any] = {
+ # Comparable to GitHub testmodel.py: correct / successful inferences
+ "daily_omni_accuracy": accuracy_github,
+ "daily_omni_accuracy_incl_http_fail": accuracy_incl_fail,
+ "daily_omni_correct": correct,
+ "daily_omni_evaluated": evaluated,
+ "daily_omni_evaluated_ok": evaluated_ok,
+ "daily_omni_no_gold": no_gold,
+ "daily_omni_request_failed": request_failed,
+ "daily_omni_parse_failed": parse_failed,
+ "daily_omni_per_task": {k: dict(v) for k, v in per_task.items()},
+ "daily_omni_per_task_accuracy": per_task_accuracy,
+ "daily_omni_per_task_accuracy_github_style": per_task_accuracy_github,
+ "daily_omni_per_category": {k: dict(v) for k, v in per_category.items()},
+ "daily_omni_per_category_accuracy": per_category_accuracy,
+ "daily_omni_per_category_accuracy_github_style": per_category_accuracy_github,
+ "daily_omni_per_duration": {k: dict(v) for k, v in per_duration.items()},
+ "daily_omni_per_duration_accuracy": per_duration_accuracy,
+ "daily_omni_per_duration_accuracy_github_style": per_duration_accuracy_github,
+ }
+ if include_per_item:
+ out["daily_omni_eval_items"] = items
+ return out
+
+
+def print_daily_omni_accuracy_summary(metrics: dict[str, Any]) -> None:
+ """Pretty-print accuracy block (stdout)."""
+ acc = metrics.get("daily_omni_accuracy")
+ acc_fail = metrics.get("daily_omni_accuracy_incl_http_fail")
+ if acc is None and acc_fail is None and metrics.get("daily_omni_evaluated", 0) == 0:
+ return
+ print("{s:{c}^{n}}".format(s=" Daily-Omni accuracy (MCQ) ", n=50, c="="))
+ ok = int(metrics.get("daily_omni_evaluated_ok", 0) or 0)
+ cor = int(metrics.get("daily_omni_correct", 0) or 0)
+ if ok > 0 and acc is not None:
+ print(f"Overall Accuracy: {cor}/{ok} = {acc:.2%}")
+ elif int(metrics.get("daily_omni_evaluated", 0) or 0) > 0:
+ print("Overall Accuracy: 0/0 = N/A (no successful HTTP responses)")
+ print(
+ "{:<40} {:<10}".format(
+ "Submitted (gold present):",
+ metrics.get("daily_omni_evaluated", 0),
+ )
+ )
+ print(
+ "{:<40} {:<10}".format(
+ "Successful HTTP (GitHub denom.):",
+ metrics.get("daily_omni_evaluated_ok", 0),
+ )
+ )
+ print("{:<40} {:<10}".format("Correct:", metrics.get("daily_omni_correct", 0)))
+ if acc is not None:
+ print("{:<40} {:<10.4f}".format("Accuracy (ratio, same as above):", acc))
+ if acc_fail is not None and metrics.get("daily_omni_request_failed", 0):
+ print(
+ "{:<40} {:<10.4f}".format(
+ "Accuracy (incl. HTTP as wrong):",
+ acc_fail,
+ )
+ )
+ print("{:<40} {:<10}".format("Skipped (no gold):", metrics.get("daily_omni_no_gold", 0)))
+ print(
+ "{:<40} {:<10}".format(
+ "HTTP failed (excl. from GitHub acc.):",
+ metrics.get("daily_omni_request_failed", 0),
+ )
+ )
+ print(
+ "{:<40} {:<10}".format(
+ "Parsed OK but no A–D found:",
+ metrics.get("daily_omni_parse_failed", 0),
+ )
+ )
+ pt = metrics.get("daily_omni_per_task") or {}
+ pta = metrics.get("daily_omni_per_task_accuracy_github_style") or {}
+ if pta:
+ print("\n--- Accuracy by QA Type ---")
+ for name in sorted(pta.keys()):
+ a = pta[name]
+ st = pt.get(name) or {}
+ tok = int(st.get("total_ok", 0) or 0)
+ cok = int(st.get("correct_ok", 0) or 0)
+ if tok and a is not None:
+ print(f"{name}: {cok}/{tok} = {a:.2%}")
+ else:
+ print(f"{name}: 0/0 = N/A")
+
+ pc = metrics.get("daily_omni_per_category") or {}
+ ptc = metrics.get("daily_omni_per_category_accuracy_github_style") or {}
+ if ptc:
+ print("\n--- Accuracy by Video Category ---")
+ for name in sorted(ptc.keys()):
+ a = ptc[name]
+ st = pc.get(name) or {}
+ tok = int(st.get("total_ok", 0) or 0)
+ cok = int(st.get("correct_ok", 0) or 0)
+ if tok and a is not None:
+ print(f"{name}: {cok}/{tok} = {a:.2%}")
+ else:
+ print(f"{name}: 0/0 = N/A")
+
+ pdf = metrics.get("daily_omni_per_duration_accuracy_github_style") or {}
+ if pdf:
+ print("\n--- Accuracy by Video Duration ---")
+ for name in DAILY_OMNI_DURATION_KEYS:
+ a = pdf.get(name)
+ st = (metrics.get("daily_omni_per_duration") or {}).get(name) or {}
+ tok = int(st.get("total_ok", 0) or 0)
+ cor = int(st.get("correct_ok", 0) or 0)
+ if tok and a is not None:
+ print(f"{name} Duration: {cor}/{tok} = {a:.2%}")
+ else:
+ print(f"{name} Duration: 0/0 = N/A")
+ print("=" * 50)
diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py b/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py
new file mode 100644
index 0000000000..69fbe026bd
--- /dev/null
+++ b/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py
@@ -0,0 +1,255 @@
+"""Daily-Omni: optional consistency check between text stream and generated speech.
+
+The benchmark MCQ accuracy uses ``generated_text`` only. When the omni server also
+streams ``modality=audio`` (TTS), this module can transcribe the concatenated WAV
+with Whisper and compare the inferred option letter to the one parsed from text.
+
+Requires ``openai-whisper`` (``pip install openai-whisper``). Enable via env
+``DAILY_OMNI_TEXT_AUDIO_CONSISTENCY=1`` or CLI ``--daily-omni-text-audio-consistency``.
+
+Whisper model name defaults to ``tiny`` (override with ``DAILY_OMNI_WHISPER_MODEL``).
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+import re
+import threading
+from typing import Any
+
+from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniSampleRequest
+from vllm_omni.benchmarks.data_modules.daily_omni_eval import extract_predicted_choice
+
+logger = logging.getLogger(__name__)
+
+_whisper_model = None
+_whisper_model_name: str | None = None
+_whisper_lock = threading.Lock()
+
+
+def env_text_audio_check_enabled() -> bool:
+ return os.environ.get("DAILY_OMNI_TEXT_AUDIO_CONSISTENCY", "").lower() in (
+ "1",
+ "true",
+ "yes",
+ )
+
+
+def extract_choice_from_asr_transcript(transcript: str) -> str | None:
+ """Parse A–D from ASR text; extends :func:`extract_predicted_choice` with spoken Chinese phrases."""
+ c = extract_predicted_choice(transcript)
+ if c:
+ return c
+ t = transcript or ""
+ for pat in (
+ r"(?i)选项\s*([ABCD])\b",
+ r"(?i)选\s*([ABCD])\b",
+ r"(?i)答案\s*是\s*([ABCD])\b",
+ r"(?i)答案\s*([ABCD])\b",
+ ):
+ m = re.search(pat, t)
+ if m:
+ return m.group(1).upper()
+ return None
+
+
+def _get_whisper_model(model_name: str):
+ global _whisper_model, _whisper_model_name
+ with _whisper_lock:
+ if _whisper_model is None or _whisper_model_name != model_name:
+ import whisper
+
+ logger.warning(
+ "Loading Whisper model %r for Daily-Omni text/audio consistency (one-time)...",
+ model_name,
+ )
+ _whisper_model = whisper.load_model(model_name)
+ _whisper_model_name = model_name
+ return _whisper_model
+
+
+def transcribe_wav_bytes(
+ wav_bytes: bytes,
+ *,
+ language: str | None = None,
+ model_name: str | None = None,
+) -> tuple[str | None, str | None]:
+ """Transcribe WAV bytes. Returns ``(transcript, error)`` — one of them is set.
+
+ Args:
+ wav_bytes: RIFF WAV file bytes.
+ language: Optional Whisper language code (e.g. ``en``, ``zh``); improves accuracy/latency.
+ model_name: Override model id; else ``DAILY_OMNI_WHISPER_MODEL`` or ``tiny``.
+ """
+ if not wav_bytes:
+ return None, "empty_wav"
+ if model_name is None or not str(model_name).strip():
+ model_name = os.environ.get("DAILY_OMNI_WHISPER_MODEL") or "tiny"
+ model_name = str(model_name).strip() or "tiny"
+ path: str | None = None
+ try:
+ import tempfile
+
+ model = _get_whisper_model(model_name)
+ fd, path = tempfile.mkstemp(suffix=".wav")
+ with os.fdopen(fd, "wb") as fp:
+ fp.write(wav_bytes)
+ kwargs: dict = {}
+ if language:
+ kwargs["language"] = language
+ result = model.transcribe(path, **kwargs)
+ text = (result.get("text") or "").strip()
+ return (text if text else None), None
+ except ImportError:
+ return None, "openai-whisper is not installed (pip install openai-whisper)"
+ except Exception as e:
+ return None, str(e)[:500]
+ finally:
+ if path:
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+
+
+def compute_daily_omni_text_audio_consistency_metrics(
+ input_requests: list[Any],
+ outputs: list[Any],
+ *,
+ include_per_item: bool = False,
+) -> dict[str, Any] | None:
+ """Compare option letter from ``generated_text`` vs Whisper transcript of output audio.
+
+ Only considers requests where ``outputs[i]`` has ``generated_audio_wav_bytes`` set
+ (populated by the omni benchmark when TA check is enabled).
+ """
+ if not input_requests or len(input_requests) != len(outputs):
+ return None
+ if not all(isinstance(r, DailyOmniSampleRequest) for r in input_requests):
+ return None
+
+ ta_no_wav = 0
+ ta_asr_failed = 0
+ ta_text_unparsed = 0
+ ta_audio_unparsed = 0
+ ta_consistent = 0
+ ta_mismatch = 0
+ ta_both_parsed = 0
+ items: list[dict[str, Any]] = []
+
+ for req, out in zip(input_requests, outputs, strict=True):
+ assert isinstance(req, DailyOmniSampleRequest)
+ rid = req.request_id
+ if not getattr(out, "success", False):
+ if include_per_item:
+ items.append(
+ {
+ "request_id": rid,
+ "skipped": True,
+ "reason": "request_not_success",
+ }
+ )
+ continue
+
+ wav = getattr(out, "generated_audio_wav_bytes", None)
+ if not wav:
+ ta_no_wav += 1
+ if include_per_item:
+ items.append(
+ {
+ "request_id": rid,
+ "skipped": False,
+ "reason": "no_output_audio",
+ "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""),
+ }
+ )
+ continue
+
+ transcript, asr_err = transcribe_wav_bytes(wav)
+ if asr_err:
+ ta_asr_failed += 1
+ if include_per_item:
+ items.append(
+ {
+ "request_id": rid,
+ "asr_error": asr_err,
+ "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""),
+ }
+ )
+ continue
+
+ text_choice = extract_predicted_choice(getattr(out, "generated_text", "") or "")
+ audio_choice = extract_choice_from_asr_transcript(transcript or "")
+
+ if text_choice is None:
+ ta_text_unparsed += 1
+ if audio_choice is None:
+ ta_audio_unparsed += 1
+
+ if text_choice is not None and audio_choice is not None:
+ ta_both_parsed += 1
+ if text_choice == audio_choice:
+ ta_consistent += 1
+ else:
+ ta_mismatch += 1
+
+ if include_per_item:
+ consistent: bool | None
+ if text_choice is None or audio_choice is None:
+ consistent = None
+ else:
+ consistent = text_choice == audio_choice
+ items.append(
+ {
+ "request_id": rid,
+ "text_choice": text_choice,
+ "audio_choice": audio_choice,
+ "asr_transcript": (transcript or "")[:500],
+ "text_audio_consistent": consistent,
+ }
+ )
+
+ comparable = ta_consistent + ta_mismatch
+ rate = (ta_consistent / comparable) if comparable else None
+
+ out: dict[str, Any] = {
+ "daily_omni_ta_enabled": True,
+ "daily_omni_ta_no_output_audio": ta_no_wav,
+ "daily_omni_ta_asr_failed": ta_asr_failed,
+ "daily_omni_ta_text_unparsed": ta_text_unparsed,
+ "daily_omni_ta_audio_unparsed": ta_audio_unparsed,
+ "daily_omni_ta_both_parsed": ta_both_parsed,
+ "daily_omni_ta_consistent": ta_consistent,
+ "daily_omni_ta_mismatch": ta_mismatch,
+ "daily_omni_ta_consistency_rate": rate,
+ }
+ if include_per_item:
+ out["daily_omni_ta_items"] = items
+ return out
+
+
+def print_daily_omni_text_audio_summary(metrics: dict[str, Any]) -> None:
+ if not metrics.get("daily_omni_ta_enabled"):
+ return
+ print("{s:{c}^{n}}".format(s=" Daily-Omni text vs audio (ASR) ", n=50, c="="))
+ print("{:<40} {:<10}".format("No output audio captured:", metrics.get("daily_omni_ta_no_output_audio", 0)))
+ print("{:<40} {:<10}".format("ASR failed:", metrics.get("daily_omni_ta_asr_failed", 0)))
+ print("{:<40} {:<10}".format("Both text+audio letter parsed:", metrics.get("daily_omni_ta_both_parsed", 0)))
+ print("{:<40} {:<10}".format("Consistent (same letter):", metrics.get("daily_omni_ta_consistent", 0)))
+ print("{:<40} {:<10}".format("Mismatch:", metrics.get("daily_omni_ta_mismatch", 0)))
+ r = metrics.get("daily_omni_ta_consistency_rate")
+ if r is not None:
+ print("{:<40} {:<10.4f}".format("Consistency rate (of both parsed):", r))
+ print(
+ "{:<40} {:<10}".format(
+ "Text unparsed (among w/ audio):",
+ metrics.get("daily_omni_ta_text_unparsed", 0),
+ )
+ )
+ print(
+ "{:<40} {:<10}".format(
+ "Audio unparsed (among w/ audio):",
+ metrics.get("daily_omni_ta_audio_unparsed", 0),
+ )
+ )
diff --git a/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py b/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py
new file mode 100644
index 0000000000..ca6de4cb20
--- /dev/null
+++ b/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py
@@ -0,0 +1,272 @@
+"""Seed-TTS zero-shot evaluation-style prompts for ``vllm bench serve``.
+
+Loads rows from the `meta.lst` format used in `BytedanceSpeech/seed-tts-eval`_ (or any
+HuggingFace dataset repo with the same layout)::
+
+ utt_id|prompt_transcript|prompt_wav_relative_path|text_to_synthesize
+
+Each benchmark request supplies target text plus ``ref_text`` / ``ref_audio`` (Qwen3-TTS ``Base`` /
+voice clone), merged into the JSON body. By default ``ref_audio`` is an inline ``data:`` URL so
+the server does not need ``--allowed-local-media-path``. Use ``--seed-tts-file-ref-audio`` for
+``file://`` (smaller bodies; requires that flag). Use ``--backend openai-audio-speech``
+(``/v1/audio/speech``) or ``--backend openai-chat-omni`` (``/v1/chat/completions`` with the same
+fields on the body plus a Qwen3-Omni-style ``system`` message and the target text as ``user`` content).
+
+.. _BytedanceSpeech/seed-tts-eval: https://github.com/BytedanceSpeech/seed-tts-eval
+"""
+
+from __future__ import annotations
+
+import base64
+import logging
+import random
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest
+from vllm.tokenizers import TokenizerLike
+from vllm.tokenizers.hf import get_cached_tokenizer
+
+logger = logging.getLogger(__name__)
+
+# Matches Qwen3-Omni serving examples (``openai_chat_completion_client_for_multimodal_generation`` /
+# ``qwen3_omni/gradio_demo``) plus explicit TTS / voice-clone instructions for chat completions.
+SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
+ "capable of perceiving auditory and visual inputs, as well as generating text and speech.\n"
+ "For this request you act as a text-to-speech engine with zero-shot voice cloning: "
+ "the API provides reference audio and its transcript (ref_audio, ref_text) and task_type Base. "
+ "The user message is the exact text you must speak. "
+ "Synthesize natural speech in the same language as that user text, "
+ "matching the timbre, prosody, and speaking style of the reference audio while reading the new content clearly."
+)
+
+
+@dataclass
+class SeedTTSSampleRequest(SampleRequest):
+ """``SampleRequest`` with per-row fields merged into ``/v1/audio/speech`` JSON."""
+
+ #: Shallow-merged into ``RequestFuncInput.extra_body`` (ref_audio, ref_text, task_type, …).
+ seed_tts_speech_extra: dict[str, Any] | None = None
+ seed_tts_utterance_id: str = ""
+ seed_tts_locale: str = ""
+ #: For ``openai-chat-omni``: becomes the chat ``system`` message (Qwen3-Omni + TTS behavior).
+ seed_tts_system_prompt: str = ""
+ #: Local path to reference prompt WAV (for SIM vs. synthesized PCM in ``seed_tts_eval``).
+ seed_tts_ref_wav_path: str = ""
+
+
+@dataclass
+class _SeedTTSRow:
+ utterance_id: str
+ ref_text: str
+ prompt_wav_rel: str
+ target_text: str
+
+
+def _parse_meta_line(line: str) -> _SeedTTSRow | None:
+ line = line.strip()
+ if not line or line.startswith("#"):
+ return None
+ parts = line.split("|")
+ if len(parts) < 4:
+ logger.warning("Skipping malformed meta.lst line (need 4 '|'-fields): %r", line[:120])
+ return None
+ utt_id, ref_text, wav_rel, target = parts[0], parts[1], parts[2], parts[3]
+ if not target.strip():
+ return None
+ return _SeedTTSRow(
+ utterance_id=utt_id.strip(),
+ ref_text=ref_text.strip(),
+ prompt_wav_rel=wav_rel.strip(),
+ target_text=target.strip(),
+ )
+
+
+def _load_meta_rows(meta_file: Path) -> list[_SeedTTSRow]:
+ text = meta_file.read_text(encoding="utf-8")
+ rows: list[_SeedTTSRow] = []
+ for line in text.splitlines():
+ r = _parse_meta_line(line)
+ if r is not None:
+ rows.append(r)
+ return rows
+
+
+def resolve_seed_tts_root(dataset_path: str | None, *, explicit_root: str | None) -> Path:
+ """Return directory containing ``{locale}/meta.lst`` and ``{locale}/prompt-wavs/``."""
+ if explicit_root:
+ root = Path(explicit_root).expanduser().resolve()
+ if not root.is_dir():
+ raise FileNotFoundError(f"--seed-tts-root is not a directory: {root}")
+ return root
+
+ if not dataset_path:
+ raise ValueError("Seed-TTS requires --dataset-path (HF repo id or local root) or --seed-tts-root.")
+
+ p = Path(dataset_path).expanduser()
+ if p.exists() and p.is_dir():
+ return p.resolve()
+
+ repo_id = dataset_path.strip()
+ try:
+ from huggingface_hub import snapshot_download
+ except ImportError as e:
+ raise ImportError(
+ "Install huggingface_hub to download Seed-TTS from the Hub, or clone the dataset "
+ "locally and pass --dataset-path / --seed-tts-root to that directory."
+ ) from e
+ cache = snapshot_download(repo_id=repo_id, repo_type="dataset")
+ return Path(cache).resolve()
+
+
+def _ref_audio_payload(wav_path: Path, *, inline: bool) -> str:
+ if inline:
+ raw = wav_path.read_bytes()
+ b64 = base64.b64encode(raw).decode("ascii")
+ return f"data:audio/wav;base64,{b64}"
+ return wav_path.expanduser().resolve().as_uri()
+
+
+class SeedTTSDataset(BenchmarkDataset):
+ """Seed-TTS-style zero-shot TTS rows for throughput/latency benchmarking.
+
+ Args:
+ dataset_path: HuggingFace dataset repo id (``org/dataset``) or local directory with
+ ``en/meta.lst`` (and ``zh/meta.lst`` if using zh).
+ locale: ``en`` or ``zh`` — which subfolder under the root to read.
+ inline_ref_audio: If True (default), embed prompt WAV as ``data:audio/wav;base64,...``
+ so Qwen3-TTS / ``/v1/audio/speech`` works without server
+ ``--allowed-local-media-path``. If False, use ``file://`` (smaller
+ requests; server must set ``--allowed-local-media-path`` to the dataset root).
+ seed_tts_root: Optional override for the root directory (same layout as HF dataset).
+ system_prompt: Optional override for the chat system message when using
+ ``--backend openai-chat-omni``; defaults to :data:`SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT`.
+ """
+
+ IS_MULTIMODAL = False
+ DEFAULT_OUTPUT_LEN = 2048
+
+ def __init__(
+ self,
+ dataset_path: str,
+ random_seed: int = 0,
+ locale: str = "en",
+ inline_ref_audio: bool = True,
+ seed_tts_root: str | None = None,
+ system_prompt: str | None = None,
+ disable_shuffle: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ if locale not in ("en", "zh"):
+ raise ValueError("locale must be 'en' or 'zh'")
+ self.locale = locale
+ self.inline_ref_audio = inline_ref_audio
+ self._explicit_root = seed_tts_root
+ sp = (system_prompt or "").strip()
+ self._system_prompt = sp if sp else SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT
+ super().__init__(
+ dataset_path=dataset_path,
+ random_seed=random_seed,
+ disable_shuffle=disable_shuffle,
+ **kwargs,
+ )
+ self._root = resolve_seed_tts_root(self.dataset_path, explicit_root=self._explicit_root)
+ self._rows: list[_SeedTTSRow] = []
+ self.load_data()
+
+ def load_data(self) -> None:
+ meta = self._root / self.locale / "meta.lst"
+ if not meta.is_file():
+ raise FileNotFoundError(
+ f"Seed-TTS meta not found: {meta}. "
+ f"Expected layout from seed-tts-eval (e.g. {self._root}/{self.locale}/meta.lst)."
+ )
+ self._rows = _load_meta_rows(meta)
+ if not self._rows:
+ raise ValueError(f"No valid rows in {meta}")
+ if not self.disable_shuffle:
+ rng = random.Random(self.random_seed)
+ rng.shuffle(self._rows)
+ self.data = self._rows
+ logger.info(
+ "Loaded Seed-TTS: root=%s locale=%s rows=%d inline_ref_audio=%s",
+ self._root,
+ self.locale,
+ len(self._rows),
+ self.inline_ref_audio,
+ )
+
+ def sample(
+ self,
+ tokenizer: TokenizerLike,
+ num_requests: int,
+ output_len: int | None = None,
+ request_id_prefix: str = "",
+ no_oversample: bool = False,
+ **kwargs: Any,
+ ) -> list[SampleRequest]:
+ if output_len is None:
+ output_len = self.DEFAULT_OUTPUT_LEN
+
+ tok = get_cached_tokenizer(tokenizer)
+ out: list[SampleRequest] = []
+ for i, row in enumerate(self._rows):
+ if len(out) >= num_requests:
+ break
+ wav_path = (self._root / self.locale / row.prompt_wav_rel).resolve()
+ if not wav_path.is_file():
+ logger.warning("Missing prompt wav for %s: %s", row.utterance_id, wav_path)
+ continue
+
+ target = row.target_text
+ prompt_len = len(tok.encode(f"{self._system_prompt}\n{target}"))
+ lang = "English" if self.locale == "en" else "Chinese"
+ ref_uri = _ref_audio_payload(wav_path, inline=self.inline_ref_audio)
+ speech_extra: dict[str, Any] = {
+ "ref_audio": ref_uri,
+ "ref_text": row.ref_text,
+ "task_type": "Base",
+ "language": lang,
+ "max_new_tokens": output_len,
+ }
+
+ out.append(
+ SeedTTSSampleRequest(
+ prompt=target,
+ prompt_len=prompt_len,
+ expected_output_len=output_len,
+ multi_modal_data=None,
+ request_id=f"{request_id_prefix}{i}",
+ seed_tts_speech_extra=speech_extra,
+ seed_tts_utterance_id=row.utterance_id,
+ seed_tts_locale=self.locale,
+ seed_tts_system_prompt=self._system_prompt,
+ seed_tts_ref_wav_path=str(wav_path),
+ )
+ )
+
+ logger.info("Seed-TTS: built %d requests (asked %d)", len(out), num_requests)
+ self.maybe_oversample_requests(out, num_requests, request_id_prefix, no_oversample)
+ return out
+
+
+def load_seed_tts_dataset(
+ dataset_path: str,
+ random_seed: int = 0,
+ locale: str = "en",
+ inline_ref_audio: bool = True,
+ seed_tts_root: str | None = None,
+ system_prompt: str | None = None,
+ **kwargs: Any,
+) -> SeedTTSDataset:
+ return SeedTTSDataset(
+ dataset_path=dataset_path,
+ random_seed=random_seed,
+ locale=locale,
+ inline_ref_audio=inline_ref_audio,
+ seed_tts_root=seed_tts_root,
+ system_prompt=system_prompt,
+ **kwargs,
+ )
diff --git a/vllm_omni/benchmarks/data_modules/seed_tts_eval.py b/vllm_omni/benchmarks/data_modules/seed_tts_eval.py
new file mode 100644
index 0000000000..d5f1b64709
--- /dev/null
+++ b/vllm_omni/benchmarks/data_modules/seed_tts_eval.py
@@ -0,0 +1,729 @@
+"""Seed-TTS WER aligned with Bytedance ``seed-tts-eval`` / ``run_wer.py``.
+
+Matches the published protocol (see Hugging Face dataset card and
+https://github.com/BytedanceSpeech/seed-tts-eval):
+
+- **EN**: ``openai/whisper-large-v3`` via ``transformers``, audio resampled to **16 kHz**
+ (same as ``run_wer.py``).
+- **ZH**: ``funasr`` **paraformer-zh**, hypothesis converted with **zhconv** to zh-cn.
+- **WER**: ``jiwer`` after punctuation stripping (``zhon.hanzi.punctuation`` + ``string.punctuation``,
+ preserving ``'``) and EN lowercasing / ZH per-character spacing. Supports jiwer 3.x
+ (``compute_measures``) and 4.x (``process_words``).
+
+- **SIM** (speaker similarity proxy): cosine similarity of L2-normalized mean-pooled **WavLM**
+ embeddings (reference prompt WAV vs. synthesized PCM), 16 kHz. Official ``cal_sim.sh`` uses
+ UniSpeech ``verification_pair_list_v2.py`` with a **fine-tuned** WavLM SV checkpoint — set
+ ``SEED_TTS_WAVLM_MODEL`` to another HF id if you need closer parity. Disable with
+ ``SEED_TTS_SIM_EVAL=0``. Optional: ``SEED_TTS_SIM_DEVICE`` (e.g. ``cpu``) to avoid GPU
+ issues when Whisper already uses CUDA; ``SEED_TTS_WAVLM_MIN_SAMPLES`` pads very short
+ waveforms so the WavLM CNN front-end does not fail.
+
+- **UTMOS** (predicted MOS from TorchScript): default ``balacoon/utmos`` → ``utmos.jit``
+ (Sarulab-style demo export). Uses ``torch`` + ``huggingface_hub`` only. Aggregate metrics
+ are over **all requests with captured PCM** (independent of ASR/WER). Non-finite scores are
+ dropped and counted as failures. Override repo/file via ``SEED_TTS_UTMOS_HF_REPO`` /
+ ``SEED_TTS_UTMOS_JIT_FILE``. **Device**: defaults to **CPU** when ``SEED_TTS_UTMOS_DEVICE``
+ is unset; set ``SEED_TTS_UTMOS_DEVICE=cuda:0`` (or ``cuda:1`` etc.) to run on GPU. The JIT
+ model is loaded directly onto the target device via ``map_location`` to avoid cross-device
+ issues (some PyTorch builds/Windows have problems moving TorchScript modules after load).
+ Forward uses **float32** waveform in ``[-1, 1]`` (same as the WER resampled array) so
+ tensor dtypes match JIT weights; using int16 triggers
+ ``RuntimeError: input type and weight type should be same`` on common exports. Disable
+ with ``SEED_TTS_UTMOS_EVAL=0``.
+
+Enable with ``SEED_TTS_WER_EVAL=1`` or ``--seed-tts-wer-eval``. Install optional deps::
+
+ pip install 'vllm-omni[seed-tts-eval]'
+
+Env: ``SEED_TTS_EVAL_DEVICE`` (e.g. ``cuda:0``, ``cpu``); ``SEED_TTS_HF_WHISPER_MODEL``
+defaults to ``openai/whisper-large-v3`` (override for debugging only).
+"""
+
+from __future__ import annotations
+
+import io
+import logging
+import math
+import os
+import statistics
+import string
+import tempfile
+import threading
+import wave
+from typing import Any
+
+import numpy as np
+from vllm.benchmarks.datasets import SampleRequest
+
+from vllm_omni.benchmarks.data_modules.seed_tts_dataset import SeedTTSSampleRequest
+
+logger = logging.getLogger(__name__)
+
+# Mirrors seed-tts-eval/run_wer.py
+OFFICIAL_WHISPER_HF_ID = "openai/whisper-large-v3"
+PARAFORMER_MODEL_ID = "paraformer-zh"
+
+_lock = threading.Lock()
+_device: str | None = None
+_en_processor = None
+_en_model = None
+_zh_paraformer = None
+_wavlm_model = None
+_wavlm_processor = None
+_wavlm_device: str | None = None
+_utmos_jit_model = None
+_utmos_jit_device: str | None = None
+_utmos_jit_load_failed = False
+_utmos_forward_warned = False
+
+
+def pcm_s16le_mono_to_wav_bytes(pcm: bytes, *, sample_rate: int = 24000) -> bytes:
+ buf = io.BytesIO()
+ with wave.open(buf, "wb") as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(sample_rate)
+ wf.writeframes(pcm)
+ return buf.getvalue()
+
+
+def _get_eval_device() -> str:
+ explicit = os.environ.get("SEED_TTS_EVAL_DEVICE", "").strip()
+ if explicit:
+ return explicit
+ try:
+ import torch
+
+ return "cuda:0" if torch.cuda.is_available() else "cpu"
+ except ImportError:
+ return "cpu"
+
+
+def _punctuation_all() -> str:
+ from zhon.hanzi import punctuation
+
+ return punctuation + string.punctuation
+
+
+def _jiwer_wer(reference: str, hypothesis: str) -> float:
+ """Word-level WER; strings are normalized like ``run_wer.process_one``.
+
+ jiwer 4.x removed ``compute_measures`` (``ImportError``); fall back to ``process_words``.
+ """
+ try:
+ from jiwer import compute_measures
+
+ return float(compute_measures(reference, hypothesis)["wer"])
+ except ImportError:
+ import jiwer
+
+ out = jiwer.process_words(reference, hypothesis)
+ return float(out.wer)
+
+
+def process_one_official(hypo: str, truth: str, lang: str) -> tuple[float, str, str]:
+ """Same normalization + ``jiwer`` call as ``run_wer.process_one`` (hypo=ASR, truth=reference)."""
+ raw_truth = truth
+ raw_hypo = hypo
+ truth_n = truth
+ hypo_n = hypo
+ for x in _punctuation_all():
+ if x == "'":
+ continue
+ truth_n = truth_n.replace(x, "")
+ hypo_n = hypo_n.replace(x, "")
+ truth_n = truth_n.replace(" ", " ")
+ hypo_n = hypo_n.replace(" ", " ")
+ if lang == "zh":
+ truth_n = " ".join([x for x in truth_n])
+ hypo_n = " ".join([x for x in hypo_n])
+ elif lang == "en":
+ truth_n = truth_n.lower()
+ hypo_n = hypo_n.lower()
+ else:
+ raise ValueError(f"unsupported lang {lang!r}")
+ wer = _jiwer_wer(truth_n, hypo_n)
+ return wer, raw_truth, raw_hypo
+
+
+def _pcm_s16le_to_f32_16k(pcm: bytes, pcm_sample_rate: int = 24000) -> np.ndarray:
+ import scipy.signal
+
+ if not pcm:
+ return np.zeros(0, dtype=np.float32)
+ raw = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0
+ target_len = int(len(raw) * 16000 / pcm_sample_rate)
+ if target_len <= 0:
+ return np.zeros(0, dtype=np.float32)
+ return scipy.signal.resample(raw, target_len).astype(np.float32)
+
+
+def _eval_submetric_enabled(env_name: str, *, default: bool = True) -> bool:
+ raw = os.environ.get(env_name, "").strip().lower()
+ if raw in ("0", "false", "no", "off"):
+ return False
+ if raw in ("1", "true", "yes", "on"):
+ return True
+ return default
+
+
+def _audio_path_to_f32_16k(path: str) -> np.ndarray:
+ import scipy.signal
+ import soundfile as sf
+
+ data, sr = sf.read(path, dtype="float32", always_2d=True)
+ mono = np.mean(data, axis=1).astype(np.float32)
+ if int(sr) == 16000:
+ return mono
+ target_len = max(1, int(len(mono) * 16000 / int(sr)))
+ return scipy.signal.resample(mono, target_len).astype(np.float32)
+
+
+def _ensure_wavlm_sim() -> None:
+ global _wavlm_model, _wavlm_processor, _wavlm_device
+ with _lock:
+ if _wavlm_model is not None:
+ return
+ from transformers import AutoFeatureExtractor, AutoModel
+
+ mid = os.environ.get("SEED_TTS_WAVLM_MODEL", "microsoft/wavlm-base-plus").strip() or "microsoft/wavlm-base-plus"
+ _wavlm_device = os.environ.get("SEED_TTS_SIM_DEVICE", "").strip() or _get_eval_device()
+ logger.warning(
+ "Loading WavLM %r on %s for Seed-TTS SIM (embedding cosine; not identical to "
+ "seed-tts-eval UniSpeech SV checkpoint).",
+ mid,
+ _wavlm_device,
+ )
+ _wavlm_processor = AutoFeatureExtractor.from_pretrained(mid)
+ _wavlm_model = AutoModel.from_pretrained(mid).to(_wavlm_device)
+ _wavlm_model.eval()
+
+
+def _wavlm_prepare_waveform(wav: np.ndarray) -> np.ndarray:
+ """Trim, pad to a minimum length WavLM/Wav2Vec2 CNN stack accepts, float32 mono."""
+ max_sec = float(os.environ.get("SEED_TTS_WAVLM_MAX_SECONDS", "30"))
+ cap = int(max_sec * 16000)
+ w = np.asarray(wav, dtype=np.float32).reshape(-1)
+ if len(w) == 0:
+ return w
+ if len(w) > cap:
+ w = w[:cap].copy()
+ # Very short clips make the strided conv front-end fail (shape / empty time dim).
+ min_samples = int(os.environ.get("SEED_TTS_WAVLM_MIN_SAMPLES", "4000"))
+ if len(w) < min_samples:
+ w = np.pad(w, (0, min_samples - len(w)), mode="constant")
+ return w
+
+
+def _wavlm_mean_embedding_f32_16k(wav: np.ndarray) -> np.ndarray | None:
+ import torch
+
+ _ensure_wavlm_sim()
+ w = _wavlm_prepare_waveform(wav)
+ if len(w) == 0:
+ return None
+ assert _wavlm_processor is not None and _wavlm_model is not None and _wavlm_device is not None
+ # Single utterance: avoid padding=True (adds zeros that distort mean pooling). Still pass
+ # attention_mask when the extractor provides it (sample-level; do not mix with hidden length).
+ try:
+ inputs = _wavlm_processor(
+ w,
+ sampling_rate=16000,
+ return_tensors="pt",
+ padding=False,
+ return_attention_mask=True,
+ )
+ except TypeError:
+ inputs = _wavlm_processor(
+ w,
+ sampling_rate=16000,
+ return_tensors="pt",
+ padding=False,
+ )
+ iv = inputs["input_values"].to(_wavlm_device)
+ am = inputs.get("attention_mask")
+ if am is not None:
+ am = am.to(_wavlm_device)
+ with torch.inference_mode():
+ out = _wavlm_model(iv, attention_mask=am)
+ h = out.last_hidden_state
+ v = h.mean(dim=1).squeeze(0).float().cpu().numpy()
+ n = float(np.linalg.norm(v))
+ if not np.isfinite(n) or n < 1e-8:
+ return None
+ return (v / n).astype(np.float32)
+
+
+def _cosine_similarity_unit_vectors(a: np.ndarray, b: np.ndarray) -> float:
+ return float(np.dot(a, b))
+
+
+def _ensure_utmos_jit_model() -> Any | None:
+ """Load UTMOS as TorchScript (``balacoon/utmos`` style): no ``import utmos`` / fairseq."""
+ global _utmos_jit_model, _utmos_jit_device, _utmos_jit_load_failed
+ with _lock:
+ if _utmos_jit_load_failed:
+ return None
+ if _utmos_jit_model is not None:
+ return _utmos_jit_model
+ try:
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ repo = os.environ.get("SEED_TTS_UTMOS_HF_REPO", "balacoon/utmos").strip() or "balacoon/utmos"
+ fname = os.environ.get("SEED_TTS_UTMOS_JIT_FILE", "utmos.jit").strip() or "utmos.jit"
+ logger.warning(
+ "Loading UTMOS TorchScript from Hugging Face %r file %r (one-time download/cache)...",
+ repo,
+ fname,
+ )
+ path = hf_hub_download(repo_id=repo, filename=fname, repo_type="model")
+
+ # TODO The model weights in UTMOS must be loaded in cuda:0; otherwise, the model execution will fail.
+ want = "cuda:0"
+ if want.startswith("cuda") and torch.cuda.is_available():
+ idx = want.split(":")[-1] if ":" in want else "0"
+ target_dev = f"cuda:{idx}"
+ else:
+ target_dev = "cpu"
+
+ try:
+ m = torch.jit.load(path, map_location=target_dev)
+ m.eval()
+ _utmos_jit_device = target_dev
+ except Exception as load_e:
+ if target_dev.startswith("cuda"):
+ logger.warning(
+ "UTMOS JIT load on %s failed (%s), retrying on CPU...",
+ target_dev,
+ load_e,
+ )
+ m = torch.jit.load(path, map_location="cpu")
+ m.eval()
+ _utmos_jit_device = "cpu"
+ else:
+ raise
+ _utmos_jit_model = m
+ except Exception as e:
+ logger.warning(
+ "UTMOS JIT unavailable (install torch + huggingface_hub; check HF access): %s",
+ e,
+ )
+ _utmos_jit_load_failed = True
+ return None
+ return _utmos_jit_model
+
+
+def _utmos_predict_f32_16k(wav_f32: np.ndarray) -> float | None:
+ """MOS from JIT model; input is float32 mono @ 16 kHz in ``[-1, 1]`` (WER pipeline).
+
+ ``balacoon/utmos`` demos sometimes use int16 numpy, but the exported ``.jit`` weights are
+ float32; passing int16 tensors causes: "RuntimeError: ... input type and weight type
+ should be same".
+ """
+ import torch
+
+ if len(wav_f32) == 0:
+ return None
+ model = _ensure_utmos_jit_model()
+ if model is None:
+ return None
+ # Infer model's device from its first parameter/buffer to guarantee input sits with weights.
+ try:
+ model_dev = next(model.parameters()).device
+ except StopIteration:
+ try:
+ model_dev = next(model.buffers()).device
+ except StopIteration:
+ model_dev = torch.device("cpu")
+ w = np.ascontiguousarray(wav_f32, dtype=np.float32)
+ x = torch.from_numpy(w).unsqueeze(0).to(device=model_dev, dtype=torch.float32)
+ with torch.no_grad():
+ out = model(x)
+ val = float(out.reshape(-1)[0].item())
+ if not math.isfinite(val):
+ return None
+ return val
+
+
+def _ensure_en_asr() -> None:
+ global _en_processor, _en_model, _device
+ with _lock:
+ if _en_processor is not None:
+ return
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
+
+ _device = _get_eval_device()
+ mid = os.environ.get("SEED_TTS_HF_WHISPER_MODEL", OFFICIAL_WHISPER_HF_ID).strip() or OFFICIAL_WHISPER_HF_ID
+ logger.warning(
+ "Loading Seed-TTS eval Whisper HF model %r on %s (one-time, seed-tts-eval protocol)...",
+ mid,
+ _device,
+ )
+ _en_processor = WhisperProcessor.from_pretrained(mid)
+ _en_model = WhisperForConditionalGeneration.from_pretrained(mid).to(_device)
+ _en_model.eval()
+
+
+def _ensure_zh_asr() -> None:
+ global _zh_paraformer, _device
+ with _lock:
+ if _zh_paraformer is not None:
+ return
+ from funasr import AutoModel
+
+ _device = _get_eval_device()
+ logger.warning(
+ "Loading Seed-TTS eval Paraformer %r on %s (one-time, seed-tts-eval protocol)...",
+ PARAFORMER_MODEL_ID,
+ _device,
+ )
+ try:
+ _zh_paraformer = AutoModel(model=PARAFORMER_MODEL_ID, device=_device)
+ except TypeError:
+ _zh_paraformer = AutoModel(model=PARAFORMER_MODEL_ID)
+
+
+def _transcribe_en_f32_16k(wav_f32: np.ndarray) -> str:
+ import torch
+
+ _ensure_en_asr()
+ if len(wav_f32) == 0:
+ return ""
+ with _lock:
+ assert _en_processor is not None and _en_model is not None and _device is not None
+ inputs = _en_processor(wav_f32, sampling_rate=16000, return_tensors="pt")
+ input_features = inputs.input_features.to(_device)
+ with torch.no_grad():
+ try:
+ forced = _en_processor.get_decoder_prompt_ids(language="english", task="transcribe")
+ predicted_ids = _en_model.generate(input_features, forced_decoder_ids=forced)
+ except Exception:
+ predicted_ids = _en_model.generate(
+ input_features,
+ language="english",
+ task="transcribe",
+ )
+ text = _en_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+ return (text or "").strip()
+
+
+def _transcribe_zh_wav_path(wav_path: str) -> str:
+ import zhconv
+
+ _ensure_zh_asr()
+ with _lock:
+ assert _zh_paraformer is not None
+ res = _zh_paraformer.generate(input=wav_path, batch_size_s=300)
+ transcription = res[0]["text"] if res else ""
+ return zhconv.convert(transcription, "zh-cn").strip()
+
+
+def _missing_deps_message(lang: str) -> str | None:
+ try:
+ import jiwer # noqa: F401
+ from zhon.hanzi import punctuation # noqa: F401
+ except ImportError as e:
+ return f"Seed-TTS WER eval needs jiwer and zhon ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'"
+ try:
+ import scipy.signal # noqa: F401
+ import soundfile # noqa: F401
+ except ImportError as e:
+ return f"Seed-TTS WER eval needs scipy and soundfile ({e!s})."
+ if lang == "en":
+ try:
+ import torch # noqa: F401
+ from transformers import WhisperForConditionalGeneration # noqa: F401
+ except ImportError as e:
+ return f"English WER needs torch and transformers ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'"
+ else:
+ try:
+ import zhconv # noqa: F401
+ from funasr import AutoModel # noqa: F401
+ except ImportError as e:
+ return f"Chinese WER needs funasr and zhconv ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'"
+ return None
+
+
+def compute_seed_tts_wer_metrics(
+ input_requests: list[SampleRequest],
+ outputs: list[Any],
+ *,
+ include_per_item: bool = False,
+) -> dict[str, Any] | None:
+ """If all requests are :class:`SeedTTSSampleRequest`, run seed-tts-eval-style WER."""
+ global _utmos_forward_warned
+ if not input_requests or len(input_requests) != len(outputs):
+ return None
+ if not all(isinstance(r, SeedTTSSampleRequest) for r in input_requests):
+ return None
+
+ first = input_requests[0]
+ assert isinstance(first, SeedTTSSampleRequest)
+ lang = "zh" if (first.seed_tts_locale or "en").lower().startswith("zh") else "en"
+
+ setup_err = _missing_deps_message(lang)
+ if setup_err:
+ logger.error("%s", setup_err)
+ return {
+ "seed_tts_eval_setup_error": setup_err,
+ "seed_tts_eval_protocol": "seed-tts-eval",
+ "seed_tts_content_evaluated": 0,
+ "seed_tts_content_error_mean": None,
+ "seed_tts_content_error_median": None,
+ "seed_tts_request_failed": 0,
+ "seed_tts_no_pcm": 0,
+ "seed_tts_asr_failed": 0,
+ "seed_tts_content_metric": "wer",
+ }
+
+ import soundfile as sf
+
+ errs: list[float] = []
+ items: list[dict[str, Any]] = []
+ asr_failed = 0
+ no_pcm = 0
+ request_failed = 0
+ sim_values: list[float] = []
+ utmos_values: list[float] = []
+ sim_failed = 0
+ sim_skipped_no_ref = 0
+ utmos_failed = 0
+ utmos_on = _eval_submetric_enabled("SEED_TTS_UTMOS_EVAL", default=True)
+
+ for req, out in zip(input_requests, outputs, strict=True):
+ assert isinstance(req, SeedTTSSampleRequest)
+ ref = req.prompt
+ locale = req.seed_tts_locale or "en"
+ row_lang = "zh" if locale.lower().startswith("zh") else "en"
+ utmos_v: float | None = None
+
+ if not out.success:
+ request_failed += 1
+ if include_per_item:
+ items.append(
+ {
+ "utterance_id": req.seed_tts_utterance_id,
+ "locale": locale,
+ "error": "request_failed",
+ "detail": (out.error or "")[:500],
+ }
+ )
+ continue
+
+ pcm = getattr(out, "tts_output_pcm_bytes", None)
+ if not pcm:
+ no_pcm += 1
+ if include_per_item:
+ items.append(
+ {
+ "utterance_id": req.seed_tts_utterance_id,
+ "locale": locale,
+ "error": "no_pcm",
+ }
+ )
+ continue
+
+ wav_16k = _pcm_s16le_to_f32_16k(pcm)
+ if len(wav_16k) == 0:
+ asr_failed += 1
+ if include_per_item:
+ items.append(
+ {
+ "utterance_id": req.seed_tts_utterance_id,
+ "locale": locale,
+ "error": "empty_audio",
+ }
+ )
+ continue
+
+ # UTMOS scores synthesized audio only; do not gate on ASR/WER (those can fail independently).
+ if utmos_on:
+ try:
+ utmos_v = _utmos_predict_f32_16k(wav_16k)
+ if utmos_v is not None:
+ utmos_values.append(utmos_v)
+ elif not _utmos_jit_load_failed:
+ utmos_failed += 1
+ except Exception:
+ if not _utmos_forward_warned:
+ _utmos_forward_warned = True
+ logger.warning(
+ "UTMOS JIT forward failed (first utterance=%s; set logging DEBUG for "
+ "full trace). Check sample rate (16 kHz), input shape, or "
+ "SEED_TTS_UTMOS_DEVICE.",
+ req.seed_tts_utterance_id,
+ exc_info=True,
+ )
+ else:
+ logger.debug(
+ "UTMOS forward failed for %s",
+ req.seed_tts_utterance_id,
+ exc_info=True,
+ )
+ utmos_failed += 1
+
+ try:
+ if row_lang == "en":
+ hyp = _transcribe_en_f32_16k(wav_16k)
+ else:
+ fd, tmp_wav = tempfile.mkstemp(suffix=".wav")
+ os.close(fd)
+ try:
+ sf.write(tmp_wav, wav_16k, 16000, subtype="PCM_16")
+ hyp = _transcribe_zh_wav_path(tmp_wav)
+ finally:
+ try:
+ os.unlink(tmp_wav)
+ except OSError:
+ pass
+ except Exception as e:
+ logger.exception("Seed-TTS ASR failed for %s", req.seed_tts_utterance_id)
+ asr_failed += 1
+ if include_per_item:
+ items.append(
+ {
+ "utterance_id": req.seed_tts_utterance_id,
+ "locale": locale,
+ "error": "asr_exception",
+ "detail": str(e)[:500],
+ }
+ )
+ continue
+
+ if not hyp:
+ asr_failed += 1
+ if include_per_item:
+ items.append(
+ {
+ "utterance_id": req.seed_tts_utterance_id,
+ "locale": locale,
+ "error": "empty_asr",
+ }
+ )
+ continue
+
+ try:
+ wer, raw_truth, raw_hypo = process_one_official(hyp, ref, row_lang)
+ except Exception as e:
+ logger.warning("jiwer/normalize failed for %s: %s", req.seed_tts_utterance_id, e)
+ asr_failed += 1
+ if include_per_item:
+ items.append(
+ {
+ "utterance_id": req.seed_tts_utterance_id,
+ "locale": locale,
+ "error": "wer_compute_failed",
+ "detail": str(e)[:500],
+ }
+ )
+ continue
+
+ errs.append(wer)
+ sim_v: float | None = None
+
+ if _eval_submetric_enabled("SEED_TTS_SIM_EVAL", default=True):
+ ref_path = getattr(req, "seed_tts_ref_wav_path", "") or ""
+ if ref_path and os.path.isfile(ref_path):
+ try:
+ ref_wav = _audio_path_to_f32_16k(ref_path)
+ e_ref = _wavlm_mean_embedding_f32_16k(ref_wav)
+ e_hyp = _wavlm_mean_embedding_f32_16k(wav_16k)
+ if e_ref is not None and e_hyp is not None:
+ sim_v = _cosine_similarity_unit_vectors(e_ref, e_hyp)
+ sim_values.append(sim_v)
+ except Exception as e:
+ logger.warning(
+ "SIM embedding failed for utterance=%s: %s: %s",
+ req.seed_tts_utterance_id,
+ type(e).__name__,
+ e,
+ )
+ sim_failed += 1
+ else:
+ sim_skipped_no_ref += 1
+
+ if include_per_item:
+ row: dict[str, Any] = {
+ "utterance_id": req.seed_tts_utterance_id,
+ "locale": locale,
+ "wer": wer,
+ "reference_raw": raw_truth,
+ "asr_raw": raw_hypo,
+ }
+ if sim_v is not None:
+ row["sim"] = sim_v
+ if utmos_v is not None:
+ row["utmos"] = utmos_v
+ items.append(row)
+
+ result: dict[str, Any] = {
+ "seed_tts_eval_protocol": "seed-tts-eval",
+ "seed_tts_content_evaluated": len(errs),
+ "seed_tts_content_error_mean": statistics.fmean(errs) if errs else None,
+ "seed_tts_content_error_median": statistics.median(errs) if errs else None,
+ "seed_tts_request_failed": request_failed,
+ "seed_tts_no_pcm": no_pcm,
+ "seed_tts_asr_failed": asr_failed,
+ "seed_tts_content_metric": "wer",
+ "seed_tts_sim_evaluated": len(sim_values),
+ "seed_tts_sim_mean": statistics.fmean(sim_values) if sim_values else None,
+ "seed_tts_sim_median": statistics.median(sim_values) if sim_values else None,
+ "seed_tts_sim_failed": sim_failed,
+ "seed_tts_sim_skipped_no_ref": sim_skipped_no_ref,
+ "seed_tts_utmos_evaluated": len(utmos_values),
+ "seed_tts_utmos_mean": statistics.fmean(utmos_values) if utmos_values else None,
+ "seed_tts_utmos_median": statistics.median(utmos_values) if utmos_values else None,
+ "seed_tts_utmos_failed": utmos_failed,
+ }
+ if include_per_item:
+ result["seed_tts_wer_eval_items"] = items
+ return result
+
+
+def print_seed_tts_wer_summary(metrics: dict[str, Any]) -> None:
+ setup = metrics.get("seed_tts_eval_setup_error")
+ if setup:
+ print("{s:{c}^{n}}".format(s=" Seed-TTS eval (seed-tts-eval protocol) ", n=50, c="="))
+ print(setup)
+ return
+
+ ev = int(metrics.get("seed_tts_content_evaluated", 0) or 0)
+ rf = int(metrics.get("seed_tts_request_failed", 0) or 0)
+ npc = int(metrics.get("seed_tts_no_pcm", 0) or 0)
+ af = int(metrics.get("seed_tts_asr_failed", 0) or 0)
+ sim_ev = int(metrics.get("seed_tts_sim_evaluated", 0) or 0)
+ ut_ev = int(metrics.get("seed_tts_utmos_evaluated", 0) or 0)
+ if ev == 0 and rf == 0 and npc == 0 and af == 0 and sim_ev == 0 and ut_ev == 0:
+ return
+ print("{s:{c}^{n}}".format(s=" Seed-TTS eval (seed-tts-eval protocol) ", n=50, c="="))
+ print("{:<40} {:<10}".format("Evaluated (WER, lower is better):", ev))
+ mean = metrics.get("seed_tts_content_error_mean")
+ if mean is not None:
+ print("{:<40} {:<10.4f}".format("Mean WER:", float(mean)))
+ med = metrics.get("seed_tts_content_error_median")
+ if med is not None:
+ print("{:<40} {:<10.4f}".format("Median WER:", float(med)))
+ print("{:<40} {:<10}".format("Request failed:", metrics.get("seed_tts_request_failed", 0)))
+ print("{:<40} {:<10}".format("No PCM captured:", metrics.get("seed_tts_no_pcm", 0)))
+ print("{:<40} {:<10}".format("ASR / WER failed:", metrics.get("seed_tts_asr_failed", 0)))
+ if sim_ev or metrics.get("seed_tts_sim_skipped_no_ref") or metrics.get("seed_tts_sim_failed"):
+ print("{:<40} {:<10}".format("SIM evaluated (higher ~ closer):", sim_ev))
+ sm = metrics.get("seed_tts_sim_mean")
+ if sm is not None:
+ print("{:<40} {:<10.4f}".format("Mean SIM:", float(sm)))
+ s_med = metrics.get("seed_tts_sim_median")
+ if s_med is not None:
+ print("{:<40} {:<10.4f}".format("Median SIM:", float(s_med)))
+ print("{:<40} {:<10}".format("SIM skipped (no ref path):", metrics.get("seed_tts_sim_skipped_no_ref", 0)))
+ print("{:<40} {:<10}".format("SIM embedding errors:", metrics.get("seed_tts_sim_failed", 0)))
+ if ut_ev or metrics.get("seed_tts_utmos_failed"):
+ print("{:<40} {:<10}".format("UTMOS evaluated (JIT MOS, higher better):", ut_ev))
+ um = metrics.get("seed_tts_utmos_mean")
+ if um is not None:
+ print("{:<40} {:<10.4f}".format("Mean UTMOS:", float(um)))
+ u_med = metrics.get("seed_tts_utmos_median")
+ if u_med is not None:
+ print("{:<40} {:<10.4f}".format("Median UTMOS:", float(u_med)))
+ print("{:<40} {:<10}".format("UTMOS errors:", metrics.get("seed_tts_utmos_failed", 0)))
+ print("=" * 50)
diff --git a/vllm_omni/benchmarks/patch/__init__.py b/vllm_omni/benchmarks/patch/__init__.py
index e69de29bb2..ca6b41ba8f 100644
--- a/vllm_omni/benchmarks/patch/__init__.py
+++ b/vllm_omni/benchmarks/patch/__init__.py
@@ -0,0 +1,3 @@
+"""Omni benchmark monkey-patches (side effects in ``patch.patch``)."""
+
+from . import patch as _patch_module # noqa: F401
diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py
index 343655df20..41aed09423 100644
--- a/vllm_omni/benchmarks/patch/patch.py
+++ b/vllm_omni/benchmarks/patch/patch.py
@@ -6,6 +6,7 @@
import os
import random
import ssl
+import sys
import time
import traceback
from collections.abc import Iterable
@@ -33,15 +34,245 @@
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
+
+from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniDataset, DailyOmniSampleRequest
from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import OmniRandomMultiModalDataset
+from vllm_omni.benchmarks.data_modules.seed_tts_dataset import (
+ SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT,
+ SeedTTSDataset,
+ SeedTTSSampleRequest,
+)
get_samples_old = datasets.get_samples
+_DEFAULT_DAILY_OMNI_REPO = "liarliar/Daily-Omni"
+
+
+def _seed_tts_capture_pcm_for_wer() -> bool:
+ return os.environ.get("SEED_TTS_WER_EVAL", "").lower() in (
+ "1",
+ "true",
+ "yes",
+ )
+
+
+def _merge_extra_body_mm_kwargs(base: dict | None, overlay: dict | None) -> dict | None:
+ """Shallow-merge ``extra_body`` dicts; deep-merge ``mm_processor_kwargs`` if both set."""
+ if not base and not overlay:
+ return None
+ out = dict(base or {})
+ if not overlay:
+ return out
+ for k, v in overlay.items():
+ if k == "mm_processor_kwargs" and isinstance(v, dict):
+ prev = out.get("mm_processor_kwargs")
+ merged_kw = {**(prev if isinstance(prev, dict) else {}), **v}
+ out["mm_processor_kwargs"] = merged_kw
+ else:
+ out[k] = v
+ return out
+
+
+def _attach_daily_omni_to_request_func_input(sample: SampleRequest, rfi: RequestFuncInput) -> None:
+ """Apply per-request OpenAI fields (``mm_processor_kwargs``, messages) for Daily-Omni."""
+ if not isinstance(sample, DailyOmniSampleRequest):
+ return
+ rfi.extra_body = _merge_extra_body_mm_kwargs(rfi.extra_body, sample.omni_extra_body)
+ if sample.omni_chat_messages is not None:
+ setattr(rfi, "omni_chat_messages", sample.omni_chat_messages)
+ else:
+ setattr(rfi, "mm_position", sample.omni_chat_mm_position)
+
+
+def _attach_seed_tts_to_request_func_input(sample: SampleRequest, rfi: RequestFuncInput) -> None:
+ """Merge Seed-TTS per-row TTS fields (ref_audio, ref_text, task_type, …) into ``extra_body``.
+
+ Used by both ``/v1/audio/speech`` and ``/v1/chat/completions`` (flattened into JSON body).
+ For ``openai-chat-omni``, also sets ``omni_chat_messages`` (system + user) so Qwen3-Omni
+ follows the same role layout as official TTS / multimodal demos. ``/v1/audio/speech`` ignores
+ ``messages`` and only uses ``input`` + body fields.
+ Flags ``openai-chat-omni`` to request audio output and optionally export PCM for WER.
+ """
+ if not isinstance(sample, SeedTTSSampleRequest):
+ return
+ ex = sample.seed_tts_speech_extra
+ if not ex:
+ return
+ base = dict(rfi.extra_body) if rfi.extra_body else {}
+ base.update(ex)
+ rfi.extra_body = base
+ # Used by request funcs to force streaming TTS behavior and to export PCM when WER is on.
+ setattr(rfi, "seed_tts_row", True)
+ sys_prompt = (sample.seed_tts_system_prompt or "").strip() or SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT
+ setattr(
+ rfi,
+ "omni_chat_messages",
+ [
+ {"role": "system", "content": [{"type": "text", "text": sys_prompt}]},
+ {"role": "user", "content": [{"type": "text", "text": sample.prompt}]},
+ ],
+ )
+
+
+def _daily_omni_repo_from_args(args) -> str | None:
+ """Resolve HuggingFace repo id for Daily-Omni from CLI args.
+
+ vLLM allows ``--dataset-path`` to be a local path while the real HF id is
+ passed via ``--hf-name``. Upstream ``get_samples`` for ``hf`` only matches
+ a fixed elif-chain and never discovers Omni's loader, so we must detect
+ Daily-Omni here using either field.
+ """
+ dp = getattr(args, "dataset_path", None)
+ hn = getattr(args, "hf_name", None)
+ if dp in DailyOmniDataset.SUPPORTED_DATASET_PATHS:
+ return dp
+ if hn in DailyOmniDataset.SUPPORTED_DATASET_PATHS:
+ return hn
+ return None
+
def get_samples(args, tokenizer):
- if args.backend not in ["openai-chat-omni", "openai-audio-speech"]:
+ # Daily-Omni: explicit dataset name, or hf + matching path/hf-name
+ is_daily_omni = args.dataset_name == "daily-omni" or (
+ args.dataset_name == "hf" and _daily_omni_repo_from_args(args) is not None
+ )
+ is_seed_tts = args.dataset_name == "seed-tts"
+
+ # Check if we need to handle omni-related backends/datasets
+ is_omni_backend = args.backend in ["openai-chat-omni", "openai-audio-speech", "daily-omni"]
+ is_omni_dataset = is_daily_omni or is_seed_tts or args.dataset_name == "random-mm"
+
+ if not is_omni_backend and not is_omni_dataset:
+ # Not an omni-related request, delegate to original implementation
return get_samples_old(args, tokenizer)
- elif args.dataset_name == "random-mm":
+
+ # Handle Daily-Omni dataset
+ if is_daily_omni:
+ # Support:
+ # --dataset-name daily-omni [--dataset-path liarliar/Daily-Omni]
+ # --dataset-name daily-omni --daily-omni-qa-json /path/to/qa.json (offline QA)
+ # --dataset-name hf --dataset-path liarliar/Daily-Omni
+ # --dataset-name hf --hf-name liarliar/Daily-Omni (dataset-path may be local)
+
+ # Validate backend supports multimodal (video)
+ if args.backend not in ["openai-chat-omni", "daily-omni"]:
+ raise ValueError(
+ f"Daily-Omni dataset requires a multimodal backend that supports video. "
+ f"Got backend='{args.backend}'. Please use '--backend openai-chat-omni'"
+ )
+
+ # Determine video directory if specified (for local video files)
+ video_dir = getattr(args, "daily_omni_video_dir", None)
+
+ # Get HF split (default to "train"; unused when loading from local qa.json)
+ dataset_split = getattr(args, "hf_split", None) or "train"
+
+ qa_json = getattr(args, "daily_omni_qa_json", None)
+ if isinstance(qa_json, str):
+ qa_json = qa_json.strip() or None
+
+ if qa_json is not None:
+ logger.info(
+ "Loading Daily-Omni dataset: qa_json=%s, video_dir=%s (Hub not used for QA)",
+ qa_json,
+ video_dir,
+ )
+ dataset = DailyOmniDataset(
+ qa_json_path=qa_json,
+ dataset_path=None,
+ dataset_split=dataset_split,
+ random_seed=args.seed,
+ video_dir=video_dir,
+ input_mode=getattr(args, "daily_omni_input_mode", "all"),
+ inline_local_video=getattr(args, "daily_omni_inline_local_video", False),
+ trust_remote_code=getattr(args, "trust_remote_code", False),
+ disable_shuffle=getattr(args, "disable_shuffle", False),
+ )
+ else:
+ repo_id = _daily_omni_repo_from_args(args)
+ if args.dataset_name == "daily-omni":
+ if repo_id is None:
+ repo_id = _DEFAULT_DAILY_OMNI_REPO
+ elif repo_id is None:
+ raise ValueError(
+ "Daily-Omni with --dataset-name hf requires "
+ f"--dataset-path {_DEFAULT_DAILY_OMNI_REPO} or "
+ f"--hf-name {_DEFAULT_DAILY_OMNI_REPO}."
+ )
+
+ logger.info(
+ "Loading Daily-Omni dataset: hf_repo=%s, split=%s, video_dir=%s",
+ repo_id,
+ dataset_split,
+ video_dir,
+ )
+
+ dataset = DailyOmniDataset(
+ dataset_path=repo_id,
+ dataset_split=dataset_split,
+ dataset_subset=getattr(args, "hf_subset", None),
+ random_seed=args.seed,
+ video_dir=video_dir,
+ input_mode=getattr(args, "daily_omni_input_mode", "all"),
+ inline_local_video=getattr(args, "daily_omni_inline_local_video", False),
+ trust_remote_code=getattr(args, "trust_remote_code", False),
+ no_stream=getattr(args, "no_stream", False),
+ disable_shuffle=getattr(args, "disable_shuffle", False),
+ )
+
+ out_len = getattr(args, "output_len", None)
+ if out_len is None:
+ out_len = getattr(args, "hf_output_len", None)
+ if out_len is None:
+ out_len = DailyOmniDataset.DEFAULT_OUTPUT_LEN
+
+ input_requests = dataset.sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ output_len=out_len,
+ request_id_prefix=args.request_id_prefix,
+ no_oversample=args.no_oversample,
+ )
+ return input_requests
+
+ if is_seed_tts:
+ if args.backend not in ("openai-audio-speech", "openai-chat-omni"):
+ raise ValueError(
+ "Seed-TTS requires --backend openai-audio-speech (POST /v1/audio/speech) or "
+ "--backend openai-chat-omni (POST /v1/chat/completions with ref_audio/ref_text). "
+ f"Got backend={args.backend!r}."
+ )
+ repo_id = getattr(args, "dataset_path", None) or getattr(args, "hf_name", None)
+ if not repo_id:
+ raise ValueError(
+ "Seed-TTS requires --dataset-path (HF dataset repo id or local directory) or "
+ "--hf-name for the Hub dataset id."
+ )
+
+ dataset = SeedTTSDataset(
+ dataset_path=repo_id,
+ random_seed=args.seed,
+ locale=getattr(args, "seed_tts_locale", "en"),
+ inline_ref_audio=not getattr(args, "seed_tts_file_ref_audio", False),
+ seed_tts_root=getattr(args, "seed_tts_root", None),
+ system_prompt=getattr(args, "seed_tts_system_prompt", None),
+ disable_shuffle=getattr(args, "disable_shuffle", False),
+ )
+ out_len = getattr(args, "output_len", None)
+ if out_len is None:
+ out_len = getattr(args, "hf_output_len", None)
+ if out_len is None:
+ out_len = SeedTTSDataset.DEFAULT_OUTPUT_LEN
+ return dataset.sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ output_len=out_len,
+ request_id_prefix=args.request_id_prefix,
+ no_oversample=args.no_oversample,
+ )
+
+ # Handle random-mm dataset (Omni's synthetic multimodal dataset)
+ if args.dataset_name == "random-mm":
dataset = OmniRandomMultiModalDataset(random_seed=args.seed, dataset_path=args.dataset_path)
input_requests = dataset.sample(
tokenizer=tokenizer,
@@ -64,6 +295,10 @@ def get_samples(args, tokenizer):
datasets.get_samples = get_samples
+_serve_mod = sys.modules.get("vllm.benchmarks.serve")
+if _serve_mod is not None:
+ _serve_mod.get_samples = get_samples
+
@dataclass
class MixRequestFuncOutput(RequestFuncOutput):
@@ -72,6 +307,9 @@ class MixRequestFuncOutput(RequestFuncOutput):
audio_frames: int = 0
audio_rtf: float = 0.0
text_latency: float = 0.0
+ #: Raw PCM s16le mono at 24 kHz for Seed-TTS WER: from ``/v1/audio/speech`` stream or
+ #: resampled export after ``openai-chat-omni`` audio deltas.
+ tts_output_pcm_bytes: bytes | None = None
async def async_request_openai_chat_omni_completions(
@@ -83,13 +321,17 @@ async def async_request_openai_chat_omni_completions(
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
- content = _get_chat_content(request_func_input, mm_position=mm_position)
+ omni_messages = getattr(request_func_input, "omni_chat_messages", None)
+ if omni_messages is not None:
+ messages_payload = omni_messages
+ else:
+ effective_mm_position = getattr(request_func_input, "mm_position", mm_position)
+ content = _get_chat_content(request_func_input, mm_position=effective_mm_position)
+ messages_payload = [{"role": "user", "content": content}]
payload = {
"model": request_func_input.model_name if request_func_input.model_name else request_func_input.model,
- "messages": [
- {"role": "user", "content": content},
- ],
+ "messages": messages_payload,
"temperature": 0.0,
"max_tokens": request_func_input.output_len,
"stream": True,
@@ -98,6 +340,10 @@ async def async_request_openai_chat_omni_completions(
},
}
_update_payload_common(payload, request_func_input)
+ # Seed-TTS via chat: voice-clone fields live on the body; ensure audio is streamed.
+ if getattr(request_func_input, "seed_tts_row", False):
+ if payload.get("modalities") is None:
+ payload["modalities"] = ["text", "audio"]
response_format = payload.get("response_format", "wav")
if response_format == "pcm":
@@ -143,7 +389,11 @@ async def async_request_openai_chat_omni_completions(
if response.status == 200:
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
- chunk_bytes = chunk_bytes.strip()
+ # NOTE: Do NOT strip() here; TCP may fragment the SSE messages,
+ # so stripping here can cause problems depending on how it is split.
+ #
+ # Simple example: [b'data: ', b'{json}\n\n'] <- stripping the first
+ # chunk will break SSE parsing because the space after 'data:' is required.
if not chunk_bytes:
continue
@@ -163,7 +413,10 @@ async def async_request_openai_chat_omni_completions(
data = json.loads(chunk)
if choices := data.get("choices"):
modality = data.get("modality")
- content = choices[0]["delta"].get("content")
+ delta = choices[0].get("delta") or {}
+ content = delta.get("content")
+ if not content and isinstance(delta.get("audio"), dict):
+ content = delta["audio"].get("data")
if modality == "text":
# First token
if ttft == 0.0:
@@ -178,7 +431,7 @@ async def async_request_openai_chat_omni_completions(
if output.audio_ttfp == 0.0:
output.audio_ttfp = timestamp - st
audio_generate_time = timestamp - st
- if content != "":
+ if content:
audio_bytes = base64.b64decode(content)
seg = AudioSegment.from_file(io.BytesIO(audio_bytes))
if seg is not None:
@@ -210,6 +463,12 @@ async def async_request_openai_chat_omni_completions(
else:
output.audio_rtf = 0
logger.warning("Audio duration is zero")
+ if _seed_tts_capture_pcm_for_wer() and getattr(request_func_input, "seed_tts_row", False):
+ try:
+ seg = generated_audio.set_frame_rate(24000).set_channels(1).set_sample_width(2)
+ output.tts_output_pcm_bytes = bytes(seg.raw_data)
+ except Exception as ex:
+ logger.warning("seed_tts WER PCM export failed: %s", ex)
output.success = True
else:
output.error = response.reason or ""
@@ -264,6 +523,10 @@ async def async_request_openai_audio_speech(
"response_format": "pcm",
}
_update_payload_common(payload, request_func_input)
+ # Seed-TTS + WER: ``--extra-body`` may set stream=false / other formats; speech must stream PCM.
+ if getattr(request_func_input, "seed_tts_row", False) and _seed_tts_capture_pcm_for_wer():
+ payload["stream"] = True
+ payload["response_format"] = "pcm"
headers = {
"Content-Type": "application/json",
@@ -282,6 +545,8 @@ async def async_request_openai_audio_speech(
st = time.perf_counter()
output.start_time = st
total_pcm_bytes = 0
+ capture_wer_pcm = _seed_tts_capture_pcm_for_wer() and getattr(request_func_input, "seed_tts_row", False)
+ pcm_capture = bytearray() if capture_wer_pcm else None
try:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
@@ -293,6 +558,8 @@ async def async_request_openai_audio_speech(
output.audio_ttfp = timestamp - st
output.ttft = output.audio_ttfp
total_pcm_bytes += len(chunk)
+ if pcm_capture is not None:
+ pcm_capture.extend(chunk)
end_time = time.perf_counter()
output.latency = end_time - st
@@ -305,6 +572,16 @@ async def async_request_openai_audio_speech(
else:
output.audio_rtf = 0
logger.warning("Audio duration is zero")
+ if pcm_capture is not None and pcm_capture:
+ output.tts_output_pcm_bytes = bytes(pcm_capture)
+ elif capture_wer_pcm:
+ ct = response.headers.get("Content-Type", "")
+ logger.warning(
+ "Seed-TTS WER: HTTP 200 but no PCM bytes (Content-Type=%r, url=%s). "
+ "Check stream=true and response_format=pcm on the server.",
+ ct,
+ api_url,
+ )
output.success = True
else:
output.error = response.reason or ""
@@ -327,6 +604,12 @@ async def async_request_openai_audio_speech(
if "openai-audio-speech" not in OPENAI_COMPATIBLE_BACKENDS:
OPENAI_COMPATIBLE_BACKENDS.append("openai-audio-speech")
+# Daily-Omni backend for audio-visual reasoning benchmark
+# Reuses openai-chat-omni completions for video+text understanding
+ASYNC_REQUEST_FUNCS["daily-omni"] = async_request_openai_chat_omni_completions
+if "daily-omni" not in OPENAI_COMPATIBLE_BACKENDS:
+ OPENAI_COMPATIBLE_BACKENDS.append("daily-omni")
+
# ruff: noqa: E402
# Prevent import order from causing patch failures
from vllm.benchmarks import serve
@@ -418,6 +701,8 @@ async def benchmark(
extra_headers=extra_headers,
extra_body=extra_body,
)
+ _attach_daily_omni_to_request_func_input(input_requests[0], test_input)
+ _attach_seed_tts_to_request_func_input(input_requests[0], test_input)
if ready_check_timeout_sec > 0:
test_output = await wait_for_endpoint(
@@ -480,6 +765,8 @@ async def warmup_limited_request_func():
extra_headers=extra_headers,
extra_body=extra_body,
)
+ _attach_daily_omni_to_request_func_input(input_requests[0], profile_input)
+ _attach_seed_tts_to_request_func_input(input_requests[0], profile_input)
profile_output = await request_func(request_func_input=profile_input, session=session)
if profile_output.success:
print("Profiler started")
@@ -560,6 +847,8 @@ async def limited_request_func(request_func_input, session, pbar):
extra_body=extra_body,
request_id=request_id,
)
+ _attach_daily_omni_to_request_func_input(request, request_func_input)
+ _attach_seed_tts_to_request_func_input(request, request_func_input)
tasks.append(
asyncio.create_task(limited_request_func(request_func_input=request_func_input, session=session, pbar=pbar))
)
@@ -627,6 +916,37 @@ async def limited_request_func(request_func_input, session, pbar):
"errors": [output.error for output in outputs],
}
+ from vllm_omni.benchmarks.data_modules.daily_omni_eval import (
+ compute_daily_omni_accuracy_metrics,
+ print_daily_omni_accuracy_summary,
+ )
+
+ _save_items = os.environ.get("DAILY_OMNI_SAVE_EVAL_ITEMS", "").lower() in (
+ "1",
+ "true",
+ "yes",
+ )
+ _daily_acc = compute_daily_omni_accuracy_metrics(input_requests, outputs, include_per_item=_save_items)
+ if _daily_acc is not None:
+ result.update(_daily_acc)
+ print_daily_omni_accuracy_summary(_daily_acc)
+
+ if _seed_tts_capture_pcm_for_wer():
+ from vllm_omni.benchmarks.data_modules.seed_tts_eval import (
+ compute_seed_tts_wer_metrics,
+ print_seed_tts_wer_summary,
+ )
+
+ _save_wer = os.environ.get("SEED_TTS_WER_SAVE_ITEMS", "").lower() in (
+ "1",
+ "true",
+ "yes",
+ )
+ _wer_m = compute_seed_tts_wer_metrics(input_requests, outputs, include_per_item=_save_wer)
+ if _wer_m is not None:
+ result.update(_wer_m)
+ print_seed_tts_wer_summary(_wer_m)
+
if rps_change_events:
result["rps_change_events"] = rps_change_events
diff --git a/vllm_omni/benchmarks/serve.py b/vllm_omni/benchmarks/serve.py
index fe94603693..d3f3510c56 100644
--- a/vllm_omni/benchmarks/serve.py
+++ b/vllm_omni/benchmarks/serve.py
@@ -1,9 +1,21 @@
import argparse
import asyncio
+import os
from typing import Any
from vllm.benchmarks.serve import main_async
+# Import patch to register daily-omni dataset and omni backends
+# This monkey-patches vllm.benchmarks.datasets.get_samples before it's used
+# Must be imported before any vllm.benchmarks module usage
+import vllm_omni.benchmarks.patch.patch # noqa: F401
+
def main(args: argparse.Namespace) -> dict[str, Any]:
+ if getattr(args, "seed_tts_wer_eval", False):
+ os.environ["SEED_TTS_WER_EVAL"] = "1"
+ if getattr(args, "seed_tts_wer_save_items", False):
+ os.environ["SEED_TTS_WER_SAVE_ITEMS"] = "1"
+ if getattr(args, "daily_omni_save_eval_items", False):
+ os.environ["DAILY_OMNI_SAVE_EVAL_ITEMS"] = "1"
return asyncio.run(main_async(args))
diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py
index 2aa236e69f..f02c075880 100644
--- a/vllm_omni/config/__init__.py
+++ b/vllm_omni/config/__init__.py
@@ -5,10 +5,18 @@
from vllm_omni.config.lora import LoRAConfig
from vllm_omni.config.model import OmniModelConfig
from vllm_omni.config.stage_config import (
+ DeployConfig,
ModelPipeline,
+ PipelineConfig,
StageConfig,
StageConfigFactory,
+ StageDeployConfig,
+ StageExecutionType,
+ StagePipelineConfig,
StageType,
+ load_deploy_config,
+ merge_pipeline_deploy,
+ register_pipeline,
)
from vllm_omni.config.yaml_util import (
create_config,
@@ -24,6 +32,14 @@
"StageConfigFactory",
"ModelPipeline",
"StageType",
+ "StageExecutionType",
+ "StagePipelineConfig",
+ "PipelineConfig",
+ "StageDeployConfig",
+ "DeployConfig",
+ "load_deploy_config",
+ "merge_pipeline_deploy",
+ "register_pipeline",
"create_config",
"load_yaml_config",
"merge_configs",
diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py
new file mode 100644
index 0000000000..c07bc2610c
--- /dev/null
+++ b/vllm_omni/config/pipeline_registry.py
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Central declarative registry of all vllm-omni pipelines.
+
+Mirrors the pattern in ``vllm/model_executor/models/registry.py``: each entry
+is ``model_type -> (module_path, variable_name)``, and the module is imported
+lazily on first lookup (see ``_LazyPipelineRegistry`` in
+``vllm_omni/config/stage_config.py``). Keeping every pipeline declared in one
+file makes it easy to spot a missing registration, which was the original
+motivation in https://github.com/vllm-project/vllm-omni/issues/2887 (item 4).
+
+Per-model ``pipeline.py`` modules still define the ``PipelineConfig`` instance;
+they just no longer need to self-register via ``register_pipeline(...)``.
+
+Adding a new pipeline:
+ 1. Define the ``PipelineConfig`` instance as a module-level variable in
+ ``vllm_omni/.../pipeline.py``.
+ 2. Add one line to ``_OMNI_PIPELINES`` or ``_DIFFUSION_PIPELINES`` below.
+
+``register_pipeline(config)`` in ``stage_config`` is still supported for
+out-of-tree plugins and tests that create pipelines at runtime; those override
+the entries declared here.
+"""
+
+from __future__ import annotations
+
+# --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) ---
+_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
+ # model_type -> (module_path, variable_name)
+ "qwen2_5_omni": (
+ "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
+ "QWEN2_5_OMNI_PIPELINE",
+ ),
+ "qwen2_5_omni_thinker_only": (
+ "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
+ "QWEN2_5_OMNI_THINKER_ONLY_PIPELINE",
+ ),
+ "qwen3_omni_moe": (
+ "vllm_omni.model_executor.models.qwen3_omni.pipeline",
+ "QWEN3_OMNI_PIPELINE",
+ ),
+ "qwen3_tts": (
+ "vllm_omni.model_executor.models.qwen3_tts.pipeline",
+ "QWEN3_TTS_PIPELINE",
+ ),
+}
+
+# --- Single-stage diffusion pipelines (populated in PR 3/N) ---
+_DIFFUSION_PIPELINES: dict[str, tuple[str, str]] = {}
+
+# Union view used by ``_LazyPipelineRegistry``; don't mutate at runtime.
+_VLLM_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
+ **_OMNI_PIPELINES,
+ **_DIFFUSION_PIPELINES,
+}
diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py
index a4e186c3bd..392a550be6 100644
--- a/vllm_omni/config/stage_config.py
+++ b/vllm_omni/config/stage_config.py
@@ -1,18 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Stage Configuration System for vLLM-Omni.
-
-Pipeline structure (stages, types, data-flow) is defined in per-model YAML
-files and is set by model developers at integration time.
-Runtime parameters (gpu_memory_utilization, tp_size, etc.) come from CLI.
-"""
+"""Stage configuration system for vLLM-Omni."""
from __future__ import annotations
+import dataclasses
import re
import warnings
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from pathlib import Path
from typing import Any
@@ -20,76 +15,818 @@
from vllm.logger import init_logger
from vllm_omni.config.yaml_util import create_config, load_yaml_config, to_dict
+from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler
+from vllm_omni.core.sched.omni_generation_scheduler import OmniGenerationScheduler
-# Pipeline YAMLs live alongside model code in model_executor/models//
_MODELS_DIR = Path(__file__).resolve().parent.parent / "model_executor" / "models"
def get_pipeline_path(model_dir: str, filename: str) -> Path:
- """Return the full path to a pipeline YAML file.
+ return _MODELS_DIR / model_dir / filename
+
+
+logger = init_logger(__name__)
+
+
+_STAGE_OVERRIDE_PATTERN = re.compile(r"^stage_(\d+)_(.+)$")
- Args:
- model_dir: Model subdirectory name (e.g., "qwen3_omni").
- filename: Name of the YAML file (e.g., "pipeline.yaml").
- Returns:
- Absolute path to the file.
+def build_stage_runtime_overrides(
+ stage_id: int,
+ cli_overrides: dict[str, Any],
+ *,
+ internal_keys: set[str] | frozenset[str] | None = None,
+) -> dict[str, Any]:
+ """Build per-stage runtime overrides from global and ``stage__*`` kwargs.
+
+ ``internal_keys`` defaults to the union of
+ ``arg_utils.internal_blacklist_keys()`` and ``arg_utils.SHARED_FIELDS``
+ so that neither orchestrator-only fields nor shared-pipeline fields
+ (``model`` / ``stage_configs_path`` / ``log_stats`` / ``stage_id``) leak
+ into a stage's per-stage runtime overrides — the orchestrator sets those
+ uniformly for every stage, they are not per-stage knobs. Callers can
+ pass an explicit set for tests or specialized flows.
"""
- return _MODELS_DIR / model_dir / filename
+ if internal_keys is None:
+ from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
+ internal_keys = internal_blacklist_keys() | SHARED_FIELDS
-logger = init_logger(__name__)
+ result: dict[str, Any] = {}
+
+ for key, value in cli_overrides.items():
+ if value is None or key in internal_keys:
+ continue
+
+ match = _STAGE_OVERRIDE_PATTERN.match(key)
+ if match is not None:
+ override_stage_id = int(match.group(1))
+ param_name = match.group(2)
+ if override_stage_id == stage_id and param_name not in internal_keys:
+ result[param_name] = value
+ continue
+
+ result[key] = value
+
+ return result
+
+
+def strip_parent_engine_args(
+ kwargs: dict[str, Any],
+ *,
+ parent_fields: dict[str, dataclasses.Field],
+ keep_keys: set[str] | frozenset[str] = frozenset(),
+ strip_keys: set[str] | frozenset[str] = frozenset(),
+ no_warn_keys: set[str] | frozenset[str] = frozenset(),
+) -> tuple[dict[str, Any], list[str]]:
+ """Strip parent ``EngineArgs`` fields before merging into stage YAML."""
+ overridden: list[str] = []
+ result: dict[str, Any] = {}
+
+ for key, value in kwargs.items():
+ if key in strip_keys:
+ continue
+
+ if key not in parent_fields or key in keep_keys:
+ result[key] = value
+ continue
+
+ field_def = parent_fields[key]
+ if field_def.default is not dataclasses.MISSING:
+ default = field_def.default
+ elif field_def.default_factory is not dataclasses.MISSING:
+ default = field_def.default_factory()
+ else:
+ default = dataclasses.MISSING
+
+ if default is dataclasses.MISSING or value is None:
+ continue
+
+ if dataclasses.is_dataclass(default) and not isinstance(default, type):
+ default = asdict(default)
+
+ if value != default and key not in no_warn_keys:
+ overridden.append(key)
+
+ return result, sorted(overridden)
class StageType(str, Enum):
"""Type of processing stage in the Omni pipeline."""
+ # TODO(@lishunyang12): remove once all models migrate to StageExecutionType
LLM = "llm"
DIFFUSION = "diffusion"
+class StageExecutionType(str, Enum):
+ """Merged StageType + WorkerType — 3 combinations today."""
+
+ LLM_AR = "llm_ar"
+ LLM_GENERATION = "llm_generation"
+ DIFFUSION = "diffusion"
+
+
+# Mapping class refs (not dotted-path strings) so module/class renames fail
+# at import time instead of lazily at scheduler resolution. YAML overrides
+# and downstream serialization still use the dotted-path string form; the
+# conversion happens at the map lookup site via _scheduler_path().
+_EXECUTION_TYPE_TO_SCHEDULER: dict[StageExecutionType, type | None] = {
+ StageExecutionType.LLM_AR: OmniARScheduler,
+ StageExecutionType.LLM_GENERATION: OmniGenerationScheduler,
+ StageExecutionType.DIFFUSION: None,
+}
+
+
+def _scheduler_path(cls: type | None) -> str | None:
+ """Return the dotted import path for a scheduler class (``None`` passes through)."""
+ if cls is None:
+ return None
+ return f"{cls.__module__}.{cls.__qualname__}"
+
+
+@dataclass(frozen=True)
+class StagePipelineConfig:
+ """Fixed topology for one stage (frozen, not user-configurable)."""
+
+ stage_id: int
+ model_stage: str
+ execution_type: StageExecutionType = StageExecutionType.LLM_AR
+ input_sources: tuple[int, ...] = ()
+ final_output: bool = False
+ final_output_type: str | None = None
+ owns_tokenizer: bool = False
+ requires_multimodal_data: bool = False
+ hf_config_name: str | None = None
+ engine_output_type: str | None = None
+ model_arch: str | None = None
+ sampling_constraints: dict[str, Any] = field(default_factory=dict)
+ custom_process_input_func: str | None = None
+ custom_process_next_stage_input_func: str | None = None
+ # Alternates picked by ``merge_pipeline_deploy`` based on ``deploy.async_chunk``.
+ async_chunk_process_next_stage_input_func: str | None = None
+ sync_process_input_func: str | None = None
+ prompt_expand_func: str | None = None
+ cfg_kv_collect_func: str | None = None
+ omni_kv_config: dict[str, Any] | None = None
+ extras: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(frozen=True)
+class PipelineConfig:
+ """Complete pipeline topology for a model (frozen)."""
+
+ model_type: str
+ model_arch: str = ""
+ stages: tuple[StagePipelineConfig, ...] = ()
+ # HF architecture aliases: used by StageConfigFactory when the model's
+ # HF config reports a generic model_type that collides with a different
+ # model (e.g. MiMo Audio reports model_type="qwen2"). The factory
+ # matches ``hf_config.architectures[*]`` against this tuple to route
+ # to the correct pipeline. Leave empty for models with unique model_type.
+ hf_architectures: tuple[str, ...] = ()
+
+ def get_stage(self, stage_id: int) -> StagePipelineConfig | None:
+ """Look up a stage by its ID."""
+ for stage in self.stages:
+ if stage.stage_id == stage_id:
+ return stage
+ return None
+
+ def get_scheduler_cls(self, stage_id: int) -> str | None:
+ """Return the inferred scheduler class path for a stage.
+
+ Returns ``None`` for DIFFUSION stages (no vLLM scheduler). Raises
+ ``ValueError`` if ``stage_id`` doesn't exist in this pipeline, and
+ ``KeyError`` if ``execution_type`` isn't in the scheduler map.
+ """
+ stage = self.get_stage(stage_id)
+ if stage is None:
+ raise ValueError(f"Pipeline {self.model_type!r} has no stage with id {stage_id}")
+ return _scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER[stage.execution_type])
+
+ def validate(self) -> list[str]:
+ """Return list of topology errors (empty if valid)."""
+ errors: list[str] = []
+ if not self.stages:
+ errors.append("Pipeline has no stages defined")
+ return errors
+ stage_ids = [s.stage_id for s in self.stages]
+ if len(stage_ids) != len(set(stage_ids)):
+ errors.append("Duplicate stage IDs found")
+ stage_id_set = set(stage_ids)
+ for stage in self.stages:
+ for src in stage.input_sources:
+ if src not in stage_id_set:
+ errors.append(f"Stage {stage.stage_id} references non-existent input source {src}")
+ if src == stage.stage_id:
+ errors.append(f"Stage {stage.stage_id} references itself")
+ if not any(not s.input_sources for s in self.stages):
+ errors.append("No entry point (stage with empty input_sources)")
+ return errors
+
+
+class _LazyPipelineRegistry:
+ """Dict-like registry that lazy-loads pipelines from the central declaration.
+
+ In-tree pipelines are declared once in
+ ``vllm_omni/config/pipeline_registry.py`` as
+ ``model_type -> (module_path, variable_name)`` entries; the module is
+ imported only when the pipeline is first looked up. This mirrors the
+ pattern in ``vllm/model_executor/models/registry.py`` and addresses
+ https://github.com/vllm-project/vllm-omni/issues/2887 (item 4): having
+ every registration in one file makes a missing entry easy to spot.
+
+ Out-of-tree / dynamic registrations via ``register_pipeline()`` are stored
+ directly in ``_loaded`` and take precedence over the lazy-map entry with
+ the same ``model_type``.
+
+ The class exposes the subset of ``dict`` operations the rest of this
+ module relies on (``__contains__``, ``__getitem__``, ``__setitem__``,
+ ``get``, ``keys``, ``values``, ``items``, ``__iter__``), so existing call
+ sites don't need to change.
+ """
+
+ def __init__(self) -> None:
+ self._loaded: dict[str, PipelineConfig] = {}
+ # Populated lazily to avoid a circular import at module init time.
+ self._lazy_map: dict[str, tuple[str, str]] | None = None
+
+ def _get_lazy_map(self) -> dict[str, tuple[str, str]]:
+ if self._lazy_map is None:
+ from vllm_omni.config.pipeline_registry import _VLLM_OMNI_PIPELINES
+
+ self._lazy_map = _VLLM_OMNI_PIPELINES
+ return self._lazy_map
+
+ def _load_lazy(self, model_type: str) -> PipelineConfig | None:
+ entry = self._get_lazy_map().get(model_type)
+ if entry is None:
+ return None
+ module_path, var_name = entry
+ import importlib
+
+ try:
+ module = importlib.import_module(module_path)
+ except ImportError as exc:
+ logger.error(
+ "Failed to import pipeline module %r for %r: %s",
+ module_path,
+ model_type,
+ exc,
+ )
+ return None
+ pipeline = getattr(module, var_name, None)
+ if pipeline is None:
+ logger.error(
+ "Pipeline variable %r not found in module %r (registered for %r)",
+ var_name,
+ module_path,
+ model_type,
+ )
+ return None
+ errors = pipeline.validate()
+ if errors:
+ logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
+ self._loaded[model_type] = pipeline
+ return pipeline
+
+ def __contains__(self, model_type: str) -> bool:
+ if model_type in self._loaded:
+ return True
+ return model_type in self._get_lazy_map()
+
+ def __getitem__(self, model_type: str) -> PipelineConfig:
+ if model_type in self._loaded:
+ return self._loaded[model_type]
+ pipeline = self._load_lazy(model_type)
+ if pipeline is None:
+ raise KeyError(model_type)
+ return pipeline
+
+ def get(self, model_type: str, default: PipelineConfig | None = None) -> PipelineConfig | None:
+ if model_type in self._loaded:
+ return self._loaded[model_type]
+ pipeline = self._load_lazy(model_type)
+ return pipeline if pipeline is not None else default
+
+ def __setitem__(self, model_type: str, pipeline: PipelineConfig) -> None:
+ self._loaded[model_type] = pipeline
+
+ def __delitem__(self, model_type: str) -> None:
+ """Remove a dynamically-registered pipeline.
+
+ Only the dynamic-cache side of the registry can be mutated; the
+ central declarative registry is immutable at runtime. Calling ``del``
+ on a model_type that only exists in the central registry raises
+ ``KeyError``.
+ """
+ if model_type in self._loaded:
+ del self._loaded[model_type]
+ return
+ if model_type in self._get_lazy_map():
+ raise KeyError(
+ f"{model_type!r} is declared in the central pipeline_registry and "
+ "cannot be removed at runtime. Edit "
+ "vllm_omni/config/pipeline_registry.py to delete a built-in entry."
+ )
+ raise KeyError(model_type)
+
+ def keys(self) -> set[str]:
+ return set(self._get_lazy_map().keys()) | set(self._loaded.keys())
+
+ def values(self):
+ # Iterating values forces load of every lazy pipeline.
+ for key in self.keys():
+ yield self[key]
+
+ def items(self):
+ for key in self.keys():
+ yield key, self[key]
+
+ def __iter__(self):
+ return iter(self.keys())
+
+
+_PIPELINE_REGISTRY = _LazyPipelineRegistry()
+
+
+def register_pipeline(pipeline: PipelineConfig) -> None:
+ """Register a pipeline config dynamically.
+
+ In-tree pipelines are declared in ``pipeline_registry._VLLM_OMNI_PIPELINES``
+ and loaded lazily; calling ``register_pipeline`` is only needed for
+ out-of-tree plugins or tests that build a ``PipelineConfig`` at runtime.
+ A dynamic registration overrides the central-registry entry with the same
+ ``model_type``.
+ """
+ errors = pipeline.validate()
+ if errors:
+ logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
+ _PIPELINE_REGISTRY[pipeline.model_type] = pipeline
+
+
+_DEPLOY_DIR = Path(__file__).resolve().parent.parent / "deploy"
+
+
+@dataclass
+class StageDeployConfig:
+ """Per-stage deployment knobs.
+
+ Only fields whose value legitimately varies across stages of the same
+ pipeline live here (e.g. ``max_num_seqs`` on thinker vs talker,
+ ``devices`` for GPU placement). Pipeline-wide settings
+ (``trust_remote_code``, ``distributed_executor_backend``, ``dtype``,
+ ``quantization``, prefix/chunked prefill, DP/PP sizes) are declared at
+ the top level of ``DeployConfig`` and propagated to every stage.
+ """
+
+ stage_id: int
+ max_num_seqs: int = 64
+ gpu_memory_utilization: float = 0.9
+ tensor_parallel_size: int = 1
+ enforce_eager: bool = False
+ max_num_batched_tokens: int = 32768
+ max_model_len: int | None = None
+ async_scheduling: bool | None = None
+ devices: str = "0"
+ output_connectors: dict[str, str] | None = None
+ input_connectors: dict[str, str] | None = None
+ default_sampling_params: dict[str, Any] | None = None
+ engine_extras: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class DeployConfig:
+ """Loaded from deploy/.yaml — the only config file users edit.
+
+ Top-level fields (``trust_remote_code``, ``distributed_executor_backend``,
+ ``dtype``, ``quantization``, ``enable_prefix_caching``,
+ ``enable_chunked_prefill``, ``data_parallel_size``,
+ ``pipeline_parallel_size``) are pipeline-wide: they apply uniformly to
+ every stage. Fields that legitimately vary per stage live in the
+ individual ``StageDeployConfig`` entries under ``stages:``.
+ """
+
+ async_chunk: bool = True
+ connectors: dict[str, Any] | None = None
+ edges: list[dict[str, Any]] | None = None
+ stages: list[StageDeployConfig] = field(default_factory=list)
+ platforms: dict[str, Any] | None = None
+ # Overrides the auto-detected pipeline registry key for structural variants.
+ pipeline: str | None = None
+
+ # === Pipeline-wide engine settings (applied uniformly to every stage) ===
+ trust_remote_code: bool = True
+ distributed_executor_backend: str = "mp"
+ dtype: str | None = None
+ quantization: str | None = None
+ enable_prefix_caching: bool = False
+ enable_chunked_prefill: bool | None = None
+ data_parallel_size: int = 1
+ pipeline_parallel_size: int = 1
+
+
+_STAGE_NON_ENGINE_KEYS = frozenset(
+ {
+ "stage_id",
+ "devices",
+ "output_connectors",
+ "input_connectors",
+ "default_sampling_params",
+ "engine_extras",
+ }
+)
+
+# Fields on StageDeployConfig that are populated from engine_args dict
+_STAGE_DEPLOY_FIELDS = {f.name: f for f in fields(StageDeployConfig) if f.name not in _STAGE_NON_ENGINE_KEYS}
+
+
+def _parse_stage_deploy(stage_data: dict[str, Any]) -> StageDeployConfig:
+ """Parse a single stage entry from deploy YAML into StageDeployConfig."""
+ if "engine_args" in stage_data:
+ engine_args = dict(stage_data["engine_args"])
+ devices = stage_data.get("runtime", {}).get("devices", stage_data.get("devices", "0"))
+ else:
+ engine_args = {k: v for k, v in stage_data.items() if k not in _STAGE_NON_ENGINE_KEYS and k != "stage_id"}
+ devices = stage_data.get("devices", "0")
+
+ kwargs: dict[str, Any] = {"stage_id": stage_data["stage_id"], "devices": devices}
+ for name, f in _STAGE_DEPLOY_FIELDS.items():
+ if name in engine_args:
+ kwargs[name] = engine_args.pop(name)
+
+ kwargs["output_connectors"] = stage_data.get("output_connectors")
+ kwargs["input_connectors"] = stage_data.get("input_connectors")
+ kwargs["default_sampling_params"] = stage_data.get("default_sampling_params")
+ kwargs["engine_extras"] = engine_args
+ return StageDeployConfig(**kwargs)
+
+
+_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "engine_extras", "engine_args"})
+
+
+def _deep_merge_stage(base: dict, overlay: dict) -> dict:
+ """Deep-merge ``_DEEP_MERGE_KEYS`` so thin overlays don't drop base keys."""
+ merged = dict(base)
+ for k, v in overlay.items():
+ if k in _DEEP_MERGE_KEYS:
+ base_val = merged.get(k)
+ if isinstance(v, dict) and isinstance(base_val, dict):
+ merged[k] = {**base_val, **v}
+ continue
+ # Deep-merge key but at least one side isn't a dict: surface the
+ # silent clobber so mismatched YAML types don't get past review.
+ if base_val is not None:
+ logger.warning(
+ "Deep-merge key %r has non-dict value (base=%s, overlay=%s); "
+ "overlay will fully replace base instead of merging.",
+ k,
+ type(base_val).__name__,
+ type(v).__name__,
+ )
+ merged[k] = v
+ return merged
+
+
+def _merge_stage_lists(
+ base_stages: list[dict[str, Any]] | None,
+ overlay_stages: list[dict[str, Any]] | None,
+) -> list[dict[str, Any]]:
+ """Merge two ``stages:`` lists by ``stage_id`` (overlay wins per field)."""
+ by_id: dict[int, dict[str, Any]] = {s["stage_id"]: s for s in (base_stages or [])}
+ for overlay_stage in overlay_stages or []:
+ sid = overlay_stage["stage_id"]
+ if sid in by_id:
+ by_id[sid] = _deep_merge_stage(by_id[sid], overlay_stage)
+ else:
+ by_id[sid] = overlay_stage
+ return list(by_id.values())
+
+
+def _merge_platforms(
+ base: dict[str, Any] | None,
+ overlay: dict[str, Any] | None,
+) -> dict[str, Any] | None:
+ """Deep-merge two ``platforms:`` blocks per-platform, per-stage_id."""
+ if not base and not overlay:
+ return None
+ base = base or {}
+ overlay = overlay or {}
+ merged: dict[str, Any] = {}
+ for plat in set(base) | set(overlay):
+ bp = base.get(plat) or {}
+ op = overlay.get(plat) or {}
+ merged_plat = {**bp, **{k: v for k, v in op.items() if k != "stages"}}
+ merged_plat["stages"] = _merge_stage_lists(bp.get("stages"), op.get("stages"))
+ merged[plat] = merged_plat
+ return merged
+
+
+def resolve_deploy_yaml(path: str | Path) -> dict[str, Any]:
+ """Load a deploy YAML with optional ``base_config`` inheritance."""
+ raw_dict = to_dict(load_yaml_config(path))
+
+ base_path = raw_dict.pop("base_config", None)
+ if base_path is None:
+ return raw_dict
+
+ # Resolve relative to the overlay file's directory
+ base_path = Path(path).parent / base_path
+ base_dict = resolve_deploy_yaml(base_path)
+
+ # Merge top-level scalars: overlay wins. ``stages:`` and ``platforms:``
+ # are deep-merged below so an overlay can layer on top of the base.
+ merged = {
+ **base_dict,
+ **{k: v for k, v in raw_dict.items() if k not in ("stages", "platforms")},
+ }
+ merged["stages"] = _merge_stage_lists(base_dict.get("stages"), raw_dict.get("stages"))
+ merged_platforms = _merge_platforms(base_dict.get("platforms"), raw_dict.get("platforms"))
+ if merged_platforms is not None:
+ merged["platforms"] = merged_platforms
+
+ return merged
+
+
+def load_deploy_config(path: str | Path) -> DeployConfig:
+ """Load a deploy YAML (with optional base_config inheritance)."""
+ raw_dict = resolve_deploy_yaml(path)
+
+ stages = [_parse_stage_deploy(s) for s in raw_dict.get("stages", [])]
+
+ kwargs: dict[str, Any] = {
+ "async_chunk": raw_dict.get("async_chunk", True),
+ "connectors": raw_dict.get("connectors", None),
+ "edges": raw_dict.get("edges", None),
+ "stages": stages,
+ "platforms": raw_dict.get("platforms", None),
+ "pipeline": raw_dict.get("pipeline", None),
+ }
+ # Pipeline-wide engine settings: only set if explicitly present in YAML
+ # so the DeployConfig dataclass defaults take effect otherwise.
+ for name in (
+ "trust_remote_code",
+ "distributed_executor_backend",
+ "dtype",
+ "quantization",
+ "enable_prefix_caching",
+ "enable_chunked_prefill",
+ "data_parallel_size",
+ "pipeline_parallel_size",
+ ):
+ if name in raw_dict:
+ kwargs[name] = raw_dict[name]
+ return DeployConfig(**kwargs)
+
+
+def _detect_platform() -> str | None:
+ """Return "npu", "rocm", "xpu", or None (CUDA default)."""
+ try:
+ from vllm.platforms import current_platform
+
+ name = current_platform.device_name.lower()
+ if "npu" in name:
+ return "npu"
+ if "rocm" in name or "amd" in name:
+ return "rocm"
+ if "xpu" in name:
+ return "xpu"
+ except Exception as e:
+ logger.debug("Platform auto-detect failed, falling back to CUDA: %s", e)
+ return None
+
+
+def _extract_platform_overrides(ps: dict[str, Any]) -> tuple[dict[str, Any], str | None]:
+ """Return ``(overrides, devices)`` from a platform stage entry.
+
+ Handles both the nested layout (``engine_args:`` / ``runtime.devices``) and
+ the flat layout. ``devices`` is ``None`` when no override is set.
+ """
+ if "engine_args" in ps:
+ return dict(ps["engine_args"]), ps.get("runtime", {}).get("devices")
+ overrides = {k: v for k, v in ps.items() if k not in ("stage_id", "devices")}
+ return overrides, ps.get("devices")
+
+
+def _apply_platform_overrides(
+ deploy: DeployConfig,
+ platform: str | None = None,
+) -> DeployConfig:
+ """Merge platform-specific stage overrides into deploy config."""
+ if platform is None:
+ platform = _detect_platform()
+ if platform is None or deploy.platforms is None:
+ return deploy
+ platform_section = deploy.platforms.get(platform)
+ if platform_section is None:
+ return deploy
+
+ platform_stages = platform_section.get("stages", [])
+ base_by_id = {s.stage_id: s for s in deploy.stages}
+
+ for ps in platform_stages:
+ base = base_by_id.get(ps["stage_id"])
+ if base is None:
+ continue
+ overrides, devices = _extract_platform_overrides(ps)
+ if devices is not None:
+ base.devices = devices
+ for key, val in overrides.items():
+ if hasattr(base, key):
+ setattr(base, key, val)
+ else:
+ base.engine_extras[key] = val
+
+ return deploy
+
+
+_EXECUTION_TYPE_TO_STAGE_WORKER: dict[StageExecutionType, tuple[StageType, str | None]] = {
+ StageExecutionType.LLM_AR: (StageType.LLM, "ar"),
+ StageExecutionType.LLM_GENERATION: (StageType.LLM, "generation"),
+ StageExecutionType.DIFFUSION: (StageType.DIFFUSION, None),
+}
+
+
+def _resolve_execution_mode(
+ execution_type: StageExecutionType,
+) -> tuple[StageType, str | None]:
+ """Map ``execution_type`` → ``(stage_type, worker_type)`` legacy tuple."""
+ return _EXECUTION_TYPE_TO_STAGE_WORKER.get(execution_type, (StageType.LLM, None))
+
+
+def _select_processor_funcs(
+ ps: StagePipelineConfig,
+ async_chunk: bool,
+) -> tuple[str | None, str | None]:
+ """Pick ``(input_proc, next_stage_proc)`` based on the async_chunk mode."""
+ next_stage_proc = ps.custom_process_next_stage_input_func
+ input_proc = ps.custom_process_input_func
+ if async_chunk and ps.async_chunk_process_next_stage_input_func:
+ next_stage_proc = ps.async_chunk_process_next_stage_input_func
+ elif not async_chunk and ps.sync_process_input_func:
+ input_proc = ps.sync_process_input_func
+ return input_proc, next_stage_proc
+
+
+# Pipeline-wide DeployConfig fields that are propagated to every stage's
+# engine args during merge. These live at top level of the deploy YAML.
+_PIPELINE_WIDE_ENGINE_FIELDS: tuple[str, ...] = (
+ "trust_remote_code",
+ "distributed_executor_backend",
+ "dtype",
+ "quantization",
+ "enable_prefix_caching",
+ "enable_chunked_prefill",
+ "data_parallel_size",
+ "pipeline_parallel_size",
+)
+
+
+def _build_engine_args(
+ ps: StagePipelineConfig,
+ ds: StageDeployConfig | None,
+ pipeline: PipelineConfig,
+ deploy: DeployConfig,
+ next_stage_proc: str | None,
+) -> dict[str, Any]:
+ """Assemble the flat ``yaml_engine_args`` dict for one stage.
+
+ Pipeline-wide DeployConfig fields are applied uniformly to every stage;
+ per-stage StageDeployConfig overrides take precedence when present (e.g.
+ ``engine_extras`` can still carry a stage-specific ``dtype``).
+ """
+ engine_args: dict[str, Any] = {"model_arch": ps.model_arch or pipeline.model_arch}
+ if ps.engine_output_type:
+ engine_args["engine_output_type"] = ps.engine_output_type
+ if next_stage_proc:
+ engine_args["custom_process_next_stage_input_func"] = next_stage_proc
+
+ # Pipeline-wide top-level DeployConfig settings, applied to every stage.
+ for name in _PIPELINE_WIDE_ENGINE_FIELDS:
+ value = getattr(deploy, name)
+ if value is not None:
+ engine_args[name] = value
+
+ # Per-stage StageDeployConfig values override pipeline-wide settings.
+ if ds is not None:
+ for k, v in asdict(ds).items():
+ if k in _STAGE_NON_ENGINE_KEYS or v is None:
+ continue
+ engine_args[k] = v
+ engine_args.update(ds.engine_extras)
+ if deploy.async_chunk:
+ engine_args["async_chunk"] = True
+ return engine_args
+
+
+def _build_extras(
+ ps: StagePipelineConfig,
+ ds: StageDeployConfig | None,
+) -> dict[str, Any]:
+ """Assemble ``yaml_extras`` (sampling + connectors + pipeline extras)."""
+ extras: dict[str, Any] = {}
+ sampling: dict[str, Any] = {}
+ if ds is not None and ds.default_sampling_params:
+ sampling.update(ds.default_sampling_params)
+ sampling.update(ps.sampling_constraints)
+ if sampling:
+ extras["default_sampling_params"] = sampling
+ if ds is not None and ds.output_connectors:
+ extras["output_connectors"] = dict(ds.output_connectors)
+ if ds is not None and ds.input_connectors:
+ extras["input_connectors"] = dict(ds.input_connectors)
+ if ps.extras:
+ extras.update(ps.extras)
+ return extras
+
+
+def merge_pipeline_deploy(
+ pipeline: PipelineConfig,
+ deploy: DeployConfig,
+ cli_overrides: dict[str, Any] | None = None,
+) -> list[StageConfig]:
+ """Merge pipeline + deploy + platform overrides → list[StageConfig]."""
+ if cli_overrides is None:
+ cli_overrides = {}
+
+ deploy = _apply_platform_overrides(deploy)
+ deploy_by_id = {s.stage_id: s for s in deploy.stages}
+
+ # A pipeline supports async_chunk if any stage has either an explicit
+ # async-chunk-only processor slot OR a custom next-stage processor (some
+ # pipelines like qwen3_omni wire async-chunk processing directly through
+ # ``custom_process_next_stage_input_func``). Only raise when neither is
+ # present — that's the "user enabled async_chunk but pipeline has no
+ # inter-stage processing at all" case.
+ if deploy.async_chunk and not any(
+ ps.async_chunk_process_next_stage_input_func or ps.custom_process_next_stage_input_func
+ for ps in pipeline.stages
+ ):
+ raise ValueError(
+ f"Pipeline {pipeline.model_type!r} has async_chunk=True in deploy but no stage "
+ "declares a next-stage input processor "
+ "(``async_chunk_process_next_stage_input_func`` or ``custom_process_next_stage_input_func``). "
+ "Either set async_chunk=False or implement an async-chunk processor on the pipeline."
+ )
+
+ result: list[StageConfig] = []
+ for ps in pipeline.stages:
+ ds = deploy_by_id.get(ps.stage_id)
+ stage_type, worker_type = _resolve_execution_mode(ps.execution_type)
+ input_proc, next_stage_proc = _select_processor_funcs(ps, deploy.async_chunk)
+ engine_args = _build_engine_args(ps, ds, pipeline, deploy, next_stage_proc)
+ extras = _build_extras(ps, ds)
+ runtime: dict[str, Any] = {"process": True}
+ if ds is not None:
+ runtime["devices"] = ds.devices
+
+ result.append(
+ StageConfig(
+ stage_id=ps.stage_id,
+ model_stage=ps.model_stage,
+ stage_type=stage_type,
+ input_sources=list(ps.input_sources),
+ custom_process_input_func=input_proc,
+ final_output=ps.final_output,
+ final_output_type=ps.final_output_type,
+ worker_type=worker_type,
+ scheduler_cls=_scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER.get(ps.execution_type)),
+ hf_config_name=ps.hf_config_name,
+ is_comprehension=ps.owns_tokenizer,
+ yaml_engine_args=engine_args,
+ yaml_runtime=runtime,
+ yaml_extras=extras,
+ )
+ )
+ return result
+
+
@dataclass
class StageConfig:
- """Per-stage configuration from pipeline YAML.
+ """Per-stage config (legacy path). Used by both new and legacy loaders.
- Topology fields (stage_id, input_sources, etc.) define the DAG.
- Engine and runtime defaults come from the YAML; CLI overrides take
- precedence via ``runtime_overrides``.
+ TODO(@lishunyang12): replace with ResolvedStageConfig once all models are migrated.
"""
- # Identity
stage_id: int
model_stage: str
-
- # Stage type
stage_type: StageType = StageType.LLM
-
input_sources: list[int] = field(default_factory=list)
custom_process_input_func: str | None = None
final_output: bool = False
- final_output_type: str | None = None # "text", "audio", "image"
- worker_type: str | None = None # "ar" or "generation"
+ final_output_type: str | None = None
+ worker_type: str | None = None
scheduler_cls: str | None = None
hf_config_name: str | None = None
is_comprehension: bool = False
-
- # Per-stage engine args from pipeline YAML (defaults)
yaml_engine_args: dict[str, Any] = field(default_factory=dict)
- # Per-stage runtime config from pipeline YAML (devices, etc.)
yaml_runtime: dict[str, Any] = field(default_factory=dict)
- # Pass-through fields from pipeline YAML (default_sampling_params,
- # output_connectors, input_connectors, tts_args, etc.)
yaml_extras: dict[str, Any] = field(default_factory=dict)
-
- # Runtime overrides (populated from CLI, not from pipeline YAML)
runtime_overrides: dict[str, Any] = field(default_factory=dict)
def to_omegaconf(self) -> Any:
- """Convert to OmegaConf for backward compatibility with OmniStage.
-
- Returns:
- OmegaConf DictConfig with stage configuration in legacy format.
- """
+ """TODO(@lishunyang12): remove once engine consumes ResolvedStageConfig directly."""
# Start with YAML engine_args defaults
engine_args: dict[str, Any] = dict(self.yaml_engine_args)
@@ -152,9 +889,9 @@ def to_omegaconf(self) -> Any:
@dataclass
class ModelPipeline:
- """Complete pipeline definition for a multi-stage model.
+ """Complete pipeline definition for a multi-stage model (legacy).
- Defined by model developers, bundled with the model, not user-editable.
+ TODO(@lishunyang12): remove once all models migrate to PipelineConfig.
"""
model_type: str
@@ -225,49 +962,55 @@ class StageConfigFactory:
"""Factory that loads pipeline YAML and merges CLI overrides.
Handles both single-stage and multi-stage models.
- """
- # Mapping of model types to directories under model_executor/models/.
- PIPELINE_MODELS: dict[str, str] = {
- "qwen3_omni_moe": "qwen3_omni",
- "qwen2_5_omni": "qwen2_5_omni",
- "bagel": "bagel",
- "qwen3_tts": "qwen3_tts",
- "voxtral_tts": "voxtral_tts",
- "mimo_audio": "mimo_audio",
- "glm-image": "glm_image",
- "cosyvoice3": "cosyvoice3",
- "mammothmoda2": "mammoth_moda2",
- }
-
- # Fallback: map HF architecture class names to pipeline dirs.
- # Used when model_type collides with another model (e.g. MiMo Audio
- # reports model_type="qwen2" which matches plain Qwen2, not our pipeline).
- _ARCHITECTURE_MODELS: dict[str, str] = {
- "MiMoAudioForConditionalGeneration": "mimo_audio",
- "HunyuanImage3ForCausalMM": "hunyuan_image3",
- }
+ Pipelines are declared in ``vllm_omni/config/pipeline_registry.py`` and
+ loaded lazily via ``_PIPELINE_REGISTRY``; no hardcoded model-type →
+ directory mapping is maintained here. Models with generic HF
+ ``model_type`` collisions (e.g. MiMo Audio reports ``qwen2``) should
+ declare ``hf_architectures=(...)`` on their ``PipelineConfig`` so the
+ factory can disambiguate via ``hf_config.architectures``.
+ """
@classmethod
def create_from_model(
cls,
model: str,
cli_overrides: dict[str, Any] | None = None,
+ deploy_config_path: str | None = None,
+ cli_explicit_keys: set[str] | None = None,
) -> list[StageConfig] | None:
- """Load pipeline YAML, merge with CLI overrides.
+ """Load pipeline + deploy config, merge with CLI overrides.
- Args:
- model: Model name or path.
- cli_overrides: CLI overrides from VllmConfig/OmniDiffusionConfig.
+ Checks _PIPELINE_REGISTRY first (new path), falls back to legacy YAML.
- Returns:
- List of StageConfig objects with CLI overrides applied,
- or None if no pipeline definition was found for this model.
+ ``cli_explicit_keys`` is the set of CLI keys the user actually typed
+ (captured at the parser layer in ``vllm serve``). When ``None`` —
+ which is the case for programmatic ``Omni()`` callers — every kwarg
+ in ``cli_overrides`` is treated as explicit.
"""
if cli_overrides is None:
cli_overrides = {}
trust_remote_code = cli_overrides.get("trust_remote_code", True)
+
+ # --- New path: check pipeline registry by model_type first ---
+ model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code)
+ if model_type and model_type in _PIPELINE_REGISTRY:
+ return cls._create_from_registry(model_type, cli_overrides, deploy_config_path, cli_explicit_keys)
+
+ # --- HF architecture fallback: some models report a generic
+ # model_type that collides with another model. Match by the
+ # hf_architectures declared on each registered PipelineConfig.
+ if hf_config is not None:
+ hf_archs = set(getattr(hf_config, "architectures", []) or [])
+ if hf_archs:
+ for registered in _PIPELINE_REGISTRY.values():
+ if hf_archs.intersection(registered.hf_architectures):
+ return cls._create_from_registry(
+ registered.model_type, cli_overrides, deploy_config_path, cli_explicit_keys
+ )
+
+ # --- Legacy path: load from pipeline YAML ---
pipeline = cls._load_pipeline(model, trust_remote_code=trust_remote_code)
if pipeline is None:
@@ -295,6 +1038,78 @@ def create_from_model(
return result
+ @classmethod
+ def _create_from_registry(
+ cls,
+ model_type: str,
+ cli_overrides: dict[str, Any],
+ deploy_config_path: str | None = None,
+ cli_explicit_keys: set[str] | None = None,
+ ) -> list[StageConfig]:
+ """Create StageConfigs from pipeline registry + deploy YAML.
+
+ Precedence (high → low):
+ explicit CLI args > deploy YAML > parser default CLI values
+
+ ``cli_explicit_keys`` carries the set of long-option attribute names
+ the user actually typed (captured in ``OmniServeCommand.cmd``). Any
+ kwarg whose key is not in that set is treated as a parser default
+ and is only used to fill fields YAML doesn't already cover. When the
+ set is ``None`` (programmatic ``Omni()`` callers, which have no
+ argparse layer), every kwarg is treated as explicit.
+ """
+ # Resolve deploy config path
+ if deploy_config_path is None:
+ deploy_path = _DEPLOY_DIR / f"{model_type}.yaml"
+ else:
+ deploy_path = Path(deploy_config_path)
+
+ if not deploy_path.exists():
+ logger.warning(
+ "Deploy config not found: %s — using pipeline defaults only",
+ deploy_path,
+ )
+ deploy_cfg = DeployConfig()
+ else:
+ deploy_cfg = load_deploy_config(deploy_path)
+
+ cli_async_chunk = cli_overrides.get("async_chunk")
+ if cli_async_chunk is not None and (cli_explicit_keys is None or "async_chunk" in cli_explicit_keys):
+ deploy_cfg.async_chunk = bool(cli_async_chunk)
+
+ pipeline_key = deploy_cfg.pipeline or model_type
+ if pipeline_key not in _PIPELINE_REGISTRY:
+ raise KeyError(
+ f"Pipeline {pipeline_key!r} not in registry "
+ f"(resolved from {deploy_path.name!r}). Available: "
+ f"{sorted(_PIPELINE_REGISTRY.keys())}"
+ )
+ pipeline_cfg = _PIPELINE_REGISTRY[pipeline_key]
+
+ stages = merge_pipeline_deploy(pipeline_cfg, deploy_cfg, cli_overrides)
+
+ # Precedence: explicit CLI > yaml > parser-default CLI.
+ # Per-stage (``stage_N_*``) keys are always treated as explicit.
+ explicit_overrides: dict[str, Any] = {}
+ default_overrides: dict[str, Any] = {}
+ for key, value in cli_overrides.items():
+ if value is None:
+ continue
+ is_per_stage = bool(re.match(r"stage_\d+_", key))
+ is_explicit = cli_explicit_keys is None or key in cli_explicit_keys or is_per_stage
+ if is_explicit:
+ explicit_overrides[key] = value
+ else:
+ default_overrides[key] = value
+
+ for stage in stages:
+ yaml_keys = set(stage.yaml_engine_args)
+ fallback = {k: v for k, v in default_overrides.items() if k not in yaml_keys}
+ merged = {**fallback, **explicit_overrides}
+ stage.runtime_overrides = cls._merge_cli_overrides(stage, merged)
+
+ return stages
+
@classmethod
def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]]:
"""Single-stage diffusion - no YAML needed.
@@ -322,9 +1137,16 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]
continue
engine_args[key] = value
- # Serialize parallel_config as dict for OmegaConf compatibility
+ # Serialize parallel_config as dict for OmegaConf. Test helpers
+ # sometimes pass SimpleNamespace rather than a dataclass instance.
if "parallel_config" in kwargs:
- engine_args["parallel_config"] = asdict(kwargs["parallel_config"])
+ parallel_config = kwargs["parallel_config"]
+ if dataclasses.is_dataclass(parallel_config) and not isinstance(parallel_config, type):
+ engine_args["parallel_config"] = asdict(parallel_config)
+ elif hasattr(parallel_config, "__dict__"):
+ engine_args["parallel_config"] = dict(vars(parallel_config))
+ else:
+ engine_args["parallel_config"] = parallel_config
engine_args.setdefault("cache_backend", "none")
engine_args["model_stage"] = "diffusion"
@@ -351,40 +1173,49 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]
@classmethod
def _load_pipeline(cls, model: str, trust_remote_code: bool = True) -> ModelPipeline | None:
- """Load pipeline YAML for the model.
+ """Load a legacy ``pipeline.yaml`` for the model.
- Args:
- model: Model name or path.
- trust_remote_code: Whether to trust remote code for HF config loading.
+ Searches ``model_executor/models//pipeline.yaml`` by trying
+ (a) the raw ``model_type`` as the directory name, then
+ (b) ``model_type`` with hyphens replaced by underscores,
+ and finally (c) scanning every ``pipeline.yaml`` for one that
+ declares a matching ``model_type`` or ``hf_architectures``.
- Returns:
- ModelPipeline if found, None otherwise.
+ Returns None if no pipeline.yaml is found — caller handles the
+ ``resolve_model_config_path`` fallback via stage_configs/ YAMLs.
"""
model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code)
if model_type is None:
return None
- pipeline_dir = cls.PIPELINE_MODELS.get(model_type)
-
- # Fallback: check HF architectures when model_type doesn't match
- if pipeline_dir is None and hf_config is not None:
- for arch in getattr(hf_config, "architectures", []) or []:
- pipeline_dir = cls._ARCHITECTURE_MODELS.get(arch)
- if pipeline_dir is not None:
- model_type = pipeline_dir
- break
-
- if pipeline_dir is None:
- logger.debug(f"No pipeline mapping for model_type: {model_type}")
- return None
-
- pipeline_path = get_pipeline_path(pipeline_dir, "pipeline.yaml")
-
- if not pipeline_path.exists():
- logger.debug(f"Pipeline file not found: {pipeline_path}")
- return None
+ # Direct lookups by convention
+ candidates = [model_type, model_type.replace("-", "_")]
+ for dir_name in candidates:
+ pipeline_path = get_pipeline_path(dir_name, "pipeline.yaml")
+ if pipeline_path.exists():
+ return cls._parse_pipeline_yaml(pipeline_path, model_type)
+
+ # Scan fallback: read every pipeline.yaml and match on declared fields
+ hf_archs = set(getattr(hf_config, "architectures", []) or []) if hf_config else set()
+ if _MODELS_DIR.exists():
+ for subdir in sorted(_MODELS_DIR.iterdir()):
+ if not subdir.is_dir():
+ continue
+ pipeline_path = subdir / "pipeline.yaml"
+ if not pipeline_path.exists():
+ continue
+ try:
+ cfg = load_yaml_config(pipeline_path)
+ except Exception as exc:
+ logger.debug("Skip %s: %s", pipeline_path, exc)
+ continue
+ declared_type = getattr(cfg, "model_type", None)
+ declared_archs = set(getattr(cfg, "hf_architectures", None) or [])
+ if declared_type == model_type or (hf_archs and hf_archs.intersection(declared_archs)):
+ return cls._parse_pipeline_yaml(pipeline_path, declared_type or model_type)
- return cls._parse_pipeline_yaml(pipeline_path, model_type)
+ logger.debug("No pipeline.yaml found for model_type %s (archs=%s)", model_type, sorted(hf_archs))
+ return None
# Keys consumed as explicit StageConfig fields — everything else is
# passed through via yaml_extras.
@@ -542,66 +1373,17 @@ def _auto_detect_model_type(cls, model: str, trust_remote_code: bool = True) ->
return None, None
- # Keys that should never be forwarded as engine overrides (internal /
- # orchestrator-only knobs, complex objects, etc.).
- _INTERNAL_KEYS: set[str] = {
- "model",
- "stage_configs_path",
- "stage_id",
- "stage_init_timeout",
- "init_timeout",
- "shm_threshold_bytes",
- "worker_backend",
- "ray_address",
- "batch_timeout",
- "log_stats",
- "tokenizer",
- "parallel_config",
- }
-
@classmethod
def _merge_cli_overrides(
cls,
stage: StageConfig,
cli_overrides: dict[str, Any],
) -> dict[str, Any]:
- """Merge CLI overrides into stage runtime config.
+ """Merge global and per-stage (``stage_N_*``) CLI overrides.
- All CLI arguments registered by engine config classes (e.g.
- EngineArgs / OmniDiffusionConfig) are accepted as overrides
- unless they appear in ``_INTERNAL_KEYS``.
-
- Handles:
- - Global overrides (apply to all stages)
- - Per-stage overrides (--stage-N-* format, take precedence)
-
- Args:
- stage: The stage to merge overrides into.
- cli_overrides: CLI arguments from VllmConfig/OmniDiffusionConfig.
-
- Returns:
- Dict of runtime overrides for this stage.
+ Orchestrator-owned keys are filtered by ``build_stage_runtime_overrides``
+ using ``OrchestratorArgs`` as the single source of truth; unknown
+ server/uvicorn keys are dropped downstream by
+ ``filter_dataclass_kwargs(OmniEngineArgs, ...)``.
"""
- result: dict[str, Any] = {}
-
- # Apply global overrides – any key not in the internal blocklist
- # is forwarded so that engine-registered params work out of the box.
- for key, value in cli_overrides.items():
- if key in cls._INTERNAL_KEYS:
- continue
- if re.match(r"stage_\d+_", key):
- # Per-stage keys handled below
- continue
- if value is not None:
- result[key] = value
-
- # Apply per-stage overrides (--stage-N-* format, take precedence)
- stage_prefix = f"stage_{stage.stage_id}_"
- for key, value in cli_overrides.items():
- if key.startswith(stage_prefix) and value is not None:
- param_name = key[len(stage_prefix) :]
- if param_name in cls._INTERNAL_KEYS:
- continue
- result[param_name] = value
-
- return result
+ return build_stage_runtime_overrides(stage.stage_id, cli_overrides)
diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py
new file mode 100644
index 0000000000..69e7346c4c
--- /dev/null
+++ b/vllm_omni/core/prefix_cache.py
@@ -0,0 +1,264 @@
+"""
+Utilities for Prefix Caching in Omni models.
+"""
+
+import torch
+from vllm.logger import init_logger
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element
+
+logger = init_logger(__name__)
+
+
+class OmniTensorPrefixCache:
+ """Prefix cache for hidden states (model outputs) and model specific
+ multimodal outputs.
+
+ This class implements prefix caching in a non-invasive way on top of
+ vLLM by leveraging the same slot mappings that the vLLM scheduler uses
+ for the KV Cache.
+
+ Conceptually, this means we are mapping vLLM's cache mapping:
+ (num_blocks, block_size)
+
+ to 3D tensors of shape:
+ (num_blocks, block_size, feature_size)
+
+ Note that feature_size may vary across multimodal_outputs.
+ """
+
+ def __init__(
+ self,
+ num_blocks: int,
+ block_size: int,
+ hidden_size: int,
+ hs_dtype: torch.dtype,
+ ):
+ self.num_blocks = num_blocks
+ self.block_size = block_size
+ self.default_hidden_size = hidden_size
+
+ # Initialize the hidden states cache immediately
+ self.hidden_states_cache = self._get_cache_tensor(dtype=hs_dtype)
+
+ # Defer initialization of the mm_outputs_cache until we
+ # actually see mm output tensors dependent on num tokens.
+ self.mm_outputs_cache = {}
+ self.mm_cache_keys = set()
+ self._new_req_cache_hit_ids: set[str] = set()
+
+ def maybe_init_missing_mm_cache_keys(self, multimodal_outputs: dict, seq_len: int):
+ """Given multimodal outputs from executing the model, dynamically
+ determine which multimodal outputs are tensors depending on sequence
+ length and should be cached, and initialize the cache tensors
+ accordingly.
+
+ NOTE: This is done to avoid the need for explicit specification of
+ cache keys for every model/stage and aligns with the current way
+ that we slice the multimodal outputs based on the first dimension.
+
+ This will usually be called by the first forward pass, i.e.,
+ determined by the warmup.
+ """
+ for key, val in multimodal_outputs.items():
+ if isinstance(val, torch.Tensor) and val.shape[0] == seq_len and key not in self.mm_cache_keys:
+ feat_dim = val.shape[-1]
+ self.mm_outputs_cache[key] = self._get_cache_tensor(
+ dtype=val.dtype,
+ hidden_size=feat_dim,
+ )
+ self.mm_cache_keys.add(key)
+ new_tensor_shape = self.mm_outputs_cache[key].shape
+ logger.info("Initializing multimodal output cache of size %s for key: %s", list(new_tensor_shape), key)
+
+ def _get_cache_tensor(self, dtype: torch.dtype, hidden_size: int | None = None) -> torch.Tensor:
+ """Allocate a CPU cache tensor for a specific key."""
+ actual_hidden_size = hidden_size if hidden_size is not None else self.default_hidden_size
+ return torch.zeros(
+ (self.num_blocks, self.block_size, actual_hidden_size),
+ dtype=dtype,
+ device="cpu",
+ )
+
+ def add_prefix_cached_new_req_id(self, req_id: str):
+ """Adds a new request ID to the set of prefix cache hits on the batch."""
+ self._new_req_cache_hit_ids.add(req_id)
+
+ def reset_prefix_cached_new_req_ids(self):
+ """Clears the cache hit IDs to prepare for a new engine step."""
+ self._new_req_cache_hit_ids.clear()
+
+ @staticmethod
+ def _coerce_to_cpu_tensor(maybe_gpu_tensor: torch.Tensor) -> torch.Tensor:
+ """Convert GPU tensors -> contiguous CPU tensors if needed."""
+ return maybe_gpu_tensor.detach().cpu().contiguous()
+
+ def update_omni_tensor_prefix_cache(
+ self,
+ hidden_states: torch.Tensor | None,
+ multimodal_outputs: dict[str, torch.Tensor] | None,
+ num_tokens_unpadded: int,
+ slot_mapping: torch.Tensor,
+ num_tokens_padded: int | None = None,
+ ):
+ """Updates the hidden cache state for the provided hidden states and multimodal outputs.
+
+ Args:
+ hidden_states: Hidden states tensor to cache (if any)
+ multimodal_outputs: Multimodal dict whose tensors may be cached
+ num_tokens_unpadded: Number of tokens without padding
+ slot_mapping: Slot mapping for the input sequence
+ num_tokens_padded: Total number of tokens including padding
+ """
+ unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded]
+ if num_tokens_padded is None:
+ num_tokens_padded = num_tokens_unpadded
+
+ if hidden_states is not None:
+ # Slice to unpadded portion before caching
+ hidden_states = hidden_states[:num_tokens_unpadded]
+ # Ensure that hidden states are on the CPU
+ hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states)
+ # View the cache as 2D so that we can treat our slots as row indices
+ flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1])
+ flat_cache[unpadded_slot_mapping] = hidden_states
+ logger.debug("Writing to hidden states for %s tokens", num_tokens_unpadded)
+
+ # Do the same for the stage's cached multimodal outputs
+ if multimodal_outputs is not None:
+ # If we haven't initialized the keys already, do it now
+ # We check against the padded token count since we haven't sliced yet
+ self.maybe_init_missing_mm_cache_keys(
+ multimodal_outputs,
+ seq_len=num_tokens_padded,
+ )
+
+ for mm_out_key, mm_cache in self.mm_outputs_cache.items():
+ if mm_out_key in multimodal_outputs:
+ # Slice to unpadded portion before caching
+ mm_state = multimodal_outputs[mm_out_key][:num_tokens_unpadded]
+ mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state)
+ flat_cache = mm_cache.view(-1, mm_cache.shape[-1])
+ flat_cache[unpadded_slot_mapping] = mm_state
+ logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded)
+
+ def _coerce_to_payload_dict(
+ self,
+ element: object,
+ query_start_loc: torch.Tensor,
+ input_batch: InputBatch,
+ num_scheduled_tokens: dict[str, int],
+ ) -> dict[str, object]:
+ """Build the multimodal passthrough data per request for
+ the object under consideration. This is identical to the case
+ for no prefix cache when we tensor does have a first dimension
+ matching the seq len.
+ """
+ elem_dict = {}
+ for req_id in input_batch.req_ids:
+ req_idx = input_batch.req_id_to_index[req_id]
+ start = query_start_loc[req_idx]
+ end = start + num_scheduled_tokens[req_id]
+ elem_dict[req_id] = to_payload_element(
+ element, req_idx, start=start, end=end, pass_lists_through=True, seq_len=None
+ )
+ return elem_dict
+
+ def get_merged_multimodal_states(
+ self,
+ query_start_loc: torch.Tensor,
+ input_batch: InputBatch,
+ multimodal_outputs: dict,
+ num_scheduled_tokens: dict[str, int],
+ ):
+ """Get the merged multimodal states if hidden state prefix caching is enabled."""
+ combined_multimodal_outputs = {}
+ # First get the prefix cached tensors that are present in the mm data
+ for mm_key in self.mm_cache_keys:
+ if mm_key in multimodal_outputs:
+ combined_multimodal_outputs[mm_key] = self._get_merged_tensors(
+ query_start_loc=query_start_loc,
+ input_batch=input_batch,
+ cache=self.mm_outputs_cache[mm_key],
+ hidden_states=multimodal_outputs[mm_key],
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
+
+ # Then, get everything else (passthrough data); first, convert to CPU
+ # tensors similarly to the non prefix cached path, and then populate
+ # the subdicts mapping request IDs -> payload objects
+ passthrough_keys = set(multimodal_outputs.keys()) - self.mm_cache_keys
+ passthrough_mm_data = {k: v for k, v in multimodal_outputs.items() if k in passthrough_keys}
+ mm_cpu = build_mm_cpu(multimodal_outputs=passthrough_mm_data)
+
+ for mm_key, mm_val in mm_cpu.items():
+ combined_multimodal_outputs[mm_key] = self._coerce_to_payload_dict(
+ element=mm_val,
+ query_start_loc=query_start_loc,
+ input_batch=input_batch,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
+ return combined_multimodal_outputs
+
+ def get_merged_hidden_states(self, *args, **kwargs) -> dict[str, torch.Tensor]:
+ """Get the merged hidden states."""
+ return self._get_merged_tensors(
+ *args,
+ **kwargs,
+ cache=self.hidden_states_cache,
+ )
+
+ def _get_merged_tensors(
+ self,
+ query_start_loc: torch.Tensor,
+ input_batch: InputBatch,
+ cache: torch.Tensor,
+ hidden_states: torch.Tensor,
+ num_scheduled_tokens: dict[str, int],
+ ) -> dict[str, torch.Tensor]:
+ """When hidden state caching is enabled, takes the input hidden_states,
+ which only correspond to the scheduled tokens, and returns a mapping
+ from request IDs to their full hidden states. This is accomplished by
+ looking up the block IDs & scheduled token counts to split the
+ hidden_states.
+ """
+ # We do not support hybrid caches at the moment.
+ if len(input_batch.block_table.block_tables) > 1:
+ logger.warning_once(
+ "Omni prefix caching is enabled, but the batch block table appears to"
+ " have multiple kv groups; only the first group will be used!"
+ )
+
+ combined_hidden_states = {}
+ hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states)
+ for req_id in input_batch.req_ids:
+ req_idx = input_batch.req_id_to_index[req_id]
+
+ if req_id in self._new_req_cache_hit_ids:
+ block_ids = self._get_cached_block_ids(req_idx, input_batch)
+ cached_hs = cache[block_ids].reshape(-1, cache.shape[-1])
+
+ # Slice the hidden states corresponding to this request;
+ # we do this by using the query start
+ start = query_start_loc[req_idx]
+ new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]]
+ combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0)
+ else:
+ # cache miss for this request, pass through normally
+ start = query_start_loc[req_idx]
+ new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]]
+ combined_hidden_states[req_id] = new_hs
+
+ return combined_hidden_states
+
+ def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch.Tensor:
+ """Given an input batch and request index in the batch (not ID), get the
+ block IDs corresponding to the cache hit.
+ """
+ num_computed = input_batch.num_computed_tokens_cpu[req_idx]
+ # NOTE: vLLM only caches full blocks
+ num_cached_blocks = num_computed // self.block_size
+ # Get the block IDs attached to this cache hit and reindex into
+ # the flattened cached hidden states (i.e., 1 row per token).
+ return input_batch.block_table[0].block_table.cpu[req_idx, :num_cached_blocks]
diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py
index eac737b6e6..a5579dd464 100644
--- a/vllm_omni/core/sched/omni_ar_scheduler.py
+++ b/vllm_omni/core/sched/omni_ar_scheduler.py
@@ -15,9 +15,10 @@
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.spec_decode.metrics import SpecDecodingStats
+from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
from vllm_omni.core.sched.output import OmniSchedulerOutput
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
@@ -38,7 +39,7 @@ def to_dict(self) -> dict[str, Any]:
return asdict(self)
-class OmniARScheduler(VLLMScheduler):
+class OmniARScheduler(OmniSchedulerMixin, VLLMScheduler):
"""
OmniARScheduler: Scheduler for vLLM-Omni multimodal processing.
@@ -59,6 +60,11 @@ def __init__(self, *args, **kwargs):
# Track ACTIVE transfers (submitted to runner but not yet acked via kv_extracted_req_ids)
self.active_kv_transfers: set[str] = set()
+ # Requests marked for deferred stop: keep running until KV extraction
+ # completes so that kv_ready can be emitted while the request is still
+ # alive. Stopped on the first scheduler step after extraction ack.
+ self.pending_stop_after_extraction: set[str] = set()
+
# [Omni] Pre-parse KV transfer criteria
self.kv_transfer_criteria = self._get_kv_transfer_criteria()
@@ -71,6 +77,8 @@ def __init__(self, *args, **kwargs):
self.chunk_transfer_adapter = None
if getattr(model_config, "async_chunk", False):
self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config)
+ # Snapshot prompt length for each streaming input update
+ self._new_prompt_len_snapshot: dict[str, int] = {}
def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
@@ -126,11 +134,16 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
stop_decode_on_trigger = self.kv_transfer_criteria.get("stop_after_transfer", True)
if request.request_id in self.transfer_triggered_requests:
- # Already triggered. When stop_decode_on_trigger is True AND
- # transfer was actually queued, the request was already stopped
- # at trigger time (see below). Any request that reaches this
- # point either has stop_decode_on_trigger=False (continue
- # decoding) or was not actually queued (should not be stopped).
+ # Deferred stop: once KV extraction is complete (no longer in
+ # active_kv_transfers), stop the request. This guarantees the
+ # kv_ready signal was emitted while the request was still alive.
+ if (
+ request.request_id in self.pending_stop_after_extraction
+ and request.request_id not in self.active_kv_transfers
+ ):
+ self.pending_stop_after_extraction.discard(request.request_id)
+ request.status = RequestStatus.FINISHED_STOPPED
+ return True
return False
if criteria_type == "prefill_finished":
@@ -140,14 +153,11 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer
if stop_decode_on_trigger and actually_queued:
- # Stop immediately so the request is NOT scheduled in
- # the next step, freeing scheduling budget for companion
- # requests whose chunked-prefill boundaries must be
- # deterministic. waiting_for_transfer_free keeps blocks
- # alive until the model runner finishes KV extraction.
- self.waiting_for_transfer_free.add(request.request_id)
- request.status = RequestStatus.FINISHED_STOPPED
- return True
+ # Defer the stop until KV extraction completes so that
+ # the kv_ready signal can be emitted while the request
+ # is still alive. The request will be stopped on the
+ # next scheduler step after extraction ack arrives.
+ self.pending_stop_after_extraction.add(request.request_id)
return False
@@ -167,9 +177,7 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer
if stop_decode_on_trigger and actually_queued:
- self.waiting_for_transfer_free.add(request.request_id)
- request.status = RequestStatus.FINISHED_STOPPED
- return True
+ self.pending_stop_after_extraction.add(request.request_id)
return False
@@ -268,6 +276,26 @@ def update_from_output(
num_scheduled_tokens,
)
+ # Pre-process KV extraction acks so that the per-request loop below
+ # can see up-to-date active_kv_transfers state and emit kv_ready
+ # signals while requests are still alive (before any deferred stop).
+ kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
+ if kv_extracted_ids:
+ for req_id in kv_extracted_ids:
+ try:
+ self.active_kv_transfers.discard(req_id)
+ req = self.requests.get(req_id)
+ if req is not None and not req.is_finished():
+ outputs[req.client_index].append(
+ EngineCoreOutput(
+ request_id=req_id,
+ new_token_ids=[],
+ kv_transfer_params={"kv_ready": True},
+ )
+ )
+ except Exception:
+ init_logger(__name__).exception("Failed to pre-process KV extraction for %s", req_id)
+
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
@@ -313,6 +341,7 @@ def update_from_output(
)
stopped = False
+ is_segment_finished = False
new_logprobs = None
new_token_ids = generated_token_ids
pooler_output = pooler_outputs[req_index] if pooler_outputs else None
@@ -341,6 +370,7 @@ def update_from_output(
# Capture finish_reason BEFORE _handle_stopped_request, which may
# reset the status to WAITING for streaming requests that continue.
finish_reason = request.get_finished_reason()
+ is_segment_finished = request.is_finished() and request.resumable
finished = self._handle_stopped_request(request)
if finished:
kv_transfer_params = self._free_request(request)
@@ -393,6 +423,8 @@ def update_from_output(
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
+ is_segment_finished=is_segment_finished,
+ new_prompt_len_snapshot=self._new_prompt_len_snapshot.get(req_id, None),
)
)
if self.chunk_transfer_adapter is not None:
@@ -436,6 +468,7 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
+ self.pending_stop_after_extraction.discard(req.request_id)
# Same for preempted
for req in stopped_preempted_reqs:
@@ -444,6 +477,8 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
+ self.pending_stop_after_extraction.discard(req.request_id)
+
# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
@@ -489,35 +524,12 @@ def update_from_output(
engine_core_outputs[0] = eco = EngineCoreOutputs()
eco.scheduler_stats = stats
- # This is where we free blocks that were held for transfer
- try:
- kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
- if kv_extracted_ids:
- for req_id in kv_extracted_ids:
- # Emit a kv_ready signal so the orchestrator can forward
- # the request to the DiT stage immediately after KV
- # extraction, without waiting for AR decode to finish.
- req = self.requests.get(req_id)
- if req is not None and not req.is_finished():
- eco = engine_core_outputs.get(req.client_index)
- if eco is None:
- eco = EngineCoreOutputs()
- engine_core_outputs[req.client_index] = eco
- eco.outputs.append(
- EngineCoreOutput(
- request_id=req_id,
- new_token_ids=[],
- kv_transfer_params={"kv_ready": True},
- )
- )
-
- # Mark transfer as finished
- if req_id in self.active_kv_transfers:
- self.active_kv_transfers.remove(req_id)
- logger.debug(f"[Omni] KV Transfer finished for {req_id}")
-
+ # Free blocks that were held for transfer (kv_ready and
+ # active_kv_transfers updates already done before the per-request loop).
+ if kv_extracted_ids:
+ for req_id in kv_extracted_ids:
+ try:
if req_id in self.waiting_for_transfer_free:
- # Now it's safe to free blocks
req = self.requests.get(req_id)
if req:
self.kv_cache_manager.free(req)
@@ -525,16 +537,48 @@ def update_from_output(
del self.requests[req_id]
if req_id in self.transfer_triggered_requests:
self.transfer_triggered_requests.remove(req_id)
- if req_id in self.active_kv_transfers:
- self.active_kv_transfers.remove(req_id)
-
+ self.active_kv_transfers.discard(req_id)
+ self.pending_stop_after_extraction.discard(req_id)
logger.debug(f"Freed blocks for {req_id} after transfer extraction")
self.waiting_for_transfer_free.remove(req_id)
- except Exception:
- init_logger(__name__).exception("Failed to process finished transfer requests")
+ except Exception:
+ init_logger(__name__).exception("Failed to free blocks for %s after transfer", req_id)
return engine_core_outputs
+ def finish_requests(self, request_ids: Any, finished_status: RequestStatus) -> list[tuple[str, int]]:
+ """Handles the finish signal from outside the scheduler.
+
+ For example, the API server can abort a request when the client
+ disconnects.
+
+ If request_ids is None, all requests will be finished.
+
+ Returns:
+ Tuple of (req_id, client_index) for requests that were aborted. Will not
+ include any that were already finished.
+ """
+
+ if self.chunk_transfer_adapter:
+ self.chunk_transfer_adapter.finish_requests(request_ids, finished_status, self.requests)
+
+ return super().finish_requests(request_ids, finished_status)
+
+ def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None:
+ """
+ Override: Only extend prompt at stage 0, and replace
+ the existing session with the next streaming update at other stages.
+
+ Discards the last sampled output token from the prior input chunk at stage 0.
+ """
+ req_id = session.request_id
+ self._new_prompt_len_snapshot[req_id] = len(update.prompt_token_ids)
+ if self.vllm_config.model_config.stage_id != 0:
+ self._replace_session_with_streaming_update(session, update)
+
+ else:
+ super()._update_request_as_session(session, update)
+
def _free_request(self, request: Request, delay_free_blocks: bool = False) -> dict[str, Any] | None:
# TODO(wzliu)! for offline mode, we should not end process until all data is transferred
"""Mark a request as finished and free its resources."""
@@ -548,6 +592,7 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
self.encoder_cache_manager.free(request)
request_id = request.request_id
self.finished_req_ids.add(request_id)
+ self._new_prompt_len_snapshot.pop(request_id, None)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
@@ -564,8 +609,7 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
kv_xfer_params = None
return kv_xfer_params
elif request_id in self.waiting_for_transfer_free:
- # Stopped immediately by stop_decode_on_trigger; blocks are
- # held until KV extraction completes in a future step.
+ # Blocks held until KV extraction completes in a future step.
return None
else:
logger.debug(
diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py
index 1c4356d4f5..81f0b7fc2b 100644
--- a/vllm_omni/core/sched/omni_generation_scheduler.py
+++ b/vllm_omni/core/sched/omni_generation_scheduler.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import time
from collections import defaultdict
@@ -11,11 +13,16 @@
from vllm.v1.core.sched.request_queue import create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import remove_all
-from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.engine import (
+ EngineCoreEventType,
+ EngineCoreOutput,
+ EngineCoreOutputs,
+)
from vllm.v1.metrics.perf import PerfStats
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.spec_decode.metrics import SpecDecodingStats
+from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
@@ -25,7 +32,7 @@
logger = init_logger(__name__)
-class OmniGenerationScheduler(VLLMScheduler):
+class OmniGenerationScheduler(OmniSchedulerMixin, VLLMScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
model_config = self.vllm_config.model_config
@@ -324,6 +331,24 @@ def schedule(self) -> SchedulerOutput:
return scheduler_output
+ def finish_requests(self, request_ids, finished_status: RequestStatus) -> list[tuple[str, int]]:
+ """Handles the finish signal from outside the scheduler.
+
+ For example, the API server can abort a request when the client
+ disconnects.
+
+ If request_ids is None, all requests will be finished.
+
+ Returns:
+ Tuple of (req_id, client_index) for requests that were aborted. Will not
+ include any that were already finished.
+ """
+
+ if self.chunk_transfer_adapter:
+ self.chunk_transfer_adapter.finish_requests(request_ids, finished_status, self.requests)
+
+ return super().finish_requests(request_ids, finished_status)
+
"""
Scheduler for the diffusion model.
This scheduler is modified to stop the request immediately for the diffusion model.
@@ -581,3 +606,11 @@ def update_from_output(
eco.scheduler_stats = stats
return engine_core_outputs
+
+ def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None:
+ """
+ Override: Just replace the existing session with the next streaming update.
+
+ Do not expend prompt id using update.
+ """
+ self._replace_session_with_streaming_update(session, update)
diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py
new file mode 100644
index 0000000000..36080e63ac
--- /dev/null
+++ b/vllm_omni/core/sched/omni_scheduler_mixin.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from vllm.v1.engine import EngineCoreEventType
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+
+
+class OmniSchedulerMixin:
+ """Shared scheduler helpers for omni-specific request handling."""
+
+ def _replace_session_with_streaming_update(
+ self,
+ session: Request,
+ update: StreamingUpdate,
+ ) -> None:
+ """For streaming input: Replace an existing streaming session payload with the latest update."""
+ session._output_token_ids.clear()
+ session._all_token_ids.clear()
+ new_prompt = update.prompt_token_ids or ()
+ session._all_token_ids.extend(new_prompt)
+ session.num_computed_tokens = 0
+ session.prompt_token_ids = update.prompt_token_ids or ()
+ session.additional_information = update.additional_information or None
+ # Update block hashes for the new tokens.
+ session.update_block_hashes()
+ session.num_prompt_tokens = len(session.prompt_token_ids)
+ session.arrival_time = update.arrival_time
+ session.sampling_params = update.sampling_params
+ if session.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+ self.num_waiting_for_streaming_input -= 1
+ session.status = RequestStatus.WAITING
+
+ if self.log_stats:
+ session.record_event(EngineCoreEventType.QUEUED)
diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py
new file mode 100644
index 0000000000..c9d891afb4
--- /dev/null
+++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py
@@ -0,0 +1,380 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Scheduling-side coordination for chunk and full_payload input waiting.
+
+Manages WAITING_FOR_CHUNK and WAITING_FOR_INPUT state transitions
+based on readiness signals from OmniConnectorOutput, without ever
+calling connector.put()/get().
+
+This replaces the scheduling half of OmniChunkTransferAdapter; the
+transport half lives in OmniConnectorModelRunnerMixin.
+"""
+
+from __future__ import annotations
+
+import time
+from collections import deque
+from typing import Any
+
+from vllm.logger import init_logger
+from vllm.v1.request import Request, RequestStatus
+
+logger = init_logger(__name__)
+
+
+class OmniSchedulingCoordinator:
+ """Pure-scheduling coordinator for chunk and full_payload input waiting.
+
+ The Scheduler owns an instance of this class. It consumes readiness
+ signals produced by the Model Runner's ``OmniConnectorModelRunnerMixin``
+ (via ``OmniConnectorOutput``) and manages ``WAITING_FOR_CHUNK`` and
+ ``WAITING_FOR_INPUT`` state transitions accordingly.
+ """
+
+ def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: bool = False):
+ self._stage_id = stage_id
+ self._scheduler_max_num_seqs = scheduler_max_num_seqs
+ self._async_chunk = async_chunk
+
+ self.finished_requests: set[str] = set()
+ self.requests_with_ready_chunks: set[str] = set()
+ self._full_payload_input_received: set[str] = set()
+
+ self._waiting_for_chunk_waiting: deque[Any] = deque()
+ self._waiting_for_chunk_running: deque[Any] = deque()
+
+ # Request IDs that were newly registered for chunk recv this cycle.
+ # The engine/Model Runner should call register_chunk_recv() for these
+ # so the bg thread starts polling.
+ self.pending_chunk_registrations: list[Any] = []
+
+ # Requests waiting for full_payload stage input (WAITING_FOR_INPUT).
+ self._waiting_for_input: deque[Any] = deque()
+ self.pending_input_registrations: list[Any] = []
+
+ # Monotonic timestamp recording when each request first entered
+ # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. Used by
+ # collect_timed_out_request_ids() to detect orphaned waits.
+ self._waiting_since: dict[str, float] = {}
+
+ # ------------------------------------------------------------------ #
+ # Core scheduling methods
+ # ------------------------------------------------------------------ #
+
+ def process_pending_chunks(
+ self,
+ waiting_queue: Any,
+ running_queue: list[Request],
+ chunk_ready_req_ids: set[str],
+ chunk_finished_req_ids: set[str],
+ ) -> None:
+ """Transition requests whose chunks have arrived.
+
+ Args:
+ waiting_queue: Scheduler's waiting request queue.
+ running_queue: Scheduler's running request list.
+ chunk_ready_req_ids: IDs with a newly arrived chunk this cycle.
+ chunk_finished_req_ids: IDs whose final chunk has arrived.
+ """
+ if self._stage_id == 0 or not self._async_chunk:
+ return
+
+ terminal_ready_req_ids = chunk_ready_req_ids.intersection(chunk_finished_req_ids)
+ self.finished_requests.update(chunk_finished_req_ids - terminal_ready_req_ids)
+ self.pending_chunk_registrations = []
+
+ self._process_chunk_queue(
+ waiting_queue,
+ self._waiting_for_chunk_waiting,
+ RequestStatus.WAITING,
+ chunk_ready_req_ids,
+ )
+ self._process_chunk_queue(
+ running_queue,
+ self._waiting_for_chunk_running,
+ RequestStatus.RUNNING,
+ chunk_ready_req_ids,
+ )
+ self.finished_requests.update(terminal_ready_req_ids)
+
+ while len(running_queue) > self._scheduler_max_num_seqs:
+ request = running_queue.pop()
+ # Must reset status to WAITING so the scheduler treats it as
+ # schedulable work. KV blocks are NOT freed here (unlike a
+ # real preemption), so PREEMPTED would be incorrect.
+ request.status = RequestStatus.WAITING
+ waiting_queue.prepend_requests([request])
+
+ def process_pending_full_payload_inputs(
+ self,
+ waiting_queue: Any,
+ running_queue: list[Request],
+ stage_recv_req_ids: set[str],
+ ) -> None:
+ """Manage WAITING_FOR_INPUT lifecycle for full_payload_mode.
+
+ For non-Stage-0 stages in full_payload_mode (``async_chunk=False``):
+ 1. Fresh WAITING requests are transitioned to WAITING_FOR_INPUT
+ and registered for bg-thread polling.
+ 2. WAITING_FOR_INPUT requests whose data has arrived (in
+ ``stage_recv_req_ids``) are transitioned back to WAITING.
+ """
+ if self._stage_id == 0:
+ return
+
+ self._full_payload_input_received.update(stage_recv_req_ids)
+ if not self._async_chunk and stage_recv_req_ids:
+ self.finished_requests.update(stage_recv_req_ids)
+ logger.debug(
+ "[Coordinator stage-%s] full_payload recv -> finished_requests: %s",
+ self._stage_id,
+ stage_recv_req_ids,
+ )
+ self.pending_input_registrations = []
+
+ remaining: deque[Any] = deque()
+ for request in self._waiting_for_input:
+ if request.request_id in stage_recv_req_ids:
+ request.status = RequestStatus.WAITING
+ self._waiting_since.pop(request.request_id, None)
+ waiting_queue.add_request(request)
+ else:
+ remaining.append(request)
+ self._waiting_for_input = remaining
+
+ if not self._async_chunk:
+ to_remove: list[Any] = []
+ queue_snapshot = list(waiting_queue)
+ for request in queue_snapshot:
+ if request.status == RequestStatus.WAITING:
+ if request.request_id in self._full_payload_input_received:
+ continue
+ if request.request_id in self.requests_with_ready_chunks:
+ continue
+ if request.request_id in self.finished_requests:
+ continue
+ request.status = RequestStatus.WAITING_FOR_INPUT
+ self._waiting_since.setdefault(request.request_id, time.monotonic())
+ to_remove.append(request)
+ self._waiting_for_input.append(request)
+ self.pending_input_registrations.append(request)
+ elif request.status == RequestStatus.WAITING_FOR_INPUT:
+ if request.request_id in stage_recv_req_ids:
+ request.status = RequestStatus.WAITING
+ self._waiting_since.pop(request.request_id, None)
+ else:
+ to_remove.append(request)
+ self._waiting_for_input.append(request)
+ self.pending_input_registrations.append(request)
+ for request in to_remove:
+ waiting_queue.remove(request)
+
+ def process_pending_full_payload_inputs_legacy(
+ self,
+ waiting_queue: Any,
+ running_queue: list[Request],
+ stage_recv_req_ids: set[str],
+ ) -> None:
+ """Compatibility wrapper for ``process_pending_full_payload_inputs``."""
+ self.process_pending_full_payload_inputs(waiting_queue, running_queue, stage_recv_req_ids)
+
+ def free_finished_request(self, request_id: str) -> None:
+ """Prune internal tracking sets for a freed request to prevent unbounded growth."""
+ self._full_payload_input_received.discard(request_id)
+ self.finished_requests.discard(request_id)
+ self.requests_with_ready_chunks.discard(request_id)
+ self._waiting_since.pop(request_id, None)
+
+ def collect_timed_out_request_ids(
+ self,
+ timeout_s: float,
+ ) -> set[str]:
+ """Return IDs of requests that have been waiting longer than *timeout_s*.
+
+ Uses ``_waiting_since`` timestamps (always up-to-date) to detect
+ timed-out requests. This method is safe to call at any point in
+ the scheduling cycle — it does **not** rely on coordinator internal
+ queues (which are empty after ``restore_queues()``).
+
+ Clears ``_waiting_since`` for timed-out IDs and defensively removes
+ them from coordinator internal queues if present. The caller
+ (scheduler) should then remove the requests from its queues,
+ set ``FINISHED_ERROR``, and call ``_free_request()`` so that
+ ``cleanup_finished_request()`` fires in the model runner mixin.
+ """
+ if timeout_s <= 0:
+ return set()
+ now = time.monotonic()
+ timed_out_ids: set[str] = set()
+ for req_id, start_time in self._waiting_since.items():
+ if now - start_time > timeout_s:
+ timed_out_ids.add(req_id)
+ if not timed_out_ids:
+ return set()
+
+ # Defensively remove from coordinator internal queues (may already
+ # be empty if restore_queues() has run).
+ for queue_attr in (
+ "_waiting_for_chunk_waiting",
+ "_waiting_for_chunk_running",
+ "_waiting_for_input",
+ ):
+ queue = getattr(self, queue_attr)
+ remaining: deque[Any] = deque()
+ for request in queue:
+ if request.request_id not in timed_out_ids:
+ remaining.append(request)
+ setattr(self, queue_attr, remaining)
+
+ for req_id in timed_out_ids:
+ self._waiting_since.pop(req_id, None)
+ logger.warning(
+ "[Coordinator stage-%s] Request %s timed out waiting for chunk/input (waited > %.0fs)",
+ self._stage_id,
+ req_id,
+ timeout_s,
+ )
+
+ return timed_out_ids
+
+ def restore_queues(
+ self,
+ waiting_queue: Any,
+ running_queue: list[Request],
+ ) -> None:
+ """Return waiting-for-chunk/input requests to scheduling queues."""
+ for request in self._waiting_for_chunk_waiting:
+ waiting_queue.add_request(request)
+ self._waiting_for_chunk_waiting = deque()
+
+ if self._waiting_for_chunk_running:
+ running_queue.extend(self._waiting_for_chunk_running)
+ self._waiting_for_chunk_running = deque()
+
+ for request in self._waiting_for_input:
+ waiting_queue.add_request(request)
+ self._waiting_for_input = deque()
+
+ def update_request_metadata(
+ self,
+ requests: dict[str, Request],
+ request_metadata: dict[str, dict[str, Any]],
+ model_mode: str = "ar",
+ ) -> None:
+ """Apply received scheduling metadata to request objects.
+
+ For AR mode: only scheduler-visible metadata is applied locally.
+ For Generation mode: updates ``request.prompt_token_ids``.
+
+ Additionally, if the payload contains ``next_stage_prompt_len``,
+ updates the request's ``prompt_token_ids`` to the correct length.
+ """
+ for req_id, metadata in request_metadata.items():
+ request = requests.get(req_id)
+ if request is None:
+ continue
+
+ # Handle next_stage_prompt_len if present (for models like Qwen3-Omni).
+ # Only apply when the request has not started decoding yet
+ # (no output tokens). Resetting a mid-decode request would
+ # destroy generated tokens and desync KV cache state.
+ if "next_stage_prompt_len" in metadata:
+ next_len = metadata["next_stage_prompt_len"]
+ if isinstance(next_len, int) and next_len > 0:
+ output_token_ids = getattr(request, "_output_token_ids", None)
+ has_decode_output = output_token_ids is not None and len(output_token_ids) > 0
+ if has_decode_output:
+ logger.debug(
+ "[Coordinator stage-%s] Skipping prompt resize for req %s: "
+ "request already has %s output tokens",
+ self._stage_id,
+ req_id,
+ len(output_token_ids),
+ )
+ else:
+ current_prompt_ids = getattr(request, "prompt_token_ids", []) or []
+ current_prompt_len = len(current_prompt_ids)
+ if current_prompt_len != next_len or getattr(request, "num_prompt_tokens", None) != next_len:
+ new_prompt = [0] * next_len
+ request.prompt_token_ids = new_prompt
+ request.num_prompt_tokens = next_len
+ request._all_token_ids.clear()
+ request._all_token_ids.extend(new_prompt)
+ request._output_token_ids.clear()
+ request.num_computed_tokens = 0
+ logger.debug(
+ "[Coordinator stage-%s] Updated prompt_token_ids length to %s for req %s",
+ self._stage_id,
+ next_len,
+ req_id,
+ )
+
+ if model_mode != "ar":
+ new_ids = metadata.get("code_predictor_codes", [])
+ runtime_seed = None
+ if "left_context_size" in metadata:
+ runtime_seed = {
+ "left_context_size": metadata["left_context_size"],
+ }
+ request._omni_initial_model_buffer = runtime_seed
+ if new_ids:
+ request.prompt_token_ids = new_ids
+ request.num_computed_tokens = 0
+
+ def postprocess_scheduler_output(
+ self,
+ scheduler_output: Any,
+ requests: dict[str, Request] | None = None,
+ ) -> None:
+ """Clear per-cycle ready state after scheduler output is materialized."""
+ self._clear_chunk_ready(scheduler_output)
+
+ # ------------------------------------------------------------------ #
+ # Internal helpers
+ # ------------------------------------------------------------------ #
+
+ def _process_chunk_queue(
+ self,
+ queue: Any,
+ waiting_for_chunk_list: deque[Any],
+ target_status: RequestStatus,
+ chunk_ready_req_ids: set[str],
+ ) -> None:
+ queue_snapshot = list(queue)
+ for request in queue_snapshot:
+ if request.status != RequestStatus.WAITING_FOR_CHUNK:
+ if request.request_id in self.requests_with_ready_chunks:
+ continue
+ if request.request_id in self.finished_requests:
+ continue
+ if request.status == RequestStatus.WAITING_FOR_INPUT:
+ continue
+ if request.request_id in chunk_ready_req_ids:
+ self.requests_with_ready_chunks.add(request.request_id)
+ continue
+ self.pending_chunk_registrations.append(request)
+ request.status = RequestStatus.WAITING_FOR_CHUNK
+ self._waiting_since.setdefault(request.request_id, time.monotonic())
+ else:
+ if request.request_id in chunk_ready_req_ids:
+ request.status = target_status
+ self.requests_with_ready_chunks.add(request.request_id)
+ self._waiting_since.pop(request.request_id, None)
+ continue
+ queue.remove(request)
+ waiting_for_chunk_list.append(request)
+
+ def _clear_chunk_ready(self, scheduler_output: Any) -> None:
+ if scheduler_output.scheduled_new_reqs:
+ for req_data in scheduler_output.scheduled_new_reqs:
+ self.requests_with_ready_chunks.discard(
+ getattr(req_data, "req_id", None),
+ )
+
+ if scheduler_output.scheduled_cached_reqs:
+ for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
+ self.requests_with_ready_chunks.discard(req_id)
+
+
+# Backward-compatible alias
+ChunkSchedulingCoordinator = OmniSchedulingCoordinator
diff --git a/vllm_omni/deploy/qwen2_5_omni.yaml b/vllm_omni/deploy/qwen2_5_omni.yaml
new file mode 100644
index 0000000000..41aef0df6f
--- /dev/null
+++ b/vllm_omni/deploy/qwen2_5_omni.yaml
@@ -0,0 +1,92 @@
+# Qwen2.5-Omni deploy: CUDA defaults + platform overrides, verified on 2x H100.
+# Stage 2 disables flashinfer autotune because its DiT block never invokes
+# flashinfer; the autotune dummy run OOMs the shared cuda:0 device otherwise.
+#
+# Fields omitted from a stage fall back to StageDeployConfig dataclass
+# defaults (see vllm_omni/config/stage_config.py). For instance, every
+# stage here uses vLLM's default max_num_batched_tokens=32768 because
+# chat-sized prefill comfortably fits; only models with codec prefill
+# (Qwen3-Omni, Qwen3-TTS) need to bump it above 32k.
+#
+# enforce_eager policy across the three deploy YAMLs:
+# * code2wav / generation stages: always true (cudagraph incompatible with
+# the custom generation loop — set explicitly everywhere).
+# * AR stages (thinker, talker): model-dependent. Qwen2.5-Omni runs eager
+# on CUDA (thinker uses custom ops that don't trace cleanly); NPU / XPU
+# platform overrides flip back to false where cudagraph is verified.
+# Qwen3-Omni / Qwen3-TTS AR stages use the default (false = cudagraph on).
+async_chunk: false
+
+stages:
+ - stage_id: 0
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ mm_processor_cache_gb: 0
+ devices: "0"
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.1
+
+ - stage_id: 1
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ devices: "1"
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 2
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.15
+ enforce_eager: true
+ enable_flashinfer_autotune: false
+ async_scheduling: false
+ devices: "0"
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.1
+
+platforms:
+ npu:
+ stages:
+ # NPU has cudagraph support for the thinker, unlike GPU which still
+ # only runs eager.
+ - stage_id: 0
+ enforce_eager: false
+ - stage_id: 2
+ # 3-NPU layout: stage 2 lives on its own card.
+ devices: "2"
+
+ rocm:
+ stages:
+ - stage_id: 2
+ # 3-GPU MI325 layout: stage 2 on a separate card.
+ devices: "2"
+
+ xpu:
+ stages:
+ # Verified on 2x Intel Arc Pro B60. Both AR stages use cudagraphs.
+ - stage_id: 0
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ - stage_id: 1
+ gpu_memory_utilization: 0.5
+ enforce_eager: false
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ # Stage 2 colocates with stage 1's device on XPU.
+ devices: "1"
diff --git a/vllm_omni/deploy/qwen3_omni_moe.yaml b/vllm_omni/deploy/qwen3_omni_moe.yaml
new file mode 100644
index 0000000000..fb8b616213
--- /dev/null
+++ b/vllm_omni/deploy/qwen3_omni_moe.yaml
@@ -0,0 +1,98 @@
+# Qwen3-Omni-MoE production deploy, verified on 2x H100 (stage 0 on cuda:0,
+# stages 1+2 on cuda:1).
+#
+# Fields omitted from a stage fall back to StageDeployConfig defaults (see
+# vllm_omni/config/stage_config.py). Notable implicit defaults for this
+# model:
+# * Stages 0/1 (thinker, talker) do not set max_num_batched_tokens —
+# chat-sized prefill fits in the 32768 default.
+# * Stages 0/1 do not set enforce_eager — cudagraph runs by default
+# (false). Stage 2 (code2wav) sets true because its generation loop
+# is cudagraph-incompatible.
+# * Platform sections flip enforce_eager per-stage where platform
+# cudagraph support differs.
+async_chunk: true
+
+connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.9
+ devices: "0"
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ gpu_memory_utilization: 0.6
+ devices: "1"
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 2
+ gpu_memory_utilization: 0.1
+ max_num_seqs: 1
+ enforce_eager: true
+ async_scheduling: false
+ max_num_batched_tokens: 51200
+ devices: "1"
+ input_connectors:
+ from_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ repetition_penalty: 1.1
+
+platforms:
+ npu:
+ stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.6
+ tensor_parallel_size: 2
+ devices: "0,1"
+ - stage_id: 1
+ gpu_memory_utilization: 0.6
+ enforce_eager: true
+ devices: "2"
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ devices: "2"
+
+ rocm:
+ stages:
+ - stage_id: 0
+ enforce_eager: true
+
+ xpu:
+ stages:
+ - stage_id: 0
+ tensor_parallel_size: 4
+ enforce_eager: true
+ max_cudagraph_capture_size: 0
+ devices: "0,1,2,3"
+ - stage_id: 1
+ enforce_eager: true
+ max_cudagraph_capture_size: 0
+ devices: "4"
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ max_cudagraph_capture_size: 0
+ devices: "4"
diff --git a/vllm_omni/deploy/qwen3_tts.yaml b/vllm_omni/deploy/qwen3_tts.yaml
new file mode 100644
index 0000000000..32dceebd80
--- /dev/null
+++ b/vllm_omni/deploy/qwen3_tts.yaml
@@ -0,0 +1,81 @@
+# Qwen3-TTS deploy: talker → code2wav via shared-memory chunk streaming.
+# Verified on 1x H100.
+#
+# Fields omitted from a stage fall back to StageDeployConfig defaults (see
+# vllm_omni/config/stage_config.py). Notable choices for this model:
+# * Stage 0 (talker) sets max_num_batched_tokens=512 for async-chunk
+# latency tuning (not correctness) — small per-step batches keep
+# first-chunk latency low.
+# * Stage 1 (code2wav) sets max_num_batched_tokens=65536 for correctness:
+# codec prefill length (Q * num_frames) exceeds the 32k default.
+# * Stage 0 does not set enforce_eager — talker runs cudagraph by default.
+# Stage 1 sets true because its codec generation loop is not
+# cudagraph-compatible. NPU platform flips stage 0 to true where
+# cudagraph is not yet verified.
+async_chunk: true
+
+connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ # Must match the decoder sliding attention window.
+ codec_chunk_frames: 25
+ codec_left_context_frames: 72
+
+stages:
+ - stage_id: 0
+ max_num_seqs: 10
+ gpu_memory_utilization: 0.3
+ async_scheduling: true
+ max_num_batched_tokens: 512
+ max_model_len: 4096
+ devices: "0"
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.3
+ enforce_eager: true
+ async_scheduling: true
+ # Must be divisible by num_code_groups and cover (left_context + chunk).
+ # Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep
+ # headroom past 32k.
+ max_num_batched_tokens: 65536
+ # async_chunk appends windows per step; max_model_len must cover the
+ # accumulated flat codec stream.
+ max_model_len: 65536
+ devices: "0"
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ repetition_penalty: 1.0
+
+platforms:
+ npu:
+ stages:
+ # NPU does not yet support async-scheduling for TTS, and the
+ # talker fits at max_num_seqs=1 only.
+ - stage_id: 0
+ max_num_seqs: 1
+ enforce_eager: true
+ async_scheduling: false
+ - stage_id: 1
+ gpu_memory_utilization: 0.2
+ async_scheduling: false
diff --git a/vllm_omni/diffusion/attention/parallel/ulysses.py b/vllm_omni/diffusion/attention/parallel/ulysses.py
index 5d860b3350..326b5d4567 100644
--- a/vllm_omni/diffusion/attention/parallel/ulysses.py
+++ b/vllm_omni/diffusion/attention/parallel/ulysses.py
@@ -414,10 +414,6 @@ def pre_attention(
def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor:
assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}"
- # If we have joint tensors (Text), they were Head-Sliced.
- # The main sequence (Image) was Sequence-Sliced.
- # attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front').
-
if ctx.joint_len > 0:
joint_len = ctx.joint_len
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index a5055a0688..d5397dd166 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -281,6 +281,7 @@ def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable
],
forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
params_modifiers=[modifier],
+ has_separate_cfg=True,
)
),
cache_config=db_cache_config,
@@ -464,6 +465,77 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context
+def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for Stable Audio Open pipeline.
+
+ Args:
+ pipeline: The StableAudioPipeline instance.
+ cache_config: DiffusionCacheConfig instance with cache configuration.
+
+ Returns:
+ A refresh function that can be called to update cache context with new num_inference_steps.
+ """
+ db_cache_config = _build_db_cache_config(cache_config)
+
+ calibrator_config = None
+ if cache_config.enable_taylorseer:
+ taylorseer_order = cache_config.taylorseer_order
+ calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+ logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+ # StableAudio is officially registered in CacheDiT as Pattern_3:
+ # https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562
+ #
+ # Pattern_3 is required because StableAudioDiT uses cross-attention
+ # with static encoder_hidden_states that do not change inside the
+ # transformer block loop.
+ cache_dit.enable_cache(
+ BlockAdapter(
+ transformer=pipeline.transformer,
+ blocks=pipeline.transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_3,
+ params_modifiers=[
+ ParamsModifier(
+ cache_config=db_cache_config,
+ calibrator_config=calibrator_config,
+ )
+ ],
+ ),
+ cache_config=db_cache_config,
+ )
+
+ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+ """Refresh cache context for the transformer with new num_inference_steps.
+
+ Args:
+ pipeline: The StableAudioPipeline instance.
+ num_inference_steps: New number of inference steps.
+ verbose: Whether to log refresh operations.
+ """
+ # Bypass SCM for step counts that don't support predefined masks (e.g., vLLM's 1-step dummy run)
+ scm_supported_steps = num_inference_steps >= 8 or num_inference_steps in (4, 6)
+
+ if cache_config.scm_steps_mask_policy is None or not scm_supported_steps:
+ cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+ else:
+ updated_scm_config = DBCacheConfig().reset(
+ num_inference_steps=num_inference_steps,
+ steps_computation_mask=cache_dit.steps_mask(
+ mask_policy=cache_config.scm_steps_mask_policy,
+ total_steps=num_inference_steps,
+ ),
+ steps_computation_policy=cache_config.scm_steps_policy,
+ )
+
+ cache_dit.refresh_context(
+ pipeline.transformer,
+ cache_config=updated_scm_config,
+ verbose=verbose,
+ )
+
+ return refresh_cache_context
+
+
def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Enable cache-dit for StableDiffusion3Pipeline.
@@ -561,6 +633,7 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
forward_pattern=ForwardPattern.Pattern_0,
# Treat audio_hidden_states as encoder_hidden_states in Pattern_0
check_forward_pattern=False,
+ has_separate_cfg=True,
),
cache_config=db_cache_config,
calibrator_config=calibrator_config,
@@ -1097,41 +1170,85 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context
-def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
- """Enable cache-dit for GLM-Image pipeline.
+def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for Flux.2-dev pipeline.
- GLM-Image processes prompt and image by calling the transformer before the
- denoising loop. When an input image is provided (editing mode), the cache must
- be force-refreshed after the preprocessing step so stale hidden states are
- discarded. Set force_refresh_step_hint = 1 for editing, None for text-to-image.
+ Args:
+ pipeline: The Flux2 pipeline instance.
+ cache_config: DiffusionCacheConfig instance with cache configuration.
+ Returns:
+ A refresh function that can be called with a new ``num_inference_steps``
+ to update the cache context for the pipeline.
"""
+ # Build DBCacheConfig for transformer
db_cache_config = _build_db_cache_config(cache_config)
- calibrator_config = None
+ calibrator = None
if cache_config.enable_taylorseer:
- calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=cache_config.taylorseer_order)
- logger.info(f"TaylorSeer enabled with order={cache_config.taylorseer_order}")
+ taylorseer_order = cache_config.taylorseer_order
+ calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+ logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+ # Build ParamsModifier for transformer
+ modifier = ParamsModifier(
+ cache_config=db_cache_config,
+ calibrator_config=calibrator,
+ )
logger.info(
- f"Enabling cache-dit on GLM-Image transformer: "
+ f"Enabling cache-dit on Flux transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
- f"force_refresh_step_hint={db_cache_config.force_refresh_step_hint}, "
)
+ # Enable cache-dit using BlockAdapter for transformer
cache_dit.enable_cache(
- pipeline.transformer,
+ (
+ BlockAdapter(
+ transformer=pipeline.transformer,
+ blocks=[
+ pipeline.transformer.transformer_blocks,
+ pipeline.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+ params_modifiers=[modifier],
+ )
+ ),
cache_config=db_cache_config,
- calibrator_config=calibrator_config,
)
+ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+ """Refresh cache context for the transformer with new num_inference_steps.
-def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
- """Enable cache-dit for Flux.2-dev pipeline.
+ Args:
+ pipeline: The Flux2 pipeline instance.
+ num_inference_steps: New number of inference steps.
+ """
+ if cache_config.scm_steps_mask_policy is None:
+ cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+ else:
+ cache_dit.refresh_context(
+ pipeline.transformer,
+ cache_config=DBCacheConfig().reset(
+ num_inference_steps=num_inference_steps,
+ steps_computation_mask=cache_dit.steps_mask(
+ mask_policy=cache_config.scm_steps_mask_policy,
+ total_steps=num_inference_steps,
+ ),
+ steps_computation_policy=cache_config.scm_steps_policy,
+ ),
+ verbose=verbose,
+ )
+
+ return refresh_cache_context
+
+
+def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for GlmImage pipeline.
Args:
- pipeline: The Flux2 pipeline instance.
+ pipeline: The GlmImage pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
Returns:
A refresh function that can be called with a new ``num_inference_steps``
@@ -1153,23 +1270,25 @@ def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int],
)
logger.info(
- f"Enabling cache-dit on Flux transformer with BlockAdapter: "
+ f"Enabling cache-dit on GlmImage transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)
# Enable cache-dit using BlockAdapter for transformer
+ # Note: We don't use patch_functor here because it's designed for diffusers' GlmImage,
+ # and our vllm-omni implementation has a different forward signature.
+ # We use ForwardPattern.Pattern_0 because our block returns (hidden_states, encoder_hidden_states)
cache_dit.enable_cache(
(
BlockAdapter(
transformer=pipeline.transformer,
- blocks=[
- pipeline.transformer.transformer_blocks,
- pipeline.transformer.single_transformer_blocks,
- ],
- forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+ blocks=pipeline.transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_0,
params_modifiers=[modifier],
+ patch_functor=None,
+ has_separate_cfg=True,
)
),
cache_config=db_cache_config,
@@ -1179,7 +1298,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"""Refresh cache context for the transformer with new num_inference_steps.
Args:
- pipeline: The Flux2 pipeline instance.
+ pipeline: The GlmImage pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
@@ -1212,6 +1331,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"Flux2KleinPipeline": enable_cache_for_flux2_klein,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
+ "StableAudioPipeline": enable_cache_for_stable_audio_open,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
"LTX2Pipeline": enable_cache_for_ltx2,
"LTX2ImageToVideoPipeline": enable_cache_for_ltx2,
diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
index 5dd80718d1..38c805c28d 100644
--- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
+++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
@@ -1,19 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
from typing import Any
import numpy as np
import torch
from vllm.config import LoadConfig
-from vllm.utils.torch_utils import set_default_torch_dtype
+from vllm.transformers_utils.config import get_hf_file_to_dict
from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
-from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
-from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -35,6 +34,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
ctx = self.extractor_fn(module, *args, **kwargs)
+ # NOTE: We upcast to float32 to also handle bfloat16.
modulated_input_cpu = ctx.modulated_input.detach().float().cpu().numpy()
outputs = ctx.run_transformer_blocks()
@@ -53,23 +53,39 @@ def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
return list(self.current_trajectory)
-class BagelAdapter:
- """Adapter for Bagel model."""
+class DefaultAdapter:
+ """Default adapter for standard diffusers pipelines."""
- @staticmethod
- def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
- od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
- od_config.model_class_name = "BagelPipeline"
+ model_class_name = None
+ uses_tf_config = True
+
+ @classmethod
+ def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any:
+ if cls.model_class_name is None:
+ raise ValueError("Adapter doesn't have a set class name.")
+
+ od_config = OmniDiffusionConfig.from_kwargs(
+ model_class_name=cls.model_class_name,
+ model=model_path,
+ dtype=dtype,
+ )
+
+ if cls.uses_tf_config:
+ # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig
+ # instead of OmniDiffusion and remove the manual population here
+ tf_config_dict = get_hf_file_to_dict(
+ os.path.join("transformer", "config.json"),
+ od_config.model,
+ )
+ od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
- pipeline = BagelPipeline(od_config=od_config)
- loader = DiffusersPipelineLoader(LoadConfig())
- loader.load_weights(pipeline)
- pipeline.to(device)
- return pipeline
+ loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config)
+ # load_model will handle dtypes / device placement, put in .eval() mode
+ return loader.load_model(od_config=od_config, load_device=device)
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.bagel, "Bagel"
+ return pipeline.transformer, pipeline.transformer.__class__.__name__
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
@@ -77,25 +93,17 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry.register_hook(hook._HOOK_NAME, hook)
-class StableAudioAdapter:
- """Adapter for Stable Audio Open 1.0 coefficient estimation."""
-
- @staticmethod
- def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.float16) -> Any:
- od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
-
- # Strictly necessary because we bypass loader.load_model()
- with set_default_torch_dtype(dtype):
- pipeline = StableAudioPipeline(od_config=od_config)
+class BagelAdapter(DefaultAdapter):
+ """Adapter for Bagel model."""
- loader = DiffusersPipelineLoader(LoadConfig())
- loader.load_weights(pipeline)
- pipeline.to(device)
- return pipeline
+ model_class_name = "BagelPipeline"
+ # Skip the hack for loading the tf model config,
+ # because bagel doesn't use it.
+ uses_tf_config = False
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.transformer, "StableAudioDiTModel"
+ return pipeline.bagel, "Bagel"
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
@@ -103,26 +111,32 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry.register_hook(hook._HOOK_NAME, hook)
-class DefaultAdapter:
- """Default adapter for standard diffusers pipelines."""
+class Flux2Adapter(DefaultAdapter):
+ """Adapter for Flux2 model coefficient estimation."""
- @staticmethod
- def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
- raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")
+ model_class_name = "Flux2Pipeline"
- @staticmethod
- def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.transformer, pipeline.transformer.__class__.__name__
- @staticmethod
- def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
- registry = HookRegistry.get_or_create(transformer)
- registry.register_hook(hook._HOOK_NAME, hook)
+class LongCatAdapter(DefaultAdapter):
+ """Adapter for LongCat Image - NOTE: currently this model needs the vLLM
+ context to be correctly configured to actually run the estimation, since it
+ uses vLLM norm layers etc.
+ """
+
+ model_class_name = "LongCatImagePipeline"
+
+
+class StableAudioAdapter(DefaultAdapter):
+ """Adapter for Stable Audio Open 1.0 coefficient estimation."""
+
+ model_class_name = "StableAudioPipeline"
_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
"StableAudio": StableAudioAdapter,
+ "Flux2": Flux2Adapter,
+ "LongCat": LongCatAdapter,
}
_EPSILON = 1e-6
@@ -169,7 +183,6 @@ def __init__(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
- # Add validation here ⬇️
if model_type not in _MODEL_ADAPTERS:
available_types = list(_MODEL_ADAPTERS.keys())
raise ValueError(
@@ -178,7 +191,7 @@ def __init__(
f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
)
- adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter)
+ adapter = _MODEL_ADAPTERS[model_type]
self.pipeline = adapter.load_pipeline(model_path, device, dtype)
self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
self.hook = DataCollectionHook(self.transformer_type)
diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py
index 96cf3f03ee..7efdd418e1 100644
--- a/vllm_omni/diffusion/cache/teacache/config.py
+++ b/vllm_omni/diffusion/cache/teacache/config.py
@@ -64,6 +64,17 @@
-1.04182570e01,
6.78098549e-01,
],
+ # Flux2 transformer coefficients
+ # Copied from Qwen-Image, need to be tuned specifically for Flux2 in future
+ "Flux2Transformer2DModel": [
+ -4.50000000e02,
+ 2.80000000e02,
+ -4.50000000e01,
+ 3.20000000e00,
+ -2.00000000e-02,
+ ],
+ # LongCat Image transformer coefficients
+ "LongCatImageTransformer2DModel": [652.5980, -424.1615, 84.5526, -4.5923, 0.1694],
}
diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py
index bdb3f6a786..d0da0d9df3 100644
--- a/vllm_omni/diffusion/cache/teacache/extractors.py
+++ b/vllm_omni/diffusion/cache/teacache/extractors.py
@@ -19,8 +19,12 @@
import torch
import torch.nn as nn
+from vllm.logger import init_logger
from vllm_omni.diffusion.forward_context import get_forward_context
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)
@dataclass
@@ -221,7 +225,8 @@ def extract_qwen_context(
block = module.transformer_blocks[0]
img_mod_params = block.img_mod(temb)
img_mod1, _ = img_mod_params.chunk(2, dim=-1)
- img_modulated, _ = block.img_norm1(hidden_states, img_mod1)
+ img_scale1, img_shift1, _ = block._modulate(img_mod1)
+ img_modulated = block.img_norm1(hidden_states, img_scale1, img_shift1)
# ============================================================================
# DEFINE TRANSFORMER EXECUTION (Qwen-specific)
@@ -721,6 +726,105 @@ def postprocess(h):
)
+def extract_longcat_context(
+ module: nn.Module, # LongCatImageTransformer2DModel
+ hidden_states,
+ timestep,
+ guidance,
+ encoder_hidden_states,
+ txt_ids,
+ img_ids,
+ **kwargs,
+) -> CacheContext:
+ """Extract the cache context for LongCat Image.
+
+ Similar to other extractors, this is currently the only code needed
+ for TeaCache support for LongCat image, and encapsulates preprocessing,
+ modulated input extraction, transformer execution, and postprocessing
+ logic.
+
+ Args & kawrgs are identical to the inputs to LongCat Image's forward.
+
+ Returns:
+ CacheContext with all information needed for generic caching
+ """
+ # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+ # 1. Model specific preprocessing
+ fwd_context = get_forward_context()
+ sp_size = module.parallel_config.sequence_parallel_size
+ if sp_size is not None and sp_size > 1:
+ # NOTE: For now, we set this to False on the forward context
+ # to be consistent with LongCat Image's current behavior when
+ # TeaCache is enabled. We do not need to reset it in post process
+ # since we should never split text embed in sp for this model.
+ fwd_context.split_text_embed_in_sp = False
+
+ hidden_states = module.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+
+ temb = module.time_embed(timestep, hidden_states.dtype)
+ encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+ # Compute RoPE embeddings via rope_preparer module
+ # _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3)
+ # txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention
+ txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids)
+
+ # Reconstruct image_rotary_emb with chunked values
+ # Final shape: (txt_seq_len + img_seq_len // SP, head_dim)
+ image_rotary_emb = (
+ torch.cat([txt_cos, img_cos], dim=0),
+ torch.cat([txt_sin, img_sin], dim=0),
+ )
+
+ # 2. Extract the modulated output from the first mm-DiT block
+ first_block = module.transformer_blocks[0]
+ img_modulated = first_block.norm1(hidden_states, emb=temb)[0]
+
+ # 3. Define the transformer execution
+ def run_transformer_blocks():
+ """Execute all Longcat transformer blocks."""
+ h = hidden_states
+ e = encoder_hidden_states
+ for block in module.transformer_blocks:
+ e, h = block(
+ hidden_states=h,
+ encoder_hidden_states=e,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ for block in module.single_transformer_blocks:
+ e, h = block(
+ hidden_states=h,
+ encoder_hidden_states=e,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ # Hook expects hidden states to be first
+ return (h, e)
+
+ # 4. Postprocessing
+ def postprocess(h):
+ """Apply Longcat-specific output postprocessing."""
+ h = module.norm_out(h, temb)
+ output = module.proj_out(h)
+ return Transformer2DModelOutput(sample=output)
+
+ # 5. Return the CacheContext
+ return CacheContext(
+ modulated_input=img_modulated,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ run_transformer_blocks=run_transformer_blocks,
+ postprocess=postprocess,
+ )
+
+
def extract_stable_audio_context(
module: nn.Module,
hidden_states: torch.Tensor,
@@ -827,6 +931,144 @@ def postprocess(h: torch.Tensor) -> Any:
)
+def extract_flux2_context(
+ module: nn.Module,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor | None = None,
+ joint_attention_kwargs: dict[str, Any] | None = None,
+ return_dict: bool = True,
+ **kwargs: Any,
+) -> CacheContext:
+ """
+ Extract cache context for Flux2Transformer2DModel.
+
+ This is the ONLY Flux2-specific code needed for TeaCache support.
+ It encapsulates preprocessing, modulated input extraction, transformer execution,
+ and postprocessing logic.
+
+ Args:
+ module: Flux2Transformer2DModel instance
+ hidden_states: Input hidden states tensor
+ encoder_hidden_states: Text encoder outputs
+ timestep: Current diffusion timestep
+ img_ids: Image inputs for position embedding
+ txt_ids: Text inputs for position embedding
+ guidance: Optional guidance scale for CFG
+ joint_attention_kwargs: Additional attention arguments
+ return_dict: Whether to return a Transformer2DModelOutput instead of a plain tensor
+ **kwargs: Additional keyword arguments ignored by this extractor
+
+ Returns:
+ CacheContext with all information needed for generic caching
+ """
+
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+ if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
+ raise ValueError("Module must have transformer_blocks")
+
+ # ============================================================================
+ # PREPROCESSING (Flux2-specific)
+ # ============================================================================
+ num_txt_tokens = encoder_hidden_states.shape[1]
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = module.time_guidance_embed(timestep, guidance)
+
+ double_stream_mod_img = module.double_stream_modulation_img(temb)
+ double_stream_mod_txt = module.double_stream_modulation_txt(temb)
+ single_stream_mod = module.single_stream_modulation(temb)[0]
+
+ hidden_states = module.x_embedder(hidden_states)
+ encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+ if img_ids.ndim == 3:
+ img_ids = img_ids[0]
+ if txt_ids.ndim == 3:
+ txt_ids = txt_ids[0]
+
+ if current_omni_platform.is_npu():
+ freqs_cos_image, freqs_sin_image = module.pos_embed(img_ids.cpu())
+ image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
+ freqs_cos_text, freqs_sin_text = module.pos_embed(txt_ids.cpu())
+ text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
+ else:
+ image_rotary_emb = module.pos_embed(img_ids)
+ text_rotary_emb = module.pos_embed(txt_ids)
+ concat_rotary_emb = (
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+ )
+
+ # ============================================================================
+ # EXTRACT MODULATED INPUT (for cache decision)
+ # ============================================================================
+ block = module.transformer_blocks[0]
+ (shift_msa, scale_msa, gate_msa), _ = double_stream_mod_img
+ modulated_input = block.norm1(hidden_states)
+ modulated_input = (1 + scale_msa) * modulated_input + shift_msa
+
+ # ============================================================================
+ # DEFINE TRANSFORMER EXECUTION (Flux2-specific)
+ # ============================================================================
+ def run_transformer_blocks():
+ """Execute all Flux2 transformer blocks."""
+ h = hidden_states
+ e = encoder_hidden_states
+
+ for transformer_block in module.transformer_blocks:
+ e, h = transformer_block(
+ hidden_states=h,
+ encoder_hidden_states=e,
+ temb_mod_params_img=double_stream_mod_img,
+ temb_mod_params_txt=double_stream_mod_txt,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ h = torch.cat([e, h], dim=1)
+
+ for single_transformer_block in module.single_transformer_blocks:
+ h = single_transformer_block(
+ hidden_states=h,
+ encoder_hidden_states=None,
+ temb_mod_params=single_stream_mod,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ h = h[:, num_txt_tokens:, ...]
+ return (h,)
+
+ # ============================================================================
+ # DEFINE POSTPROCESSING
+ # ============================================================================
+ def postprocess(h):
+ h = module.norm_out(h, temb)
+ output = module.proj_out(h)
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
+
+ # ============================================================================
+ # RETURN CONTEXT
+ # ============================================================================
+ return CacheContext(
+ modulated_input=modulated_input,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ run_transformer_blocks=run_transformer_blocks,
+ postprocess=postprocess,
+ )
+
+
# Registry for model-specific extractors
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -839,6 +1081,8 @@ def postprocess(h: torch.Tensor) -> Any:
"ZImageTransformer2DModel": extract_zimage_context,
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
+ "Flux2Transformer2DModel": extract_flux2_context,
+ "LongCatImageTransformer2DModel": extract_longcat_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py
index 56a891aa5c..0a19eb1197 100644
--- a/vllm_omni/diffusion/data.py
+++ b/vllm_omni/diffusion/data.py
@@ -18,6 +18,7 @@
QuantizationConfig,
)
+from vllm_omni.diffusion.model_metadata import get_diffusion_model_metadata
from vllm_omni.diffusion.utils.network_utils import is_port_available
from vllm_omni.quantization import build_quant_config
@@ -481,8 +482,10 @@ class OmniDiffusionConfig:
# Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p)
flow_shift: float | None = None
- # support multi images input
+ # Support multi-image inputs and expose any model-specific request limit
+ # through a generic config field so serving code stays model-agnostic.
supports_multimodal_inputs: bool = False
+ max_multimodal_image_inputs: int | None = None
log_level: str = "info"
@@ -664,7 +667,54 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
)
def update_multimodal_support(self) -> None:
- self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"}
+ # Resolve serving-visible multimodal behavior from shared metadata
+ # instead of importing concrete pipeline modules into the config layer.
+ metadata = get_diffusion_model_metadata(self.model_class_name)
+ self.supports_multimodal_inputs = metadata.supports_multimodal_inputs
+ self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs
+
+ def enrich_config(self) -> None:
+ """Load model metadata from HuggingFace and populate config fields.
+
+ Diffusers-style models expose ``model_index.json`` with ``_class_name``.
+ Non-diffusers models (e.g. Bagel, NextStep) only have ``config.json``,
+ so we fall back to reading that and mapping model_type manually.
+ """
+ from vllm.transformers_utils.config import get_hf_file_to_dict
+
+ try:
+ config_dict = get_hf_file_to_dict("model_index.json", self.model)
+ if config_dict is not None:
+ if self.model_class_name is None:
+ self.model_class_name = config_dict.get("_class_name", None)
+ self.update_multimodal_support()
+
+ tf_config_dict = get_hf_file_to_dict("transformer/config.json", self.model)
+ self.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
+ else:
+ raise FileNotFoundError("model_index.json not found")
+ except (AttributeError, OSError, ValueError, FileNotFoundError):
+ cfg = get_hf_file_to_dict("config.json", self.model)
+ if cfg is None:
+ raise ValueError(f"Could not find config.json or model_index.json for model {self.model}")
+
+ self.tf_model_config = TransformerConfig.from_dict(cfg)
+ model_type = cfg.get("model_type")
+ architectures = cfg.get("architectures") or []
+
+ if model_type == "bagel" or "BagelForConditionalGeneration" in architectures:
+ self.model_class_name = "BagelPipeline"
+ self.tf_model_config = TransformerConfig()
+ self.update_multimodal_support()
+ elif model_type == "nextstep":
+ if self.model_class_name is None:
+ self.model_class_name = "NextStep11Pipeline"
+ self.tf_model_config = TransformerConfig()
+ self.update_multimodal_support()
+ elif architectures and len(architectures) == 1:
+ self.model_class_name = architectures[0]
+ else:
+ raise
@classmethod
def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":
diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index 422ef479b0..fe940d623e 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -3,6 +3,7 @@
from __future__ import annotations
+import inspect
import queue
import threading
import time
@@ -78,6 +79,12 @@ def __init__(
self.post_process_func = get_diffusion_post_process_func(od_config)
self.pre_process_func = get_diffusion_pre_process_func(od_config)
+ # Cache whether the model-specific postprocess accepts request-level
+ # sampling params so step() can support both legacy and extended hooks.
+ self._post_process_accepts_sampling_params = bool(
+ self.post_process_func is not None
+ and "sampling_params" in inspect.signature(self.post_process_func).parameters
+ )
executor_class = DiffusionExecutor.get_class(od_config)
self.executor = executor_class(od_config)
@@ -143,12 +150,22 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
output_data = output_data.cpu()
postprocess_start_time = time.perf_counter()
- outputs = self.post_process_func(output_data) if self.post_process_func is not None else output_data
+ if self.post_process_func is not None:
+ # Some video pipelines need request-level controls during
+ # postprocess (for example worker-side frame interpolation).
+ if self._post_process_accepts_sampling_params:
+ outputs = self.post_process_func(output_data, sampling_params=request.sampling_params)
+ else:
+ outputs = self.post_process_func(output_data)
+ else:
+ outputs = output_data
audio_payload = None
+ custom_output = output.custom_output or {}
model_audio_sample_rate = None
model_fps = None
if isinstance(outputs, dict):
audio_payload = outputs.get("audio")
+ custom_output.update(outputs.get("custom_output") or {})
model_audio_sample_rate = outputs.get("audio_sample_rate")
model_fps = outputs.get("fps")
outputs = outputs.get("video", outputs)
@@ -225,7 +242,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
trajectory_timesteps=output.trajectory_timesteps,
trajectory_log_probs=output.trajectory_log_probs,
trajectory_decoded=output.trajectory_decoded,
- custom_output=output.custom_output or {},
+ custom_output=custom_output,
multimodal_output=mm_output,
stage_durations=output.stage_durations,
peak_memory_mb=output.peak_memory_mb,
@@ -295,7 +312,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
trajectory_timesteps=output.trajectory_timesteps,
trajectory_log_probs=output.trajectory_log_probs,
trajectory_decoded=output.trajectory_decoded,
- custom_output=output.custom_output or {},
+ custom_output=custom_output,
multimodal_output=mm_output,
stage_durations=output.stage_durations,
peak_memory_mb=output.peak_memory_mb,
@@ -361,15 +378,11 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus
)
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
- """Start or stop torch profiling on all diffusion workers.
+ """Start or stop profiling on all diffusion workers.
Args:
is_start: True to start profiling, False to stop.
- profile_prefix: Optional prefix for trace filename (vLLM compat).
-
- Note:
- Matches vLLM's worker.profile() signature for consistency.
- Traces are saved automatically via on_trace_ready callback.
+ profile_prefix: Optional prefix for trace filename.
"""
if is_start:
if profile_prefix is None:
diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py
index a8b0012f66..98757006bf 100644
--- a/vllm_omni/diffusion/distributed/cfg_parallel.py
+++ b/vllm_omni/diffusion/distributed/cfg_parallel.py
@@ -9,6 +9,7 @@
from typing import Any
import torch
+from vllm.logger import init_logger
from vllm_omni.diffusion.distributed.parallel_state import (
get_cfg_group,
@@ -16,6 +17,8 @@
get_classifier_free_guidance_world_size,
)
+logger = init_logger(__name__)
+
def _wrap(pred: torch.Tensor | tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
"""Normalize prediction to tuple form."""
@@ -32,6 +35,24 @@ def _slice_pred(pred: tuple[torch.Tensor, ...], output_slice: int) -> tuple[torc
return tuple(p[:, :output_slice] for p in pred)
+def _dispatch_branches(n_branches: int, n_ranks: int) -> list[list[int]]:
+ """
+ Round-robin dispatch N branches to M ranks.
+
+ Rule: branch i → rank (i % n_ranks).
+
+ Examples:
+ _dispatch_branches(3, 2) -> [[0, 2], [1]]
+ _dispatch_branches(3, 3) -> [[0], [1], [2]]
+ _dispatch_branches(4, 2) -> [[0, 2], [1, 3]]
+ _dispatch_branches(4, 4) -> [[0], [1], [2], [3]]
+ """
+ assignments: list[list[int]] = [[] for _ in range(n_ranks)]
+ for i in range(n_branches):
+ assignments[i % n_ranks].append(i)
+ return assignments
+
+
class CFGParallelMixin(metaclass=ABCMeta):
"""
Base Mixin class for Diffusion pipelines providing shared CFG methods.
@@ -189,6 +210,165 @@ def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, nor
results.append(comb)
return _unwrap(tuple(results))
+ # ── N-branch CFG interface (for 3+ branch models) ──
+
+ def predict_noise_with_multi_branch_cfg(
+ self,
+ do_true_cfg: bool,
+ true_cfg_scale: float | dict[str, float],
+ branches_kwargs: list[dict[str, Any]],
+ cfg_normalize: bool = False,
+ output_slice: int | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, ...]:
+ """
+ Predict noise with N-branch CFG dispatch across M GPUs.
+
+ This is the multi-branch counterpart of predict_noise_maybe_with_cfg().
+ Use this for models with 3 or more CFG branches (e.g., OmniGen2, Bagel,
+ DreamID). Existing 2-branch models should continue using
+ predict_noise_maybe_with_cfg().
+
+ Args:
+ do_true_cfg: Whether to apply CFG.
+ true_cfg_scale: CFG scale factor (passed to combine_multi_branch_cfg_noise).
+ branches_kwargs: List of N dicts, each containing kwargs for one
+ predict_noise() call. branches_kwargs[0] is always the
+ positive/conditional branch.
+ cfg_normalize: Whether to normalize (passed to combine_multi_branch_cfg_noise).
+ output_slice: If set, slice each output to [:, :output_slice].
+
+ Returns:
+ Combined noise prediction, identical on all ranks in CFG parallel.
+ """
+ if do_true_cfg:
+ n_branches = len(branches_kwargs)
+ cfg_world_size = get_classifier_free_guidance_world_size()
+ cfg_parallel_ready = cfg_world_size > 1
+
+ if cfg_parallel_ready:
+ return self._predict_multi_branch_parallel(
+ branches_kwargs,
+ n_branches,
+ cfg_world_size,
+ true_cfg_scale,
+ cfg_normalize,
+ output_slice,
+ )
+ else:
+ # Sequential: run all N branches on single device
+ preds: list[torch.Tensor | tuple[torch.Tensor, ...]] = []
+ for kw in branches_kwargs:
+ pred = _wrap(self.predict_noise(**kw))
+ if output_slice is not None:
+ pred = _slice_pred(pred, output_slice)
+ preds.append(_unwrap(pred))
+ return self.combine_multi_branch_cfg_noise(preds, true_cfg_scale, cfg_normalize)
+ else:
+ # No CFG: only compute positive/conditional prediction
+ pred = self.predict_noise(**branches_kwargs[0])
+ if output_slice is not None:
+ pred = _unwrap(_slice_pred(_wrap(pred), output_slice))
+ return pred
+
+ def _predict_multi_branch_parallel(
+ self,
+ branches_kwargs: list[dict[str, Any]],
+ n_branches: int,
+ cfg_world_size: int,
+ true_cfg_scale: float,
+ cfg_normalize: bool,
+ output_slice: int | None,
+ ) -> torch.Tensor | tuple[torch.Tensor, ...]:
+ """Dispatch N branches across M ranks, all_gather, then combine."""
+ cfg_group = get_cfg_group()
+ cfg_rank = get_classifier_free_guidance_rank()
+
+ if cfg_world_size > n_branches:
+ logger.warning_once(
+ "cfg_parallel_size=%d > n_branches=%d, %d GPU(s) will be idle for CFG",
+ cfg_world_size,
+ n_branches,
+ cfg_world_size - n_branches,
+ )
+
+ # Assign branches to ranks via round-robin
+ assignments = _dispatch_branches(n_branches, cfg_world_size)
+ my_branch_ids = assignments[cfg_rank]
+ max_per_rank = max(len(a) for a in assignments)
+
+ # Run assigned branches
+ my_preds: list[tuple[torch.Tensor, ...]] = []
+ for bid in my_branch_ids:
+ pred = _wrap(self.predict_noise(**branches_kwargs[bid]))
+ if output_slice is not None:
+ pred = _slice_pred(pred, output_slice)
+ my_preds.append(pred)
+
+ # Idle ranks (cfg_world_size > n_branches) run a forward pass to get the output shape for all_gather.
+ # Output shape cannot be inferred from kwargs — may be tuple, sliced, etc.
+ if not my_preds:
+ pred = _wrap(self.predict_noise(**branches_kwargs[0]))
+ if output_slice is not None:
+ pred = _slice_pred(pred, output_slice)
+ my_preds.append(pred)
+
+ # Pad to max_per_rank with zeros so all ranks have same size
+ ref_pred = my_preds[0]
+ while len(my_preds) < max_per_rank:
+ my_preds.append(tuple(torch.zeros_like(t) for t in ref_pred))
+
+ # All-gather each output element separately (like predict_noise_maybe_with_cfg)
+ # For each slot, gather across ranks; then pick valid results by owner_rank
+ # all_slots[slot][elem_idx] = [rank0_tensor, rank1_tensor, ...]
+ all_slots: list[list[list[torch.Tensor]]] = []
+ for slot in range(max_per_rank):
+ slot_results: list[list[torch.Tensor]] = []
+ for p in my_preds[slot]:
+ gathered = cfg_group.all_gather(p, separate_tensors=True)
+ slot_results.append(gathered)
+ all_slots.append(slot_results)
+
+ # Reconstruct final_preds in branch order
+ final_preds: list[torch.Tensor | tuple[torch.Tensor, ...]] = []
+ for bid in range(n_branches):
+ owner_rank = bid % cfg_world_size
+ slot_idx = bid // cfg_world_size
+ elements = tuple(all_slots[slot_idx][elem_idx][owner_rank] for elem_idx in range(len(ref_pred)))
+ final_preds.append(_unwrap(elements))
+
+ return self.combine_multi_branch_cfg_noise(final_preds, true_cfg_scale, cfg_normalize)
+
+ def combine_multi_branch_cfg_noise(
+ self,
+ predictions: list[torch.Tensor | tuple[torch.Tensor, ...]],
+ true_cfg_scale: float | dict[str, float],
+ cfg_normalize: bool = False,
+ ) -> torch.Tensor | tuple[torch.Tensor, ...]:
+ """
+ Combine N branch predictions. Default: standard 2-branch CFG formula.
+
+ Override this method for custom multi-branch combine logic.
+
+ Args:
+ predictions: List of N predictions, where predictions[0] is always
+ the positive/conditional branch.
+ true_cfg_scale: CFG scale factor (float for 2-branch, dict for multi-branch).
+ cfg_normalize: Whether to normalize the combined prediction.
+
+ Returns:
+ Combined noise prediction.
+ """
+ positive = _wrap(predictions[0])
+ negative = _wrap(predictions[1])
+
+ results = []
+ for p, n in zip(positive, negative):
+ comb = n + true_cfg_scale * (p - n)
+ if cfg_normalize:
+ comb = self.cfg_normalize_function(p, comb)
+ results.append(comb)
+ return _unwrap(tuple(results))
+
def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""
Forward pass through transformer to predict noise.
diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py
index 8ab38f2a65..5294e6c9ed 100644
--- a/vllm_omni/diffusion/distributed/group_coordinator.py
+++ b/vllm_omni/diffusion/distributed/group_coordinator.py
@@ -104,6 +104,7 @@ def __init__(
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
+ self.shm_broadcaster = None
for ranks in group_ranks:
device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend)
@@ -316,7 +317,7 @@ def send_object(self, obj: Any, dst: int) -> None:
assert dst < self.world_size, f"Invalid dst rank ({dst})"
- assert dst != self.rank, "Invalid destination rank. Destination rank is the same as the current rank."
+ assert dst != self.rank_in_group, "Invalid destination rank. Destination rank is the same as the current rank."
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
@@ -338,7 +339,7 @@ def recv_object(self, src: int) -> Any:
assert src < self.world_size, f"Invalid src rank ({src})"
- assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank."
+ assert src != self.rank_in_group, "Invalid source rank. Source rank is the same as the current rank."
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py
index cda4201ccf..517c661587 100644
--- a/vllm_omni/diffusion/hooks/base.py
+++ b/vllm_omni/diffusion/hooks/base.py
@@ -8,6 +8,7 @@
from __future__ import annotations
+import functools
import inspect
from collections.abc import Callable
from dataclasses import dataclass
@@ -94,10 +95,9 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
return output
def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any:
- """Override the module's forward pass completely.
-
- The default implementation calls pre_forward, then the original forward,
- then post_forward. Override this method for more complex behavior.
+ """Override the module's forward pass. This should be overridden for more complex
+ cases, e.g., TeaCache. If this method is overridden in a subclass, it will be called
+ instead of self.module._omni_original_forward when executing the hooks.
Args:
module: The module being called.
@@ -105,11 +105,9 @@ def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any:
**kwargs: Keyword arguments to forward.
Returns:
- The output of the forward pass.
+ The output of the replacement for the forward pass.
"""
- args, kwargs = self.pre_forward(module, *args, **kwargs)
- output = module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
- return self.post_forward(module, output)
+ raise NotImplementedError("By default, hooks do not implement new_forward")
def reset_state(self, module: nn.Module) -> nn.Module:
"""Reset any state associated with this hook.
@@ -136,6 +134,21 @@ def __call__(self, *args: Any, **kwargs: Any):
return registry.dispatch(*args, **kwargs)
+def sort_hooks_after_call(func):
+ """Calls the method on the hook registry, then sorts the hooks.
+
+ This should be added to methods that mutate add or remove hooks.
+ """
+
+ @functools.wraps(func)
+ def wrapper(self: HookRegistry, *args, **kwargs):
+ res = func(self, *args, **kwargs)
+ self.update_sorted_hooks()
+ return res
+
+ return wrapper
+
+
class HookRegistry:
"""Registry of hooks attached to a module.
@@ -146,6 +159,10 @@ class HookRegistry:
def __init__(self, module: nn.Module):
self.module = module
self._hooks: dict[str, ModelHook] = {}
+ # Hooks sorted by execution order
+ self._sorted_hooks: list[ModelHook] = []
+ # Hooks overriding new_forward (if any)
+ self._new_fwd_impl_hook: ModelHook | None = None
@classmethod
def get_or_create(cls, module: nn.Module) -> HookRegistry:
@@ -173,6 +190,14 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry:
return registry
+ def update_sorted_hooks(self):
+ """Sort hooks by name, which dictates pre/post process order."""
+ sorted_hooks = [self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook]
+ if self._new_fwd_impl_hook is not None:
+ sorted_hooks.append(self._new_fwd_impl_hook)
+ self._sorted_hooks = sorted_hooks
+
+ @sort_hooks_after_call
def register_hook(self, name: str, hook: ModelHook) -> None:
"""Register a hook with the given name.
@@ -182,7 +207,14 @@ def register_hook(self, name: str, hook: ModelHook) -> None:
"""
hook.initialize_hook(self.module)
self._hooks[name] = hook
-
+ # We can only have one hook that overrides new_forward,
+ # since we don't currently have a mechanism for combining them.
+ if type(hook).new_forward is not ModelHook.new_forward:
+ if self._new_fwd_impl_hook is not None:
+ raise RuntimeError("Cannot have multiple hooks that override forward active simultaneously")
+ self._new_fwd_impl_hook = hook
+
+ @sort_hooks_after_call
def remove_hook(self, name: str) -> None:
"""Remove a hook by name.
@@ -190,6 +222,9 @@ def remove_hook(self, name: str) -> None:
name: The name of the hook to remove.
"""
if name in self._hooks:
+ # clear the forward hook if it's the one to delete
+ if self._new_fwd_impl_hook is self._hooks[name]:
+ self._new_fwd_impl_hook = None
del self._hooks[name]
def get_hook(self, name: str) -> ModelHook | None:
@@ -206,8 +241,18 @@ def get_hook(self, name: str) -> ModelHook | None:
def dispatch(self, *args: Any, **kwargs: Any) -> Any:
"""Dispatch a forward call through registered hooks.
- Currently supports a single active hook. Multiple hooks are called
- in sorted order by name, with each hook's output passed to the next.
+ Multiple hooks may be used with the caveat that only one hook
+ may override new_forward. While it is assumed that pre/post process
+ on hooks are composable, the execution flow is as follows for determinism:
+
+ - Run preprocess on all hooks in their sorted order; hooks are sorted alphabetically,
+ except for the hook overriding forward (`self._new_fwd_impl_hook`), which is last
+ if it exists.
+
+ - If `self._new_fwd_impl_hook` isn't None, call its forward. Otherwise call the
+ original model forward.
+
+ - Run post process on all hooks in the reverse sorted order.
Args:
*args: Positional arguments to forward.
@@ -219,24 +264,19 @@ def dispatch(self, *args: Any, **kwargs: Any) -> Any:
if not self._hooks:
return self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
- # For single hook case, call directly
- if len(self._hooks) == 1:
- hook = next(iter(self._hooks.values()))
- return hook.new_forward(self.module, *args, **kwargs)
-
- # For multiple hooks, chain them in sorted order
- # Each hook can modify args/kwargs via pre_forward
- sorted_hooks = sorted(self._hooks.items(), key=lambda x: x[0])
-
- # Apply all pre_forward hooks
- for _, hook in sorted_hooks:
+ # Apply all pre_forward hooks; if _new_fwd_impl_hook is set, it's last
+ for hook in self._sorted_hooks:
args, kwargs = hook.pre_forward(self.module, *args, **kwargs)
- # Call original forward
- output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
+ # If we have a hook that overrides new_forward, call it directly
+ if self._new_fwd_impl_hook is not None:
+ output = self._new_fwd_impl_hook.new_forward(self.module, *args, **kwargs)
+ # Otherwise just call the original forward.
+ else:
+ output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
- # Apply all post_forward hooks in reverse order
- for _, hook in reversed(sorted_hooks):
+ # Apply all post_forward hooks in reverse order; if _new_fwd_impl_hook is set, it's first
+ for hook in reversed(self._sorted_hooks):
output = hook.post_forward(self.module, output)
return output
diff --git a/vllm_omni/diffusion/inline_stage_diffusion_client.py b/vllm_omni/diffusion/inline_stage_diffusion_client.py
new file mode 100644
index 0000000000..a33a3e9561
--- /dev/null
+++ b/vllm_omni/diffusion/inline_stage_diffusion_client.py
@@ -0,0 +1,348 @@
+"""Inline Stage Diffusion Client for vLLM-Omni multi-stage runtime.
+
+Runs DiffusionEngine in a ThreadPoolExecutor inside the Orchestrator process
+instead of spawning a separate StageDiffusionProc subprocess, eliminating ZMQ
+IPC overhead. Used when there is only a single diffusion stage.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any
+
+import torch
+from PIL import Image
+from vllm.logger import init_logger
+
+from vllm_omni.diffusion.data import DiffusionRequestAbortedError
+from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
+from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.engine.stage_init_utils import StageMetadata
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.outputs import OmniRequestOutput
+
+if TYPE_CHECKING:
+ from vllm_omni.diffusion.data import OmniDiffusionConfig
+ from vllm_omni.inputs.data import OmniPromptType
+
+logger = init_logger(__name__)
+
+
+class InlineStageDiffusionClient:
+ """Runs DiffusionEngine in a thread executor inside the Orchestrator."""
+
+ stage_type: str = "diffusion"
+
+ def __init__(
+ self,
+ model: str,
+ od_config: OmniDiffusionConfig,
+ metadata: StageMetadata,
+ batch_size: int = 1,
+ ) -> None:
+ self.model = model
+ self.od_config = od_config
+ self.stage_id = metadata.stage_id
+ self.final_output = metadata.final_output
+ self.final_output_type = metadata.final_output_type
+ self.default_sampling_params = metadata.default_sampling_params
+ self.custom_process_input_func = metadata.custom_process_input_func
+ self.engine_input_source = metadata.engine_input_source
+ self.batch_size = batch_size
+
+ self._enrich_config()
+ self._engine = DiffusionEngine.make_engine(self.od_config)
+ self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="inline-diffusion")
+
+ self._output_queue: asyncio.Queue[OmniRequestOutput] = asyncio.Queue()
+ self._tasks: dict[str, asyncio.Task] = {}
+ self._shutting_down = False
+
+ logger.info(
+ "[InlineStageDiffusionClient] Stage-%s initialized inline (batch_size=%d)",
+ self.stage_id,
+ self.batch_size,
+ )
+
+ def _enrich_config(self) -> None:
+ """Load model metadata from HuggingFace and populate od_config fields."""
+ self.od_config.enrich_config()
+
+ # ------------------------------------------------------------------
+ # Request processing
+ # ------------------------------------------------------------------
+
+ async def add_request_async(
+ self,
+ request_id: str,
+ prompt: OmniPromptType,
+ sampling_params: OmniDiffusionSamplingParams,
+ kv_sender_info: dict[int, dict[str, Any]] | None = None,
+ ) -> None:
+ task = asyncio.create_task(
+ self._dispatch_request(
+ request_id,
+ prompt,
+ sampling_params,
+ kv_sender_info,
+ )
+ )
+ self._tasks[request_id] = task
+
+ async def _dispatch_request(
+ self,
+ request_id: str,
+ prompt: Any,
+ sampling_params: OmniDiffusionSamplingParams,
+ kv_sender_info: dict[str, Any] | None = None,
+ ) -> None:
+ try:
+ request = OmniDiffusionRequest(
+ prompts=[prompt],
+ sampling_params=sampling_params,
+ request_ids=[request_id],
+ request_id=request_id,
+ kv_sender_info=kv_sender_info,
+ )
+
+ loop = asyncio.get_running_loop()
+ results = await loop.run_in_executor(self._executor, self._engine.step, request)
+ result = results[0]
+ if not result.request_id:
+ result.request_id = request_id
+
+ self._output_queue.put_nowait(result)
+ except DiffusionRequestAbortedError as e:
+ logger.info("request_id: %s aborted: %s", request_id, str(e))
+ except Exception as e:
+ logger.exception("Diffusion request %s failed: %s", request_id, e)
+ error_output = OmniRequestOutput.from_diffusion(
+ request_id=request_id,
+ images=[],
+ )
+ error_output.error = str(e)
+ self._output_queue.put_nowait(error_output)
+ finally:
+ self._tasks.pop(request_id, None)
+
+ async def add_batch_request_async(
+ self,
+ request_id: str,
+ prompts: list[OmniPromptType],
+ sampling_params: OmniDiffusionSamplingParams,
+ kv_sender_info: dict[int, dict[str, Any]] | None = None,
+ ) -> None:
+ task = asyncio.create_task(
+ self._dispatch_batch(
+ request_id,
+ prompts,
+ sampling_params,
+ kv_sender_info,
+ )
+ )
+ self._tasks[request_id] = task
+
+ async def _dispatch_batch(
+ self,
+ request_id: str,
+ prompts: list[Any],
+ sampling_params: OmniDiffusionSamplingParams,
+ kv_sender_info: dict[str, Any] | None = None,
+ ) -> None:
+ try:
+ request = OmniDiffusionRequest(
+ prompts=prompts,
+ sampling_params=sampling_params,
+ request_ids=[f"{request_id}-{i}" for i in range(len(prompts))],
+ request_id=request_id,
+ kv_sender_info=kv_sender_info,
+ )
+
+ loop = asyncio.get_running_loop()
+ results = await loop.run_in_executor(self._executor, self._engine.step, request)
+
+ all_images: list = []
+ merged_mm: dict[str, Any] = {}
+ merged_metrics: dict[str, Any] = {}
+ merged_durations: dict[str, float] = {}
+ merged_custom: dict[str, Any] = {}
+ peak_mem = 0.0
+ latents = None
+ trajectory_latents: list[torch.Tensor] | None = None
+ trajectory_timesteps: list[torch.Tensor] | None = None
+ trajectory_log_probs: torch.Tensor | None = None
+ trajectory_decoded: list[Image.Image] | None = None
+ final_output_type = "image"
+
+ for r in results:
+ all_images.extend(r.images)
+ merged_mm.update(r._multimodal_output)
+ merged_metrics.update(r.metrics)
+ merged_durations.update(r.stage_durations)
+ merged_custom.update(r._custom_output)
+ peak_mem = max(peak_mem, r.peak_memory_mb)
+ if latents is None and r.latents is not None:
+ latents = r.latents
+ if trajectory_latents is None:
+ trajectory_latents = r.trajectory_latents
+ if trajectory_timesteps is None:
+ trajectory_timesteps = r.trajectory_timesteps
+ if trajectory_log_probs is None:
+ trajectory_log_probs = r.trajectory_log_probs
+ if trajectory_decoded is None:
+ trajectory_decoded = r.trajectory_decoded
+ if r.final_output_type != "image":
+ final_output_type = r.final_output_type
+
+ result = OmniRequestOutput.from_diffusion(
+ request_id=request_id,
+ images=all_images,
+ prompt=prompts[0] if len(prompts) == 1 else None,
+ metrics=merged_metrics,
+ latents=latents,
+ trajectory_latents=trajectory_latents,
+ trajectory_timesteps=trajectory_timesteps,
+ trajectory_log_probs=trajectory_log_probs,
+ trajectory_decoded=trajectory_decoded,
+ custom_output=merged_custom or None,
+ multimodal_output=merged_mm or None,
+ final_output_type=final_output_type,
+ stage_durations=merged_durations,
+ peak_memory_mb=peak_mem,
+ )
+
+ self._output_queue.put_nowait(result)
+ except DiffusionRequestAbortedError as e:
+ logger.info("request_id: %s aborted: %s", request_id, str(e))
+ except Exception as e:
+ logger.exception("Batch diffusion request %s failed: %s", request_id, e)
+ error_output = OmniRequestOutput.from_diffusion(
+ request_id=request_id,
+ images=[],
+ )
+ error_output.error = str(e)
+ self._output_queue.put_nowait(error_output)
+ finally:
+ self._tasks.pop(request_id, None)
+
+ def get_diffusion_output_nowait(self) -> OmniRequestOutput | None:
+ try:
+ return self._output_queue.get_nowait()
+ except asyncio.QueueEmpty:
+ return None
+
+ async def abort_requests_async(self, request_ids: list[str]) -> None:
+ for rid in request_ids:
+ task = self._tasks.pop(rid, None)
+ if task:
+ task.cancel()
+ self._engine.abort(rid)
+
+ async def collective_rpc_async(
+ self,
+ method: str,
+ timeout: float | None = None,
+ args: tuple[Any, ...] = (),
+ kwargs: dict[str, Any] | None = None,
+ ) -> Any:
+ loop = asyncio.get_running_loop()
+
+ if method == "profile":
+ is_start = args[0] if args else True
+ profile_prefix = args[1] if len(args) > 1 else None
+ if is_start and profile_prefix is None:
+ profile_prefix = f"stage_{self.stage_id}_diffusion_{int(time.time())}"
+ return await loop.run_in_executor(
+ self._executor,
+ self._engine.profile,
+ is_start,
+ profile_prefix,
+ )
+
+ kwargs = kwargs or {}
+
+ # LoRA methods
+ if method == "add_lora":
+ lora_request = args[0] if args else kwargs.get("lora_request")
+ results = await loop.run_in_executor(
+ self._executor,
+ self._engine.collective_rpc,
+ "add_lora",
+ timeout,
+ (),
+ {"lora_request": lora_request},
+ None,
+ )
+ return all(results) if isinstance(results, list) else results
+
+ if method == "remove_lora":
+ results = await loop.run_in_executor(
+ self._executor,
+ self._engine.collective_rpc,
+ "remove_lora",
+ timeout,
+ args,
+ kwargs,
+ None,
+ )
+ return all(results) if isinstance(results, list) else results
+
+ if method == "list_loras":
+ results = await loop.run_in_executor(
+ self._executor,
+ self._engine.collective_rpc,
+ "list_loras",
+ timeout,
+ (),
+ {},
+ None,
+ )
+ if not isinstance(results, list):
+ return results or []
+ merged: set[int] = set()
+ for part in results:
+ merged.update(part or [])
+ return sorted(merged)
+
+ if method == "pin_lora":
+ lora_id = args[0] if args else kwargs.get("adapter_id")
+ results = await loop.run_in_executor(
+ self._executor,
+ self._engine.collective_rpc,
+ "pin_lora",
+ timeout,
+ (),
+ {"adapter_id": lora_id},
+ None,
+ )
+ return all(results) if isinstance(results, list) else results
+
+ return await loop.run_in_executor(
+ self._executor,
+ self._engine.collective_rpc,
+ method,
+ timeout,
+ args,
+ kwargs,
+ None,
+ )
+
+ def shutdown(self) -> None:
+ self._shutting_down = True
+
+ # Cancel all pending tasks
+ for task in self._tasks.values():
+ task.cancel()
+
+ try:
+ # Cancel queued futures and wait for the running one to complete deterministically
+ self._executor.shutdown(wait=True, cancel_futures=True)
+ except Exception:
+ pass
+
+ try:
+ self._engine.close()
+ except Exception:
+ pass
diff --git a/vllm_omni/diffusion/layers/adalayernorm.py b/vllm_omni/diffusion/layers/adalayernorm.py
index 35f63e2fc9..d147bdcfeb 100644
--- a/vllm_omni/diffusion/layers/adalayernorm.py
+++ b/vllm_omni/diffusion/layers/adalayernorm.py
@@ -7,6 +7,7 @@
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_omni.diffusion.layers.custom_op import CustomOp
+from vllm_omni.diffusion.layers.norm import LayerNorm
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -27,107 +28,63 @@ def __init__(self, hidden_size: int, elementwise_affine: bool = False, eps: floa
self.eps = eps
self.elementwise_affine = elementwise_affine
self.hidden_size = hidden_size
- self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps)
-
- def preprocess(
- self,
- mod_params: torch.Tensor,
- index: torch.Tensor = None,
- ) -> torch.Tensor:
- # shift: b d, scale: b d, gate: b d
- shift, scale, gate = mod_params.chunk(3, dim=-1)
-
- if index is not None:
- # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
- # So shift, scale, gate have shape [2*actual_batch, d]
- actual_batch = shift.size(0) // 2
- shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
- scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
- gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
-
- # index: [b, l] where b is actual batch size
- # Expand to [b, l, 1] to match feature dimension
- index_expanded = index.unsqueeze(-1) # [b, l, 1]
-
- # Expand chunks to [b, 1, d] then broadcast to [b, l, d]
- shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
- shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
- scale_0_exp = scale_0.unsqueeze(1)
- scale_1_exp = scale_1.unsqueeze(1)
- gate_0_exp = gate_0.unsqueeze(1)
- gate_1_exp = gate_1.unsqueeze(1)
-
- # Use torch.where to select based on index
- shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
- scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
- gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
- else:
- shift_result = shift.unsqueeze(1)
- scale_result = scale.unsqueeze(1)
- gate_result = gate.unsqueeze(1)
-
- return shift_result, scale_result, gate_result
+ self.layernorm = LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps)
def forward_cuda(
self,
x: torch.Tensor,
- mod_params: torch.Tensor,
- index: torch.Tensor = None,
+ scale: torch.Tensor,
+ shift: torch.Tensor,
) -> torch.Tensor:
- return self.forward_native(x, mod_params, index)
+ return self.forward_native(x, scale, shift)
def forward_hip(
self,
x: torch.Tensor,
- mod_params: torch.Tensor,
- index: torch.Tensor = None,
+ scale: torch.Tensor,
+ shift: torch.Tensor,
) -> torch.Tensor:
- return self.forward_native(x, mod_params, index)
+ return self.forward_native(x, scale, shift)
def forward_npu(
self,
x: torch.Tensor,
- mod_params: torch.Tensor,
- index: torch.Tensor = None,
+ scale: torch.Tensor,
+ shift: torch.Tensor,
) -> torch.Tensor:
- shift_result, scale_result, gate_result = self.preprocess(mod_params, index)
-
if _HAS_MINDIESD:
try:
from mindiesd import layernorm_scale_shift
- output = layernorm_scale_shift(self.layernorm, x, scale_result, shift_result, fused=True)
+ output = layernorm_scale_shift(self.layernorm, x, scale, shift, fused=True)
- return output, gate_result
+ return output
except ImportError as e:
logger.warning_once(f"mindiesd import failed, falling back to torch_npu: {e}")
import torch_npu
output = (
- torch_npu.npu_layer_norm_eval(x, normalized_shape=[self.hidden_size], eps=self.eps) * (1 + scale_result)
- + shift_result
+ torch_npu.npu_layer_norm_eval(x, normalized_shape=[self.hidden_size], eps=self.eps) * (1 + scale) + shift
)
- return output, gate_result
+ return output
def forward_xpu(
self,
x: torch.Tensor,
- mod_params: torch.Tensor,
- index: torch.Tensor = None,
+ scale: torch.Tensor,
+ shift: torch.Tensor,
) -> torch.Tensor:
- return self.forward_native(x, mod_params, index)
+ return self.forward_native(x, scale, shift)
def forward_native(
self,
x: torch.Tensor,
- mod_params: torch.Tensor,
- index: torch.Tensor = None,
+ scale: torch.Tensor,
+ shift: torch.Tensor,
) -> torch.Tensor:
- shift_result, scale_result, gate_result = self.preprocess(mod_params, index)
-
- return self.layernorm(x) * (1 + scale_result) + shift_result, gate_result
+ return self.layernorm(x) * (1 + scale) + shift
class AdaLayerNormZero(nn.Module):
diff --git a/vllm_omni/diffusion/layers/norm.py b/vllm_omni/diffusion/layers/norm.py
new file mode 100644
index 0000000000..6096ad7c37
--- /dev/null
+++ b/vllm_omni/diffusion/layers/norm.py
@@ -0,0 +1,110 @@
+from importlib.util import find_spec
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from vllm.logger import init_logger
+
+from vllm_omni.diffusion.layers.custom_op import CustomOp
+
+logger = init_logger(__name__)
+
+_HAS_MINDIESD = find_spec("mindiesd") is not None
+
+
+class LayerNorm(nn.LayerNorm, CustomOp):
+ """
+ LayerNorm implementation that inherits from both ``nn.LayerNorm`` and ``CustomOp``.
+ NPU:
+ Uses ``mindiesd.fast_layernorm(self, x)`` when MindIE-SD is installed.
+ CUDA / HIP / XPU / native:
+ Falls back to FP32 nn.LayerNorm implementation.
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
+ super().__init__(normalized_shape=dim, eps=eps, elementwise_affine=elementwise_affine)
+ # CustomOp.__init__ cannot be called here because it would re-run
+ # nn.Module initialization and clear LayerNorm parameters.
+ self._forward_method = CustomOp.dispatch_forward(self)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self._forward_method(x)
+
+ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+ return self.forward_native(x)
+
+ def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
+ return self.forward_native(x)
+
+ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+ return self.forward_native(x)
+
+ def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+ if _HAS_MINDIESD:
+ try:
+ from mindiesd import fast_layernorm
+
+ return fast_layernorm(self, x)
+ except ImportError as e:
+ logger.warning_once(
+ "mindiesd.fast_layernorm import failed, falling back to FP32 layer_norm: %s",
+ e,
+ )
+
+ return self.forward_native(x)
+
+ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+ origin_dtype = x.dtype
+ return F.layer_norm(
+ x.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ ).to(origin_dtype)
+
+
+class RMSNorm(CustomOp):
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward_cuda(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.forward_native(x)
+
+ def forward_hip(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.forward_native(x)
+
+ def forward_npu(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ import torch_npu
+
+ output = torch_npu.npu_rms_norm(x, gamma=self.weight, epsilon=self.variance_epsilon)[0]
+
+ return output
+
+ def forward_xpu(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.forward_native(x)
+
+ def forward_native(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ input_dtype = x.dtype
+ x = x.to(torch.float32)
+ variance = x.pow(2).mean(-1, keepdim=True)
+ out = x * torch.rsqrt(variance + self.variance_epsilon)
+ out = self.weight.to(torch.float32) * out
+ return out.to(input_dtype)
diff --git a/vllm_omni/diffusion/model_metadata.py b/vllm_omni/diffusion/model_metadata.py
new file mode 100644
index 0000000000..ec133e7380
--- /dev/null
+++ b/vllm_omni/diffusion/model_metadata.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class DiffusionModelMetadata:
+ # Keep serving-facing capability metadata in a lightweight shared module so
+ # config/model plumbing can read it without importing concrete pipelines.
+ supports_multimodal_inputs: bool = False
+ max_multimodal_image_inputs: int | None = None
+
+
+QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4
+
+
+_DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = {
+ "QwenImageEditPlusPipeline": DiffusionModelMetadata(
+ supports_multimodal_inputs=True,
+ max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES,
+ ),
+}
+
+
+def get_diffusion_model_metadata(model_class_name: str | None) -> DiffusionModelMetadata:
+ # Unknown models fall back to "no special multimodal capabilities" so new
+ # pipelines do not accidentally inherit limits meant for other models.
+ if model_class_name is None:
+ return DiffusionModelMetadata()
+ return _DIFFUSION_MODEL_METADATA.get(model_class_name, DiffusionModelMetadata())
diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
index f848077568..d1254f8456 100644
--- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py
+++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
@@ -854,6 +854,7 @@ def __init__(
config, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.model"
)
self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@@ -864,6 +865,12 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
def set_decoder(self, decoder):
self.model = decoder
@@ -1207,7 +1214,7 @@ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
- text_ids = tokenizer.encode(prompt)
+ text_ids = tokenizer.encode(prompt, add_special_tokens=False)
text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]]
text_token_lens.append(len(text_ids))
packed_text_ids.extend(text_ids)
@@ -1619,10 +1626,110 @@ def _merge_naive_caches(caches: list) -> NaiveCache:
num_layers = len(caches[0].key_cache)
merged = NaiveCache(num_layers)
for layer_idx in range(num_layers):
- merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0)
- merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0)
+ key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None]
+ val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None]
+ merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None
+ merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None
return merged
+ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
+ """Prepare start tokens for autoregressive text generation.
+
+ Ported from the original BAGEL ``Bagel.prepare_start_tokens``.
+ """
+ packed_start_tokens, packed_key_value_indexes = list(), list()
+ packed_query_position_ids = list()
+
+ curr = 0
+ for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
+ packed_start_tokens.append(new_token_ids["bos_token_id"])
+ packed_query_position_ids.append(curr_position_id)
+ curr += curr_kvlen
+
+ generation_input = {
+ "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
+ "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
+ }
+ return generation_input
+
+ @torch.no_grad()
+ def generate_text(
+ self,
+ past_key_values: NaiveCache,
+ packed_key_value_indexes: torch.LongTensor,
+ key_values_lens: torch.IntTensor,
+ packed_start_tokens: torch.LongTensor,
+ packed_query_position_ids: torch.LongTensor,
+ max_length: int,
+ do_sample: bool = False,
+ temperature: float = 1.0,
+ end_token_id: int | None = None,
+ ):
+ """Autoregressive text generation (ported from original BAGEL).
+
+ Decodes tokens one at a time, appending to ``past_key_values``
+ until ``max_length`` is reached or ``end_token_id`` is generated.
+ """
+ step = 0
+ generated_sequence = []
+ curr_tokens = packed_start_tokens
+ while step < max_length:
+ generated_sequence.append(curr_tokens)
+ packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
+ query_lens = torch.ones_like(curr_tokens)
+ packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
+ 0,
+ len(key_values_lens),
+ device=key_values_lens.device,
+ dtype=key_values_lens.dtype,
+ )
+
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
+ for i in range(len(uppacked)):
+ uppacked[i] += i
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
+
+ output = self.language_model(
+ packed_query_sequence=packed_text_embedding,
+ query_lens=query_lens,
+ packed_query_position_ids=packed_query_position_ids,
+ packed_query_indexes=packed_query_indexes,
+ past_key_values=past_key_values,
+ key_values_lens=key_values_lens,
+ packed_key_value_indexes=packed_key_value_indexes,
+ update_past_key_values=True,
+ is_causal=True,
+ mode="und",
+ )
+ past_key_values = output.past_key_values
+ packed_query_sequence = output.packed_query_sequence
+ pred_logits = self.language_model.lm_head(packed_query_sequence)
+
+ if do_sample:
+ probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
+ curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ curr_tokens = torch.argmax(pred_logits, dim=-1)
+
+ uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
+ for i in range(len(uppacked)):
+ uppacked[i] = torch.cat(
+ [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
+ )
+ packed_key_value_indexes = torch.cat(uppacked, dim=0)
+ key_values_lens = key_values_lens + 1
+ packed_query_position_ids = packed_query_position_ids + 1
+ step += 1
+
+ if end_token_id is not None and curr_tokens[0] == end_token_id:
+ break
+
+ output_device = generated_sequence[0].device
+ return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
+
def generate_image(
self,
packed_text_ids: torch.LongTensor,
diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
index 13d0cc2093..90baf5f676 100644
--- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
+++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
@@ -365,35 +365,65 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
if req.sampling_params.kv_metadata and "image_shape" in req.sampling_params.kv_metadata:
image_shape = tuple(req.sampling_params.kv_metadata["image_shape"])
- cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
+ branch_kvs = getattr(req.sampling_params, "cfg_branch_past_key_values", None) or {}
+ branch_metadata = getattr(req.sampling_params, "cfg_branch_kv_metadata", None) or {}
+ active_branch = getattr(req.sampling_params, "cfg_active_branch", None)
+ branch_roles = getattr(req.sampling_params, "cfg_branch_roles", None) or list(branch_kvs.keys())
+
+ cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None) or branch_kvs.get("cfg_text")
+ cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None) or branch_metadata.get(
+ "cfg_text"
+ )
+ cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or branch_kvs.get("cfg_img")
+ cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None) or branch_metadata.get(
+ "cfg_img"
+ )
+
+ cfg_parallel_contract = (
+ active_branch is not None or bool(branch_roles) or cfg_text_kv is not None or cfg_img_kv is not None
+ )
+ if cfg_parallel_contract:
+ logger.info(
+ "CFG enabled with injected branch KV context roles=%s active=%s",
+ branch_roles,
+ active_branch,
+ )
+
if cfg_text_kv is not None:
- logger.info("CFG enabled with multi-KV: using injected cfg_text KV Cache")
cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0]
cfg_text_context["past_key_values"] = cfg_text_kv
cfg_text_context["kv_lens"] = [cfg_text_seq_len]
- cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None)
if cfg_text_metadata and "ropes" in cfg_text_metadata:
cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
else:
cfg_text_context["ropes"] = [cfg_text_seq_len]
-
- cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or injected_kv
+ else:
+ # No cfg_text companion received. For text2img this is the
+ # expected path: original BAGEL uses an empty KV cache (0
+ # tokens) as the text-unconditional branch. Keep the default
+ # empty NaiveCache in cfg_text_context and preserve the
+ # original cfg_text_scale so CFG still applies.
+ pass
+
+ if cfg_img_kv is None:
+ # text2img multi-stage: cfg_img reuses gen KV (positive prompt,
+ # no image), mirroring forward_cache_update_text on cfg_img_context
+ # in the single-stage path.
+ cfg_img_seq_len = injected_kv.key_cache[0].shape[0]
+ cfg_img_context["past_key_values"] = injected_kv
+ cfg_img_context["kv_lens"] = [cfg_img_seq_len]
+ if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata:
+ cfg_img_context["ropes"] = req.sampling_params.kv_metadata["ropes"]
+ else:
+ cfg_img_context["ropes"] = [cfg_img_seq_len]
+ else:
cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0]
cfg_img_context["past_key_values"] = cfg_img_kv
cfg_img_context["kv_lens"] = [cfg_img_seq_len]
- cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None)
if cfg_img_metadata and "ropes" in cfg_img_metadata:
cfg_img_context["ropes"] = cfg_img_metadata["ropes"]
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]
- else:
- logger.warning("CFG is disabled: only single KV cache available")
- gen_params = BagelGenParams(
- num_timesteps=gen_params.num_timesteps,
- timestep_shift=gen_params.timestep_shift,
- cfg_text_scale=1.0,
- cfg_img_scale=1.0,
- )
else:
image_input = (
@@ -495,11 +525,15 @@ def vae_transforms(img):
cfg_text_context = deepcopy(gen_context)
+ # Strip <|im_start|>/<|im_end|> wrappers that end2end.py may have
+ # already added, so prepare_prompts doesn't double-add bos/eos.
+ clean_prompt = prompt.removeprefix("<|im_start|>").removesuffix("<|im_end|>")
+
# Update gen_context with text prompt
generation_input, newlens, new_rope = self.bagel.prepare_prompts(
curr_kvlens=gen_context["kv_lens"],
curr_rope=gen_context["ropes"],
- prompts=[prompt],
+ prompts=[clean_prompt],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
@@ -527,34 +561,37 @@ def vae_transforms(img):
gen_context["kv_lens"] = newlens
gen_context["ropes"] = new_rope
- # cfg_text_context: update with negative prompt (no text condition)
+ # cfg_text_context: update with negative prompt (no text condition).
+ # When empty, keep cfg_text_context as-is (kv_lens=0) to match
+ # original BAGEL; _merge_naive_caches handles None KV entries.
neg_prompt = extra_args.get("negative_prompt", "")
- neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts(
- curr_kvlens=cfg_text_context["kv_lens"],
- curr_rope=cfg_text_context["ropes"],
- prompts=[neg_prompt],
- tokenizer=self.tokenizer,
- new_token_ids=self.new_token_ids,
- )
- for k, v in neg_input.items():
- if torch.is_tensor(v):
- neg_input[k] = v.to(self.device)
- with torch.autocast(
- device_type=self.device.type,
- enabled=self.device.type != "cpu",
- dtype=self.od_config.dtype,
- ):
- cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text(
- cfg_text_context["past_key_values"], **neg_input
+ if neg_prompt:
+ neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts(
+ curr_kvlens=cfg_text_context["kv_lens"],
+ curr_rope=cfg_text_context["ropes"],
+ prompts=[neg_prompt],
+ tokenizer=self.tokenizer,
+ new_token_ids=self.new_token_ids,
)
- cfg_text_context["kv_lens"] = neg_newlens
- cfg_text_context["ropes"] = neg_rope
+ for k, v in neg_input.items():
+ if torch.is_tensor(v):
+ neg_input[k] = v.to(self.device)
+ with torch.autocast(
+ device_type=self.device.type,
+ enabled=self.device.type != "cpu",
+ dtype=self.od_config.dtype,
+ ):
+ cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text(
+ cfg_text_context["past_key_values"], **neg_input
+ )
+ cfg_text_context["kv_lens"] = neg_newlens
+ cfg_text_context["ropes"] = neg_rope
# cfg_img_context: update with text prompt (no image condition)
cfg_img_generation_input, cfg_img_newlens, cfg_img_new_rope = self.bagel.prepare_prompts(
curr_kvlens=cfg_img_context["kv_lens"],
curr_rope=cfg_img_context["ropes"],
- prompts=[prompt],
+ prompts=[clean_prompt],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
@@ -572,6 +609,96 @@ def vae_transforms(img):
cfg_img_context["kv_lens"] = cfg_img_newlens
cfg_img_context["ropes"] = cfg_img_new_rope
+ # ---- Detect output modality and think mode ----
+ modalities = first_prompt.get("modalities", []) if isinstance(first_prompt, dict) else []
+ is_text_output = "text" in modalities
+ think_enabled = extra_args.get("think", False)
+ think_text = None
+
+ if think_enabled and injected_kv is None:
+ max_think_tokens = int(extra_args.get("max_think_tokens", 1000))
+ do_sample = bool(extra_args.get("do_sample", False))
+ text_temperature = float(extra_args.get("text_temperature", 0.3))
+
+ with torch.autocast(
+ device_type=self.device.type,
+ enabled=self.device.type != "cpu",
+ dtype=self.od_config.dtype,
+ ):
+ start_input = self.bagel.prepare_start_tokens(
+ gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids
+ )
+ for k, v in start_input.items():
+ if torch.is_tensor(v):
+ start_input[k] = v.to(self.device)
+
+ gen_ctx_copy = deepcopy(gen_context)
+ token_ids = self.bagel.generate_text(
+ past_key_values=gen_ctx_copy["past_key_values"],
+ max_length=max_think_tokens,
+ do_sample=do_sample,
+ temperature=text_temperature,
+ end_token_id=self.new_token_ids["eos_token_id"],
+ **start_input,
+ )
+ # token_ids shape: (seq_len, batch=1)
+ decoded = self.tokenizer.decode(token_ids[:, 0].tolist())
+ # Strip chat markers to get clean text
+ think_text = decoded.split("<|im_end|>")[0]
+ if "<|im_start|>" in think_text:
+ think_text = think_text.split("<|im_start|>")[-1]
+ logger.info("Think mode generated %d tokens", token_ids.shape[0])
+
+ if not is_text_output:
+ # Use the autoregressive KV cache from think generation
+ # directly, instead of decode→re-encode which adds extra
+ # bos/eos and may alter tokenization.
+ num_think_tokens = token_ids.shape[0]
+ gen_context["past_key_values"] = gen_ctx_copy["past_key_values"]
+ gen_context["kv_lens"] = [kl + num_think_tokens for kl in gen_context["kv_lens"]]
+ gen_context["ropes"] = [r + num_think_tokens for r in gen_context["ropes"]]
+
+ # ---- Text-only output (text2text / img2text) ----
+ if is_text_output and injected_kv is None:
+ if think_text is not None:
+ # Think mode already generated the text (including reasoning)
+ text_output = think_text
+ else:
+ max_text_tokens = int(extra_args.get("max_think_tokens", 500))
+ do_sample = bool(extra_args.get("do_sample", False))
+ text_temperature = float(extra_args.get("text_temperature", 0.3))
+
+ with torch.autocast(
+ device_type=self.device.type,
+ enabled=self.device.type != "cpu",
+ dtype=self.od_config.dtype,
+ ):
+ start_input = self.bagel.prepare_start_tokens(
+ gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids
+ )
+ for k, v in start_input.items():
+ if torch.is_tensor(v):
+ start_input[k] = v.to(self.device)
+ token_ids = self.bagel.generate_text(
+ past_key_values=gen_context["past_key_values"],
+ max_length=max_text_tokens,
+ do_sample=do_sample,
+ temperature=text_temperature,
+ end_token_id=self.new_token_ids["eos_token_id"],
+ **start_input,
+ )
+ decoded = self.tokenizer.decode(token_ids[:, 0].tolist())
+ text_output = decoded.split("<|im_end|>")[0]
+ if "<|im_start|>" in text_output:
+ text_output = text_output.split("<|im_start|>")[-1]
+
+ return DiffusionOutput(
+ output=text_output,
+ custom_output={"text_output": text_output},
+ stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None,
+ )
+
+ # ---- Image generation (text2img / img2img) ----
if req.sampling_params.seed is not None:
torch.manual_seed(req.sampling_params.seed)
if self.device.type == "cuda":
@@ -676,12 +803,17 @@ def vae_transforms(img):
if trajectory_log_probs:
trajectory_log_probs_stacked = torch.stack(trajectory_log_probs)
+ custom = {}
+ if think_text is not None:
+ custom["think_text"] = think_text
+
return DiffusionOutput(
output=img,
trajectory_latents=trajectory_latents_stacked,
trajectory_timesteps=trajectory_timesteps_stacked,
trajectory_log_probs=trajectory_log_probs_stacked,
trajectory_decoded=trajectory_decoded,
+ custom_output=custom,
stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None,
)
diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py
index a534f5a76f..abca4c9474 100644
--- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py
+++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py
@@ -1,3 +1,5 @@
+import re
+
import torch
import torch.nn as nn
from vllm.logger import init_logger
@@ -15,78 +17,26 @@
logger = init_logger(__name__)
-class FusionModel(nn.Module):
- def __init__(self, video_config=None, audio_config=None):
- super().__init__()
- has_video = True
- has_audio = True
- if video_config is not None:
- self.video_model = WanModel(**video_config)
- else:
- has_video = False
- self.video_model = None
- logger.warning("No video model is provided!")
-
- if audio_config is not None:
- self.audio_model = WanModel(**audio_config)
- else:
- has_audio = False
- self.audio_model = None
- logger.warning("No audio model is provided!")
-
- if has_video and has_audio:
- assert len(self.video_model.blocks) == len(self.audio_model.blocks)
- self.num_blocks = len(self.video_model.blocks)
-
- self.inject_cross_attention_kv_projections()
- self.device = get_local_device()
-
- self.num_heads = self.video_model.num_heads
- self.head_dim = self.video_model.dim // self.video_model.num_heads
- self.attn = Attention(
- num_heads=self.num_heads,
- head_size=self.head_dim,
- num_kv_heads=self.num_heads,
- softmax_scale=1.0 / (self.head_dim**0.5),
- causal=False,
- )
-
- def inject_cross_attention_kv_projections(self):
- for vid_block in self.video_model.blocks:
- vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
- vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
- vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
- vid_block.cross_attn.norm_k_fusion = (
- WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
- )
+class FusedBlock(nn.Module):
+ """Wrapper pairing a video block and audio block for layerwise offloading.
- for audio_block in self.audio_model.blocks:
- audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
- audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
- audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
- audio_block.cross_attn.norm_k_fusion = (
- WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
- )
+ Registers both blocks as submodules so their parameters are visible to the offload hooks.
+ """
- def merge_kwargs(self, vid_kwargs, audio_kwargs):
- """
- keys in each kwarg:
- e
- seq_lens
- grid_sizes
- freqs
- context
- context_lens
- """
- merged_kwargs = {}
- for key in vid_kwargs:
- merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
- for key in audio_kwargs:
- merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
- return merged_kwargs
+ def __init__(
+ self,
+ vid_block: nn.Module,
+ audio_block: nn.Module,
+ device: torch.device,
+ ):
+ super().__init__()
+ self.vid_block = vid_block
+ self.audio_block = audio_block
+ self.device = device
- def single_fusion_cross_attention_forward(
+ def _cross_attention_forward(
self,
+ attn: Attention,
cross_attn_block,
src_seq,
src_grid_sizes,
@@ -104,21 +54,17 @@ def single_fusion_cross_attention_forward(
):
b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
if hasattr(cross_attn_block, "k_img"):
- ## means is i2v block
q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
else:
- ## means is t2v block
q, k, v = cross_attn_block.qkv_fn(src_seq, context)
k_img = v_img = None
- x = self.attn(q, k, v)
+ x = attn(q, k, v)
if k_img is not None:
- img_x = self.attn(q, k_img, v_img)
+ img_x = attn(q, k_img, v_img)
x = x + img_x
- # is_vid = src_grid_sizes.shape[1] > 1
- # compute target attention
target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
@@ -132,17 +78,16 @@ def single_fusion_cross_attention_forward(
freqs_scaling=target_freqs_scaling,
)
- target_x = self.attn(q, k_target, v_target)
+ target_x = attn(q, k_target, v_target)
x = x + target_x
-
- x = x.flatten(2) # [B, L/P, C]
-
+ x = x.flatten(2)
x = cross_attn_block.o(x)
return x
- def single_fusion_cross_attention_ffn_forward(
+ def _cross_attention_ffn_forward(
self,
+ attn: Attention,
attn_block,
src_seq,
src_grid_sizes,
@@ -159,7 +104,8 @@ def single_fusion_cross_attention_ffn_forward(
target_ref_lengths=None,
target_freqs_scaling=None,
):
- src_seq = src_seq + self.single_fusion_cross_attention_forward(
+ src_seq = src_seq + self._cross_attention_forward(
+ attn,
attn_block.cross_attn,
attn_block.norm3(src_seq),
src_grid_sizes=src_grid_sizes,
@@ -180,12 +126,11 @@ def single_fusion_cross_attention_ffn_forward(
src_seq = src_seq + y * src_e[5].squeeze(2)
return src_seq
- def single_fusion_block_forward(
+ def forward(
self,
- vid_block,
- audio_block,
vid,
audio,
+ attn: Attention,
vid_e,
vid_seq_lens,
vid_grid_sizes,
@@ -203,6 +148,9 @@ def single_fusion_block_forward(
audio_ref_lengths,
audio_freqs_scaling,
):
+ vid_block = self.vid_block
+ audio_block = self.audio_block
+
## audio modulation
assert audio_e.dtype == torch.bfloat16
assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], (
@@ -246,7 +194,8 @@ def single_fusion_block_forward(
og_audio = audio
# audio cross-attention
- audio = self.single_fusion_cross_attention_ffn_forward(
+ audio = self._cross_attention_ffn_forward(
+ attn,
audio_block,
audio,
audio_grid_sizes,
@@ -267,7 +216,8 @@ def single_fusion_block_forward(
assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"
# video cross-attention
- vid = self.single_fusion_cross_attention_ffn_forward(
+ vid = self._cross_attention_ffn_forward(
+ attn,
vid_block,
vid,
vid_grid_sizes,
@@ -287,6 +237,128 @@ def single_fusion_block_forward(
return vid, audio
+
+class FusionModel(nn.Module):
+ _layerwise_offload_blocks_attrs = ["fused_blocks"]
+
+ def __init__(self, video_config=None, audio_config=None):
+ super().__init__()
+ has_video = True
+ has_audio = True
+ self.device = get_local_device()
+ if video_config is not None:
+ self.video_model = WanModel(**video_config)
+ else:
+ has_video = False
+ self.video_model = None
+ logger.warning("No video model is provided!")
+
+ if audio_config is not None:
+ self.audio_model = WanModel(**audio_config)
+ else:
+ has_audio = False
+ self.audio_model = None
+ logger.warning("No audio model is provided!")
+
+ if has_video and has_audio:
+ assert len(self.video_model.blocks) == len(self.audio_model.blocks)
+ self.num_blocks = len(self.video_model.blocks)
+
+ self.inject_cross_attention_kv_projections()
+
+ self.num_heads = self.video_model.num_heads
+ self.head_dim = self.video_model.dim // self.video_model.num_heads
+ # Make a single shared instance to pass in at forward time
+ self.attn = Attention(
+ num_heads=self.num_heads,
+ head_size=self.head_dim,
+ num_kv_heads=self.num_heads,
+ softmax_scale=1.0 / (self.head_dim**0.5),
+ causal=False,
+ )
+
+ if has_video and has_audio:
+ self.fused_blocks = nn.ModuleList(
+ [
+ FusedBlock(
+ self.video_model.blocks[i],
+ self.audio_model.blocks[i],
+ self.device,
+ )
+ for i in range(self.num_blocks)
+ ]
+ )
+
+ def load_state_dict(self, state_dict, strict=True, assign=False):
+ """Remap checkpoints where blocks are stored under
+ `video_model.blocks.N.*` / `audio_model.blocks.N.*` to the current
+ `fused_blocks.N.vid_block.*` / `fused_blocks.N.audio_block.*`.
+ """
+ needs_remap = any(re.match(r"^(video_model|audio_model)\.blocks\.\d+\.", k) for k in state_dict)
+ if needs_remap:
+ remapped = {}
+ for k, v in state_dict.items():
+ new_k = re.sub(r"^video_model\.blocks\.(\d+)\.", r"fused_blocks.\1.vid_block.", k)
+ new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", new_k)
+ remapped[new_k] = v
+ state_dict = remapped
+
+ self._detach_blocks_from_backbones()
+
+ return super().load_state_dict(state_dict, strict=strict, assign=assign)
+
+ def inject_cross_attention_kv_projections(self):
+ for vid_block in self.video_model.blocks:
+ vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+ vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+ vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
+ vid_block.cross_attn.norm_k_fusion = (
+ WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
+ )
+
+ for audio_block in self.audio_model.blocks:
+ audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+ audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+ audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
+ audio_block.cross_attn.norm_k_fusion = (
+ WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
+ )
+
+ def _detach_blocks_from_backbones(self) -> None:
+ """Keep offloadable blocks owned only by a single place.
+
+ NOTE: This is a special workaround to support layerwise offloading.
+ The model registers the same Wan blocks under both the video/audio
+ backbones and `fused_blocks` which is a wrapper for unified blocks
+ walking through. However, layerwise offloading will only consider
+ `fused_blocks` as offloadable components and will materialize all
+ other modules onto device, including the same blocks owned by both
+ `fused_blocks` and `video_model` and `audio_model`.
+ """
+ video_blocks = list(self.video_model.blocks)
+ audio_blocks = list(self.audio_model.blocks)
+ self.video_model._modules.pop("blocks", None)
+ self.audio_model._modules.pop("blocks", None)
+ self.video_model.blocks = tuple(video_blocks)
+ self.audio_model.blocks = tuple(audio_blocks)
+
+ def merge_kwargs(self, vid_kwargs, audio_kwargs):
+ """
+ keys in each kwarg:
+ e
+ seq_lens
+ grid_sizes
+ freqs
+ context
+ context_lens
+ """
+ merged_kwargs = {}
+ for key in vid_kwargs:
+ merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
+ for key in audio_kwargs:
+ merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
+ return merged_kwargs
+
def forward(
self,
vid,
@@ -316,17 +388,8 @@ def forward(
kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)
- for i in range(self.num_blocks):
- """
- 1 fusion block refers to 1 audio block with 1 video block.
- """
-
- vid_block = self.video_model.blocks[i]
- audio_block = self.audio_model.blocks[i]
-
- vid, audio = self.single_fusion_block_forward(
- vid_block=vid_block, audio_block=audio_block, vid=vid, audio=audio, **kwargs
- )
+ for fused_block in self.fused_blocks:
+ vid, audio = fused_block(vid, audio, self.attn, **kwargs)
vid = self.video_model.post_transformer_block_out(vid, vid_kwargs["grid_sizes"], vid_e)
audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs["grid_sizes"], audio_e)
diff --git a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
index e22765f80e..cc932f8c1f 100644
--- a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
+++ b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
@@ -4,6 +4,7 @@
import logging
import math
import os
+from collections.abc import Iterable
import torch
import torch.distributed
@@ -15,12 +16,8 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
-from vllm_omni.diffusion.distributed.parallel_state import (
- get_cfg_group,
- get_classifier_free_guidance_rank,
- get_classifier_free_guidance_world_size,
-)
from vllm_omni.diffusion.distributed.utils import get_local_device
+from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput
from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -32,7 +29,6 @@
init_mmaudio_vae,
init_text_model,
init_wan_vae_2_2,
- load_fusion_checkpoint,
)
from dreamid_omni.utils.rearrange import Rearrange
from dreamid_omni.utils.resize import NaResize
@@ -43,6 +39,21 @@
logger = logging.getLogger(__name__)
+def get_dreamid_omni_post_process_func(*args, **kwargs):
+ def post_process(output):
+ if isinstance(output, tuple) and len(output) == 2:
+ video, audio = output
+ return {
+ "video": video,
+ "audio": audio,
+ "audio_sample_rate": 16000,
+ "fps": 24,
+ }
+ return output
+
+ return post_process
+
+
AUDIO_CONFIG = {
"patch_size": [1],
"model_type": "t2a",
@@ -112,16 +123,24 @@ def __init__(
self.text_model = init_text_model(model, rank=self.device)
self.text_encoder = self.text_model.model
- # Fusion model
- ## load audio/video model config
- Fusion_model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG)
-
- checkpoint_path = self.od_config.tf_model_config.get("fusion", None)
- assert checkpoint_path is not None, "fusion checkpoint path is None"
- load_fusion_checkpoint(Fusion_model, checkpoint_path=os.path.join(model, checkpoint_path))
- self.model = Fusion_model
+ # Fusion model — weights are loaded later via load_weights()
+ self.model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG)
self.transformer = self.model
+ fusion_path = self.od_config.tf_model_config.get("fusion", None)
+ assert fusion_path is not None, "fusion checkpoint path is None in transformer config"
+ fusion_subfolder = os.path.dirname(fusion_path) or None
+ fusion_filename = os.path.basename(fusion_path)
+ self.weights_sources = [
+ DiffusersPipelineLoader.ComponentSource(
+ model_or_path=model,
+ subfolder=fusion_subfolder,
+ revision=None,
+ prefix="model.",
+ allow_patterns_overrides=[fusion_filename],
+ )
+ ]
+
# Fixed attributes, non-configurable
self.audio_latent_channel = AUDIO_CONFIG.get("in_dim")
self.video_latent_channel = VIDEO_CONFIG.get("in_dim")
@@ -216,8 +235,11 @@ def load_image_latent_ref_ip_video(
return ref_vae_latents, ref_audio_lengths
- def load_weights(self, weights):
- pass
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ prefix = "model."
+ state_dict = {name[len(prefix) :]: tensor for name, tensor in weights if name.startswith(prefix)}
+ self.model.load_state_dict(state_dict, strict=True)
+ return {prefix + k for k in state_dict}
def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0, shift=5.0):
torch.manual_seed(4)
@@ -249,6 +271,28 @@ def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0
return sample_scheduler, timesteps
+ def predict_noise(self, **kwargs):
+ pred_vid, pred_audio = self.model(**kwargs)
+ return (pred_vid[0], pred_audio[0])
+
+ def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
+ vid_pos, audio_pos = predictions[0]
+ vid_neg, audio_neg = predictions[1]
+ vid_ip_neg, _ = predictions[2]
+ _, refaudio_neg = predictions[3]
+
+ pred_video = (
+ vid_neg
+ + true_cfg_scale["video_cfg_scale"] * (vid_pos - vid_neg)
+ + true_cfg_scale["video_ref_cfg_scale"] * (vid_pos - vid_ip_neg)
+ )
+ pred_audio = (
+ audio_neg
+ + true_cfg_scale["audio_cfg_scale"] * (audio_pos - audio_neg)
+ + true_cfg_scale["audio_ref_cfg_scale"] * (audio_pos - refaudio_neg)
+ )
+ return (pred_video, pred_audio)
+
def diffuse(
self,
video_noise: torch.Tensor,
@@ -306,72 +350,22 @@ def diffuse(
"vid_context": [text_embeddings_video_neg],
}
- if get_classifier_free_guidance_world_size() > 1:
- # Enable CFG-parallel: rank0 computes positive, rank1 computes negative.
- cfg_group = get_cfg_group()
- cfg_rank = get_classifier_free_guidance_rank()
-
- if cfg_rank == 0:
- pred_vid, pred_audio = self.model(
- vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **pos_args
- )
- pre_vid_ip_neg, _ = self.model(
- vid=[model_input_video_neg], audio=[model_input_audio], t=timestep_input, **pos_args
- )
- pred_vid_0 = pred_vid[0]
- pred_audio_0 = pred_audio[0]
- pre_vid_ip_0 = pre_vid_ip_neg[0]
- pred_refaudio_0 = torch.zeros_like(pred_audio_0) # dummy tensor
- else:
- pred_vid, pred_audio = self.model(
- vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **neg_args
- )
- _, pred_refaudio_neg = self.model(
- vid=[model_input_video], audio=[model_input_audio_neg], t=timestep_input, **pos_args
- )
- pred_vid_0 = pred_vid[0]
- pred_audio_0 = pred_audio[0]
- pre_vid_ip_0 = torch.zeros_like(pred_vid_0) # dummy tensor
- pred_refaudio_0 = pred_refaudio_neg[0]
-
- pred_vid_gathered = cfg_group.all_gather(pred_vid_0, separate_tensors=True)
- pred_audio_gathered = cfg_group.all_gather(pred_audio_0, separate_tensors=True)
- pre_vid_ip_gathered = cfg_group.all_gather(pre_vid_ip_0, separate_tensors=True)
- pred_refaudio_gathered = cfg_group.all_gather(pred_refaudio_0, separate_tensors=True)
-
- pred_vid_pos = [pred_vid_gathered[0]]
- pred_vid_neg = [pred_vid_gathered[1]]
- pred_audio_pos = [pred_audio_gathered[0]]
- pred_audio_neg = [pred_audio_gathered[1]]
- pre_vid_ip_neg = [pre_vid_ip_gathered[0]]
- pred_refaudio_neg = [pred_refaudio_gathered[1]]
- else:
- pred_vid_pos, pred_audio_pos = self.model(
- vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **pos_args
- )
-
- pred_vid_neg, pred_audio_neg = self.model(
- vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **neg_args
- )
-
- pre_vid_ip_neg, _ = self.model(
- vid=[model_input_video_neg], audio=[model_input_audio], t=timestep_input, **pos_args
- )
-
- _, pred_refaudio_neg = self.model(
- vid=[model_input_video], audio=[model_input_audio_neg], t=timestep_input, **pos_args
- )
-
- pred_video_guided = (
- pred_vid_neg[0]
- + self.video_cfg_scale * (pred_vid_pos[0] - pred_vid_neg[0])
- + self.video_ref_cfg_scale * (pred_vid_pos[0] - pre_vid_ip_neg[0])
- )
-
- pred_audio_guided = (
- pred_audio_neg[0]
- + self.audio_cfg_scale * (pred_audio_pos[0] - pred_audio_neg[0])
- + self.audio_ref_cfg_scale * (pred_audio_pos[0] - pred_refaudio_neg[0])
+ branches_kwargs = [
+ {"vid": [model_input_video], "audio": [model_input_audio], "t": timestep_input, **pos_args},
+ {"vid": [model_input_video], "audio": [model_input_audio], "t": timestep_input, **neg_args},
+ {"vid": [model_input_video_neg], "audio": [model_input_audio], "t": timestep_input, **pos_args},
+ {"vid": [model_input_video], "audio": [model_input_audio_neg], "t": timestep_input, **pos_args},
+ ]
+
+ pred_video_guided, pred_audio_guided = self.predict_noise_with_multi_branch_cfg(
+ do_true_cfg=True,
+ true_cfg_scale={
+ "video_cfg_scale": self.video_cfg_scale,
+ "video_ref_cfg_scale": self.video_ref_cfg_scale,
+ "audio_cfg_scale": self.audio_cfg_scale,
+ "audio_ref_cfg_scale": self.audio_ref_cfg_scale,
+ },
+ branches_kwargs=branches_kwargs,
)
video_noise = scheduler_video.step(
pred_video_guided.unsqueeze(0), t_v, video_noise.unsqueeze(0), return_dict=False
diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py
index 680b8bfbbe..dd88ee76c1 100644
--- a/vllm_omni/diffusion/models/flux/flux_transformer.py
+++ b/vllm_omni/diffusion/models/flux/flux_transformer.py
@@ -381,7 +381,9 @@ def __init__(
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
- self.norm = AdaLayerNormZeroSingle(dim, quant_config=quant_config, prefix=f"{prefix}.norm")
+ # Modulation linear kept full precision; shift/scale/gate outputs
+ # are multiplied into the residual stream every block (see #2728).
+ self.norm = AdaLayerNormZeroSingle(dim, quant_config=None, prefix=f"{prefix}.norm")
self.proj_mlp = ReplicatedLinear(
dim,
self.mlp_hidden_dim,
@@ -524,10 +526,10 @@ def _is_transformer_block(name: str, module) -> bool:
def __init__(
self,
- od_config: OmniDiffusionConfig = None,
+ od_config: OmniDiffusionConfig | None = None,
patch_size: int = 1,
in_channels: int = 64,
- out_channels: int = None,
+ out_channels: int | None = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
@@ -563,13 +565,16 @@ def __init__(
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+ # Dual-stream blocks kept full precision — FP8 on their joint
+ # attention path causes noise on FLUX (#2728). Single-stream
+ # blocks (38 vs 19) still get FP8 for memory savings.
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
- quant_config=quant_config,
+ quant_config=None,
prefix=f"transformer_blocks.{i}",
)
for i in range(num_layers)
@@ -589,12 +594,13 @@ def __init__(
]
)
+ # Final modulation feeds proj_out; keep full precision (see #2728).
self.norm_out = AdaLayerNormContinuous(
self.inner_dim,
self.inner_dim,
elementwise_affine=False,
eps=1e-6,
- quant_config=quant_config,
+ quant_config=None,
prefix="norm_out",
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py
index 6f43e8dbb5..70d572d9a6 100644
--- a/vllm_omni/diffusion/models/flux/pipeline_flux.py
+++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py
@@ -30,6 +30,7 @@
from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
logger = logging.getLogger(__name__)
@@ -106,7 +107,11 @@ def __init__(
self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
- self.transformer = FluxTransformer2DModel(od_config=od_config, quant_config=od_config.quantization_config)
+
+ transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, FluxTransformer2DModel)
+ self.transformer = FluxTransformer2DModel(
+ **transformer_kwargs, od_config=od_config, quant_config=od_config.quantization_config
+ )
self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(
diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
index 437dd58d0c..021d28c2ac 100644
--- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
+++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
@@ -234,7 +234,15 @@ def __init__(
self.transformer = Flux2Transformer2DModel(quant_config=od_config.quantization_config, **transformer_kwargs)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.latent_channels if hasattr(self.vae, "config") else 16
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
self.tokenizer_max_length = 512
self.default_sample_size = 128
@@ -247,6 +255,14 @@ def __init__(
enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
)
+ def get_timesteps(self, num_inference_steps, strength, device):
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+ return timesteps, num_inference_steps - t_start
+
@staticmethod
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
@@ -582,6 +598,54 @@ def prepare_image_latents(
return image_latents, image_latent_ids
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if masked_image is not None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ if masked_image.shape[1] != num_channels_latents:
+ masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator)
+ else:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError("The passed mask and the required batch size don't match.")
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents is not None and masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError("The passed mask and the required batch size don't match.")
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ if masked_image_latents is not None:
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(masked_image_latents)
+
+ mask = mask.repeat(1, self.latent_channels, 1, 1)
+ mask = self._patchify_latents(mask)
+ mask = self._pack_latents(mask)
+
+ return mask, masked_image_latents
+
def check_inputs(
self,
prompt,
@@ -590,7 +654,15 @@ def check_inputs(
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale=None,
+ strength=None,
+ num_inference_steps=None,
):
+ if strength is not None:
+ if strength < 0 or strength > 1:
+ raise ValueError(f"strength must be between 0 and 1, got {strength}")
+ if num_inference_steps is not None and num_inference_steps <= 0:
+ raise ValueError(f"num_inference_steps must be positive, got {num_inference_steps}")
+
if (
height is not None
and height % (self.vae_scale_factor * 2) != 0
@@ -622,7 +694,7 @@ def check_inputs(
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
- if guidance_scale > 1.0 and self.is_distilled:
+ if guidance_scale is not None and guidance_scale > 1.0 and self.is_distilled:
logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.")
@property
@@ -653,11 +725,14 @@ def forward(
self,
req: OmniDiffusionRequest,
image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
+ reference_image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
+ mask_image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
prompt: str | list[str] | None = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
sigmas: list[float] | None = None,
+ strength: float = 1.0,
guidance_scale: float | None = 4.0,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
@@ -671,6 +746,7 @@ def forward(
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: tuple[int, ...] = (9, 18, 27),
+ padding_mask_crop: int | None = None,
) -> DiffusionOutput:
r"""
Function invoked when calling the pipeline for generation.
@@ -804,6 +880,8 @@ def forward(
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
guidance_scale=guidance_scale,
+ strength=strength,
+ num_inference_steps=num_inference_steps,
)
self._guidance_scale = guidance_scale
@@ -848,6 +926,9 @@ def forward(
if image is not None and not isinstance(image, list):
image = [image]
+ multiple_of = self.vae_scale_factor * 2
+ crops_coords = None
+ resize_mode = "crop"
condition_images = None
if image is not None:
for img in image:
@@ -860,10 +941,14 @@ def forward(
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
- multiple_of = self.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
- img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ img = self.image_processor.preprocess(
+ img, height=image_height, width=image_width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
condition_images.append(img)
height = height or image_height
width = width or image_width
@@ -871,6 +956,16 @@ def forward(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+ # Get mask_image and reference_image
+ multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if req.prompts else {}
+ mask_image = multi_modal_data.get("mask_image")
+ reference_image = multi_modal_data.get("reference_image")
+
+ if mask_image is not None and (image is None or (isinstance(image, list) and len(image) == 0)):
+ raise ValueError("image must be provided when using mask_image for inpainting")
+
+ init_image = condition_images[0] if condition_images else None
+
# 5. prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_ids = self.prepare_latents(
@@ -884,6 +979,8 @@ def forward(
latents=latents,
)
+ original_latent_ids = latent_ids
+
image_latents = None
image_latent_ids = None
if condition_images is not None:
@@ -895,6 +992,71 @@ def forward(
dtype=self.vae.dtype,
)
+ # Preprocess reference_image
+ if reference_image is not None and not (
+ isinstance(reference_image, torch.Tensor) and reference_image.size(1) == self.latent_channels
+ ):
+ if (
+ isinstance(reference_image, list)
+ and isinstance(reference_image[0], torch.Tensor)
+ and reference_image[0].ndim == 4
+ ):
+ reference_image = torch.cat(reference_image, dim=0)
+ img_reference = reference_image[0] if isinstance(reference_image, list) else reference_image
+ reference_image_height, reference_image_width = self.image_processor.get_default_height_width(img_reference)
+
+ reference_image_width = reference_image_width // multiple_of * multiple_of
+ reference_image_height = reference_image_height // multiple_of * multiple_of
+ reference_image = self.image_processor.resize(
+ reference_image, reference_image_height, reference_image_width
+ )
+ reference_image = self.image_processor.preprocess(
+ reference_image,
+ reference_image_height,
+ reference_image_width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+ else:
+ pass
+
+ reference_image_latents = None
+ reference_image_latent_ids = None
+ if reference_image is not None:
+ reference_image_latents, reference_image_latent_ids = self.prepare_image_latents(
+ images=[reference_image],
+ batch_size=batch_size * num_images_per_prompt,
+ generator=generator,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ if reference_image_latent_ids is not None:
+ latent_ids = torch.cat([latent_ids, reference_image_latent_ids], dim=1)
+ elif image_latent_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
+
+ mask = None
+ if mask_image is not None:
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+ if init_image is not None:
+ masked_image = init_image * (mask_condition < 0.5).to(init_image.dtype)
+ else:
+ masked_image = None
+ mask, _ = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ self.latent_channels,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
# 6. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
@@ -908,6 +1070,13 @@ def forward(
sigmas=sigmas,
mu=mu,
)
+ if reference_image is not None or mask_image is not None:
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
self._num_timesteps = len(timesteps)
# 7. Denoising loop
@@ -924,7 +1093,10 @@ def forward(
latent_model_input = latents.to(self.transformer.dtype)
latent_image_ids = latent_ids
- if image_latents is not None:
+ if reference_image_latents is not None:
+ latent_model_input = torch.cat([latents, reference_image_latents], dim=1)
+ latent_image_ids = latent_ids
+ elif image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
@@ -953,7 +1125,9 @@ def forward(
negative_kwargs = None
# For editing pipelines, we need to slice the output to remove condition latents
- output_slice = latents.size(1) if image_latents is not None else None
+ output_slice = (
+ latents.size(1) if (image_latents is not None or reference_image_latents is not None) else None
+ )
noise_pred = self.predict_noise_maybe_with_cfg(
do_true_cfg=self.do_classifier_free_guidance,
@@ -964,9 +1138,22 @@ def forward(
output_slice=output_slice,
)
+ latents_dtype = latents.dtype
# Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance)
+ if mask is not None and image_latents is not None:
+ init_latents_proper = image_latents
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep], device=device), latents
+ )
+ latents = (1 - mask) * init_latents_proper + mask * latents
+
+ if latents.dtype != latents_dtype and torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
@@ -978,7 +1165,7 @@ def forward(
self._current_timestep = None
- latents = self._unpack_latents_with_ids(latents, latent_ids)
+ latents = self._unpack_latents_with_ids(latents, original_latent_ids)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
index 490e0198b9..7ff42a5f00 100644
--- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
+++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
@@ -19,10 +19,16 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.cache.base import CachedTransformer
-from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module
+from vllm_omni.diffusion.distributed.sp_plan import (
+ SequenceParallelInput,
+ SequenceParallelOutput,
+)
+from vllm_omni.diffusion.forward_context import get_forward_context
logger = init_logger(__name__)
@@ -108,8 +114,8 @@ def __init__(
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
- post_patch_height = height // self.patch_size
- post_patch_width = width // self.patch_size
+ post_patch_height = torch.tensor(height // self.patch_size, device=hidden_states.device, dtype=torch.int64)
+ post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64)
# Reshape: [B, C, H, W] -> [B, H', W', C*p*p] -> [B, H'*W', C*p*p]
hidden_states = hidden_states.reshape(
@@ -159,6 +165,65 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
return (freqs.cos(), freqs.sin())
+class GlmImagePrepare(nn.Module):
+ """Prepare module for GLM-Image that handles patch embedding and RoPE computation.
+
+ This module encapsulates the input processing pipeline to create a module boundary
+ where _sp_plan can shard outputs via split_output=True.
+
+ Similar to Qwen-Image's ImageRopePrepare, this ensures hidden_states and RoPE
+ embeddings are sharded together to maintain dimension alignment.
+ """
+
+ def __init__(
+ self,
+ image_projector: nn.Module,
+ rope: GlmImageRotaryPosEmbed,
+ patch_size: int,
+ ):
+ super().__init__()
+ self.image_projector = image_projector
+ self.rope = rope
+ self.patch_size = patch_size
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ prior_hidden_states: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Process hidden_states and compute RoPE embeddings.
+
+ Args:
+ hidden_states: Input latent tensor [B, C, H, W]
+ prior_hidden_states: Optional prior embedding to add
+
+ Returns:
+ hidden_states: Patched hidden states [B, seq_len, D]
+ rope_cos: RoPE cos embeddings [seq_len, dim]
+ rope_sin: RoPE sin embeddings [seq_len, dim]
+ post_patch_height: Scalar tensor for height after patching
+ post_patch_width: Scalar tensor for width after patching
+ """
+ batch_size, num_channels, height, width = hidden_states.shape
+
+ post_patch_height = torch.tensor(height // self.patch_size, device=hidden_states.device, dtype=torch.int64)
+ post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64)
+
+ # Compute RoPE (uses original 4D hidden_states shape)
+ image_rotary_emb = self.rope(hidden_states)
+ rope_cos = image_rotary_emb[0].to(hidden_states.device)
+ rope_sin = image_rotary_emb[1].to(hidden_states.device)
+
+ # Patch embedding: [B, C, H, W] -> [B, seq_len, D]
+ hidden_states = self.image_projector(hidden_states)
+
+ # Add prior embedding if provided
+ if prior_hidden_states is not None:
+ hidden_states = hidden_states + prior_hidden_states
+
+ return hidden_states, rope_cos, rope_sin, post_patch_height, post_patch_width
+
+
class GlmImageAdaLayerNormZero(nn.Module):
"""Adaptive LayerNorm with zero initialization for both image and text streams."""
@@ -397,6 +462,7 @@ def __init__(
dim: int,
num_heads: int,
head_dim: int,
+ parallel_config: DiffusionParallelConfig | None = None,
out_bias: bool = True,
eps: float = 1e-5,
):
@@ -404,6 +470,7 @@ def __init__(
self.dim = dim
self.total_num_heads = num_heads
self.head_dim = head_dim
+ self.parallel_config = parallel_config
# QKV projection (fused for efficiency)
self.to_qkv = QKVParallelLinear(
@@ -450,16 +517,19 @@ def forward(
attention_mask: torch.Tensor | None = None,
kv_cache: GlmImageLayerKVCache | None = None,
kv_cache_mode: KVCacheMode | None = None,
+ hidden_states_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for joint attention.
Args:
- hidden_states: Image hidden states [B, img_seq_len, D]
- encoder_hidden_states: Text hidden states [B, text_seq_len, D]
- image_rotary_emb: Tuple of (cos, sin) for RoPE
+ hidden_states: Image hidden states [B, img_seq_len, D] (sharded in SP mode)
+ encoder_hidden_states: Text hidden states [B, text_seq_len, D] (full in SP mode)
+ image_rotary_emb: Tuple of (cos, sin) for RoPE (sharded in SP mode)
+ attention_mask: Optional attention mask
kv_cache: Optional layer KV cache for image editing
kv_cache_mode: Cache mode (WRITE, READ, SKIP)
+ hidden_states_mask: Mask for SP padding (True=valid, False=padding)
Returns:
Tuple of (image_hidden_states, text_hidden_states)
@@ -467,6 +537,13 @@ def forward(
dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, _ = encoder_hidden_states.shape
+ # Check if SP is enabled
+ sp_size = self.parallel_config.sequence_parallel_size if self.parallel_config else None
+ use_sp = sp_size is not None and sp_size > 1
+ if use_sp:
+ forward_ctx = get_forward_context()
+ use_sp = not forward_ctx.split_text_embed_in_sp
+
# Concatenate text and image: [text, image]
hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -485,41 +562,88 @@ def forward(
query = self.norm_q(query).to(dtype=dtype)
key = self.norm_k(key).to(dtype=dtype)
- # Apply RoPE only to image tokens (not text tokens)
- if image_rotary_emb is not None:
- # Only apply RoPE to image part (after text_seq_length)
- query_img = query[:, text_seq_length:, :, :]
- key_img = key[:, text_seq_length:, :, :]
- from diffusers.models.embeddings import apply_rotary_emb
-
- query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
- key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
- query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1)
- key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1)
-
- # Handle KV cache for image editing
- if kv_cache is not None and kv_cache_mode is not None:
- if kv_cache_mode == KVCacheMode.WRITE:
- kv_cache.store(key, value)
- elif kv_cache_mode == KVCacheMode.READ:
- k_cached, v_cached = kv_cache.get()
- if k_cached is not None:
- key = torch.cat([k_cached, key], dim=1)
- value = torch.cat([v_cached, value], dim=1)
- # KVCacheMode.SKIP: do nothing
-
- # Attention computation
- hidden_states_out = self.attn(query, key, value)
- hidden_states_out = hidden_states_out.flatten(2, 3)
- hidden_states_out = hidden_states_out.to(dtype)
+ if use_sp:
+ # SP mode: use joint attention mechanism
+ # Split Q/K/V into text and image parts
+ text_query = query[:, :text_seq_length, :, :]
+ text_key = key[:, :text_seq_length, :, :]
+ text_value = value[:, :text_seq_length, :, :]
+ img_query = query[:, text_seq_length:, :, :]
+ img_key = key[:, text_seq_length:, :, :]
+ img_value = value[:, text_seq_length:, :, :]
+
+ # Apply RoPE only to image part
+ if image_rotary_emb is not None:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ img_query = apply_rotary_emb(img_query, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+ img_key = apply_rotary_emb(img_key, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+
+ # Create attention metadata for joint attention
+ attn_metadata = AttentionMetadata(
+ joint_query=text_query,
+ joint_key=text_key,
+ joint_value=text_value,
+ joint_strategy="front",
+ )
- # Output projection
- for module in self.to_out:
- hidden_states_out = module(hidden_states_out)
+ # Add padding mask for SP if available
+ if hidden_states_mask is not None:
+ attn_metadata.attn_mask = hidden_states_mask
+
+ # Attention computation with joint text/image
+ # Note: Ulysses post_attention returns [text, image] concatenated
+ joint_hidden_states_out = self.attn(img_query, img_key, img_value, attn_metadata)
+
+ # Project combined [text, image] outputs, then split.
+ # This keeps SP numerically aligned with the non-SP path.
+ joint_hidden_states_out = joint_hidden_states_out.flatten(2, 3).to(dtype)
+ for module in self.to_out:
+ joint_hidden_states_out = module(joint_hidden_states_out)
- # Split back to text and image
- encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :]
- hidden_states_out = hidden_states_out[:, text_seq_length:, :]
+ encoder_hidden_states_out = joint_hidden_states_out[:, :text_seq_length, :]
+ hidden_states_out = joint_hidden_states_out[:, text_seq_length:, :]
+ else:
+ # Non-SP mode: original logic
+ # Apply RoPE only to image tokens (not text tokens)
+ if image_rotary_emb is not None:
+ query_img = query[:, text_seq_length:, :, :]
+ key_img = key[:, text_seq_length:, :, :]
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+ key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+ query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1)
+ key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1)
+
+ # Handle KV cache for image editing
+ if kv_cache is not None and kv_cache_mode is not None:
+ if kv_cache_mode == KVCacheMode.WRITE:
+ kv_cache.store(key, value)
+ elif kv_cache_mode == KVCacheMode.READ:
+ k_cached, v_cached = kv_cache.get()
+ if k_cached is not None:
+ key = torch.cat([k_cached, key], dim=1)
+ value = torch.cat([v_cached, value], dim=1)
+
+ # Attention computation
+ attn_metadata = None
+ if attention_mask is not None:
+ if attention_mask.dim() == 3:
+ attention_mask = attention_mask.unsqueeze(1)
+ attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+ hidden_states_out = self.attn(query, key, value, attn_metadata)
+ hidden_states_out = hidden_states_out.flatten(2, 3)
+ hidden_states_out = hidden_states_out.to(dtype)
+
+ # Output projection
+ for module in self.to_out:
+ hidden_states_out = module(hidden_states_out)
+
+ # Split back to text and image
+ encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :]
+ hidden_states_out = hidden_states_out[:, text_seq_length:, :]
return hidden_states_out, encoder_hidden_states_out
@@ -628,6 +752,7 @@ def __init__(
attention_head_dim: int = 40,
time_embed_dim: int = 512,
ffn_hidden_dim: int | None = None,
+ parallel_config: DiffusionParallelConfig | None = None,
) -> None:
super().__init__()
@@ -637,6 +762,7 @@ def __init__(
dim=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
+ parallel_config=parallel_config,
)
# 2. Feedforward
@@ -654,6 +780,7 @@ def forward(
attention_kwargs: dict[str, Any] | None = None,
kv_cache: GlmImageLayerKVCache | None = None,
kv_cache_mode: KVCacheMode | None = None,
+ hidden_states_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for transformer block.
@@ -667,6 +794,7 @@ def forward(
attention_kwargs: Additional attention arguments
kv_cache: Layer-specific KV cache for image editing
kv_cache_mode: Cache mode (WRITE, READ, SKIP)
+ hidden_states_mask: Mask for SP padding (True=valid, False=padding)
Returns:
Tuple of (image_hidden_states, text_hidden_states)
@@ -693,6 +821,7 @@ def forward(
attention_mask=attention_mask,
kv_cache=kv_cache,
kv_cache_mode=kv_cache_mode,
+ hidden_states_mask=hidden_states_mask,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -724,6 +853,26 @@ class GlmImageTransformer2DModel(CachedTransformer):
"""
_repeated_blocks = ["GlmImageTransformerBlock"]
+ # SP plan using GlmImagePrepare module for sharding hidden_states and RoPE together.
+ # Similar to Qwen-Image's ImageRopePrepare, this creates a module boundary where
+ # _sp_plan can shard outputs via split_output=True.
+ #
+ # Key insight: hidden_states and RoPE embeddings MUST be sharded together
+ # to maintain dimension alignment for RoPE computation in attention layers.
+ _sp_plan = {
+ # Shard GlmImagePrepare outputs (hidden_states and RoPE must be sharded together)
+ "prepare": {
+ # hidden_states: [B, seq_len, D] - shard along sequence dimension
+ 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True),
+ # RoPE cos: [seq_len, dim] - shard along sequence dimension
+ 1: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
+ # RoPE sin: [seq_len, dim] - shard along sequence dimension
+ 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
+ # post_patch_height and post_patch_width are scalars, not sharded
+ },
+ # Gather output at proj_out
+ "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+ }
_hsdp_shard_conditions = [is_transformer_block_module]
@@ -790,6 +939,9 @@ def __init__(
dim=inner_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="linear-silu"
)
+ # Prepare module for SP (encapsulates patch embedding and RoPE for _sp_plan)
+ self.prepare = GlmImagePrepare(self.image_projector, self.rope, patch_size)
+
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
@@ -806,6 +958,7 @@ def __init__(
attention_head_dim,
time_embed_dim,
ffn_hidden_dim=ffn_hidden_dim,
+ parallel_config=self.parallel_config,
)
for _ in range(num_layers)
]
@@ -859,33 +1012,51 @@ def forward(
# Get KV cache mode
kv_cache_mode = kv_cache.mode if kv_cache is not None else None
- # 1. RoPE
- if image_rotary_emb is None:
- image_rotary_emb = self.rope(hidden_states)
- # Move to correct device
- image_rotary_emb = (
- image_rotary_emb[0].to(hidden_states.device),
- image_rotary_emb[1].to(hidden_states.device),
- )
-
- # 2. Patch & Timestep embeddings
- p = self.patch_size
- post_patch_height = height // p
- post_patch_width = width // p
+ # Set SP context if enabled
+ sp_size = self.parallel_config.sequence_parallel_size
+ if sp_size is not None and sp_size > 1:
+ get_forward_context().split_text_embed_in_sp = False
- hidden_states = self.image_projector(hidden_states)
+ # Text embedding projection
encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
# Prior embedding with dropout
prior_embedding = self.prior_token_embedding(prior_token_id)
prior_embedding[prior_token_drop] *= 0.0
prior_hidden_states = self.prior_projector(prior_embedding)
- hidden_states = hidden_states + prior_hidden_states
+
+ # 1. Prepare hidden_states and RoPE via GlmImagePrepare module
+ # _sp_plan will shard hidden_states and RoPE together via split_output=True
+ hidden_states, rope_cos, rope_sin, post_patch_height_t, post_patch_width_t = self.prepare(
+ hidden_states, prior_hidden_states
+ )
+ image_rotary_emb = (rope_cos, rope_sin)
+ post_patch_height = int(post_patch_height_t.item())
+ post_patch_width = int(post_patch_width_t.item())
# Timestep conditioning
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
- # 3. Transformer blocks
+ # Create padding mask for SP if needed (after _sp_plan hooks have run)
+ hidden_states_mask = None
+ if sp_size is not None and sp_size > 1:
+ from vllm_omni.diffusion.forward_context import is_forward_context_available
+
+ if is_forward_context_available():
+ ctx = get_forward_context()
+ if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
+ img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
+ hidden_states_mask = torch.ones(
+ batch_size,
+ img_padded_seq_len,
+ dtype=torch.bool,
+ device=hidden_states.device,
+ )
+ hidden_states_mask[:, ctx.sp_original_seq_len :] = False
+ if hidden_states_mask.all():
+ hidden_states_mask = None
+
+ # 2. Transformer blocks
for layer_idx, block in enumerate(self.transformer_blocks):
# Get layer-specific KV cache if available
layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None
@@ -899,13 +1070,16 @@ def forward(
attention_kwargs,
kv_cache=layer_kv_cache,
kv_cache_mode=kv_cache_mode,
+ hidden_states_mask=hidden_states_mask,
)
- # 4. Output norm & projection
+ # 3. Output norm & projection
+ # _sp_plan will gather hidden_states via proj_out hook
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
- # 5. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W]
+ # 4. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W]
+ p = self.patch_size
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py
index 375f7e7b80..0386364998 100644
--- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py
+++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py
@@ -712,6 +712,14 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
if img is not None:
preprocessed_images = [img]
+ # Priority: prompt dict (from ar2diffusion) > sampling_params
+ # ar2diffusion returns adjusted height/width that matches prior_token_ids
+ if not isinstance(first_prompt, str):
+ ar_height = first_prompt.get("height")
+ ar_width = first_prompt.get("width")
+ else:
+ ar_height = ar_width = None
+
img_height = req.sampling_params.height
img_width = req.sampling_params.width
@@ -719,12 +727,19 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
# Treat that as t2i warmup to avoid requiring i2i-only KV-cache inputs.
is_image_edit = (preprocessed_images is not None) and (not is_dummy_warmup)
- # Use image dimensions as default if available
- height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor
- width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor
+ # Use prompt dict dimensions (from ar2diffusion) as priority, then sampling_params
+ height = (
+ ar_height or req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor
+ )
+ width = ar_width or req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or 50
guidance_scale = req.sampling_params.guidance_scale or 1.5
+ # Ensure dimensions are multiples of vae_scale_factor * patch_size
+ multiple_of = self.vae_scale_factor * self._patch_size
+ height = height // multiple_of * multiple_of
+ width = width // multiple_of * multiple_of
+
self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds)
batch_size = 1
@@ -753,6 +768,20 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
prior_token_id = prior_token_id.to(device=self.device, dtype=torch.long)
if prior_token_id.dim() == 1:
prior_token_id = prior_token_id.unsqueeze(0)
+
+ # Validate that prior_token_id seq_len matches dimensions
+ prior_seq_len = prior_token_id.shape[1]
+ expected_seq_len = (height // self.vae_scale_factor // self._patch_size) * (
+ width // self.vae_scale_factor // self._patch_size
+ )
+ if prior_seq_len != expected_seq_len:
+ raise ValueError(
+ f"prior_token_ids seq_len ({prior_seq_len}) doesn't match dimensions "
+ f"({height}x{width}, expected seq_len={expected_seq_len}). "
+ f"This indicates a mismatch between AR output and Diffusion input. "
+ f"Please ensure ar2diffusion returns correct height/width."
+ )
+
prior_token_image_ids = None
if external_prior_image_ids is not None:
if isinstance(external_prior_image_ids, torch.Tensor):
diff --git a/vllm_omni/diffusion/models/helios/helios_transformer.py b/vllm_omni/diffusion/models/helios/helios_transformer.py
index b3d2621ad8..5e7934c3ba 100644
--- a/vllm_omni/diffusion/models/helios/helios_transformer.py
+++ b/vllm_omni/diffusion/models/helios/helios_transformer.py
@@ -62,10 +62,16 @@ def apply_rotary_emb_helios(
"""
x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
- out = torch.empty_like(hidden_states)
- out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
- out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
- return out.type_as(hidden_states)
+ # Use stack+flatten instead of strided slice assignment for contiguous
+ # memory layout and better performance on GPU/NPU (#2436, cf. PR #2393).
+ rotated = torch.stack(
+ (
+ x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2],
+ x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2],
+ ),
+ dim=-1,
+ )
+ return rotated.flatten(-2, -1).type_as(hidden_states)
class DistributedRMSNorm(nn.Module):
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py b/vllm_omni/diffusion/models/hunyuan_image3/__init__.py
similarity index 58%
rename from vllm_omni/diffusion/models/hunyuan_image_3/__init__.py
rename to vllm_omni/diffusion/models/hunyuan_image3/__init__.py
index cbc6a8ad1f..6612bd855b 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py
+++ b/vllm_omni/diffusion/models/hunyuan_image3/__init__.py
@@ -2,12 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Hunyuan Image 3 diffusion model components."""
-from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE
-from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_image_3_transformer import (
+from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import HunyuanFusedMoE
+from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_image3_transformer import (
HunyuanImage3Model,
HunyuanImage3Text2ImagePipeline,
)
-from vllm_omni.diffusion.models.hunyuan_image_3.pipeline_hunyuan_image_3 import (
+from vllm_omni.diffusion.models.hunyuan_image3.pipeline_hunyuan_image3 import (
HunyuanImage3Pipeline,
)
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/autoencoder.py b/vllm_omni/diffusion/models/hunyuan_image3/autoencoder.py
similarity index 100%
rename from vllm_omni/diffusion/models/hunyuan_image_3/autoencoder.py
rename to vllm_omni/diffusion/models/hunyuan_image3/autoencoder.py
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py
similarity index 100%
rename from vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py
rename to vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py
similarity index 99%
rename from vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py
rename to vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py
index ce563f7115..4a29e9df93 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py
+++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py
@@ -13,7 +13,7 @@
from transformers import AutoTokenizer
from vllm.logger import init_logger
-from .hunyuan_image_3_transformer import ImageInfo, JointImageInfo, default
+from .hunyuan_image3_transformer import ImageInfo, JointImageInfo, default
logger = init_logger(__name__)
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py
similarity index 99%
rename from vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
rename to vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py
index bc81ca9c3e..fbdacddaf3 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
+++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py
@@ -74,7 +74,7 @@
)
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.layers.rope import RotaryEmbedding
-from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE
+from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import HunyuanFusedMoE
logger = logging.getLogger(__name__)
@@ -1684,7 +1684,8 @@ def forward(
else:
attn_output = self.attn(q, k, v)
# For o_proj
- attn_output = attn_output.view(q.shape[0], -1)
+ # image_attn may return a non-contiguous tensor; reshape is safe here.
+ attn_output = attn_output.reshape(q.shape[0], -1)
output, _ = self.o_proj(attn_output)
output = output.reshape(bsz, q_len, -1)
return output, None, past_key_value
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
similarity index 99%
rename from vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
rename to vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
index 7e9e2d2787..3de0ab3101 100644
--- a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
+++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
@@ -6,7 +6,6 @@
from collections.abc import Iterable
from typing import Any
-import numpy as np
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
@@ -25,8 +24,8 @@
from vllm_omni.diffusion.request import OmniDiffusionRequest
from .autoencoder import AutoencoderKLConv3D
-from .hunyuan_image_3_tokenizer import TokenizerWrapper
-from .hunyuan_image_3_transformer import (
+from .hunyuan_image3_tokenizer import TokenizerWrapper
+from .hunyuan_image3_transformer import (
CausalMMOutputWithPast,
HunyuanImage3ImageProcessor,
HunyuanImage3Model,
@@ -544,7 +543,7 @@ def prepare_model_inputs(
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
# 3. apply chat template
- cfg_factor = {"gen_text": 1, "gen_image": 2}
+ cfg_factor = {"gen_text": 1, "gen_image": 1 + int(guidance_scale > 1.0)}
bot_task = kwargs.pop("bot_task", "auto")
# If `drop_think` enabled, always drop parts in the context.
drop_think = kwargs.get("drop_think", self.generation_config.drop_think)
@@ -1009,8 +1008,7 @@ def forward(
if req.sampling_params.guidance_scale_provided:
guidance_scale = req.sampling_params.guidance_scale
if guidance_scale <= 1.0:
- logger.warning("HunyuanImage3.0 does not support guidance_scale <= 1.0, will set it to 1.0 + epsilon.")
- guidance_scale = 1.0 + np.finfo(float).eps
+ logger.info("HunyuanImage3.0 runs without classifier-free guidance when guidance_scale <= 1.0.")
image_size = (height, width)
model_inputs = self.prepare_model_inputs(
prompt=prompt,
diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/system_prompt.py b/vllm_omni/diffusion/models/hunyuan_image3/system_prompt.py
similarity index 100%
rename from vllm_omni/diffusion/models/hunyuan_image_3/system_prompt.py
rename to vllm_omni/diffusion/models/hunyuan_image3/system_prompt.py
diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
index a1bf7f7809..95ef919c24 100644
--- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
+++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
@@ -1264,6 +1264,7 @@ class LTX2VideoTransformer3DModel(nn.Module):
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTX2VideoTransformerBlock"]
+ _layerwise_offload_blocks_attrs = ["transformer_blocks"]
_sp_plan: dict[str, Any] | None = None
@staticmethod
diff --git a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
index 9ff681a3c0..3f03563a1c 100644
--- a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
+++ b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
@@ -5,6 +5,8 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
+import vllm._custom_ops as ops
from diffusers.models.activations import get_activation
from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -16,6 +18,7 @@
QKVParallelLinear,
RowParallelLinear,
)
+from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm_omni.diffusion.attention.layer import Attention
@@ -24,6 +27,105 @@
logger = logging.getLogger(__name__)
+def _patch_cutlass_padded_fp8():
+ """Monkey-patch vllm._custom_ops.cutlass_scaled_mm to pad tensors whose
+ dimensions are not multiples of 16, so the CUTLASS FP8 kernel is used.
+
+ OmniGen2 has hidden_size=2520 (2520 % 16 == 8). Without this patch,
+ vLLM's cutlass_scaled_mm falls back to a Triton scaled_mm kernel for
+ every FP8 linear layer (QKV, attn output, gate_up_proj, down_proj),
+ which is dramatically slower than the native CUTLASS FP8 tensor-core
+ path on H100/H200 GPUs.
+
+ Weight tensors (b) are constant across forward passes, so padded
+ versions are computed once and cached by data_ptr to avoid repeated
+ allocation and column-major conversion overhead.
+ """
+ _orig_cutlass_scaled_mm = ops.cutlass_scaled_mm
+ # Cache: data_ptr → (padded_b, padded_bias, padded_scale_b, pad_k, pad_n, orig_n)
+ _weight_cache: dict[int, tuple] = {}
+
+ def _padded_cutlass_scaled_mm(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ scale_a: torch.Tensor,
+ scale_b: torch.Tensor,
+ out_dtype: torch.dtype,
+ bias: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0:
+ return _orig_cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+ # Reshape to 2D (mirrors the original function)
+ target_shape = (*a.shape[:-1], b.shape[1])
+ a = a.view(-1, a.shape[-1])
+ orig_n = b.shape[1]
+
+ # Cache the padded weight — it's a model parameter that never changes.
+ key = b.data_ptr()
+ if key not in _weight_cache:
+ pad_k = (16 - b.shape[0] % 16) % 16
+ pad_n = (16 - orig_n % 16) % 16
+ b_pad = b
+ if pad_k > 0:
+ b_pad = F.pad(b_pad, (0, 0, 0, pad_k))
+ if pad_n > 0:
+ b_pad = F.pad(b_pad, (0, pad_n))
+ # CUTLASS requires b column-major (stride(0)==1).
+ b_pad = b_pad.t().contiguous().t()
+
+ bias_pad = None
+ if bias is not None and pad_n > 0:
+ bias_pad = F.pad(bias, (0, pad_n))
+
+ scale_b_pad = scale_b
+ if scale_b.numel() > 1 and pad_n > 0:
+ scale_b_pad = F.pad(
+ scale_b.view(-1, scale_b.shape[-1]),
+ (0, pad_n),
+ value=1.0,
+ )
+
+ _weight_cache[key] = (
+ b_pad,
+ bias_pad,
+ scale_b_pad,
+ pad_k,
+ pad_n,
+ orig_n,
+ )
+
+ b_pad, bias_pad, scale_b_pad, pad_k, pad_n, orig_n = _weight_cache[key]
+
+ # Pad activations on K dimension (cheap — activations are small).
+ if pad_k > 0:
+ a = F.pad(a, (0, pad_k)).contiguous()
+
+ out = torch.empty((a.shape[0], b_pad.shape[1]), dtype=out_dtype, device=a.device)
+ torch.ops._C.cutlass_scaled_mm(
+ out,
+ a,
+ b_pad,
+ scale_a,
+ scale_b_pad,
+ bias_pad if bias is not None else None,
+ )
+
+ if pad_n > 0:
+ out = out[:, :orig_n]
+
+ return out.view(*target_shape)
+
+ ops.cutlass_scaled_mm = _padded_cutlass_scaled_mm
+ logger.info(
+ "Patched vllm._custom_ops.cutlass_scaled_mm with CUTLASS-padded FP8 "
+ "variant (avoids slow Triton fallback for non-%%16 dimensions)"
+ )
+
+
+_patch_cutlass_padded_fp8()
+
+
class OmniGen2Attention(nn.Module):
def __init__(
self,
@@ -31,6 +133,8 @@ def __init__(
num_heads: int,
num_kv_heads: int,
eps: float = 1e-5,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
):
super().__init__()
self.dim = dim
@@ -46,12 +150,26 @@ def __init__(
total_num_kv_heads=num_kv_heads,
disable_tp=True,
bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.to_qkv",
)
self.norm_q = RMSNorm(self.head_dim, eps=eps)
self.norm_k = RMSNorm(self.head_dim, eps=eps)
- self.to_out = nn.ModuleList([nn.Linear(dim, dim, bias=False)])
+ self.to_out = nn.ModuleList(
+ [
+ RowParallelLinear(
+ dim,
+ dim,
+ bias=False,
+ input_is_parallel=False,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.to_out.0",
+ )
+ ]
+ )
self.attn = Attention(
num_heads=num_heads,
head_size=self.head_dim,
@@ -78,6 +196,9 @@ def forward(
"""
batch_size = hidden_states.shape[0]
+ # Contiguous layout for FP8 quantized linear GEMMs (matches FLUX DiT).
+ hidden_states = hidden_states.contiguous()
+
# Get Query-Key-Value Pair
qkv, _ = self.to_qkv(hidden_states)
@@ -121,7 +242,7 @@ def forward(
hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(dtype)
- hidden_states = self.to_out[0](hidden_states)
+ hidden_states = self.to_out[0](hidden_states.contiguous())
return hidden_states
@@ -233,6 +354,7 @@ def __init__(
embedding_dim: int,
norm_eps: float,
norm_elementwise_affine: bool,
+ **kwargs,
):
super().__init__()
self.silu = nn.SiLU()
@@ -325,6 +447,8 @@ def __init__(
inner_dim: int,
multiple_of: int | None = 256,
ffn_dim_multiplier: float | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
):
super().__init__()
@@ -338,6 +462,8 @@ def __init__(
[inner_dim, inner_dim],
bias=False,
return_bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
)
self.act_fn = get_act_and_mul_fn("silu")
self.down_proj = RowParallelLinear(
@@ -346,6 +472,8 @@ def __init__(
bias=False,
input_is_parallel=True,
return_bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
)
def forward(self, x):
@@ -591,6 +719,8 @@ def __init__(
ffn_dim_multiplier: float,
norm_eps: float,
modulation: bool = True,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
) -> None:
"""Initialize the transformer block."""
super().__init__()
@@ -602,6 +732,8 @@ def __init__(
num_heads=num_attention_heads,
num_kv_heads=num_kv_heads,
eps=1e-5,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
)
# Initialize feed-forward network
@@ -610,11 +742,19 @@ def __init__(
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
+ quant_config=quant_config,
+ prefix=f"{prefix}.feed_forward",
)
# Initialize normalization layers
if modulation:
- self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True)
+ self.norm1 = LuminaRMSNormZero(
+ embedding_dim=dim,
+ norm_eps=norm_eps,
+ norm_elementwise_affine=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.norm1",
+ )
else:
self.norm1 = RMSNorm(dim, eps=norm_eps)
@@ -713,6 +853,7 @@ def __init__(
axes_lens: tuple[int, int, int] = (1024, 1664, 1664),
text_feat_dim: int = 2048,
timestep_scale: float = 1000.0,
+ quant_config: QuantizationConfig | None = None,
) -> None:
"""Initialize the OmniGen2 transformer model."""
super().__init__()
@@ -770,8 +911,10 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=True,
+ quant_config=quant_config,
+ prefix=f"noise_refiner.{i}",
)
- for _ in range(num_refiner_layers)
+ for i in range(num_refiner_layers)
]
)
@@ -785,8 +928,10 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=True,
+ quant_config=quant_config,
+ prefix=f"ref_image_refiner.{i}",
)
- for _ in range(num_refiner_layers)
+ for i in range(num_refiner_layers)
]
)
@@ -800,8 +945,10 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=False,
+ quant_config=quant_config,
+ prefix=f"context_refiner.{i}",
)
- for _ in range(num_refiner_layers)
+ for i in range(num_refiner_layers)
]
)
@@ -816,8 +963,10 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=True,
+ quant_config=quant_config,
+ prefix=f"layers.{i}",
)
- for _ in range(num_layers)
+ for i in range(num_layers)
]
)
@@ -847,11 +996,25 @@ def img_patch_embed_and_refine(
temb,
):
batch_size = len(hidden_states)
+ has_ref_tokens = any(ref_img_len > 0 for ref_lens in l_effective_ref_img_len for ref_img_len in ref_lens)
max_combined_img_len = max(
[img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]
)
hidden_states = self.x_embedder(hidden_states)
+ if not has_ref_tokens:
+ # FP8 kernels do not support zero-token GEMM on ref_image_patch_embedder; skip that path only.
+ # Still run noise_refiner and return the same combined layout as the no-ref case below
+ # (batch, max_combined_img_len, hidden) — not raw noise tokens alone.
+ for layer in self.noise_refiner:
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+ combined_img_hidden_states = hidden_states.new_zeros(
+ batch_size, max_combined_img_len, self.config.hidden_size
+ )
+ for i, img_len in enumerate(l_effective_img_len):
+ combined_img_hidden_states[i, :img_len] = hidden_states[i, :img_len]
+ return combined_img_hidden_states
+
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
for i in range(batch_size):
diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
index 2d370aea19..04720c932f 100644
--- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
+++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
@@ -29,6 +29,7 @@
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.omnigen2.omnigen2_transformer import (
@@ -619,7 +620,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class OmniGen2Pipeline(nn.Module):
+class OmniGen2Pipeline(CFGParallelMixin, nn.Module):
"""
Pipeline for text-to-image generation using OmniGen2.
@@ -675,7 +676,10 @@ def __init__(
)
transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, OmniGen2Transformer2DModel)
- self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs)
+ self.transformer = OmniGen2Transformer2DModel(
+ **transformer_kwargs,
+ quant_config=od_config.quantization_config,
+ )
self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model, subfolder="mllm", local_files_only=local_files_only
).to(self.device)
@@ -1171,7 +1175,14 @@ def processing(
self._num_timesteps = len(timesteps)
for i, t in enumerate(timesteps):
- model_pred = self.predict(
+ text_guidance_scale = (
+ self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+ )
+ image_guidance_scale = (
+ self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+ )
+
+ positive_kwargs = dict(
t=t,
latents=latents,
prompt_embeds=prompt_embeds,
@@ -1179,15 +1190,18 @@ def processing(
prompt_attention_mask=prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
- text_guidance_scale = (
- self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
- )
- image_guidance_scale = (
- self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+ uncond_kwargs = dict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=None,
)
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
- model_pred_ref = self.predict(
+ # 3-branch CFG: pos + ref_neg + uncond
+ ref_neg_kwargs = dict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
@@ -1195,31 +1209,24 @@ def processing(
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
-
- model_pred_uncond = self.predict(
- t=t,
- latents=latents,
- prompt_embeds=negative_prompt_embeds,
- freqs_cis=freqs_cis,
- prompt_attention_mask=negative_prompt_attention_mask,
- ref_image_hidden_states=None,
- )
-
- model_pred = (
- model_pred_uncond
- + image_guidance_scale * (model_pred_ref - model_pred_uncond)
- + text_guidance_scale * (model_pred - model_pred_ref)
+ model_pred = self.predict_noise_with_multi_branch_cfg(
+ do_true_cfg=True,
+ true_cfg_scale={
+ "text": text_guidance_scale,
+ "image": image_guidance_scale,
+ },
+ branches_kwargs=[positive_kwargs, ref_neg_kwargs, uncond_kwargs],
)
elif text_guidance_scale > 1.0:
- model_pred_uncond = self.predict(
- t=t,
- latents=latents,
- prompt_embeds=negative_prompt_embeds,
- freqs_cis=freqs_cis,
- prompt_attention_mask=negative_prompt_attention_mask,
- ref_image_hidden_states=None,
+ # 2-branch CFG: pos + uncond
+ model_pred = self.predict_noise_with_multi_branch_cfg(
+ do_true_cfg=True,
+ true_cfg_scale=text_guidance_scale,
+ branches_kwargs=[positive_kwargs, uncond_kwargs],
)
- model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
+ else:
+ # No CFG
+ model_pred = self.predict_noise(**positive_kwargs)
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
@@ -1249,8 +1256,6 @@ def predict(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- batch_size, num_channels_latents, height, width = latents.shape
-
optional_kwargs = {}
if "ref_image_hidden_states" in set(inspect.signature(self.transformer.forward).parameters.keys()):
optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states
@@ -1265,6 +1270,21 @@ def predict(
)
return model_pred
+ def predict_noise(self, **kwargs):
+ """Override CFGParallelMixin.predict_noise to use self.predict."""
+ return self.predict(**kwargs)
+
+ def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
+ """Override: 3-branch dual scale or 2-branch standard CFG."""
+ if len(predictions) == 3:
+ text_scale = true_cfg_scale["text"]
+ image_scale = true_cfg_scale["image"]
+ pos, ref, uncond = predictions[0], predictions[1], predictions[2]
+ return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)
+ # 2-branch: standard CFG
+ pos, neg = predictions[0], predictions[1]
+ return neg + true_cfg_scale * (pos - neg)
+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
diff --git a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py
index 568e2f5164..c330e91de8 100644
--- a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py
+++ b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py
@@ -16,6 +16,7 @@
from collections.abc import Iterable
from typing import ClassVar
+import numpy as np
import torch
from tokenizers import Tokenizer as HFTokenizer
from torch import nn
@@ -30,6 +31,13 @@
from vllm_omni.model_executor.models.omnivoice.omnivoice_decoder import OmniVoiceDecoder
from vllm_omni.model_executor.models.omnivoice.omnivoice_generator import OmniVoiceGenerator
+try:
+ from transformers import HiggsAudioV2TokenizerModel
+except ImportError:
+ HiggsAudioV2TokenizerModel = None
+
+import torchaudio
+
logger = init_logger(__name__)
@@ -79,6 +87,17 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
tokenizer_path = os.path.join(self.model_path, "tokenizer.json")
self.tokenizer = HFTokenizer.from_file(tokenizer_path)
+ # Audio tokenizer for voice cloning (requires transformers>=5.3)
+ if HiggsAudioV2TokenizerModel is not None:
+ audio_tokenizer_path = os.path.join(self.model_path, "audio_tokenizer")
+ self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
+ audio_tokenizer_path, device_map=self.device
+ ).eval()
+ logger.info("HiggsAudioV2 tokenizer loaded for voice cloning on %s", self.device)
+ else:
+ self.audio_tokenizer = None
+ logger.warning("Voice cloning disabled (requires transformers>=5.3.0).")
+
# Duration estimator
self.duration_estimator = RuleDurationEstimator()
@@ -91,20 +110,46 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
self.class_temperature = self.config.class_temperature
self.sample_rate = self.config.sample_rate
+ def _encode_ref_audio(self, audio_signal: torch.Tensor, sr: int) -> torch.Tensor:
+ """Encode reference audio to 8-codebook tokens for voice cloning."""
+ if self.audio_tokenizer is None:
+ raise RuntimeError("Audio tokenizer not available for voice cloning")
+ if audio_signal.dim() == 1:
+ audio_signal = audio_signal.unsqueeze(0)
+ # Resample to tokenizer's expected sample rate
+ target_sr = self.audio_tokenizer.config.sample_rate
+ if sr != target_sr:
+ audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr)
+ # Ensure mono [B, 1, samples]
+ if audio_signal.dim() == 2:
+ audio_signal = audio_signal.unsqueeze(1)
+ with torch.inference_mode():
+ tokens = self.audio_tokenizer.encode(
+ audio_signal.to(self.audio_tokenizer.device), return_dict=False
+ ) # [B, 8, T_ref]
+ tokens = tokens.squeeze(0) # [8, T_ref]
+ return tokens
+
@torch.inference_mode()
def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
- """Generate speech audio from text.
-
- Args:
- req: Diffusion request containing text prompt(s).
+ """Generate speech audio from text, optionally with voice cloning.
- Returns:
- DiffusionOutput with audio tensor in .output
+ Accepts either a plain text prompt or a structured dict:
+ {"text": "...", "ref_audio": (samples, sr), "ref_text": "...",
+ "lang": "...", "instruct": "..."}
"""
- # Extract text from request
prompt = req.prompts[0] if req.prompts else ""
+ ref_audio = None
+ ref_text = None
+ lang = "None"
+ instruct = "None"
+
if isinstance(prompt, dict):
text = prompt.get("input", prompt.get("text", str(prompt)))
+ ref_audio = prompt.get("ref_audio")
+ ref_text = prompt.get("ref_text")
+ lang = prompt.get("lang") or "None"
+ instruct = prompt.get("instruct") or "None"
else:
text = str(prompt)
@@ -119,17 +164,37 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
target_len = self.duration_estimator.estimate_duration(text, "Nice to meet you.", 25)
target_len = max(1, int(target_len))
- # Tokenize with control tokens
- style = "<|denoise|><|lang_start|>None<|lang_end|><|instruct_start|>None<|instruct_end|>"
- full_prompt = f"{style}<|text_start|>{text}<|text_end|>"
+ # Build text prompt with control tokens
+ style = f"<|denoise|><|lang_start|>{lang}<|lang_end|><|instruct_start|>{instruct}<|instruct_end|>"
+ if ref_text:
+ full_text = f"{ref_text} {text}"
+ else:
+ full_text = text
+ full_prompt = f"{style}<|text_start|>{full_text}<|text_end|>"
encoding = self.tokenizer.encode(full_prompt)
text_tokens = torch.tensor(encoding.ids, dtype=torch.long, device=device)
text_len = text_tokens.shape[0]
+ # Encode reference audio tokens if provided
+ ref_audio_tokens = None
+ if ref_audio is not None:
+ if self.audio_tokenizer is None:
+ raise RuntimeError(
+ "Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'"
+ )
+ audio_signal, sr = ref_audio
+ if isinstance(audio_signal, np.ndarray):
+ audio_signal = torch.from_numpy(audio_signal).float()
+ ref_audio_tokens = self._encode_ref_audio(audio_signal, int(sr)).to(device)
+
# Build conditional + unconditional batches [2, 8, max_len]
text_ids = text_tokens.unsqueeze(0).repeat(num_cb, 1)
target_ids = torch.full((num_cb, target_len), mask_id, dtype=torch.long, device=device)
- cond_ids = torch.cat([text_ids, target_ids], dim=1)
+
+ if ref_audio_tokens is not None:
+ cond_ids = torch.cat([text_ids, ref_audio_tokens, target_ids], dim=1)
+ else:
+ cond_ids = torch.cat([text_ids, target_ids], dim=1)
cond_len = cond_ids.shape[1]
uncond_ids = target_ids.clone()
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
index 9f75c84538..9ef0cacd5a 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
@@ -34,6 +34,9 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.prompt_utils import (
+ validate_prompt_sequence_lengths,
+)
from vllm_omni.diffusion.utils.size_utils import (
normalize_min_aligned_size,
)
@@ -363,8 +366,10 @@ def check_inputs(
"that was used to generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > 1024:
- raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+ if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
+ )
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -378,6 +383,8 @@ def _get_qwen_prompt_embeds(
self,
prompt: str | list[str] = None,
dtype: torch.dtype | None = None,
+ max_sequence_length: int | None = None,
+ prompt_name: str = "prompt",
):
dtype = dtype or self.text_encoder.dtype
@@ -388,12 +395,27 @@ def _get_qwen_prompt_embeds(
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt,
- max_length=self.tokenizer_max_length + drop_idx,
padding=True,
- truncation=True,
+ truncation=False,
+ return_tensors="pt",
+ ).to(self.device)
+ # Validate only the user prompt contribution. The Qwen template also
+ # adds a fixed suffix after the user text, so subtracting only
+ # prompt_template_encode_start_idx would overcount near-limit prompts.
+ template_tokens = self.tokenizer(
+ [template.format("")],
+ padding=True,
+ truncation=False,
return_tensors="pt",
).to(self.device)
- # print(f"attention mask: {txt_tokens.attention_mask}")
+ validate_prompt_sequence_lengths(
+ txt_tokens.attention_mask,
+ max_sequence_length=max_sequence_length or self.tokenizer_max_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ prompt_name=prompt_name,
+ baseline_attention_mask=template_tokens.attention_mask,
+ error_context="after applying the Qwen prompt template",
+ )
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
@@ -422,6 +444,7 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
+ prompt_name: str = "prompt",
):
r"""
@@ -439,7 +462,11 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt)
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ prompt,
+ max_sequence_length=max_sequence_length,
+ prompt_name=prompt_name,
+ )
prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
@@ -632,6 +659,7 @@ def _prepare_generation_context(
prompt_embeds_mask=negative_prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
+ prompt_name="negative_prompt",
)
else:
negative_prompt_embeds = None
@@ -703,7 +731,7 @@ def prepare_encode(
num_images_per_prompt=sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1,
generator=sampling.generator,
true_cfg_scale=sampling.true_cfg_scale or 4.0,
- max_sequence_length=sampling.max_sequence_length or 512,
+ max_sequence_length=sampling.max_sequence_length or self.tokenizer_max_length,
attention_kwargs=kwargs.get("attention_kwargs"),
)
@@ -934,7 +962,7 @@ def forward(
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
- max_sequence_length: int = 512,
+ max_sequence_length: int = 1024,
) -> DiffusionOutput:
extracted_prompt, negative_prompt = self._extract_prompts(req.prompts)
prompt = extracted_prompt or prompt
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
index dd77d71b1e..cef7fe473a 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
@@ -37,6 +37,9 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.prompt_utils import (
+ validate_prompt_sequence_lengths,
+)
from vllm_omni.diffusion.utils.size_utils import (
normalize_min_aligned_size,
)
@@ -323,8 +326,10 @@ def check_inputs(
"that was used to generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > 1024:
- raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+ if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
+ )
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -384,6 +389,8 @@ def _get_qwen_prompt_embeds(
prompt: str | list[str] = None,
image: PIL.Image.Image | torch.Tensor | None = None,
dtype: torch.dtype | None = None,
+ max_sequence_length: int | None = None,
+ prompt_name: str = "prompt",
):
"""Get prompt embeddings with image support for editing."""
dtype = dtype or self.text_encoder.dtype
@@ -393,6 +400,33 @@ def _get_qwen_prompt_embeds(
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt,
+ padding=True,
+ truncation=False,
+ return_tensors="pt",
+ ).to(self.device)
+ # The edit template contains fixed multimodal scaffolding around the
+ # instruction. Validate against the empty-template baseline so image
+ # placeholder text does not consume the user's text budget.
+ template_tokens = self.tokenizer(
+ [template.format("")],
+ padding=True,
+ truncation=False,
+ return_tensors="pt",
+ ).to(self.device)
+ # Qwen-Image-Edit expands image placeholders into many vision tokens
+ # inside the processor. `max_sequence_length` is meant to constrain the
+ # prompt text length, so validate on the text template before image
+ # token expansion.
+ validate_prompt_sequence_lengths(
+ txt_tokens.attention_mask,
+ max_sequence_length=max_sequence_length or self.tokenizer_max_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ prompt_name=prompt_name,
+ baseline_attention_mask=template_tokens.attention_mask,
+ error_context="after applying the Qwen prompt template",
+ )
# Use processor to handle both text and image inputs
model_inputs = self.processor(
@@ -434,6 +468,7 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
+ prompt_name: str = "prompt",
):
r"""
@@ -453,7 +488,12 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image)
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ prompt,
+ image,
+ max_sequence_length=max_sequence_length,
+ prompt_name=prompt_name,
+ )
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -624,7 +664,7 @@ def forward(
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
- max_sequence_length: int = 512,
+ max_sequence_length: int = 1024,
) -> DiffusionOutput:
"""Forward pass for image editing."""
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
@@ -739,6 +779,7 @@ def forward(
prompt_embeds_mask=negative_prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
+ prompt_name="negative_prompt",
)
num_channels_latents = self.transformer.in_channels // 4
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
index 6f6c9d2ba3..2e25d0fe6b 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
@@ -25,6 +25,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.qwen_image.cfg_parallel import (
QwenImageCFGParallelMixin,
@@ -40,6 +41,9 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.prompt_utils import (
+ validate_prompt_sequence_lengths,
+)
from vllm_omni.diffusion.utils.size_utils import (
normalize_min_aligned_size,
)
@@ -53,6 +57,12 @@
CONDITION_IMAGE_SIZE = 384 * 384
VAE_IMAGE_SIZE = 1024 * 1024
+# Keep this in sync with the practical conditioning-token budget for
+# Qwen-Image-Edit-2511. Empirically, 4 images stays within the supported range
+# while 5 images overflows the prompt/conditioning path and fails downstream.
+# Re-export the shared metadata value locally so this pipeline keeps a nearby,
+# descriptive constant for validation and tests without becoming the source of truth.
+MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
def get_qwen_image_edit_plus_pre_process_func(
@@ -90,6 +100,11 @@ def pre_process_func(
if not isinstance(raw_image, list):
raw_image = [raw_image]
+ if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES:
+ raise ValueError(
+ f"Received {len(raw_image)} input images. "
+ f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model."
+ )
image = [
PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im)
for im in raw_image
@@ -283,8 +298,10 @@ def check_inputs(
"that was used to generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > 1024:
- raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+ if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
+ )
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -299,6 +316,8 @@ def _get_qwen_prompt_embeds(
prompt: str | list[str],
image: list[torch.Tensor] | torch.Tensor | None = None,
dtype: torch.dtype | None = None,
+ max_sequence_length: int | None = None,
+ prompt_name: str = "prompt",
):
"""Get prompt embeddings with support for multiple images."""
dtype = dtype or self.text_encoder.dtype
@@ -319,6 +338,32 @@ def _get_qwen_prompt_embeds(
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(base_img_prompt + e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt,
+ padding=True,
+ truncation=False,
+ return_tensors="pt",
+ ).to(self.device)
+ # Multi-image edit prepends "Picture N" placeholders before the user
+ # instruction. Subtract the placeholder-aware baseline so attached
+ # images do not reduce the remaining prompt budget.
+ template_tokens = self.tokenizer(
+ [template.format(base_img_prompt)],
+ padding=True,
+ truncation=False,
+ return_tensors="pt",
+ ).to(self.device)
+ # The processor expands image placeholders into many vision tokens.
+ # `max_sequence_length` should guard the prompt text length before that
+ # multimodal expansion happens.
+ validate_prompt_sequence_lengths(
+ txt_tokens.attention_mask,
+ max_sequence_length=max_sequence_length or self.tokenizer_max_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ prompt_name=prompt_name,
+ baseline_attention_mask=template_tokens.attention_mask,
+ error_context="after applying the Qwen prompt template",
+ )
# Use processor to handle both text and image inputs
model_inputs = self.processor(
@@ -360,6 +405,7 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
+ prompt_name: str = "prompt",
):
r"""
@@ -379,7 +425,12 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image)
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ prompt,
+ image,
+ max_sequence_length=max_sequence_length,
+ prompt_name=prompt_name,
+ )
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -557,7 +608,7 @@ def forward(
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
- max_sequence_length: int = 512,
+ max_sequence_length: int = 1024,
) -> DiffusionOutput:
"""Forward pass for image editing with support for multiple images."""
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
@@ -692,6 +743,7 @@ def forward(
prompt_embeds_mask=negative_prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
+ prompt_name="negative_prompt",
)
num_channels_latents = self.transformer.in_channels // 4
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
index 38866d89c5..905ef5b424 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
@@ -36,6 +36,9 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.prompt_utils import (
+ validate_prompt_sequence_lengths,
+)
from vllm_omni.diffusion.utils.size_utils import (
normalize_min_aligned_size,
)
@@ -340,8 +343,10 @@ def check_inputs(
"generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > 1024:
- raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+ if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
+ )
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -356,6 +361,8 @@ def _get_qwen_prompt_embeds(
prompt: str | list[str] | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
+ max_sequence_length: int | None = None,
+ prompt_name: str = "prompt",
):
device = device or self.device
dtype = dtype or self.text_encoder.dtype
@@ -368,8 +375,26 @@ def _get_qwen_prompt_embeds(
txt_tokens = self.tokenizer(
txt,
padding=True,
+ truncation=False,
+ return_tensors="pt",
+ ).to(device)
+ # The layered template also appends fixed non-user tokens after the
+ # editable text, so use the empty-template tokenized baseline instead of
+ # counting everything after prompt_template_encode_start_idx.
+ template_tokens = self.tokenizer(
+ [template.format("")],
+ padding=True,
+ truncation=False,
return_tensors="pt",
).to(device)
+ validate_prompt_sequence_lengths(
+ txt_tokens.attention_mask,
+ max_sequence_length=max_sequence_length or self.tokenizer_max_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ prompt_name=prompt_name,
+ baseline_attention_mask=template_tokens.attention_mask,
+ error_context="after applying the Qwen prompt template",
+ )
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
@@ -399,6 +424,7 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
+ prompt_name: str = "prompt",
):
r"""
@@ -419,7 +445,11 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt)
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ prompt,
+ max_sequence_length=max_sequence_length,
+ prompt_name=prompt_name,
+ )
prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
@@ -603,7 +633,7 @@ def forward(
negative_prompt_embeds_mask: torch.Tensor | None = None,
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
- max_sequence_length: int = 512,
+ max_sequence_length: int = 1024,
resolution: int = 640,
cfg_normalize: bool = False,
use_en_prompt: bool = False,
@@ -736,6 +766,7 @@ def forward(
device=self.device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
+ prompt_name="negative_prompt",
)
# 4. Prepare latent variables
diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
index b34f19e954..88a66d7f6b 100644
--- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
+++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
@@ -169,12 +169,15 @@ def __init__(
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ # Time embedding MLP is kept full precision (quant_config=None) —
+ # small layers that feed per-block modulation; precision-sensitive
+ # (see #2728).
self.timestep_embedder.linear_1 = ReplicatedLinear(
256,
embedding_dim,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="timestep_embedder.linear_1",
)
self.timestep_embedder.linear_2 = ReplicatedLinear(
@@ -182,7 +185,7 @@ def __init__(
embedding_dim,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="timestep_embedder.linear_2",
)
self.use_additional_t_cond = use_additional_t_cond
@@ -701,7 +704,10 @@ def __init__(
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
- # Image processing modules
+ # Image processing modules.
+ # Modulation linear is kept full precision (quant_config=None) — it
+ # produces shift/scale/gate values that are precision-sensitive
+ # (see #2728).
self.img_mod = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
@@ -709,7 +715,7 @@ def __init__(
6 * dim,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="img_mod.1",
),
)
@@ -725,7 +731,7 @@ def __init__(
self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="img_mlp")
- # Text processing modules
+ # Text processing modules.
self.txt_mod = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
@@ -733,7 +739,7 @@ def __init__(
6 * dim,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="txt_mod.1",
),
)
@@ -744,9 +750,9 @@ def __init__(
self.zero_cond_t = zero_cond_t
- def _modulate(self, x, mod_params, index=None):
+ def _modulate(self, mod_params, index=None):
"""Apply modulation to input tensor"""
- # x: b l d, shift: b d, scale: b d, gate: b d
+ # shift: b d, scale: b d, gate: b d
shift, scale, gate = mod_params.chunk(3, dim=-1)
if index is not None:
@@ -778,7 +784,7 @@ def _modulate(self, x, mod_params, index=None):
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
- return x * (1 + scale_result) + shift_result, gate_result
+ return scale_result, shift_result, gate_result
def forward(
self,
@@ -804,10 +810,12 @@ def forward(
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
- img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index)
+ img_scale1, img_shift1, img_gate1 = self._modulate(img_mod1, modulate_index)
+ img_modulated = self.img_norm1(hidden_states, img_scale1, img_shift1)
# Process text stream - norm1 + modulation
- txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1)
+ txt_scale1, txt_shift1, txt_gate1 = self._modulate(txt_mod1)
+ txt_modulated = self.txt_norm1(encoder_hidden_states, txt_scale1, txt_shift1)
# Use QwenAttnProcessor2_0 for joint attention computation
# This directly implements the DoubleStreamLayerMegatron logic:
@@ -832,13 +840,16 @@ def forward(
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
- img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index)
+ img_scale2, img_shift2, img_gate2 = self._modulate(img_mod2, modulate_index)
+ img_modulated2 = self.img_norm2(hidden_states, img_scale2, img_shift2)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
- txt_modulated2, txt_gate2 = self.txt_norm2(encoder_hidden_states, txt_mod2)
+ txt_scale2, txt_shift2, txt_gate2 = self._modulate(txt_mod2)
+ txt_modulated2 = self.txt_norm2(encoder_hidden_states, txt_scale2, txt_shift2)
+
txt_mlp_output = self.txt_mlp(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
@@ -958,12 +969,14 @@ def __init__(
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
+ # Entry projections (image/text) are kept full precision —
+ # small sensitive layers at the network boundary (see #2728).
self.img_in = ReplicatedLinear(
in_channels,
self.inner_dim,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="img_in",
)
self.txt_in = ReplicatedLinear(
@@ -971,7 +984,7 @@ def __init__(
self.inner_dim,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="txt_in",
)
@@ -988,13 +1001,16 @@ def __init__(
]
)
+ # Final modulation and output projection are kept full precision —
+ # they produce the output latent and are precision-sensitive
+ # (see #2728).
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.norm_out.linear = ReplicatedLinear(
self.inner_dim,
2 * self.inner_dim,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="norm_out.linear",
)
self.proj_out = ReplicatedLinear(
@@ -1002,7 +1018,7 @@ def __init__(
patch_size * patch_size * self.out_channels,
bias=True,
return_bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix="proj_out",
)
diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py
index 22d56ac1fd..4a4892673f 100644
--- a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py
+++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py
@@ -375,6 +375,8 @@ class StableAudioDiTModel(nn.Module):
- Output: [B, out_channels, L]
"""
+ _repeated_blocks = ["StableAudioDiTBlock"]
+
def __init__(
self,
od_config: OmniDiffusionConfig | None = None,
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index a550e576f0..652425d509 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -24,14 +24,60 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
+from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler
from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel
+from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.prompt_utils import (
+ validate_prompt_sequence_lengths,
+)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform
logger = logging.getLogger(__name__)
DEBUG_PERF = False
+WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"}
+WAN22_MAX_SEQUENCE_LENGTH = 512
+
+
+def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any:
+ if sample_solver == "unipc":
+ return FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ prediction_type="flow_prediction",
+ )
+ if sample_solver == "euler":
+ return WanEulerScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ )
+
+ raise ValueError(
+ f"Unsupported Wan sample_solver: {sample_solver}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}"
+ )
+
+
+def resolve_wan_sample_solver(req: OmniDiffusionRequest, default: str = "unipc") -> str:
+ extra_args = getattr(req.sampling_params, "extra_args", {}) or {}
+ raw = extra_args.get("sample_solver", default)
+ sample_solver = str(raw).strip().lower()
+ if sample_solver not in WAN_SAMPLE_SOLVER_CHOICES:
+ raise ValueError(f"Invalid sample_solver={raw!r}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}")
+ return sample_solver
+
+
+def resolve_wan_flow_shift(req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> float:
+ extra_args = getattr(req.sampling_params, "extra_args", {}) or {}
+ raw_flow_shift = extra_args.get("flow_shift")
+ if raw_flow_shift is None:
+ raw_flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
+
+ try:
+ return float(raw_flow_shift)
+ except (TypeError, ValueError) as exc:
+ raise ValueError(f"Invalid flow_shift={raw_flow_shift!r}. flow_shift must be a float.") from exc
def retrieve_latents(
@@ -121,10 +167,23 @@ def get_wan22_post_process_func(
def post_process_func(
video: torch.Tensor,
output_type: str = "np",
+ sampling_params=None,
):
if output_type == "latent":
return video
- return video_processor.postprocess_video(video, output_type=output_type)
+ custom_output = {}
+ if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False):
+ video, multiplier = interpolate_video_tensor(
+ video,
+ exp=sampling_params.frame_interpolation_exp,
+ scale=sampling_params.frame_interpolation_scale,
+ model_path=sampling_params.frame_interpolation_model_path,
+ )
+ custom_output["video_fps_multiplier"] = multiplier
+ return {
+ "video": video_processor.postprocess_video(video, output_type=output_type),
+ "custom_output": custom_output,
+ }
return post_process_func
@@ -234,6 +293,7 @@ def __init__(
pass
self.boundary_ratio = od_config.boundary_ratio
+ self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
# Determine which transformers to load based on boundary_ratio
# boundary_ratio=1.0: only load transformer_2 (low-noise stage only)
@@ -296,13 +356,9 @@ def __init__(
else:
raise RuntimeError("No transformer loaded")
- # Initialize UniPC scheduler
- flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
- self.scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- prediction_type="flow_prediction",
- )
+ self._sample_solver = "unipc"
+ self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
+ self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
@@ -336,6 +392,102 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ guidance_low: float,
+ guidance_high: float,
+ boundary_timestep: float | None,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, Any],
+ latent_condition: torch.Tensor | None = None,
+ first_frame_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Select model based on timestep and boundary_ratio
+ # High noise stage (t >= boundary_timestep): use transformer
+ # Low noise stage (t < boundary_timestep): use transformer_2
+ if boundary_timestep is not None and t < boundary_timestep:
+ # Low noise stage - always use guidance_high for this stage
+ current_guidance_scale = guidance_high
+ if self.transformer_2 is not None:
+ current_model = self.transformer_2
+ elif self.transformer is not None:
+ # Fallback to transformer if transformer_2 not loaded
+ current_model = self.transformer
+ else:
+ raise RuntimeError("No transformer available for low-noise stage")
+ else:
+ # High noise stage - always use guidance_low for this stage
+ current_guidance_scale = guidance_low
+ if self.transformer is not None:
+ current_model = self.transformer
+ elif self.transformer_2 is not None:
+ # Fallback to transformer_2 if transformer not loaded
+ current_model = self.transformer_2
+ else:
+ raise RuntimeError("No transformer available for high-noise stage")
+
+ if self.expand_timesteps and latent_condition is not None:
+ # I2V mode: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps per patch - use floor division to match patch embedding
+ patch_size = self.transformer_config.patch_size
+ patch_height = latents.shape[3] // patch_size[1]
+ patch_width = latents.shape[4] // patch_size[2]
+
+ # Create mask at patch resolution (same as hidden states sequence length)
+ patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]]
+ patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions
+ temp_ts = (patch_mask[0][0] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # T2V mode: standard forward
+ latent_model_input = latents.to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ else:
+ negative_kwargs = None
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=current_guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def forward(
self,
req: OmniDiffusionRequest,
@@ -413,6 +565,7 @@ def forward(
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale_2=guidance_high if boundary_ratio is not None else None,
boundary_ratio=boundary_ratio,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -446,7 +599,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or 512,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
device=device,
dtype=dtype,
)
@@ -462,6 +615,13 @@ def forward(
current_omni_platform.synchronize()
_t_text_enc_ms = (time.perf_counter() - _t_text_enc_start) * 1000
+ sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver)
+ flow_shift = resolve_wan_flow_shift(req, self.od_config)
+ if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
+ self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
+ self._sample_solver = sample_solver
+ self._flow_shift = flow_shift
+
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -571,90 +731,19 @@ def forward(
if DEBUG_PERF:
_t_denoise_start = time.perf_counter()
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Select model based on timestep and boundary_ratio
- # High noise stage (t >= boundary_timestep): use transformer
- # Low noise stage (t < boundary_timestep): use transformer_2
- if boundary_timestep is not None and t < boundary_timestep:
- # Low noise stage - always use guidance_high for this stage
- current_guidance_scale = guidance_high
- if self.transformer_2 is not None:
- current_model = self.transformer_2
- elif self.transformer is not None:
- # Fallback to transformer if transformer_2 not loaded
- current_model = self.transformer
- else:
- raise RuntimeError("No transformer available for low-noise stage")
- else:
- # High noise stage - always use guidance_low for this stage
- current_guidance_scale = guidance_low
- if self.transformer is not None:
- current_model = self.transformer
- elif self.transformer_2 is not None:
- # Fallback to transformer_2 if transformer not loaded
- current_model = self.transformer_2
- else:
- raise RuntimeError("No transformer available for high-noise stage")
-
- if self.expand_timesteps and latent_condition is not None:
- # I2V mode: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps per patch - use floor division to match patch embedding
- patch_size = self.transformer_config.patch_size
- num_latent_frames = latents.shape[2]
- patch_height = latents.shape[3] // patch_size[1]
- patch_width = latents.shape[4] // patch_size[2]
-
- # Create mask at patch resolution (same as hidden states sequence length)
- patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]]
- patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions
- temp_ts = (patch_mask[0][0] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # T2V mode: standard forward
- latent_model_input = latents.to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
- # Prepare kwargs for positive and negative predictions
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- else:
- negative_kwargs = None
-
- # Predict noise with automatic CFG parallel handling
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=current_guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_low=guidance_low,
+ guidance_high=guidance_high,
+ boundary_timestep=boundary_timestep,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ latent_condition=latent_condition,
+ first_frame_mask=first_frame_mask,
+ )
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
@@ -743,6 +832,20 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
+ text_inputs_untruncated = self.tokenizer(
+ prompt_clean,
+ padding=True,
+ truncation=False,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ validate_prompt_sequence_lengths(
+ text_inputs_untruncated.attention_mask,
+ max_sequence_length=max_sequence_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ error_context="for Wan2.2 text encoding",
+ )
text_inputs = self.tokenizer(
prompt_clean,
@@ -771,8 +874,24 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
+ neg_text_inputs_untruncated = self.tokenizer(
+ negative_prompt_clean,
+ padding=True,
+ truncation=False,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ validate_prompt_sequence_lengths(
+ neg_text_inputs_untruncated.attention_mask,
+ max_sequence_length=max_sequence_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ prompt_name="negative_prompt",
+ error_context="for Wan2.2 text encoding",
+ )
neg_text_inputs = self.tokenizer(
- [self._prompt_clean(p) for p in negative_prompt],
+ negative_prompt_clean,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
@@ -844,6 +963,7 @@ def check_inputs(
negative_prompt_embeds=None,
guidance_scale_2=None,
boundary_ratio=None,
+ max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -870,5 +990,10 @@ def check_inputs(
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
+ )
+
if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
index c05ecc9c9a..95d1e08bbc 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
@@ -12,6 +12,7 @@
import numpy as np
import PIL.Image
import torch
+import torchvision.transforms.functional as TF
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
@@ -24,14 +25,21 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
-from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
+ WAN22_MAX_SEQUENCE_LENGTH,
+ build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
+ resolve_wan_flow_shift,
+ resolve_wan_sample_solver,
retrieve_latents,
)
+from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.prompt_utils import (
+ validate_prompt_sequence_lengths,
+)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform
@@ -72,10 +80,23 @@ def get_wan22_i2v_post_process_func(
def post_process_func(
video: torch.Tensor,
output_type: str = "np",
+ sampling_params=None,
):
if output_type == "latent":
return video
- return video_processor.postprocess_video(video, output_type=output_type)
+ custom_output = {}
+ if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False):
+ video, multiplier = interpolate_video_tensor(
+ video,
+ exp=sampling_params.frame_interpolation_exp,
+ scale=sampling_params.frame_interpolation_scale,
+ model_path=sampling_params.frame_interpolation_model_path,
+ )
+ custom_output["video_fps_multiplier"] = multiplier
+ return {
+ "video": video_processor.postprocess_video(video, output_type=output_type),
+ "custom_output": custom_output,
+ }
return post_process_func
@@ -197,6 +218,7 @@ def __init__(
# Text encoder
self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
+ self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
self.text_encoder = UMT5EncoderModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
@@ -230,13 +252,9 @@ def __init__(
else:
self.transformer_2 = None
- # Initialize UniPC scheduler
- flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
- self.scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- prediction_type="flow_prediction",
- )
+ self._sample_solver = "unipc"
+ self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
+ self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
@@ -272,6 +290,82 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ image_embeds: torch.Tensor | None,
+ guidance_low: float,
+ guidance_high: float,
+ boundary_timestep: float | None,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, Any],
+ condition: torch.Tensor,
+ first_frame_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Select model and guidance scale based on timestep
+ current_model = self.transformer
+ current_guidance_scale = guidance_low
+ if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_high
+
+ # Prepare latent input
+ if self.expand_timesteps:
+ # TI2V-5B style: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps for each patch
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # Wan2.1 style: concatenate condition with latents
+ latent_model_input = torch.cat([latents, condition], dim=1).to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "encoder_hidden_states_image": image_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "encoder_hidden_states_image": image_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ else:
+ negative_kwargs = None
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=current_guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def encode_image(
self,
image: PIL.Image.Image | list[PIL.Image.Image],
@@ -380,6 +474,7 @@ def forward(
image_embeds=image_embeds,
guidance_scale_2=guidance_high if boundary_ratio is not None else None,
boundary_ratio=boundary_ratio,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
# Adjust num_frames to be compatible with VAE temporal scaling
@@ -408,7 +503,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or 512,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
device=device,
dtype=dtype,
)
@@ -440,6 +535,13 @@ def forward(
current_omni_platform.synchronize()
_t_img_enc_ms = (time.perf_counter() - _t_img_enc_start) * 1000
+ sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver)
+ flow_shift = resolve_wan_flow_shift(req, self.od_config)
+ if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
+ self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
+ self._sample_solver = sample_solver
+ self._flow_shift = flow_shift
+
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -459,6 +561,7 @@ def forward(
video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
if isinstance(image, PIL.Image.Image):
+ image = TF.to_tensor(image).to(device)
image_tensor = video_processor.preprocess(image, height=height, width=width)
else:
image_tensor = image
@@ -467,6 +570,7 @@ def forward(
# Handle last_image if provided
if last_image is not None:
if isinstance(last_image, PIL.Image.Image):
+ image = TF.to_tensor(last_image).to(device)
last_image_tensor = video_processor.preprocess(last_image, height=height, width=width)
else:
last_image_tensor = last_image
@@ -497,68 +601,20 @@ def forward(
if DEBUG_PERF:
_t_denoise_start = time.perf_counter()
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Select model and guidance scale based on timestep
- current_model = self.transformer
- current_guidance_scale = guidance_low
- if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
- current_model = self.transformer_2
- current_guidance_scale = guidance_high
-
- # Prepare latent input
- if self.expand_timesteps:
- # TI2V-5B style: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps for each patch
- temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # Wan2.1 style: concatenate condition with latents
- latent_model_input = torch.cat([latents, condition], dim=1).to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
- # Prepare kwargs for positive and negative predictions
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "encoder_hidden_states_image": image_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "encoder_hidden_states_image": image_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- else:
- negative_kwargs = None
-
- # Predict noise with automatic CFG parallel handling
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=current_guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ image_embeds=image_embeds,
+ guidance_low=guidance_low,
+ guidance_high=guidance_high,
+ boundary_timestep=boundary_timestep,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ condition=condition,
+ first_frame_mask=first_frame_mask,
+ )
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
@@ -652,6 +708,20 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
+ text_inputs_untruncated = self.tokenizer(
+ prompt_clean,
+ padding=True,
+ truncation=False,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ validate_prompt_sequence_lengths(
+ text_inputs_untruncated.attention_mask,
+ max_sequence_length=max_sequence_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ error_context="for Wan2.2 text encoding",
+ )
text_inputs = self.tokenizer(
prompt_clean,
@@ -680,8 +750,24 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
+ neg_text_inputs_untruncated = self.tokenizer(
+ negative_prompt_clean,
+ padding=True,
+ truncation=False,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ validate_prompt_sequence_lengths(
+ neg_text_inputs_untruncated.attention_mask,
+ max_sequence_length=max_sequence_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ prompt_name="negative_prompt",
+ error_context="for Wan2.2 text encoding",
+ )
neg_text_inputs = self.tokenizer(
- [self._prompt_clean(p) for p in negative_prompt],
+ negative_prompt_clean,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
@@ -789,12 +875,14 @@ def prepare_latents(
return latents, latent_condition, first_frame_mask
# Wan2.1 style: create mask and concatenate with condition
- mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+ mask_lat_size = torch.ones(
+ batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device
+ )
if last_image is None:
- mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ mask_lat_size[:, :, 1:] = 0
else:
- mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+ mask_lat_size[:, :, 1 : num_frames - 1] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
@@ -823,6 +911,7 @@ def check_inputs(
image_embeds=None,
guidance_scale_2=None,
boundary_ratio=None,
+ max_sequence_length=None,
):
if image is None and image_embeds is None:
raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.")
@@ -844,6 +933,11 @@ def check_inputs(
if prompt is None and prompt_embeds is None:
raise ValueError("Provide either `prompt` or `prompt_embeds`.")
+ if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
+ )
+
if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
index 261f62fb79..dba76ba8af 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
@@ -36,13 +36,20 @@
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
-from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
+ WAN22_MAX_SEQUENCE_LENGTH,
+ build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
+ resolve_wan_flow_shift,
+ resolve_wan_sample_solver,
retrieve_latents,
)
+from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.prompt_utils import (
+ validate_prompt_sequence_lengths,
+)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform
@@ -59,10 +66,23 @@ def get_wan22_ti2v_post_process_func(
def post_process_func(
video: torch.Tensor,
output_type: str = "np",
+ sampling_params=None,
):
if output_type == "latent":
return video
- return video_processor.postprocess_video(video, output_type=output_type)
+ custom_output = {}
+ if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False):
+ video, multiplier = interpolate_video_tensor(
+ video,
+ exp=sampling_params.frame_interpolation_exp,
+ scale=sampling_params.frame_interpolation_scale,
+ model_path=sampling_params.frame_interpolation_model_path,
+ )
+ custom_output["video_fps_multiplier"] = multiplier
+ return {
+ "video": video_processor.postprocess_video(video, output_type=output_type),
+ "custom_output": custom_output,
+ }
return post_process_func
@@ -169,6 +189,7 @@ def __init__(
# Text encoder
self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
+ self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
self.text_encoder = UMT5EncoderModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
@@ -183,13 +204,9 @@ def __init__(
transformer_config = load_transformer_config(model, "transformer", local_files_only)
self.transformer = create_transformer_from_config(transformer_config)
- # Initialize UniPC scheduler
- flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
- self.scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- prediction_type="flow_prediction",
- )
+ self._sample_solver = "unipc"
+ self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
+ self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
@@ -218,6 +235,77 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ guidance_scale: float,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, Any],
+ num_latent_frames: int,
+ latent_height: int,
+ latent_width: int,
+ latent_condition: torch.Tensor | None = None,
+ first_frame_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Prepare latent input
+ if latent_condition is not None:
+ # I2V mode: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps for each patch (TI2V style)
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # T2V mode: use latents directly
+ latent_model_input = latents.to(dtype)
+
+ # Expand timesteps for TI2V model architecture
+ mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=latents.device)
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+
+ do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": self.transformer,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": self.transformer,
+ }
+ else:
+ negative_kwargs = None
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def forward(
self,
req: OmniDiffusionRequest,
@@ -289,6 +377,7 @@ def forward(
width=width,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
# Adjust num_frames to be compatible with VAE temporal scaling
@@ -312,7 +401,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_scale > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or 512,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
device=device,
dtype=dtype,
)
@@ -323,6 +412,13 @@ def forward(
batch_size = prompt_embeds.shape[0]
+ sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver)
+ flow_shift = resolve_wan_flow_shift(req, self.od_config)
+ if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
+ self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
+ self._sample_solver = sample_solver
+ self._flow_shift = flow_shift
+
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -380,64 +476,20 @@ def forward(
if attention_kwargs is None:
attention_kwargs = {}
- # Denoising loop
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Prepare latent input
- if latent_condition is not None:
- # I2V mode: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps for each patch (TI2V style)
- temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # T2V mode: use latents directly
- latent_model_input = latents.to(dtype)
-
- # Expand timesteps for TI2V model architecture
- mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=device)
- temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
-
- do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
- # Prepare kwargs for positive and negative predictions
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": self.transformer,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": self.transformer,
- }
- else:
- negative_kwargs = None
-
- # Predict noise with automatic CFG parallel handling
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ num_latent_frames=num_latent_frames,
+ latent_height=latent_height,
+ latent_width=latent_width,
+ latent_condition=latent_condition,
+ first_frame_mask=first_frame_mask,
+ )
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
@@ -499,6 +551,20 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
+ text_inputs_untruncated = self.tokenizer(
+ prompt_clean,
+ padding=True,
+ truncation=False,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ validate_prompt_sequence_lengths(
+ text_inputs_untruncated.attention_mask,
+ max_sequence_length=max_sequence_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ error_context="for Wan2.2 text encoding",
+ )
text_inputs = self.tokenizer(
prompt_clean,
@@ -527,8 +593,24 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
+ neg_text_inputs_untruncated = self.tokenizer(
+ negative_prompt_clean,
+ padding=True,
+ truncation=False,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ validate_prompt_sequence_lengths(
+ neg_text_inputs_untruncated.attention_mask,
+ max_sequence_length=max_sequence_length,
+ supported_max_sequence_length=self.tokenizer_max_length,
+ prompt_name="negative_prompt",
+ error_context="for Wan2.2 text encoding",
+ )
neg_text_inputs = self.tokenizer(
- [self._prompt_clean(p) for p in negative_prompt],
+ negative_prompt_clean,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
@@ -653,6 +735,7 @@ def check_inputs(
width,
prompt_embeds=None,
negative_prompt_embeds=None,
+ max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -668,6 +751,11 @@ def check_inputs(
if prompt is None and prompt_embeds is None:
raise ValueError("Provide either `prompt` or `prompt_embeds`.")
+ if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
+ raise ValueError(
+ f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
+ )
+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights using AutoWeightsLoader for vLLM integration."""
loader = AutoWeightsLoader(self)
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
index ea52336311..11408e2d24 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
@@ -176,6 +176,62 @@ def _create_transformer(self, config: dict) -> WanVACETransformer3DModel:
"""Build VACE transformer directly from config dict."""
return create_vace_transformer_from_config(config)
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ guidance_scale: float,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, object],
+ vace_context: torch.Tensor | None,
+ vace_context_scale: float,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+ latent_model_input = latents.to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
+
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "vace_context": vace_context,
+ "vace_context_scale": vace_context_scale,
+ "return_dict": False,
+ }
+ negative_kwargs = (
+ {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "vace_context": vace_context,
+ "vace_context_scale": vace_context_scale,
+ "return_dict": False,
+ }
+ if do_true_cfg
+ else None
+ )
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def check_inputs(
self,
prompt,
@@ -187,6 +243,7 @@ def check_inputs(
video=None,
mask=None,
reference_images=None,
+ max_sequence_length=None,
):
super().check_inputs(
prompt=prompt,
@@ -195,6 +252,7 @@ def check_inputs(
width=width,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
)
# VACE-specific: validate video/mask/reference_images consistency
@@ -491,6 +549,7 @@ def forward(
video=source_video,
mask=source_mask,
reference_images=reference_images,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
device = self.device
@@ -509,7 +568,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_scale > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or 512,
+ max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
device=device,
dtype=dtype,
)
@@ -569,48 +628,17 @@ def forward(
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
- # Denoising loop
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
- latent_model_input = latents.to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
-
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "vace_context": vace_context,
- "vace_context_scale": vace_context_scale,
- "return_dict": False,
- }
- negative_kwargs = (
- {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "vace_context": vace_context,
- "vace_context_scale": vace_context_scale,
- "return_dict": False,
- }
- if do_true_cfg
- else None
- )
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ vace_context=vace_context,
+ vace_context_scale=vace_context_scale,
+ )
self._current_timestep = None
diff --git a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py
new file mode 100644
index 0000000000..25444044c2
--- /dev/null
+++ b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py
@@ -0,0 +1,147 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+import numpy as np
+import torch
+
+
+@dataclass
+class WanEulerSchedulerOutput:
+ prev_sample: torch.FloatTensor
+
+
+def _unsqueeze_to_ndim(in_tensor: torch.Tensor, target_ndim: int) -> torch.Tensor:
+ if in_tensor.ndim >= target_ndim:
+ return in_tensor
+ return in_tensor[(...,) + (None,) * (target_ndim - in_tensor.ndim)]
+
+
+def _get_timesteps(num_steps: int, max_steps: int = 1000) -> np.ndarray:
+ # Keep num_steps + 1 points so Euler update can always access sigma_next.
+ return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)
+
+
+def _timestep_shift(timesteps: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
+ return shift * timesteps / (1 + (shift - 1) * timesteps)
+
+
+class WanEulerScheduler:
+ order = 1
+
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ device: torch.device | str = "cpu",
+ ) -> None:
+ self.num_train_timesteps = int(num_train_timesteps)
+ self._shift = float(shift)
+ self.device = device
+ self.config = SimpleNamespace(num_train_timesteps=self.num_train_timesteps)
+ self.init_noise_sigma = 1.0
+
+ self._step_index: int | None = None
+ self._begin_index: int | None = None
+
+ self.timesteps = torch.empty(0, dtype=torch.float32)
+ self.sigmas = torch.empty(0, dtype=torch.float32)
+ self.timesteps_ori = torch.empty(0, dtype=torch.float32)
+
+ self.set_timesteps(num_inference_steps=self.num_train_timesteps, device=self.device)
+
+ @property
+ def step_index(self) -> int | None:
+ return self._step_index
+
+ @property
+ def begin_index(self) -> int | None:
+ return self._begin_index
+
+ def set_begin_index(self, begin_index: int = 0) -> None:
+ self._begin_index = int(begin_index)
+
+ def index_for_timestep(self, timestep: torch.Tensor) -> int:
+ indices = (self.timesteps == timestep).nonzero()
+ if len(indices) > 0:
+ pos = 1 if len(indices) > 1 else 0
+ return int(indices[pos].item())
+ # Fallback for tiny float drift
+ return int(torch.argmin(torch.abs(self.timesteps - timestep)).item())
+
+ def _init_step_index(self, timestep: float | torch.Tensor) -> None:
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep_t = timestep.to(self.timesteps.device, dtype=self.timesteps.dtype)
+ else:
+ timestep_t = torch.tensor(timestep, device=self.timesteps.device, dtype=self.timesteps.dtype)
+ self._step_index = self.index_for_timestep(timestep_t)
+ else:
+ self._step_index = self._begin_index
+
+ def set_shift(self, shift: float = 1.0) -> None:
+ # Compute shifted sigma schedule on [0, 1].
+ sigmas_full = self.timesteps_ori / float(self.num_train_timesteps)
+ sigmas_full = _timestep_shift(sigmas_full, shift=float(shift))
+ self.sigmas = sigmas_full
+ # Public timesteps are the first N points; next point is consumed as sigma_next.
+ self.timesteps = self.sigmas[:-1] * self.num_train_timesteps
+ self._shift = float(shift)
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: torch.device | str | int | None = None,
+ **kwargs, # noqa: ARG002 - kept for scheduler API compatibility
+ ) -> None:
+ timesteps = _get_timesteps(
+ num_steps=int(num_inference_steps),
+ max_steps=self.num_train_timesteps,
+ )
+ self.timesteps_ori = torch.from_numpy(timesteps).to(
+ dtype=torch.float32,
+ device=device or self.device,
+ )
+ self.set_shift(self._shift)
+ self._step_index = None
+ self._begin_index = None
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor: # noqa: ARG002
+ return sample
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: float | torch.FloatTensor,
+ sample: torch.FloatTensor,
+ return_dict: bool = True,
+ **kwargs, # noqa: ARG002 - kept for scheduler API compatibility
+ ) -> WanEulerSchedulerOutput | tuple[torch.FloatTensor]:
+ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
+ raise ValueError(
+ "Passing integer indices as timesteps is not supported. Use one value from scheduler.timesteps instead."
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+ assert self._step_index is not None
+
+ sample_fp32 = sample.to(torch.float32)
+ sigma = _unsqueeze_to_ndim(self.sigmas[self._step_index], sample_fp32.ndim).to(sample_fp32.device)
+ sigma_next = _unsqueeze_to_ndim(self.sigmas[self._step_index + 1], sample_fp32.ndim).to(sample_fp32.device)
+
+ prev_sample = sample_fp32 + (sigma_next - sigma) * model_output
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+ return WanEulerSchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self) -> int:
+ return self.num_train_timesteps
diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
index 65a2d4390a..d4d81b78eb 100644
--- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
+++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
@@ -11,7 +11,6 @@
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.models.normalization import FP32LayerNorm
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -29,6 +28,8 @@
SequenceParallelOutput,
)
from vllm_omni.diffusion.forward_context import get_forward_context
+from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm
+from vllm_omni.diffusion.layers.norm import LayerNorm, RMSNorm
from vllm_omni.platforms import current_omni_platform
logger = init_logger(__name__)
@@ -235,9 +236,9 @@ class WanImageEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len: int | None = None):
super().__init__()
- self.norm1 = FP32LayerNorm(in_features)
+ self.norm1 = LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
- self.norm2 = FP32LayerNorm(out_features)
+ self.norm2 = LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
@@ -377,8 +378,12 @@ def __init__(
self.tp_inner_dim = self.num_heads * head_dim
# QK normalization using vLLM's RMSNorm
- self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
- self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ if get_tensor_model_parallel_world_size() > 1:
+ self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ else:
+ self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps)
self.to_out = RowParallelLinear(
self.inner_dim,
@@ -497,8 +502,12 @@ def __init__(
self.tp_inner_dim = self.num_heads * head_dim
# QK normalization
- self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
- self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ if get_tensor_model_parallel_world_size() > 1:
+ self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ else:
+ self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps)
# Optional added KV projections for I2V (image embeddings)
self.added_kv_proj_dim = added_kv_proj_dim
@@ -517,7 +526,10 @@ def __init__(
gather_output=False,
return_bias=False,
)
- self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ if get_tensor_model_parallel_world_size() > 1:
+ self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ else:
+ self.norm_added_k = RMSNorm(self.tp_inner_dim, eps=eps)
else:
self.add_k_proj = None
self.add_v_proj = None
@@ -620,7 +632,7 @@ def __init__(
head_dim = dim // num_heads
# 1. Self-attention
- self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
self.attn1 = WanSelfAttention(
dim=dim,
num_heads=num_heads,
@@ -636,11 +648,11 @@ def __init__(
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
)
- self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.norm2 = LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = WanFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim)
- self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.norm3 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
# Scale-shift table for modulation
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
@@ -656,7 +668,7 @@ def forward(
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table.unsqueeze(0) + temb.float()
+ self.scale_shift_table.unsqueeze(0) + temb
).chunk(6, dim=2)
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
@@ -667,25 +679,23 @@ def forward(
else:
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table + temb.float()
+ self.scale_shift_table + temb
).chunk(6, dim=1)
# 1. Self-attention
- norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ norm_hidden_states = self.norm1(hidden_states, scale_msa, shift_msa).type_as(hidden_states)
attn_output = self.attn1(norm_hidden_states, rotary_emb, hidden_states_mask)
- hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+ hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
- norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ norm_hidden_states = self.norm2(hidden_states).type_as(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
- norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
- hidden_states
- )
+ norm_hidden_states = self.norm3(hidden_states, c_scale_msa, c_shift_msa).type_as(hidden_states)
ff_output = self.ffn(norm_hidden_states)
- hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+ hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)
return hidden_states
@@ -854,7 +864,7 @@ def __init__(
)
# 4. Output norm & projection
- self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.norm_out = AdaLayerNorm(inner_dim, elementwise_affine=False, eps=eps)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
# SP helper modules
@@ -942,7 +952,7 @@ def forward(
shift = shift.unsqueeze(1)
scale = scale.unsqueeze(1)
- hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.norm_out(hidden_states, scale, shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
@@ -1015,6 +1025,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if ".to_out.0." in lookup_name:
lookup_name = lookup_name.replace(".to_out.0.", ".to_out.")
+ # Compatibility: some Wan conversion pipelines still keep
+ # block modulation keys as `blocks.N.modulation` instead of
+ # `blocks.N.scale_shift_table`.
+ if lookup_name.endswith(".modulation"):
+ modulation_alias = lookup_name[: -len(".modulation")] + ".scale_shift_table"
+ if modulation_alias in params_dict:
+ lookup_name = modulation_alias
+
if lookup_name not in params_dict:
logger.warning(f"Skipping weight {original_name} -> {lookup_name}")
continue
diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py
index 4f4217dabf..c48938e1ba 100644
--- a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py
+++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py
@@ -239,7 +239,7 @@ def forward(
shift = shift.unsqueeze(1)
scale = scale.unsqueeze(1)
- hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.norm_out(hidden_states, scale, shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py
index 3ffad221ba..c36ea74665 100644
--- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py
+++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py
@@ -214,12 +214,14 @@ def __init__(
super().__init__()
if mid_size is None:
mid_size = out_size
+ # Time embedding MLP is kept full precision (quant_config=None) —
+ # small layers that feed adaLN; precision-sensitive (see #2728).
self.mlp = nn.Sequential(
ReplicatedLinear(
frequency_embedding_size,
mid_size,
bias=True,
- quant_config=quant_config,
+ quant_config=None,
return_bias=False,
),
nn.SiLU(),
@@ -227,7 +229,7 @@ def __init__(
mid_size,
out_size,
bias=True,
- quant_config=quant_config,
+ quant_config=None,
return_bias=False,
),
)
@@ -426,9 +428,16 @@ def __init__(
self.modulation = modulation
if modulation:
+ # Modulation linear is kept at full precision (quant_config=None)
+ # — it produces scale/gate values that are precision-sensitive
+ # (see #2728, mirrors OmniGen2 fix).
self.adaLN_modulation = nn.Sequential(
ReplicatedLinear(
- min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config
+ min(dim, ADALN_EMBED_DIM),
+ 4 * dim,
+ bias=True,
+ quant_config=None,
+ return_bias=False,
),
)
@@ -485,14 +494,24 @@ class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ # Final output projection and its modulation are precision-sensitive
+ # (produce the output latent); keep at full precision (see #2728).
self.linear = ReplicatedLinear(
- hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False
+ hidden_size,
+ out_channels,
+ bias=True,
+ quant_config=None,
+ return_bias=False,
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
- min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False
+ min(hidden_size, ADALN_EMBED_DIM),
+ hidden_size,
+ bias=True,
+ quant_config=None,
+ return_bias=False,
),
)
@@ -673,11 +692,13 @@ def __init__(
all_x_embedder = {}
all_final_layer = {}
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
+ # x_embedder (patch embed) is a small precision-sensitive entry
+ # layer; keep full precision (see #2728).
x_embedder = ReplicatedLinear(
f_patch_size * patch_size * patch_size * in_channels,
dim,
bias=True,
- quant_config=quant_config,
+ quant_config=None,
return_bias=False,
)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
@@ -720,9 +741,17 @@ def __init__(
]
)
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024, quant_config=quant_config)
+ # Caption embedder maps text features -> hidden; keep full precision
+ # (see #2728).
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
- ReplicatedLinear(cap_feat_dim, dim, bias=True, return_bias=False, quant_config=quant_config),
+ ReplicatedLinear(
+ cap_feat_dim,
+ dim,
+ bias=True,
+ quant_config=None,
+ return_bias=False,
+ ),
)
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
diff --git a/vllm_omni/diffusion/postprocess/__init__.py b/vllm_omni/diffusion/postprocess/__init__.py
new file mode 100644
index 0000000000..e6fe5b2d22
--- /dev/null
+++ b/vllm_omni/diffusion/postprocess/__init__.py
@@ -0,0 +1,10 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Diffusion post-processing helpers."""
+
+from vllm_omni.diffusion.postprocess.rife_interpolator import (
+ FrameInterpolator,
+ interpolate_video_tensor,
+)
+
+__all__ = ["FrameInterpolator", "interpolate_video_tensor"]
diff --git a/vllm_omni/diffusion/postprocess/rife_interpolator.py b/vllm_omni/diffusion/postprocess/rife_interpolator.py
new file mode 100644
index 0000000000..89297d0a44
--- /dev/null
+++ b/vllm_omni/diffusion/postprocess/rife_interpolator.py
@@ -0,0 +1,443 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+RIFE 4.22.lite frame interpolation for vLLM-Omni video generation.
+
+RIFE model code is vendored and adapted from:
+ - https://github.com/hzwer/ECCV2022-RIFE (MIT License)
+ - https://github.com/hzwer/Practical-RIFE (MIT License)
+ Copyright (c) 2021 Zhewei Huang
+
+The FrameInterpolator wrapper and vLLM-Omni integration code are original work.
+"""
+
+from __future__ import annotations
+
+import os
+import threading
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+_DEFAULT_RIFE_HF_REPO = "elfgum/RIFE-4.22.lite"
+_MODEL_CACHE: dict[tuple[str, str], Model] = {}
+_MODEL_CACHE_LOCK = threading.Lock()
+
+
+def warp(ten_input: torch.Tensor, ten_flow: torch.Tensor) -> torch.Tensor:
+ """Warp input tensor by optical flow using grid_sample."""
+ ten_horizontal = (
+ torch.linspace(-1.0, 1.0, ten_flow.shape[3], device=ten_flow.device)
+ .view(1, 1, 1, ten_flow.shape[3])
+ .expand(ten_flow.shape[0], -1, ten_flow.shape[2], -1)
+ )
+ ten_vertical = (
+ torch.linspace(-1.0, 1.0, ten_flow.shape[2], device=ten_flow.device)
+ .view(1, 1, ten_flow.shape[2], 1)
+ .expand(ten_flow.shape[0], -1, -1, ten_flow.shape[3])
+ )
+ ten_grid = torch.cat([ten_horizontal, ten_vertical], dim=1)
+
+ ten_flow = torch.cat(
+ [
+ ten_flow[:, 0:1, :, :] / ((ten_input.shape[3] - 1.0) / 2.0),
+ ten_flow[:, 1:2, :, :] / ((ten_input.shape[2] - 1.0) / 2.0),
+ ],
+ dim=1,
+ )
+ grid = (ten_grid + ten_flow).permute(0, 2, 3, 1)
+ return F.grid_sample(
+ input=ten_input,
+ grid=grid,
+ mode="bilinear",
+ padding_mode="border",
+ align_corners=True,
+ )
+
+
+def _conv(
+ in_planes: int,
+ out_planes: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ dilation: int = 1,
+) -> nn.Sequential:
+ return nn.Sequential(
+ nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=True,
+ ),
+ nn.LeakyReLU(0.2, True),
+ )
+
+
+class ResConv(nn.Module):
+ """Residual convolution block with learnable beta scaling."""
+
+ def __init__(self, c: int, dilation: int = 1):
+ super().__init__()
+ self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
+ self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
+ self.relu = nn.LeakyReLU(0.2, True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.relu(self.conv(x) * self.beta + x)
+
+
+class IFBlock(nn.Module):
+ """Single-scale optical flow, mask, and feature block."""
+
+ def __init__(self, in_planes: int, c: int = 64):
+ super().__init__()
+ self.conv0 = nn.Sequential(
+ _conv(in_planes, c // 2, 3, 2, 1),
+ _conv(c // 2, c, 3, 2, 1),
+ )
+ self.convblock = nn.Sequential(
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ ResConv(c),
+ )
+ self.lastconv = nn.Sequential(
+ nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1),
+ nn.PixelShuffle(2),
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ flow: torch.Tensor | None = None,
+ scale: float = 1.0,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
+ if flow is not None:
+ flow = (
+ F.interpolate(
+ flow,
+ scale_factor=1.0 / scale,
+ mode="bilinear",
+ align_corners=False,
+ )
+ * 1.0
+ / scale
+ )
+ x = torch.cat((x, flow), 1)
+ feat = self.conv0(x)
+ feat = self.convblock(feat)
+ tmp = self.lastconv(feat)
+ tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
+ flow = tmp[:, :4] * scale
+ mask = tmp[:, 4:5]
+ feat = tmp[:, 5:]
+ return flow, mask, feat
+
+
+class Head(nn.Module):
+ """Feature encoder producing four-channel features at full resolution."""
+
+ def __init__(self):
+ super().__init__()
+ self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
+ self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
+ self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
+ self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
+ self.relu = nn.LeakyReLU(0.2, True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x0 = self.cnn0(x)
+ x = self.relu(x0)
+ x1 = self.cnn1(x)
+ x = self.relu(x1)
+ x2 = self.cnn2(x)
+ x = self.relu(x2)
+ x3 = self.cnn3(x)
+ return x3
+
+
+class IFNet(nn.Module):
+ """Four-scale IFNet optical flow network."""
+
+ def __init__(self):
+ super().__init__()
+ self.block0 = IFBlock(7 + 8, c=192)
+ self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
+ self.block2 = IFBlock(8 + 4 + 8 + 8, c=64)
+ self.block3 = IFBlock(8 + 4 + 8 + 8, c=32)
+ self.encode = Head()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: float = 0.5,
+ scale_list: list[float] | None = None,
+ ) -> tuple[list[torch.Tensor], torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]]:
+ if scale_list is None:
+ scale_list = [8, 4, 2, 1]
+
+ channel = x.shape[1] // 2
+ img0 = x[:, :channel]
+ img1 = x[:, channel:]
+
+ if not torch.is_tensor(timestep):
+ timestep = (x[:, :1].clone() * 0 + 1) * timestep
+ else:
+ timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
+
+ f0 = self.encode(img0[:, :3])
+ f1 = self.encode(img1[:, :3])
+
+ flow_list: list[torch.Tensor] = []
+ merged: list[tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = []
+ mask_list: list[torch.Tensor] = []
+ warped_img0 = img0
+ warped_img1 = img1
+ flow = None
+ mask = None
+
+ for i, block in enumerate([self.block0, self.block1, self.block2, self.block3]):
+ if flow is None:
+ flow, mask, feat = block(
+ torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
+ None,
+ scale=scale_list[i],
+ )
+ else:
+ wf0 = warp(f0, flow[:, :2])
+ wf1 = warp(f1, flow[:, 2:4])
+ fd, m0, feat = block(
+ torch.cat(
+ (
+ warped_img0[:, :3],
+ warped_img1[:, :3],
+ wf0,
+ wf1,
+ timestep,
+ mask,
+ feat,
+ ),
+ 1,
+ ),
+ flow,
+ scale=scale_list[i],
+ )
+ mask = m0
+ flow = flow + fd
+
+ mask_list.append(mask)
+ flow_list.append(flow)
+ warped_img0 = warp(img0, flow[:, :2])
+ warped_img1 = warp(img1, flow[:, 2:4])
+ merged.append((warped_img0, warped_img1))
+
+ mask = torch.sigmoid(mask)
+ merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)
+ return flow_list, mask_list[3], merged
+
+
+class Model:
+ """Wraps IFNet and exposes RIFE-compatible load/inference helpers."""
+
+ def __init__(self):
+ self.flownet = IFNet()
+
+ def eval(self) -> Model:
+ self.flownet.eval()
+ return self
+
+ def device(self) -> torch.device:
+ return next(self.flownet.parameters()).device
+
+ def load_model(self, path: str) -> None:
+ flownet_path = os.path.join(path, "flownet.pkl")
+ if not os.path.isfile(flownet_path):
+ raise FileNotFoundError(
+ f"RIFE weight file not found: {flownet_path}. Expected layout: /flownet.pkl"
+ )
+
+ state = torch.load(flownet_path, map_location="cpu", weights_only=False)
+ state = {k.removeprefix("module."): v for k, v in state.items()}
+ self.flownet.load_state_dict(state, strict=False)
+ logger.info("Loaded RIFE weights from %s", flownet_path)
+
+ def inference(
+ self,
+ img0: torch.Tensor,
+ img1: torch.Tensor,
+ scale: float = 1.0,
+ timestep: float = 0.5,
+ ) -> torch.Tensor:
+ _n, _c, h, w = img0.shape
+ ph = ((h - 1) // 32 + 1) * 32
+ pw = ((w - 1) // 32 + 1) * 32
+ pad = (0, pw - w, 0, ph - h)
+ img0 = F.pad(img0, pad)
+ img1 = F.pad(img1, pad)
+
+ imgs = torch.cat((img0, img1), 1)
+ scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
+ with torch.no_grad():
+ _flow_list, _mask, merged = self.flownet(
+ imgs,
+ timestep=timestep,
+ scale_list=scale_list,
+ )
+ return merged[3][:, :, :h, :w]
+
+
+def _resolve_rife_model_path(model_path: str | None) -> str:
+ model_path = model_path or _DEFAULT_RIFE_HF_REPO
+ if os.path.isdir(model_path):
+ return model_path
+ from vllm_omni.model_executor.model_loader.weight_utils import (
+ download_weights_from_hf_specific,
+ )
+
+ return download_weights_from_hf_specific(
+ model_path,
+ cache_dir=None,
+ allow_patterns=["flownet.pkl"],
+ require_all=True,
+ )
+
+
+def _select_torch_device() -> torch.device:
+ try:
+ from vllm_omni.platforms import current_omni_platform
+
+ return current_omni_platform.get_torch_device()
+ except Exception as exc:
+ logger.warning("Failed to resolve current vLLM-Omni torch device: %s", exc)
+
+ if torch.cuda.is_available():
+ return torch.device("cuda")
+ return torch.device("cpu")
+
+
+def _normalize_video_tensor_layout(video: torch.Tensor) -> tuple[torch.Tensor, Any]:
+ if video.ndim == 5:
+ if video.shape[1] in (3, 4):
+ return video, lambda out: out
+ if video.shape[2] in (3, 4):
+ return video.permute(0, 2, 1, 3, 4), lambda out: out.permute(0, 2, 1, 3, 4)
+ elif video.ndim == 4:
+ if video.shape[0] in (3, 4):
+ return video.unsqueeze(0), lambda out: out.squeeze(0)
+ if video.shape[1] in (3, 4):
+ return video.permute(1, 0, 2, 3).unsqueeze(0), lambda out: out.squeeze(0).permute(1, 0, 2, 3)
+ raise ValueError(f"Unsupported video tensor shape for interpolation: {tuple(video.shape)}")
+
+
+def _normalize_video_tensor_range(video: torch.Tensor) -> tuple[torch.Tensor, Any]:
+ original_dtype = video.dtype
+ video = video.detach()
+ if video.is_floating_point():
+ video = video.to(torch.float32)
+ if torch.amin(video) < 0.0 or torch.amax(video) > 1.0:
+ return video.clamp(-1.0, 1.0) * 0.5 + 0.5, lambda out: (out * 2.0 - 1.0).to(original_dtype)
+ return video.clamp(0.0, 1.0), lambda out: out.to(original_dtype)
+ return video.to(torch.float32) / 255.0, lambda out: (out * 255.0).round().clamp(0, 255).to(original_dtype)
+
+
+class FrameInterpolator:
+ """Lazy-loaded RIFE 4.22.lite frame interpolator."""
+
+ def __init__(self, model_path: str | None = None):
+ self._model_path = model_path
+ self._resolved_path: str | None = None
+
+ def _ensure_model_loaded(self, preferred_device: torch.device | None = None) -> Model:
+ resolved_path = _resolve_rife_model_path(self._model_path)
+ self._resolved_path = resolved_path
+ device = preferred_device or _select_torch_device()
+ cache_key = (resolved_path, str(device))
+
+ with _MODEL_CACHE_LOCK:
+ if cache_key in _MODEL_CACHE:
+ return _MODEL_CACHE[cache_key]
+
+ model = Model()
+ model.load_model(resolved_path)
+ model.eval()
+ model.flownet = model.flownet.to(device)
+ _MODEL_CACHE[cache_key] = model
+ logger.info("RIFE model loaded on device: %s", device)
+ return model
+
+ def _make_inference(
+ self,
+ model: Model,
+ img0: torch.Tensor,
+ img1: torch.Tensor,
+ n: int,
+ scale: float,
+ ) -> list[torch.Tensor]:
+ if n == 1:
+ return [model.inference(img0, img1, scale=scale)]
+ mid = model.inference(img0, img1, scale=scale)
+ return (
+ self._make_inference(model, img0, mid, n // 2, scale)
+ + [mid]
+ + self._make_inference(model, mid, img1, n // 2, scale)
+ )
+
+ def interpolate_tensor(
+ self,
+ video: torch.Tensor,
+ exp: int = 1,
+ scale: float = 1.0,
+ ) -> tuple[torch.Tensor, int]:
+ if exp < 1:
+ raise ValueError(f"frame interpolation exp must be >= 1, got {exp}")
+ if scale <= 0:
+ raise ValueError(f"frame interpolation scale must be > 0, got {scale}")
+
+ video, restore_layout = _normalize_video_tensor_layout(video)
+ if video.shape[2] < 2:
+ return restore_layout(video), 1
+
+ video, restore_range = _normalize_video_tensor_range(video)
+ # A CPU tensor may be transport/offload state rather than an execution
+ # choice, so only trust it when it is already on an accelerator.
+ preferred_device = video.device
+ if preferred_device.type == "cpu":
+ preferred_device = _select_torch_device()
+ model = self._ensure_model_loaded(preferred_device=preferred_device)
+ video = video.to(model.device())
+ intermediates_per_pair = 2**exp // 2
+
+ result_frames: list[torch.Tensor] = []
+ for idx in range(video.shape[2] - 1):
+ img0 = video[:, :, idx, :, :]
+ img1 = video[:, :, idx + 1, :, :]
+ result_frames.append(img0)
+ result_frames.extend(self._make_inference(model, img0, img1, intermediates_per_pair, scale))
+ result_frames.append(video[:, :, -1, :, :])
+ result = torch.stack(result_frames, dim=2)
+ return restore_layout(restore_range(result)), 2**exp
+
+
+def interpolate_video_tensor(
+ video: torch.Tensor,
+ exp: int = 1,
+ scale: float = 1.0,
+ model_path: str | None = None,
+) -> tuple[torch.Tensor, int]:
+ """Interpolate a video tensor and return the FPS multiplier."""
+ interpolator = FrameInterpolator(model_path=model_path)
+ return interpolator.interpolate_tensor(video, exp=exp, scale=scale)
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py
index 97bc7fa292..0bf8c04517 100644
--- a/vllm_omni/diffusion/registry.py
+++ b/vllm_omni/diffusion/registry.py
@@ -119,8 +119,8 @@
"FluxKontextPipeline",
),
"HunyuanImage3ForCausalMM": (
- "hunyuan_image_3",
- "pipeline_hunyuan_image_3",
+ "hunyuan_image3",
+ "pipeline_hunyuan_image3",
"HunyuanImage3Pipeline",
),
"Flux2KleinPipeline": (
@@ -375,6 +375,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -
"HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_post_process_func",
"MagiHumanPipeline": "get_magi_human_post_process_func",
"OmniVoicePipeline": "get_omnivoice_post_process_func",
+ "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func",
}
_DIFFUSION_PRE_PROCESS_FUNCS = {
diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py
index cd7159b683..480d113d19 100644
--- a/vllm_omni/diffusion/stage_diffusion_client.py
+++ b/vllm_omni/diffusion/stage_diffusion_client.py
@@ -34,6 +34,24 @@
logger = init_logger(__name__)
+def create_diffusion_client(
+ model: str,
+ od_config: OmniDiffusionConfig,
+ metadata: StageMetadata,
+ stage_init_timeout: int,
+ batch_size: int = 1,
+ use_inline: bool = False,
+) -> Any:
+ """Factory to create either an inline or out-of-process diffusion client."""
+ if use_inline:
+ from vllm_omni.diffusion.inline_stage_diffusion_client import InlineStageDiffusionClient
+
+ return InlineStageDiffusionClient(model, od_config, metadata, batch_size=batch_size)
+ return StageDiffusionClient(
+ model, od_config, metadata, stage_init_timeout=stage_init_timeout, batch_size=batch_size
+ )
+
+
class StageDiffusionClient:
"""Communicates with StageDiffusionProc via ZMQ for use inside the Orchestrator.
@@ -50,11 +68,12 @@ def __init__(
model: str,
od_config: OmniDiffusionConfig,
metadata: StageMetadata,
+ stage_init_timeout: int,
batch_size: int = 1,
) -> None:
# Spawn StageDiffusionProc subprocess and wait for READY.
proc, handshake_address, request_address, response_address = spawn_diffusion_proc(model, od_config)
- complete_diffusion_handshake(proc, handshake_address)
+ complete_diffusion_handshake(proc, handshake_address, stage_init_timeout)
self._initialize_client(metadata, request_address, response_address, proc=proc, batch_size=batch_size)
@classmethod
@@ -153,6 +172,13 @@ def _drain_responses(self) -> None:
"error": True,
"reason": error_msg,
}
+ elif req_id is not None:
+ error_output = OmniRequestOutput.from_diffusion(
+ request_id=req_id,
+ images=[],
+ )
+ error_output.error = error_msg
+ self._output_queue.put_nowait(error_output)
# Fields that are subprocess-local and cannot be serialized across
# process boundaries. They are recreated in the subprocess with
diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py
index 2bba419250..eced444fd3 100644
--- a/vllm_omni/diffusion/stage_diffusion_proc.py
+++ b/vllm_omni/diffusion/stage_diffusion_proc.py
@@ -19,12 +19,11 @@
import zmq.asyncio
from PIL import Image
from vllm.logger import init_logger
-from vllm.transformers_utils.config import get_hf_file_to_dict
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.utils.system_utils import get_mp_context
from vllm.v1.utils import shutdown
-from vllm_omni.diffusion.data import DiffusionRequestAbortedError, TransformerConfig
+from vllm_omni.diffusion.data import DiffusionRequestAbortedError
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.distributed.omni_connectors.utils.serialization import (
@@ -39,8 +38,6 @@
logger = init_logger(__name__)
-_HANDSHAKE_POLL_TIMEOUT_S = 600
-
class StageDiffusionProc:
"""Subprocess entry point for diffusion inference.
@@ -68,47 +65,8 @@ def initialize(self) -> None:
logger.info("StageDiffusionProc initialized with model: %s", self._model)
def _enrich_config(self) -> None:
- """Load model metadata from HuggingFace and populate od_config fields.
-
- Diffusers-style models expose ``model_index.json`` with ``_class_name``.
- Non-diffusers models (e.g. Bagel, NextStep) only have ``config.json``,
- so we fall back to reading that and mapping model_type manually.
- """
- od_config = self._od_config
-
- try:
- config_dict = get_hf_file_to_dict("model_index.json", od_config.model)
- if config_dict is not None:
- if od_config.model_class_name is None:
- od_config.model_class_name = config_dict.get("_class_name", None)
- od_config.update_multimodal_support()
-
- tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model)
- od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
- else:
- raise FileNotFoundError("model_index.json not found")
- except (AttributeError, OSError, ValueError, FileNotFoundError):
- cfg = get_hf_file_to_dict("config.json", od_config.model)
- if cfg is None:
- raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}")
-
- od_config.tf_model_config = TransformerConfig.from_dict(cfg)
- model_type = cfg.get("model_type")
- architectures = cfg.get("architectures") or []
-
- if model_type == "bagel" or "BagelForConditionalGeneration" in architectures:
- od_config.model_class_name = "BagelPipeline"
- od_config.tf_model_config = TransformerConfig()
- od_config.update_multimodal_support()
- elif model_type == "nextstep":
- if od_config.model_class_name is None:
- od_config.model_class_name = "NextStep11Pipeline"
- od_config.tf_model_config = TransformerConfig()
- od_config.update_multimodal_support()
- elif architectures and len(architectures) == 1:
- od_config.model_class_name = architectures[0]
- else:
- raise
+ """Load model metadata from HuggingFace and populate od_config fields."""
+ self._od_config.enrich_config()
# ------------------------------------------------------------------
# Request processing
@@ -619,13 +577,14 @@ def spawn_diffusion_proc(
def complete_diffusion_handshake(
proc: BaseProcess,
handshake_address: str,
+ handshake_timeout: int,
) -> None:
"""Wait for the diffusion subprocess to signal READY.
On failure the process is terminated before re-raising.
"""
try:
- _perform_diffusion_handshake(proc, handshake_address)
+ _perform_diffusion_handshake(proc, handshake_address, handshake_timeout)
except Exception:
shutdown([proc])
raise
@@ -634,6 +593,7 @@ def complete_diffusion_handshake(
def _perform_diffusion_handshake(
proc: BaseProcess,
handshake_address: str,
+ handshake_timeout: int,
) -> None:
"""Run the handshake with the diffusion subprocess."""
with zmq_socket_ctx(handshake_address, zmq.ROUTER, bind=True) as handshake_socket:
@@ -641,11 +601,15 @@ def _perform_diffusion_handshake(
poller.register(handshake_socket, zmq.POLLIN)
poller.register(proc.sentinel, zmq.POLLIN)
- timeout_ms = _HANDSHAKE_POLL_TIMEOUT_S * 1000
+ timeout_ms = handshake_timeout * 1000
while True:
events = dict(poller.poll(timeout=timeout_ms))
if not events:
- raise TimeoutError("Timed out waiting for READY from StageDiffusionProc")
+ raise TimeoutError(
+ f"Timed out waiting for READY from StageDiffusionProc after {handshake_timeout}s. "
+ f"This typically indicates model loading or warmup is taking too long. "
+ f"Consider increasing `stage_init_timeout` for large models."
+ )
if handshake_socket in events:
identity, raw = handshake_socket.recv_multipart()
msg = msgspec.msgpack.decode(raw)
diff --git a/vllm_omni/diffusion/utils/media_utils.py b/vllm_omni/diffusion/utils/media_utils.py
index f96a28fbd7..a09cd45953 100644
--- a/vllm_omni/diffusion/utils/media_utils.py
+++ b/vllm_omni/diffusion/utils/media_utils.py
@@ -20,6 +20,7 @@ def mux_video_audio_bytes(
video_codec: str = "h264",
audio_codec: str = "aac",
crf: str = "18",
+ video_codec_options: dict[str, str] | None = None,
) -> bytes:
"""Mux video frames and optional audio waveform into MP4 bytes.
@@ -42,7 +43,11 @@ def mux_video_audio_bytes(
v_stream.width = video_frames.shape[2]
v_stream.height = video_frames.shape[1]
v_stream.pix_fmt = "yuv420p"
- v_stream.options = {"crf": crf}
+
+ options = {"crf": str(crf)}
+ if video_codec_options:
+ options.update(video_codec_options)
+ v_stream.options = options
a_stream = None
if audio_waveform is not None:
diff --git a/vllm_omni/diffusion/utils/prompt_utils.py b/vllm_omni/diffusion/utils/prompt_utils.py
new file mode 100644
index 0000000000..fc1769f4d5
--- /dev/null
+++ b/vllm_omni/diffusion/utils/prompt_utils.py
@@ -0,0 +1,38 @@
+import torch
+
+
+def validate_prompt_sequence_lengths(
+ attention_mask: torch.Tensor,
+ *,
+ max_sequence_length: int,
+ supported_max_sequence_length: int,
+ prompt_name: str = "prompt",
+ length_offset: int = 0,
+ baseline_attention_mask: torch.Tensor | None = None,
+ error_context: str,
+) -> None:
+ sequence_lengths = attention_mask.sum(dim=1)
+ if baseline_attention_mask is not None:
+ # Some callers need to validate only the user-controlled portion of a
+ # templated prompt. In those cases we subtract the fully-tokenized
+ # template baseline instead of only removing a fixed prefix length,
+ # because the template may also contribute a suffix or image markers.
+ baseline_lengths = baseline_attention_mask.sum(dim=1)
+ if baseline_lengths.shape[0] == 1 and sequence_lengths.shape[0] > 1:
+ baseline_lengths = baseline_lengths.expand(sequence_lengths.shape[0])
+ sequence_lengths = sequence_lengths - baseline_lengths
+ if length_offset:
+ sequence_lengths = sequence_lengths - length_offset
+ sequence_lengths = torch.clamp(sequence_lengths, min=0)
+ too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False)
+ if too_long.numel() == 0:
+ return
+
+ batch_idx = int(too_long[0].item())
+ actual_length = int(sequence_lengths[batch_idx].item())
+ prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`"
+ raise ValueError(
+ f"{prompt_ref} is too long {error_context}: got {actual_length} tokens, but "
+ f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase "
+ f"`max_sequence_length` up to {supported_max_sequence_length}."
+ )
diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py
index 32ea5bf64d..535f053c38 100644
--- a/vllm_omni/diffusion/worker/diffusion_model_runner.py
+++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py
@@ -35,11 +35,12 @@
from vllm_omni.diffusion.worker.utils import DiffusionRequestState, RunnerOutput
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.platforms import current_omni_platform
+from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = init_logger(__name__)
-class DiffusionModelRunner:
+class DiffusionModelRunner(OmniConnectorModelRunnerMixin):
"""
Model runner that handles model loading and execution for diffusion models.
diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py
index ea4b9d96f7..160309e0d8 100644
--- a/vllm_omni/diffusion/worker/diffusion_worker.py
+++ b/vllm_omni/diffusion/worker/diffusion_worker.py
@@ -20,6 +20,7 @@
from vllm.config import CompilationConfig, DeviceConfig, VllmConfig, set_current_vllm_config
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
+from vllm.profiler.wrapper import CudaProfilerWrapper, WorkerProfiler
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.mem_utils import GiB_bytes
from vllm.v1.worker.workspace import init_workspace_manager
@@ -83,15 +84,7 @@ def __init__(
od_config=self.od_config,
device=self.device,
)
- # Initialize profiler if configured
- self.profiler: OmniTorchProfilerWrapper | None = None
- profiler_config = self.od_config.profiler_config
- if profiler_config and profiler_config.profiler == "torch":
- self.profiler = create_omni_profiler(
- profiler_config=profiler_config,
- worker_name=f"diffusion_worker_{self.rank}",
- local_rank=self.local_rank,
- )
+ self.profiler: WorkerProfiler | None = self._create_profiler()
if not skip_load_model:
self.load_model(load_format=self.od_config.diffusion_load_format)
self.init_lora_manager()
@@ -122,6 +115,7 @@ def init_device(self) -> None:
vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size
vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size
vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel
+ vllm_config.profiler_config = self.od_config.profiler_config
self.vllm_config = vllm_config
# Initialize distributed environment
@@ -147,6 +141,24 @@ def init_device(self) -> None:
)
init_workspace_manager(self.device)
+ def _create_profiler(self) -> WorkerProfiler | None:
+ profiler_config = self.od_config.profiler_config
+ profiler_type = getattr(profiler_config, "profiler", None)
+ if profiler_type == "torch":
+ return create_omni_profiler(
+ profiler_config=profiler_config,
+ worker_name=f"diffusion_rank{self.rank}",
+ local_rank=self.local_rank,
+ )
+ if profiler_type == "cuda":
+ return CudaProfilerWrapper(profiler_config)
+ if profiler_type is not None:
+ logger.warning("Unknown profiler backend %r on diffusion worker %s", profiler_type, self.rank)
+ return None
+
+ def _get_profiler(self) -> WorkerProfiler | None:
+ return getattr(self, "profiler", None)
+
def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None) -> None:
"""Load the diffusion model using DiffusionModelRunner."""
with (
@@ -192,27 +204,21 @@ def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> N
Args:
is_start: True to start profiling, False to stop.
- profile_prefix: Optional prefix for trace filename (vLLM compat).
-
- Note:
- Matches vLLM's worker.profile() signature for consistency.
- Traces are saved automatically via on_trace_ready callback.
+ profile_prefix: Optional prefix for trace filename.
"""
- if self.profiler is None:
- logger.warning("Profiler not initialized, skipping profile(%s)", is_start)
+ profiler = self._get_profiler()
+ if profiler is None:
return
if is_start:
- from vllm_omni.profiler import OmniTorchProfilerWrapper
-
- if isinstance(self.profiler, OmniTorchProfilerWrapper):
+ if isinstance(profiler, OmniTorchProfilerWrapper):
import time
- filename = profile_prefix or f"diffusion_{int(time.time())}"
- self.profiler.set_trace_filename(filename)
- self.profiler.start()
+ filename = profile_prefix or f"diffusion_rank{self.rank}_{int(time.time())}"
+ profiler.set_trace_filename(filename)
+ profiler.start()
else:
- self.profiler.stop()
+ profiler.stop()
def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput:
"""Execute a forward pass by delegating to the model runner."""
@@ -224,7 +230,13 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfi
if req.sampling_params.lora_request is not None:
raise
logger.warning("LoRA activation skipped: %s", exc)
- return self.model_runner.execute_model(req)
+ profiler = self._get_profiler()
+ ctx = profiler.annotate_context_manager("diffusion_forward") if profiler else nullcontext()
+ with ctx:
+ output = self.model_runner.execute_model(req)
+ if profiler:
+ profiler.step()
+ return output
def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput:
"""Execute one diffusion step by delegating to the model runner."""
@@ -236,8 +248,13 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner
if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs):
raise ValueError("Step mode does not support LoRA yet.")
-
- return self.model_runner.execute_stepwise(scheduler_output)
+ profiler = self._get_profiler()
+ ctx = profiler.annotate_context_manager("diffusion_step") if profiler else nullcontext()
+ with ctx:
+ output = self.model_runner.execute_stepwise(scheduler_output)
+ if profiler:
+ profiler.step()
+ return output
def load_weights(self, weights) -> set[str]:
"""Load weights by delegating to the model runner."""
diff --git a/vllm_omni/distributed/omni_connectors/connectors/base.py b/vllm_omni/distributed/omni_connectors/connectors/base.py
index 83edb2ab0a..0df428f2ff 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/base.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/base.py
@@ -34,13 +34,21 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
pass
@abstractmethod
- def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tuple[Any, int] | None:
+ def get(
+ self, from_stage: str, to_stage: str, get_key: str, metadata: dict[str, Any] | None = None
+ ) -> tuple[Any, int] | None:
"""Retrieve Python object and payload size (bytes).
Args:
from_stage: Source stage identifier
to_stage: Destination stage identifier
get_key: Unique request identifier
+ metadata: Optional transport-specific metadata. When provided,
+ the connector uses it directly (e.g. source_host, source_port,
+ data_size) instead of querying the sender. For heterogeneous
+ TP the manager may supply partial metadata (host/port only);
+ the connector will query the sender at that address to fill
+ in data_size.
Returns:
Tuple of (Python object, serialized byte size) if found, None otherwise
diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py
index c672e35f79..fa1fc3286d 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py
@@ -78,7 +78,24 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
try:
serialized_data = self.serialize_obj(data)
key = self._make_key(put_key, from_stage, to_stage)
- self.store.put(key, serialized_data, self.pin)
+ put_rc = self.store.put(key, serialized_data, self.pin)
+
+ if isinstance(put_rc, bool):
+ put_ok = put_rc
+ else:
+ put_ok = put_rc is None or put_rc == 0
+
+ if not put_ok:
+ self._metrics["errors"] += 1
+ logger.error(
+ "MooncakeStoreConnector put failed for %s (%s -> %s), rc=%r, %d bytes",
+ key,
+ from_stage,
+ to_stage,
+ put_rc,
+ len(serialized_data),
+ )
+ return False, 0, None
self._metrics["puts"] += 1
self._metrics["bytes_transferred"] += len(serialized_data)
diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py
index 96a528963f..bd4160f3e6 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py
@@ -230,16 +230,19 @@ class MooncakeTransferEngineConnector(OmniConnectorBase):
sender immediately cleans up the buffer (``cleanup()``), so only the
first receiver to pull a given key will succeed. Broadcast / multicast
(1 sender → N receivers sharing the same data) is not yet supported.
- - **1 receiver → 1 sender**: ``update_sender_info()`` stores a single
- ``(sender_host, sender_zmq_port)`` pair, so a receiver can only query
- metadata from one sender at a time.
+ - **1 receiver → N senders**: Supported via partial metadata. The
+ manager constructs metadata with the target sender's
+ ``source_host`` / ``source_port`` (computed from ``from_rank``)
+ and passes it to ``get(metadata=...)``. The connector detects
+ that ``data_size`` is missing, queries the specified sender at
+ the given address to fill it in, then performs the RDMA pull.
+ This enables heterogeneous TP (sender TP > receiver TP) where a
+ single receiver must pull KV shards from multiple sender ranks.
Future work:
- Support 1 sender → N receivers (e.g. reference-counted buffers, or
explicit ``retain()`` / ``release()`` semantics so the buffer survives
multiple pulls).
- - Support 1 receiver → N senders (e.g. a sender registry mapping
- ``get_key`` prefixes to different sender endpoints).
"""
# RDMA connector copies raw bytes/tensor directly to the memory pool
@@ -267,6 +270,7 @@ def __init__(self, config: dict[str, Any]):
self._req_local = threading.local()
self._worker_local = threading.local()
self._last_ttl_check: float = _time_mod.monotonic()
+ self._sender_endpoints: dict[int, tuple[str, int]] = {}
self._metrics = {
"puts": 0,
@@ -408,16 +412,38 @@ def get_connection_info(self) -> dict[str, Any]:
"can_put": self.can_put,
}
- def update_sender_info(self, sender_host: str, sender_zmq_port: int) -> None:
- """
- Inject the sender's ZMQ endpoint into the receiver connector.
- Used for NO METADATA GET calls.(E.g: KV-cache transfer path)
- Must be called before using get() without metadata!
- Otherwise, get() will raise an error.
+ def update_sender_info(
+ self,
+ sender_host: str,
+ sender_zmq_port: int,
+ sender_rank: int | None = None,
+ ) -> None:
+ """Inject a sender's ZMQ endpoint into the receiver connector.
+
+ When ``sender_rank`` is ``None`` (default), sets the single default
+ sender used by ``get()`` when no rank is specified — this preserves
+ backward-compatible 1:1 semantics.
+
+ When ``sender_rank`` is an integer, the endpoint is stored in a
+ per-rank registry for internal use (e.g. by
+ ``_query_metadata_from_sender(sender_rank=R)``).
"""
- self.sender_host = sender_host
- self.sender_zmq_port = sender_zmq_port
- logger.info(f"Sender info updated: host={sender_host!r}, zmq_port={sender_zmq_port}")
+ if sender_rank is not None:
+ self._sender_endpoints[sender_rank] = (sender_host, sender_zmq_port)
+ logger.info(
+ "Sender info updated for rank %s: host=%r, zmq_port=%s",
+ sender_rank,
+ sender_host,
+ sender_zmq_port,
+ )
+ else:
+ self.sender_host = sender_host
+ self.sender_zmq_port = sender_zmq_port
+ logger.info(
+ "Sender info updated (default): host=%r, zmq_port=%s",
+ sender_host,
+ sender_zmq_port,
+ )
def _get_local_ip(self) -> str:
"""
@@ -657,56 +683,75 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
logger.error(f"RDMA Put failed for {put_key}: {e}", exc_info=True)
return False, 0, None
- def _query_metadata_from_sender(self, get_key: str) -> dict[str, Any] | None:
- """Query metadata from sender via ZMQ (fallback when ``metadata=None``).
-
- ``get()`` supports two metadata resolution paths::
-
- get(metadata=?)
- ├── metadata provided (adapter path)
- │ → use metadata directly (source_host/port/data_size)
- │ → RDMA pull
- └── metadata=None (KV-transfer polling path)
- → _query_metadata_from_sender(get_key) ← this method
- │
- ├── sender_host resolved (via update_sender_info)
- │ → ZMQ query → get data_size/is_fast_path
- │ → construct metadata → RDMA pull
- └── sender_host unresolved ("auto" / None)
- → return None → caller retries or times out
+ def _resolve_sender_endpoint(self, sender_rank: int | None = None) -> tuple[str, int] | None:
+ """Return ``(host, zmq_port)`` for *sender_rank*.
- For the second path, the caller must call
- :meth:`update_sender_info` before ``get()`` to resolve the sender's ZMQ endpoint.
- Support the two paths in case that the orchestrator pushes the request info
- to different stages at the same time knowing metadata or not.
+ Resolution order:
+ 1. Per-rank registry (``_sender_endpoints[sender_rank]``)
+ 2. Default sender (``sender_host`` / ``sender_zmq_port``)
+ 3. ``None`` if nothing is configured.
+ """
+ if sender_rank is not None and sender_rank in self._sender_endpoints:
+ return self._sender_endpoints[sender_rank]
+ host = getattr(self, "sender_host", None)
+ port = getattr(self, "sender_zmq_port", None)
+ if host and port and str(host).lower() != "auto":
+ return (host, int(port))
+ return None
+
+ def _query_metadata_at(self, get_key: str, host: str, port: int) -> dict[str, Any] | None:
+ """Query metadata from a sender endpoint via ZMQ.
+
+ Returns ``{source_host, source_port, data_size, is_fast_path}``
+ or ``None`` when the key is not found / the query fails.
"""
- zmq_addr = f"tcp://{self.sender_host}:{self.sender_zmq_port}"
+ zmq_addr = f"tcp://{host}:{port}"
req_socket = self._get_req_socket(zmq_addr, timeout_ms=5000)
-
try:
- # Send query request
- query = QueryRequest(request_id=get_key)
- req_socket.send(QUERY_INFO + msgspec.msgpack.encode(query))
+ req_socket.send(QUERY_INFO + msgspec.msgpack.encode(QueryRequest(request_id=get_key)))
resp = req_socket.recv()
-
if resp == INFO_NOT_FOUND:
return None
-
- # Parse response
query_resp = msgspec.msgpack.decode(resp, type=QueryResponse)
return {
- # source_host/source_port are used for verification
- "source_host": self.sender_host,
- "source_port": self.sender_zmq_port,
+ "source_host": host,
+ "source_port": port,
"data_size": query_resp.data_size,
"is_fast_path": query_resp.is_fast_path,
}
except Exception as e:
- # Socket may be stuck in bad state after timeout; discard it
self._invalidate_req_socket(zmq_addr)
- logger.debug(f"Failed to query metadata for {get_key}: {e}")
+ logger.debug("Failed to query metadata at %s for %s: %s", zmq_addr, get_key, e)
return None
+ def _query_metadata_from_sender(self, get_key: str, sender_rank: int | None = None) -> dict[str, Any] | None:
+ """Query metadata from sender via ZMQ (fallback when ``metadata=None``).
+
+ ``get()`` supports three metadata resolution paths::
+
+ get(metadata=?)
+ ├── Path 1: metadata has data_size (adapter path)
+ │ → use metadata directly → RDMA pull
+ ├── Path 2: metadata has source_host/port but no data_size
+ │ → _query_metadata_at(host, port) → get data_size → RDMA pull
+ └── Path 3: metadata=None (KV-transfer polling path)
+ → _query_metadata_from_sender(get_key) ← this method
+ │
+ ├── sender endpoint resolved (via update_sender_info)
+ │ → ZMQ query → get data_size/is_fast_path
+ │ → construct metadata → RDMA pull
+ └── sender endpoint unresolved
+ → return None → caller retries or times out
+
+ When *sender_rank* is provided, the query is routed to that
+ rank's endpoint (registered via ``update_sender_info(rank=...)``).
+ Otherwise the default sender is used.
+ """
+ endpoint = self._resolve_sender_endpoint(sender_rank)
+ if endpoint is None:
+ return None
+ return self._query_metadata_at(get_key, *endpoint)
+
def get(
self,
from_stage: str,
@@ -714,12 +759,18 @@ def get(
get_key: str,
metadata: dict[str, Any] | None = None,
) -> tuple[Any, int] | None:
- """
- Consumer Side.
- Allocates from local pool and pulls data via RDMA.
+ """Consumer Side. Allocates from local pool and pulls data via RDMA.
+
+ Metadata resolution:
- If metadata is not provided, will attempt to query it from sender
- using configured sender_host/sender_zmq_port.
+ 1. ``metadata`` provided **with** ``data_size`` → use directly (RDMA pull).
+ 2. ``metadata`` provided with ``source_host``/``source_port`` but
+ **without** ``data_size`` → query that specific sender for
+ ``data_size`` / ``is_fast_path``, then RDMA pull. This is the
+ heterogeneous-TP path where the manager knows the target sender
+ endpoint but not the payload size.
+ 3. ``metadata=None`` → query the default sender (set via
+ ``update_sender_info()``) for the full metadata.
Returns:
``(data, size)`` on success, ``None`` on failure.
@@ -727,9 +778,6 @@ def get(
- **is_fast_path=True** (tensor *or* bytes payload):
Returns ``(ManagedBuffer, size)``.
**CALLER MUST call ``ManagedBuffer.release()`` after consuming.**
- Note: even if the producer ``put()`` raw ``bytes``, the consumer
- receives a ``ManagedBuffer`` — use ``buf.to_bytes()`` to obtain
- a ``bytes`` copy, or ``buf.tensor`` for zero-copy access.
- **is_fast_path=False** (serialized Python object):
Returns ``(DeserializedObject, size)``.
Buffer is auto-released internally after deserialization.
@@ -741,9 +789,8 @@ def get(
_t0 = _time_mod.perf_counter()
- # If no metadata provided, try to query from sender
if not metadata:
- # Must insert sender info before using get() without metadata.
+ # Path 3: no metadata at all — query default sender
if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto":
raise RuntimeError(
f"get(metadata=None) requires sender info to be resolved, "
@@ -753,6 +800,21 @@ def get(
metadata = self._query_metadata_from_sender(get_key)
if not metadata:
return None
+ elif "data_size" not in metadata:
+ # Path 2: partial metadata (host/port only) — query that sender
+ partial_host = metadata.get("source_host")
+ partial_port = metadata.get("source_port")
+ if not partial_host or not partial_port:
+ logger.warning(
+ "get(%s): partial metadata missing source_host/source_port, cannot resolve data_size. metadata=%s",
+ get_key,
+ metadata,
+ )
+ return None
+ queried = self._query_metadata_at(get_key, str(partial_host), int(partial_port))
+ if not queried:
+ return None
+ metadata = queried
_t1 = _time_mod.perf_counter()
_query_ms = (_t1 - _t0) * 1000
diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py
index 5c7384c1f8..6cf5c2f15b 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py
@@ -15,9 +15,13 @@
class SharedMemoryConnector(OmniConnectorBase):
- """
- Connector that uses SharedMemory for large objects and inline data for small objects.
- Acts as a unified replacement for the legacy IPC fallback logic.
+ """Key-addressed local shared-memory connector.
+
+ SHM is a local-only transport: it reads/writes POSIX shared memory
+ segments identified purely by *key*. It does **not** understand
+ remote-transport metadata such as ``source_host`` / ``source_port``
+ (that is the RDMA connector's job). When such metadata is passed in,
+ the connector silently falls back to key-based lookup.
"""
def __init__(self, config: dict[str, Any]):
@@ -25,6 +29,7 @@ def __init__(self, config: dict[str, Any]):
self.stage_id = config.get("stage_id", -1)
self.device = config.get("device", "cuda:0")
self.threshold = int(config.get("shm_threshold_bytes", 65536))
+ self._pending_keys: set[str] = set()
self._metrics = {
"puts": 0,
"gets": 0,
@@ -59,6 +64,7 @@ def put(
# meta contains {'name': ..., 'size': ...}
metadata = {"shm": meta, "size": size}
+ self._pending_keys.add(put_key)
self._metrics["shm_writes"] += 1
else:
# Inline - pass bytes directly to avoid double serialization of the object
@@ -93,6 +99,28 @@ def _get_data_with_lock(self, lock_file: str, shm_handle: dict):
if obj and os.path.exists(lock_file):
os.remove(lock_file)
+ def _get_by_key(self, get_key: str) -> tuple[Any, int] | None:
+ """Read a SHM segment addressed purely by *get_key*."""
+ shm = None
+ try:
+ shm = shm_pkg.SharedMemory(name=get_key)
+ if shm is None or shm.size == 0:
+ return None
+ lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock"
+ shm_handle = {"name": get_key, "size": shm.size}
+ result = self._get_data_with_lock(lock_file, shm_handle)
+ if result is not None:
+ self._pending_keys.discard(get_key)
+ return result
+ except FileNotFoundError:
+ return None
+ except Exception:
+ logger.debug("_get_by_key: unexpected error reading SHM segment %s", get_key, exc_info=True)
+ return None
+ finally:
+ if shm:
+ shm.close()
+
def get(
self,
from_stage: str,
@@ -101,16 +129,16 @@ def get(
metadata=None,
) -> tuple[Any, int] | None:
if metadata is not None:
- # Some callers may wrap metadata by request id.
if isinstance(metadata, dict) and get_key in metadata:
metadata = metadata.get(get_key)
if not isinstance(metadata, dict):
- return None
+ return self._get_by_key(get_key)
if "inline_bytes" in metadata:
try:
obj = self.deserialize_obj(metadata["inline_bytes"])
+ self._pending_keys.discard(get_key)
return obj, int(metadata.get("size", 0))
except Exception as e:
logger.error(f"SharedMemoryConnector inline get failed for req {get_key}: {e}")
@@ -119,33 +147,64 @@ def get(
if "shm" in metadata:
shm_handle = metadata["shm"]
lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock"
- return self._get_data_with_lock(lock_file, shm_handle)
+ result = self._get_data_with_lock(lock_file, shm_handle)
+ if result is not None:
+ self._pending_keys.discard(get_key)
+ return result
- return None
- shm = None
- try:
- shm = shm_pkg.SharedMemory(name=get_key)
- if shm is None or shm.size == 0:
- return None
- lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock"
- shm_handle = {"name": get_key, "size": shm.size}
- return self._get_data_with_lock(lock_file, shm_handle)
- except Exception:
- return None
- finally:
- if shm:
- shm.close()
+ # Metadata is a dict but has no SHM-specific handle (e.g. RDMA-
+ # style source_host/source_port). Fall back to key-based read.
+ return self._get_by_key(get_key)
+
+ return self._get_by_key(get_key)
def cleanup(self, request_id: str) -> None:
- # SHM segments are automatically unlinked during 'get' (shm_read_bytes).
- # If 'get' is never called (e.g. error flow), the SHM segment might leak.
- # A robust implementation might track created segments and unlink them here
- # if they haven't been consumed.
- # For now, we rely on the consumer to read and unlink.
- pass
+ """Best-effort cleanup of unconsumed SHM segments for *request_id*.
+
+ Matches pending keys where *request_id* appears as the full key,
+ as a ``_``-delimited prefix, or as a ``_``-delimited suffix.
+ If ``get()`` was never called, we unlink it here so /dev/shm
+ doesn't leak.
+ """
+ stale = [
+ k
+ for k in self._pending_keys
+ if k == request_id or k.startswith(request_id + "_") or k.endswith("_" + request_id)
+ ]
+ for key in stale:
+ self._pending_keys.discard(key)
+ try:
+ seg = shm_pkg.SharedMemory(name=key)
+ seg.close()
+ seg.unlink()
+ logger.debug("cleanup: unlinked unconsumed SHM segment %s", key)
+ except FileNotFoundError:
+ pass
+ except Exception as e:
+ logger.debug("cleanup: failed to unlink SHM segment %s: %s", key, e)
+ lock_file = f"/dev/shm/shm_{key}_lockfile.lock"
+ if os.path.exists(lock_file):
+ try:
+ os.remove(lock_file)
+ except OSError:
+ pass
def close(self) -> None:
- pass
+ """Unlink all remaining tracked SHM segments."""
+ for key in list(self._pending_keys):
+ try:
+ seg = shm_pkg.SharedMemory(name=key)
+ seg.close()
+ seg.unlink()
+ except Exception:
+ pass
+ lock_file = f"/dev/shm/shm_{key}_lockfile.lock"
+ if os.path.exists(lock_file):
+ try:
+ os.remove(lock_file)
+ except OSError:
+ pass
+ self._pending_keys.clear()
def health(self) -> dict[str, Any]:
return {"status": "healthy", "threshold": self.threshold, **self._metrics}
diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
index 1958c9d40a..ad008c3971 100644
--- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
+++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
@@ -14,8 +14,20 @@
from .factory import OmniConnectorFactory
from .utils.config import ConnectorSpec
-from .utils.initialization import KV_TRANSFER_PORT_OFFSET
-from .utils.kv_utils import normalize_layer_kv
+from .utils.initialization import KV_RANK_PORT_STRIDE
+from .utils.kv_utils import (
+ KVTPTopology,
+ build_rank_aware_recv_keys,
+ build_rank_aware_send_keys,
+ get_kv_target_ranks,
+ get_local_tp_rank,
+ get_tp_world_size,
+ kv_zmq_port,
+ merge_received_rank_shards,
+ normalize_layer_kv,
+ slice_layer_blocks,
+ slice_received_rank_shard,
+)
logger = init_logger(__name__)
@@ -57,6 +69,8 @@ class OmniKVCacheConfig:
need_recv_cache: bool = False
need_send_cache: bool = False
recv_timeout: float = 30.0
+ from_tp: int = 1
+ to_tp: int = 1
@dataclass
@@ -72,82 +86,44 @@ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return asdict(self)
- def to_bytes(self) -> bytes:
- """Convert to compact binary format for fast transfer."""
- tensors_desc: list[dict[str, Any]] = []
- tensor_bufs: list[bytes] = []
- data_offset = 0
-
- for cache_name in ("key_cache", "value_cache"):
- cache_list = self.layer_blocks.get(cache_name, [])
- for layer_idx, tensor in enumerate(cache_list):
- if tensor is None:
- tensors_desc.append({"n": f"{cache_name}_{layer_idx}", "x": True})
- continue
-
- t = tensor.detach().cpu().contiguous()
- dtype_str = str(t.dtype).removeprefix("torch.")
- raw = t.view(torch.uint8).numpy().tobytes()
- tensors_desc.append(
- {
- "n": f"{cache_name}_{layer_idx}",
- "i": layer_idx,
- "d": dtype_str,
- "s": list(t.shape),
- "o": data_offset,
- "b": len(raw),
- }
- )
- tensor_bufs.append(raw)
- data_offset += len(raw)
-
- header = json.dumps(
- {
- "rid": self.request_id,
- "bids": self.block_ids,
- "meta": self.metadata,
- "td": tensors_desc,
- "nl": len(self.layer_blocks.get("key_cache", [])),
- },
- separators=(",", ":"),
- ).encode("utf-8")
- return b"".join([struct.pack(">I", len(header)), header] + tensor_bufs)
+ def _build_tensors_desc(self, *, cpu: bool) -> tuple[list[dict[str, Any]], list, int, torch.device | None]:
+ """Iterate layer blocks and build tensor descriptors + data chunks.
- def to_gpu_tensor(self) -> torch.Tensor:
- """Convert to a packed GPU tensor for raw-data connectors."""
+ Returns ``(tensors_desc, chunks, total_bytes, device)``.
+ *chunks* contains ``bytes`` when *cpu* is True, flat uint8 GPU tensors otherwise.
+ """
tensors_desc: list[dict[str, Any]] = []
- gpu_tensors: list[torch.Tensor] = []
+ chunks: list = []
data_offset = 0
device = None
for cache_name in ("key_cache", "value_cache"):
- cache_list = self.layer_blocks.get(cache_name, [])
- for layer_idx, tensor in enumerate(cache_list):
+ for layer_idx, tensor in enumerate(self.layer_blocks.get(cache_name, [])):
if tensor is None:
tensors_desc.append({"n": f"{cache_name}_{layer_idx}", "x": True})
continue
-
t = tensor.detach().contiguous()
- if device is None and t.is_cuda:
+ if cpu:
+ t = t.cpu()
+ elif device is None and t.is_cuda:
device = t.device
- dtype_str = str(t.dtype).removeprefix("torch.")
nbytes = t.numel() * t.element_size()
tensors_desc.append(
{
"n": f"{cache_name}_{layer_idx}",
"i": layer_idx,
- "d": dtype_str,
+ "d": str(t.dtype).removeprefix("torch."),
"s": list(t.shape),
"o": data_offset,
"b": nbytes,
}
)
- gpu_tensors.append(t.view(torch.uint8).flatten())
+ chunks.append(t.view(torch.uint8).numpy().tobytes() if cpu else t.view(torch.uint8).flatten())
data_offset += nbytes
- if device is None:
- raise RuntimeError("No CUDA tensors found, use to_bytes() instead")
+ return tensors_desc, chunks, data_offset, device
+ def _build_header_bytes(self, tensors_desc: list[dict[str, Any]]) -> bytes:
header = json.dumps(
{
"rid": self.request_id,
@@ -158,19 +134,26 @@ def to_gpu_tensor(self) -> torch.Tensor:
},
separators=(",", ":"),
).encode("utf-8")
+ return struct.pack(">I", len(header)) + header
- header_prefix = struct.pack(">I", len(header)) + header
- total_size = len(header_prefix) + data_offset
- output = torch.empty(total_size, dtype=torch.uint8, device=device)
- header_tensor = torch.frombuffer(bytearray(header_prefix), dtype=torch.uint8)
- output[: len(header_prefix)].copy_(header_tensor)
+ def to_bytes(self) -> bytes:
+ """Convert to compact binary format for fast transfer."""
+ tensors_desc, chunks, _, _ = self._build_tensors_desc(cpu=True)
+ return b"".join([self._build_header_bytes(tensors_desc)] + chunks)
+ def to_gpu_tensor(self) -> torch.Tensor:
+ """Convert to a packed GPU tensor for raw-data connectors."""
+ tensors_desc, chunks, data_offset, device = self._build_tensors_desc(cpu=False)
+ if device is None:
+ raise RuntimeError("No CUDA tensors found, use to_bytes() instead")
+ header_prefix = self._build_header_bytes(tensors_desc)
+ output = torch.empty(len(header_prefix) + data_offset, dtype=torch.uint8, device=device)
+ output[: len(header_prefix)].copy_(torch.frombuffer(bytearray(header_prefix), dtype=torch.uint8))
pos = len(header_prefix)
- for t_flat in gpu_tensors:
+ for t_flat in chunks:
n = t_flat.numel()
output[pos : pos + n].copy_(t_flat)
pos += n
-
return output
@staticmethod
@@ -237,11 +220,8 @@ def _resolve_layer_idx(info: dict[str, Any], num_layers: int) -> int:
return layer_idx
@staticmethod
- def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
- """Reconstruct KV cache data from the packed bytes format."""
- raw_mv = memoryview(raw) if not isinstance(raw, memoryview) else raw
- header, tensor_data_mv = KVCacheTransferData._load_header_from_memoryview(raw_mv)
-
+ def _populate_caches(header: dict[str, Any], get_tensor: callable) -> dict[str, Any]:
+ """Shared deserialization loop for both CPU and GPU paths."""
num_layers = header["nl"]
key_cache: list[torch.Tensor | None] = [None] * num_layers
value_cache: list[torch.Tensor | None] = [None] * num_layers
@@ -249,20 +229,9 @@ def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
for info in header["td"]:
if info.get("x"):
continue
-
name: str = info["n"]
torch_dtype = KVCacheTransferData._resolve_torch_dtype(info["d"])
- offset, nbytes = KVCacheTransferData._validate_tensor_span(name, info, len(tensor_data_mv))
- t = (
- torch.frombuffer(
- tensor_data_mv,
- dtype=torch.uint8,
- offset=offset,
- count=nbytes,
- )
- .view(torch_dtype)
- .reshape(info["s"])
- )
+ t = get_tensor(info).view(torch_dtype).reshape(info["s"])
layer_idx = KVCacheTransferData._resolve_layer_idx(info, num_layers)
if name.startswith("key_cache_"):
key_cache[layer_idx] = t
@@ -276,37 +245,30 @@ def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
"metadata": header["meta"],
}
+ @staticmethod
+ def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
+ """Reconstruct KV cache data from the packed bytes format."""
+ raw_mv = memoryview(raw) if not isinstance(raw, memoryview) else raw
+ header, tensor_data_mv = KVCacheTransferData._load_header_from_memoryview(raw_mv)
+ data_len = len(tensor_data_mv)
+
+ def _get(info: dict) -> torch.Tensor:
+ offset, nbytes = KVCacheTransferData._validate_tensor_span(info["n"], info, data_len)
+ return torch.frombuffer(tensor_data_mv, dtype=torch.uint8, offset=offset, count=nbytes)
+
+ return KVCacheTransferData._populate_caches(header, _get)
+
@staticmethod
def from_bytes_gpu(gpu_tensor: torch.Tensor) -> dict[str, Any]:
"""Reconstruct KV cache data from a packed GPU tensor."""
header, data_start = KVCacheTransferData._load_header_from_tensor(gpu_tensor)
+ data_len = int(gpu_tensor.numel()) - data_start
- num_layers = header["nl"]
- key_cache: list[torch.Tensor | None] = [None] * num_layers
- value_cache: list[torch.Tensor | None] = [None] * num_layers
- tensor_data_bytes = int(gpu_tensor.numel()) - data_start
+ def _get(info: dict) -> torch.Tensor:
+ offset, nbytes = KVCacheTransferData._validate_tensor_span(info["n"], info, data_len)
+ return gpu_tensor[data_start + offset : data_start + offset + nbytes].clone()
- for info in header["td"]:
- if info.get("x"):
- continue
-
- name: str = info["n"]
- torch_dtype = KVCacheTransferData._resolve_torch_dtype(info["d"])
- offset, nbytes = KVCacheTransferData._validate_tensor_span(name, info, tensor_data_bytes)
- t = gpu_tensor[data_start + offset : data_start + offset + nbytes].clone()
- t = t.view(torch_dtype).reshape(info["s"])
- layer_idx = KVCacheTransferData._resolve_layer_idx(info, num_layers)
- if name.startswith("key_cache_"):
- key_cache[layer_idx] = t
- elif name.startswith("value_cache_"):
- value_cache[layer_idx] = t
-
- return {
- "request_id": header["rid"],
- "layer_blocks": {"key_cache": key_cache, "value_cache": value_cache},
- "block_ids": header["bids"],
- "metadata": header["meta"],
- }
+ return KVCacheTransferData._populate_caches(header, _get)
class OmniKVTransferManager:
@@ -341,6 +303,30 @@ def __init__(self, config: OmniKVCacheConfig):
else (None, None)
)
+ local_rank = get_local_tp_rank()
+
+ if config.from_tp <= 1 and config.to_tp <= 1:
+ detected_tp = get_tp_world_size()
+ from_tp = detected_tp
+ to_tp = detected_tp
+ else:
+ from_tp = config.from_tp
+ to_tp = config.to_tp
+
+ self._tp_topo = KVTPTopology(source_tp_size=from_tp, target_tp_size=to_tp, local_rank=local_rank)
+
+ # Injectable hooks (compatible with PR #2677 OmniConnectorModelRunnerMixin).
+ self.kv_send_key_builder: Callable | None = None
+ self.kv_recv_key_builder: Callable | None = None
+ self.kv_payload_merger: Callable | None = None
+ self.kv_payload_slicer: Callable | None = None
+
+ # Base sender endpoint (rank-0 host/port) stored during
+ # update_sender_info(). Used by the receive path to construct
+ # per-rank metadata for heterogeneous TP without querying a registry.
+ self._sender_base_host: str | None = None
+ self._sender_base_zmq_port: int | None = None
+
if config.need_send_cache and config.connector_config:
try:
_ = self.connector
@@ -348,11 +334,20 @@ def __init__(self, config: OmniKVCacheConfig):
except Exception as e:
logger.warning("Failed to eagerly initialize sender connector: %s", e)
+ # ------------------------------------------------------------------ #
+ # Factory helpers
+ # ------------------------------------------------------------------ #
+
@classmethod
def _create(cls, cfg: dict | None) -> "OmniKVTransferManager":
"""Create manager from raw config dict."""
if not cfg or not isinstance(cfg, dict):
return cls(OmniKVCacheConfig())
+
+ rank_mapping = cfg.get("rank_mapping", {})
+ if not isinstance(rank_mapping, dict):
+ rank_mapping = {}
+
return cls(
OmniKVCacheConfig(
connector_config=cfg.get("connector_config"),
@@ -363,19 +358,18 @@ def _create(cls, cfg: dict | None) -> "OmniKVTransferManager":
need_recv_cache=cfg.get("need_recv_cache", False),
need_send_cache=cfg.get("need_send_cache", False),
recv_timeout=cfg.get("recv_timeout", 30.0),
+ from_tp=int(rank_mapping.get("from_tp", 1)),
+ to_tp=int(rank_mapping.get("to_tp", 1)),
)
)
- @classmethod
- def from_model_config(cls, config: Any) -> "OmniKVTransferManager":
- """Create from model config (for AR model runner)."""
- return cls._create(getattr(config, "omni_kv_config", None))
-
@classmethod
def from_od_config(cls, config: Any) -> "OmniKVTransferManager":
- """Create from OmniDiffusion config (for diffusion runner)."""
+ """Create from model or OmniDiffusion config."""
return cls._create(getattr(config, "omni_kv_config", None))
+ from_model_config = from_od_config
+
@classmethod
def from_vllm_config(cls, vllm_config: Any, model_config: Any) -> "OmniKVTransferManager":
"""Create from vllm config with fallback to kv_transfer_config."""
@@ -417,45 +411,33 @@ def connector(self):
)
c_extra["to_stage"] = str(self.config.to_stage) if self.config.to_stage is not None else "1"
+ try:
+ stage_int = int(self.config.from_stage) if self.config.from_stage is not None else 0
+ except (TypeError, ValueError):
+ stage_int = 0
+ zmq_port = kv_zmq_port(base_port, stage_int, self._tp_topo.local_rank)
+
if self.config.need_send_cache:
c_extra["role"] = "sender"
- from_stage = self.config.from_stage
- if from_stage is not None:
- try:
- c_extra["zmq_port"] = base_port + KV_TRANSFER_PORT_OFFSET + int(from_stage)
- except (TypeError, ValueError):
- c_extra["zmq_port"] = base_port + KV_TRANSFER_PORT_OFFSET
+ c_extra["zmq_port"] = zmq_port
elif self.config.need_recv_cache:
c_extra["role"] = "receiver"
- from_stage = self.config.from_stage
- sender_port = base_port + KV_TRANSFER_PORT_OFFSET
- if from_stage is not None:
- try:
- sender_port = base_port + KV_TRANSFER_PORT_OFFSET + int(from_stage)
- except (TypeError, ValueError):
- pass
c_extra.setdefault("sender_host", c_extra.get("host", "127.0.0.1"))
- c_extra.setdefault("sender_zmq_port", sender_port)
+ c_extra.setdefault("sender_zmq_port", zmq_port)
logger.info(
- "Initializing OmniConnector (purpose=kv_transfer) with config: %s, role: %s",
- cfg,
+ "Initializing OmniConnector type=%s role=%s",
+ c_type,
c_extra.get("role", "N/A"),
)
self._connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra))
- except Exception as e:
- logger.error(f"Failed to initialize OmniConnector: {e}")
- import traceback
-
- traceback.print_exc()
- # Cache failure sentinel to avoid repeated initialization attempts in hot paths.
+ except Exception:
+ logger.exception("Failed to initialize OmniConnector")
self._connector = False
return self._connector if self._connector else None
- def get_connector(self):
- """Get connector (compatibility wrapper for existing code)."""
- return self.connector
+ get_connector = property(lambda self: self.connector)
def _resolve_sender_info(
self, sender_info: dict[str, Any], sender_stage_id: str | int | None = None
@@ -513,8 +495,187 @@ def _clone_received_payload_tensors(data: dict[str, Any]) -> dict[str, Any]:
cache_list[idx] = tensor.clone()
return data
+ def _slice_transfer_data_for_target(self, kv_data: KVCacheTransferData, target_rank: int) -> KVCacheTransferData:
+ """Pre-slice sender payload for one target rank when sender TP < receiver TP."""
+ topo = self._tp_topo
+ ratio = topo.target_tp_size // topo.source_tp_size
+ offset_in_sender = target_rank % ratio
+ metadata = dict(kv_data.metadata) if isinstance(kv_data.metadata, dict) else {}
+ metadata["tp_head_slice"] = {
+ "applied": True,
+ "side": "sender",
+ "target_rank": target_rank,
+ "source_rank": topo.local_rank,
+ "from_tp": topo.source_tp_size,
+ "to_tp": topo.target_tp_size,
+ "offset_in_shard": offset_in_sender,
+ "num_slices": ratio,
+ }
+ return KVCacheTransferData(
+ request_id=kv_data.request_id,
+ layer_blocks=slice_layer_blocks(kv_data.layer_blocks, offset_in_sender, ratio),
+ block_ids=list(kv_data.block_ids),
+ metadata=metadata,
+ )
+
+ def _serialize_transfer_payload(self, kv_data: KVCacheTransferData) -> torch.Tensor | bytes | dict[str, Any]:
+ """Serialize KV transfer data using the connector's fastest supported path."""
+ if getattr(self.connector, "supports_raw_data", False):
+ try:
+ return kv_data.to_gpu_tensor()
+ except Exception:
+ pass
+ try:
+ return kv_data.to_bytes()
+ except Exception:
+ return kv_data.to_dict()
+
+ @staticmethod
+ def _collect_request_kv_payload(req: Any) -> dict[str, object]:
+ """Collect request-side KV objects for object broadcast."""
+ kv_payload: dict[str, object] = {}
+ for attr in ("past_key_values", "kv_metadata"):
+ val = getattr(req, attr, None)
+ if val is not None:
+ kv_payload[attr] = val
+
+ if hasattr(req, "sampling_params") and req.sampling_params is not None:
+ for key in list(vars(req.sampling_params).keys()):
+ if key in ("past_key_values", "kv_metadata") or (
+ key.startswith("cfg_")
+ and (
+ key.endswith("_past_key_values")
+ or key.endswith("_kv_metadata")
+ or key
+ in (
+ "cfg_kv_request_ids",
+ "cfg_active_branch",
+ "cfg_branch_roles",
+ "cfg_branch_past_key_values",
+ "cfg_branch_kv_metadata",
+ )
+ )
+ ):
+ val = getattr(req.sampling_params, key, None)
+ if val is not None:
+ kv_payload[f"sp.{key}"] = val
+
+ return kv_payload
+
+ @staticmethod
+ def _apply_request_kv_payload(
+ req: Any,
+ kv_payload: dict[str, object],
+ target_device: torch.device | None = None,
+ ) -> None:
+ """Apply a broadcast KV payload back onto a request object."""
+ for attr in ("past_key_values", "kv_metadata"):
+ val = kv_payload.get(attr)
+ if val is not None:
+ if target_device is not None:
+ val = _move_to_device(val, target_device)
+ setattr(req, attr, val)
+
+ if hasattr(req, "sampling_params") and req.sampling_params is not None:
+ for key, val in kv_payload.items():
+ if key.startswith("sp."):
+ if target_device is not None:
+ val = _move_to_device(val, target_device)
+ setattr(req.sampling_params, key[3:], val)
+
+ @staticmethod
+ def _discover_cfg_branch_roles(req: Any) -> list[str]:
+ """Discover CFG branch roles in a stable order."""
+ sampling_params = getattr(req, "sampling_params", None)
+ if sampling_params is None:
+ return []
+
+ roles: list[str] = []
+ branch_map = getattr(sampling_params, "cfg_branch_past_key_values", None) or {}
+ for preferred_role in ("cfg_text", "cfg_img"):
+ if (
+ preferred_role in branch_map
+ or getattr(sampling_params, f"{preferred_role}_past_key_values", None) is not None
+ ):
+ roles.append(preferred_role)
+
+ for role in branch_map.keys():
+ if role not in roles and branch_map.get(role) is not None:
+ roles.append(role)
+
+ for key in vars(sampling_params).keys():
+ if not (key.startswith("cfg_") and key.endswith("_past_key_values")):
+ continue
+ role = key.removesuffix("_past_key_values")
+ if role in ("cfg_branch",) or role in roles:
+ continue
+ if getattr(sampling_params, key, None) is not None:
+ roles.append(role)
+
+ return roles
+
+ @classmethod
+ def _build_cfg_rank_local_payloads(cls, req: Any, cfg_size: int) -> list[dict[str, object] | None]:
+ """Build per-cfg-rank payloads so each rank receives only its branch KV."""
+ full_payload = cls._collect_request_kv_payload(req)
+ payloads: list[dict[str, object] | None] = []
+
+ main_payload = {
+ key: value
+ for key, value in full_payload.items()
+ if key in ("past_key_values", "kv_metadata", "sp.past_key_values", "sp.kv_metadata")
+ }
+ branch_roles = cls._discover_cfg_branch_roles(req)
+ if branch_roles:
+ main_payload["sp.cfg_branch_roles"] = list(branch_roles)
+ main_payload["sp.cfg_active_branch"] = None
+ payloads.append(main_payload or None)
+
+ sampling_params = getattr(req, "sampling_params", None)
+ branch_map = getattr(sampling_params, "cfg_branch_past_key_values", None) or {}
+ branch_metadata_map = getattr(sampling_params, "cfg_branch_kv_metadata", None) or {}
+
+ for role in branch_roles:
+ if sampling_params is None:
+ payloads.append(None)
+ continue
+
+ branch_kv = branch_map.get(role)
+ if branch_kv is None:
+ branch_kv = getattr(sampling_params, f"{role}_past_key_values", None)
+ branch_metadata = branch_metadata_map.get(role)
+ if branch_metadata is None:
+ branch_metadata = getattr(sampling_params, f"{role}_kv_metadata", None)
+ if branch_kv is None:
+ payloads.append(None)
+ continue
+
+ local_payload = dict(main_payload)
+ local_payload["sp.cfg_active_branch"] = role
+ local_payload["sp.cfg_branch_roles"] = list(branch_roles)
+ local_payload["sp.cfg_branch_past_key_values"] = {role: branch_kv}
+ local_payload[f"sp.{role}_past_key_values"] = branch_kv
+ if branch_metadata is not None:
+ local_payload["sp.cfg_branch_kv_metadata"] = {role: branch_metadata}
+ local_payload[f"sp.{role}_kv_metadata"] = branch_metadata
+
+ payloads.append(local_payload)
+
+ while len(payloads) < cfg_size:
+ payloads.append(None)
+
+ return payloads[:cfg_size]
+
def update_sender_info(self, sender_info: dict[str, Any], sender_stage_id: str | int | None = None) -> None:
- """Update receiver-side sender info before loading remote KV cache."""
+ """Update receiver-side sender info before loading remote KV cache.
+
+ The orchestrator always reports rank-0's ZMQ port. When TP > 1 the
+ receiver must offset the port so that each TP rank connects to the
+ corresponding sender rank's port.
+
+ The base host/port are also stored so that the receive path can
+ construct per-rank metadata for heterogeneous TP scenarios.
+ """
if not self.config.need_recv_cache:
return
@@ -523,18 +684,39 @@ def update_sender_info(self, sender_info: dict[str, Any], sender_stage_id: str |
logger.warning("Invalid sender_info format: %s", sender_info)
return
+ sender_host = actual_info.get("host")
+ base_zmq_port = actual_info.get("zmq_port")
+
+ # Store base sender info for per-rank metadata construction.
+ self._sender_base_host = sender_host
+ if base_zmq_port is not None:
+ self._sender_base_zmq_port = int(base_zmq_port)
+
+ # --- Default sender: offset to match this receiver's corresponding sender rank ---
+ zmq_port = base_zmq_port
+ if zmq_port is not None and self._tp_topo.local_rank > 0:
+ zmq_port = int(zmq_port) + self._tp_topo.local_rank * KV_RANK_PORT_STRIDE
+
if self.config.connector_config:
- self.config.connector_config["sender_host"] = actual_info.get("host")
- self.config.connector_config["sender_zmq_port"] = actual_info.get("zmq_port")
+ self.config.connector_config["sender_host"] = sender_host
+ self.config.connector_config["sender_zmq_port"] = zmq_port
if self._connector and hasattr(self._connector, "update_sender_info"):
try:
- self._connector.update_sender_info(actual_info.get("host"), actual_info.get("zmq_port"))
+ self._connector.update_sender_info(sender_host, zmq_port)
except Exception:
if hasattr(self._connector, "sender_host"):
- self._connector.sender_host = actual_info.get("host")
+ self._connector.sender_host = sender_host
if hasattr(self._connector, "sender_zmq_port"):
- self._connector.sender_zmq_port = actual_info.get("zmq_port")
+ self._connector.sender_zmq_port = zmq_port
+
+ logger.info(
+ "Sender info updated: host=%s, base_port=%s, adjusted_port=%s (local_rank=%s)",
+ sender_host,
+ base_zmq_port,
+ zmq_port,
+ self._tp_topo.local_rank,
+ )
def handle_finished_requests_kv_transfer(
self,
@@ -692,35 +874,54 @@ def _transfer_kv_cache(self, kv_data: KVCacheTransferData, transfer_req_id: str)
kv_data.request_id = transfer_req_id
serialization_start = time.perf_counter()
- transfer_data: torch.Tensor | bytes | dict[str, Any]
- supports_raw = getattr(self.connector, "supports_raw_data", False)
+ topo = self._tp_topo
+ send_keys = build_rank_aware_send_keys(
+ transfer_req_id, from_stage, to_stage, topo, hook=self.kv_send_key_builder
+ )
+ sender_slice_active = (
+ topo.source_tp_size < topo.target_tp_size and len(send_keys) > 1 and not callable(self.kv_send_key_builder)
+ )
+ per_key_payloads: list[tuple[str, torch.Tensor | bytes | dict[str, Any]]] = []
- try:
- if supports_raw:
- transfer_data = kv_data.to_gpu_tensor()
+ if sender_slice_active:
+ target_ranks = get_kv_target_ranks(topo)
+ if len(target_ranks) != len(send_keys):
+ logger.warning(
+ "Skip sender-side KV slicing because target rank count does not match send key count: "
+ "target_ranks=%s send_keys=%s",
+ len(target_ranks),
+ len(send_keys),
+ )
+ sender_slice_active = False
else:
- raise RuntimeError("Connector does not support raw tensor")
- except Exception:
- try:
- transfer_data = kv_data.to_bytes()
- except Exception:
- data_dict = kv_data.to_dict()
- data_dict["request_id"] = transfer_req_id
- transfer_data = data_dict
+ for put_key, target_rank in zip(send_keys, target_ranks, strict=False):
+ sliced_kv_data = self._slice_transfer_data_for_target(kv_data, target_rank)
+ per_key_payloads.append((put_key, self._serialize_transfer_payload(sliced_kv_data)))
+
+ if not per_key_payloads:
+ transfer_data = self._serialize_transfer_payload(kv_data)
+ per_key_payloads = [(put_key, transfer_data) for put_key in send_keys]
serialization_ms = (time.perf_counter() - serialization_start) * 1000
logger.info("KV cache serialized for %s in %.1f ms", transfer_req_id, serialization_ms)
transfer_start = time.perf_counter()
- success, size, _ = self._transfer_with_retry(from_stage, to_stage, f"kv_cache_{transfer_req_id}", transfer_data)
+ total_size = 0
+ all_succeeded = True
+ for put_key, transfer_data in per_key_payloads:
+ success, size, _ = self._transfer_with_retry(from_stage, to_stage, put_key, transfer_data)
+ total_size += size
+ all_succeeded = all_succeeded and success
+
elapsed = time.perf_counter() - transfer_start
- if success:
- mbps = (size / 1024 / 1024) / elapsed if elapsed > 0 else 0
+ if all_succeeded:
+ mbps = (total_size / 1024 / 1024) / elapsed if elapsed > 0 else 0
logger.info(
- "KV transfer OK: %s, %s bytes, %.3fs, %.1f MB/s",
+ "KV transfer OK: %s, %s bytes across %s key(s), %.3fs, %.1f MB/s",
transfer_req_id,
- size,
+ total_size,
+ len(send_keys),
elapsed,
mbps,
)
@@ -731,7 +932,7 @@ def _transfer_with_retry(
self,
from_stage: str,
to_stage: str,
- request_id: str,
+ put_key: str,
data: "dict[str, Any] | bytes | torch.Tensor",
max_retries: int = 3,
) -> tuple[bool, int, dict[str, Any] | None]:
@@ -740,7 +941,7 @@ def _transfer_with_retry(
Args:
from_stage: Source stage identifier
to_stage: Target stage identifier
- request_id: Request identifier for the key
+ put_key: Pre-built connector key (rank-aware when TP > 1)
data: Data to transfer
max_retries: Maximum number of retry attempts
@@ -749,14 +950,12 @@ def _transfer_with_retry(
"""
for attempt in range(max_retries):
try:
- # Build the full key for connector
- full_request_id = f"omni_{from_stage}_to_{to_stage}_{request_id}"
success, size, metadata = self.connector.put(
- from_stage=from_stage, to_stage=to_stage, put_key=full_request_id, data=data
+ from_stage=from_stage, to_stage=to_stage, put_key=put_key, data=data
)
if success:
return success, size, metadata
- logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}")
+ logger.warning(f"Transfer attempt {attempt + 1} failed for {put_key}")
except Exception as e:
logger.warning(f"Transfer attempt {attempt + 1} exception: {e}")
@@ -801,22 +1000,46 @@ def receive_kv_cache_for_request(
poll_interval = 0.01
max_poll_interval = 0.5
- logger.info(f"Wait for KV cache for request {request_id} from stage {from_stage} to {to_stage}...")
+ topo = self._tp_topo
+ recv_key_pairs = build_rank_aware_recv_keys(
+ request_id, from_stage, to_stage, topo, hook=self.kv_recv_key_builder
+ )
+ pending_pairs = list(recv_key_pairs)
+ received_payloads: dict[str, tuple[dict[str, Any], int]] = {}
+
+ logger.info(
+ "Wait for KV cache for request %s from stage %s to %s via %s key(s)...",
+ request_id,
+ from_stage,
+ to_stage,
+ len(recv_key_pairs),
+ )
try:
while True:
- # Build the full key for connector
- full_request_id = f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}"
link_start = time.perf_counter()
- result = self.connector.get(
- from_stage=from_stage,
- to_stage=to_stage,
- get_key=full_request_id,
- )
- if result:
+ for get_key, from_rank in list(pending_pairs):
+ # Construct per-rank metadata so the connector queries
+ # the correct sender endpoint (heterogeneous TP path).
+ # When from_rank is None (TP<=1), metadata stays None
+ # and the connector falls back to its default sender.
+ rank_metadata: dict[str, Any] | None = None
+ if from_rank is not None and self._sender_base_host and self._sender_base_zmq_port is not None:
+ rank_metadata = {
+ "source_host": self._sender_base_host,
+ "source_port": self._sender_base_zmq_port + from_rank * KV_RANK_PORT_STRIDE,
+ }
+
+ result = self.connector.get(
+ from_stage=from_stage,
+ to_stage=to_stage,
+ get_key=get_key,
+ metadata=rank_metadata,
+ )
+ if not result:
+ continue
+
raw_data, size = result
- elapsed = time.time() - start_time
- link_ms = (time.perf_counter() - link_start) * 1000
managed_buffer = None
if hasattr(raw_data, "tensor") and hasattr(raw_data, "release"):
@@ -844,6 +1067,21 @@ def receive_kv_cache_for_request(
else:
data = raw_data
+ received_payloads[get_key] = (data, size)
+ pending_pairs.remove((get_key, from_rank))
+
+ if not pending_pairs and received_payloads:
+ elapsed = time.time() - start_time
+ link_ms = (time.perf_counter() - link_start) * 1000
+ ordered_payloads = [received_payloads[key][0] for key, _ in recv_key_pairs]
+ total_size = sum(received_payloads[key][1] for key, _ in recv_key_pairs)
+
+ if len(ordered_payloads) == 1:
+ data = ordered_payloads[0]
+ else:
+ data = merge_received_rank_shards(ordered_payloads, merger=self.kv_payload_merger)
+ data = slice_received_rank_shard(data, topo, slicer=self.kv_payload_slicer)
+
try:
if isinstance(data, dict) and "layer_blocks" in data:
layer_blocks = data["layer_blocks"]
@@ -856,18 +1094,18 @@ def receive_kv_cache_for_request(
continue
if target_device is not None and tensor.device != target_device:
cache_list[i] = tensor.to(target_device).contiguous()
- finally:
- if managed_buffer is not None:
- managed_buffer.release()
+ except Exception:
+ logger.exception("Failed to move KV cache tensors to target device")
logger.info(
- "Successfully received KV cache for %s, %s bytes, wait=%.3fs, link=%.1fms",
+ "Successfully received KV cache for %s, %s bytes across %s key(s), wait=%.3fs, link=%.1fms",
request_id,
- size,
+ total_size,
+ len(recv_key_pairs),
elapsed,
link_ms,
)
- return data, size
+ return data, total_size
if time.time() - start_time > timeout:
logger.error(f"Timeout waiting for KV cache for request {request_id} after {timeout}s")
@@ -876,11 +1114,8 @@ def receive_kv_cache_for_request(
time.sleep(poll_interval)
poll_interval = min(poll_interval * 2, max_poll_interval)
- except Exception as e:
- logger.error(f"Error receiving KV cache for {request_id}: {e}")
- import traceback
-
- traceback.print_exc()
+ except Exception:
+ logger.exception("Error receiving KV cache for %s", request_id)
return None, 0
def apply_kv_cache_to_request(self, req: Any, data: dict[str, Any]) -> None:
@@ -994,73 +1229,79 @@ def receive_multi_kv_cache_distributed(
cfg_kv_collect_func: Callable | None = None,
target_device: torch.device | None = None,
) -> bool:
- """Broadcast-aware wrapper around :meth:`receive_multi_kv_cache`.
-
- SharedMemory connector is single-reader: once rank 0 consumes the
- segment it is deleted. For multi-GPU stages (e.g. sequence-parallel)
- only rank 0 receives; the result is then broadcast to every other
- rank via the world process-group.
-
- For single-worker stages this is equivalent to calling
- :meth:`receive_multi_kv_cache` directly.
+ """Distributed wrapper around :meth:`receive_multi_kv_cache`.
+
+ TP-aware path selection:
+ - world size 1: direct receive
+ - TP active, cfg size 1: each rank independently receives
+ - TP active, cfg size > 1: cfg-rank 0 receives, then broadcasts to
+ peers that share the same TP rank
+ - TP inactive: legacy rank-0 receive then world broadcast
"""
- from vllm_omni.diffusion.distributed.parallel_state import get_world_group
+ from vllm_omni.diffusion.distributed.parallel_state import (
+ get_cfg_group,
+ get_classifier_free_guidance_rank,
+ get_classifier_free_guidance_world_size,
+ get_world_group,
+ )
world = get_world_group()
if world.world_size <= 1:
return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device)
- # --- rank 0: receive to CPU (needed for pickle-based broadcast) ---
- if world.rank_in_group == 0:
- self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu"))
+ topo = self._tp_topo
+ tp_active = topo.source_tp_size > 1 or topo.target_tp_size > 1
+ cfg_size = 1
+ cfg_rank = 0
+ cfg_group = None
+ try:
+ cfg_size = get_classifier_free_guidance_world_size()
+ cfg_rank = get_classifier_free_guidance_rank()
+ cfg_group = get_cfg_group()
+ except Exception:
+ cfg_size = 1
+ cfg_rank = 0
+ cfg_group = None
- kv_payload: dict[str, object] = {}
- for attr in ("past_key_values", "kv_metadata"):
- val = getattr(req, attr, None)
- if val is not None:
- kv_payload[attr] = val
+ if tp_active and cfg_size <= 1:
+ logger.info(
+ "Rank-aware KV receive: rank %s independently receiving (from_tp=%s, to_tp=%s)",
+ topo.local_rank,
+ topo.source_tp_size,
+ topo.target_tp_size,
+ )
+ return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device)
- if hasattr(req, "sampling_params") and req.sampling_params is not None:
- for key in list(vars(req.sampling_params).keys()):
- if (key.startswith("cfg_") and key.endswith("_past_key_values")) or key in (
- "past_key_values",
- "kv_metadata",
- ):
- val = getattr(req.sampling_params, key, None)
- if val is not None:
- kv_payload[f"sp.{key}"] = val
-
- payload_list = [kv_payload]
- # Use broadcast_object_list (pickle-based) instead of broadcast_tensor_dict
- # because the KV cache is a heterogeneous nested structure (NaiveCache objects
- # with metadata + tensors), not a flat tensor dict. This runs once before
- # the denoising loop so the serialization cost is negligible.
- torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group)
- kv_payload = payload_list[0]
- else:
- payload_list: list[dict[str, object] | None] = [None]
- torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group)
- kv_payload = payload_list[0]
+ if tp_active and cfg_size > 1 and cfg_group is not None:
+ kv_payload: dict[str, object] | None = None
+ if cfg_rank == 0:
+ received = self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu"))
+ rank_payloads = self._build_cfg_rank_local_payloads(req, cfg_size) if received else [None] * cfg_size
+ kv_payload = rank_payloads[0]
+ for dst_rank in range(1, cfg_size):
+ cfg_group.send_object(rank_payloads[dst_rank], dst_rank)
+ else:
+ kv_payload = cfg_group.recv_object(0)
- # --- apply on ALL ranks (rank 0 also needs CPU→GPU move) ---
- if not kv_payload:
- return False
+ if not kv_payload:
+ return False
- for attr in ("past_key_values", "kv_metadata"):
- val = kv_payload.get(attr)
- if val is not None:
- if target_device is not None:
- val = _move_to_device(val, target_device)
- setattr(req, attr, val)
+ self._apply_request_kv_payload(req, kv_payload, target_device)
+ return True
- if hasattr(req, "sampling_params") and req.sampling_params is not None:
- for key, val in kv_payload.items():
- if key.startswith("sp."):
- if target_device is not None:
- val = _move_to_device(val, target_device)
- setattr(req.sampling_params, key[3:], val)
+ kv_payload: dict[str, object] | None = None
+ if world.rank_in_group == 0:
+ received = self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu"))
+ if received:
+ kv_payload = self._collect_request_kv_payload(req)
+
+ kv_payload = world.broadcast_object(kv_payload, src=0)
+
+ if not kv_payload:
+ return False
+ self._apply_request_kv_payload(req, kv_payload, target_device)
return True
diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
index 393d0e8013..1d5fc398a4 100644
--- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
+++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
@@ -58,6 +58,7 @@ def __init__(self, vllm_config: Any):
self.waiting_for_chunk_waiting_requests: deque[Any] = deque()
self.waiting_for_chunk_running_requests: deque[Any] = deque()
self.requests_with_ready_chunks = set()
+ self.requests_origin_status = {}
@classmethod
def create_connector(cls, model_config: Any):
@@ -279,6 +280,7 @@ def cleanup_receiver(self, request_id: str) -> None:
self.get_req_chunk.pop(request_id, None)
self.requests_with_ready_chunks.discard(request_id)
self.request_ids_mapping.pop(request_id, None)
+ self.requests_origin_status.pop(request_id, None)
self._cancelled_load_reqs.add(request_id)
self._finished_load_reqs.discard(request_id)
@@ -408,6 +410,7 @@ def _process_chunk_queue(
self.requests_with_ready_chunks.add(request.request_id)
continue
queue.remove(request)
+ self.requests_origin_status[request.request_id] = target_status
waiting_for_chunk_list.append(request)
def _clear_chunk_ready(self, scheduler_output: Any) -> None:
@@ -420,3 +423,23 @@ def _clear_chunk_ready(self, scheduler_output: Any) -> None:
for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
if req_id in self.requests_with_ready_chunks:
self.requests_with_ready_chunks.remove(req_id)
+
+ def finish_requests(
+ self, request_ids: Any, finished_status: RequestStatus, requests: dict[str, Request] | None = None
+ ) -> list[tuple[str, int]]:
+ assert RequestStatus.is_finished(finished_status)
+ if isinstance(request_ids, str):
+ request_ids = (request_ids,)
+ elif request_ids is not None:
+ request_ids = set(request_ids)
+ else:
+ request_ids = requests.keys()
+
+ # First pass: collect requests to remove from queues
+ for req_id in request_ids:
+ request = requests.get(req_id) if requests else None
+ if request is None or request.is_finished():
+ # Invalid request ID.
+ continue
+ if req_id in self.requests_origin_status:
+ request.status = self.requests_origin_status.pop(req_id)
diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py
index 37b7d0d7f8..f012af3c9c 100644
--- a/vllm_omni/distributed/omni_connectors/utils/initialization.py
+++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py
@@ -23,6 +23,11 @@
# collide with request-forwarding endpoints that share the same base port.
KV_TRANSFER_PORT_OFFSET = 100
+# Port stride between TP ranks so each worker binds a unique ZMQ port
+# when TP > 1. Must be larger than the maximum number of pipeline stages.
+# Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage
+KV_RANK_PORT_STRIDE = 16
+
def initialize_connectors_from_config(
config_path: str | Path | None = None,
@@ -201,6 +206,19 @@ def load_omni_transfer_config(
if config_dict is None:
return None
+ # Normalize new-schema (top-level ``connectors`` + ``stages``) into the
+ # legacy ``runtime.connectors`` + ``stage_args`` shape the parser reads.
+ if "stages" in config_dict and "stage_args" not in config_dict:
+ normalized: dict[str, Any] = dict(config_dict)
+ runtime = dict(normalized.get("runtime") or {})
+ if "connectors" in normalized and "connectors" not in runtime:
+ runtime["connectors"] = normalized["connectors"]
+ if "edges" in normalized and "edges" not in runtime:
+ runtime["edges"] = normalized["edges"]
+ normalized["runtime"] = runtime
+ normalized["stage_args"] = normalized["stages"]
+ config_dict = normalized
+
# Parse connectors
connectors = {}
runtime_config = config_dict.get("runtime", {})
diff --git a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py
index 2cb48a8b34..12b9b3d4f7 100644
--- a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py
+++ b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py
@@ -1,15 +1,380 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Utility helpers for KV cache manipulation."""
+"""Utility helpers for KV cache manipulation, TP routing, and merge/slice."""
+
+from __future__ import annotations
+
+import os
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
import torch
+from vllm.distributed.parallel_state import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
from vllm.logger import init_logger
+from .initialization import KV_RANK_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET
+
logger = init_logger(__name__)
LayerKV = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+# ------------------------------------------------------------------ #
+# TP Topology
+# ------------------------------------------------------------------ #
+
+
+@dataclass(frozen=True)
+class KVTPTopology:
+ """Immutable descriptor for a KV-transfer parallel mapping.
+
+ Captures sender/receiver parallel sizes and the local rank within
+ that parallel dimension. Works for any divisible parallel dimension
+ (TP, SP, Ring Attention).
+ """
+
+ source_tp_size: int
+ target_tp_size: int
+ local_rank: int
+
+ def __post_init__(self) -> None:
+ if self.source_tp_size <= 0 or self.target_tp_size <= 0:
+ raise ValueError(
+ f"Parallel sizes must be positive: "
+ f"source_tp_size={self.source_tp_size}, target_tp_size={self.target_tp_size}"
+ )
+ if self.local_rank < 0:
+ raise ValueError(f"local_rank must be non-negative, got {self.local_rank}")
+
+ @property
+ def is_heterogeneous(self) -> bool:
+ return self.source_tp_size != self.target_tp_size
+
+ @property
+ def ratio(self) -> int:
+ """Larger parallel size divided by smaller. Always >= 1."""
+ return max(self.source_tp_size, self.target_tp_size) // min(self.source_tp_size, self.target_tp_size)
+
+
+# ------------------------------------------------------------------ #
+# Runtime TP detection
+# ------------------------------------------------------------------ #
+
+
+def get_local_tp_rank() -> int:
+ """Return the TP-local rank of this worker process.
+
+ Uses ``get_tensor_model_parallel_rank()`` which returns the rank
+ within the TP group only, not the stage-global rank.
+ """
+ try:
+ return get_tensor_model_parallel_rank()
+ except Exception:
+ logger.debug("TP parallel state not initialized, falling back to LOCAL_RANK env", exc_info=True)
+ try:
+ return int(os.environ.get("LOCAL_RANK", "0"))
+ except (ValueError, TypeError):
+ return 0
+
+
+def get_tp_world_size() -> int:
+ """Return the TP world size (tensor-parallel dimension only).
+
+ Uses ``get_tensor_model_parallel_world_size()`` so that
+ cfg_parallel, SP, PP etc. are not included in the count.
+ """
+ try:
+ return get_tensor_model_parallel_world_size()
+ except Exception:
+ logger.debug("TP parallel state not initialized, defaulting world_size=1", exc_info=True)
+ return 1
+
+
+# ------------------------------------------------------------------ #
+# ZMQ port computation
+# ------------------------------------------------------------------ #
+
+
+def kv_zmq_port(base_port: int, from_stage: int, local_rank: int = 0) -> int:
+ """Compute the ZMQ port for a KV-transfer connector.
+
+ Each TP rank gets its own port so that TP > 1 deployments do not
+ cause ``EADDRINUSE`` when multiple sender workers bind on the same
+ host. The formula is backward-compatible: rank 0 produces the same
+ port as the previous ``base + OFFSET + stage`` formula.
+ """
+ return base_port + KV_TRANSFER_PORT_OFFSET + local_rank * KV_RANK_PORT_STRIDE + from_stage
+
+
+# ------------------------------------------------------------------ #
+# TP topology validation and rank routing
+# ------------------------------------------------------------------ #
+
+
+def validate_kv_tp_topology(topo: KVTPTopology) -> None:
+ """Reject heterogeneous TP mappings that cannot be routed losslessly."""
+ larger = max(topo.source_tp_size, topo.target_tp_size)
+ smaller = min(topo.source_tp_size, topo.target_tp_size)
+ if larger % smaller != 0:
+ raise ValueError(
+ f"KV TP mapping must be divisible: "
+ f"source_tp_size={topo.source_tp_size}, "
+ f"target_tp_size={topo.target_tp_size}"
+ )
+
+
+def get_kv_target_ranks(topo: KVTPTopology) -> list[int]:
+ """Which remote ranks this local rank sends KV shards to (send side)."""
+ validate_kv_tp_topology(topo)
+ if topo.source_tp_size == topo.target_tp_size:
+ return [topo.local_rank]
+ if topo.source_tp_size > topo.target_tp_size:
+ return [topo.local_rank // (topo.source_tp_size // topo.target_tp_size)]
+ ratio = topo.target_tp_size // topo.source_tp_size
+ return [topo.local_rank * ratio + i for i in range(ratio)]
+
+
+def get_kv_source_ranks(topo: KVTPTopology) -> list[int]:
+ """Which remote ranks this local rank receives KV shards from (recv side)."""
+ validate_kv_tp_topology(topo)
+ if topo.source_tp_size == topo.target_tp_size:
+ return [topo.local_rank]
+ if topo.source_tp_size > topo.target_tp_size:
+ ratio = topo.source_tp_size // topo.target_tp_size
+ return [topo.local_rank * ratio + i for i in range(ratio)]
+ return [topo.local_rank // (topo.target_tp_size // topo.source_tp_size)]
+
+
+# ------------------------------------------------------------------ #
+# Rank-aware connector key building
+# ------------------------------------------------------------------ #
+
+
+def get_kv_connector_key(
+ req_id: str,
+ from_stage: int | str,
+ chunk_id: int,
+ from_rank: int,
+ to_rank: int,
+) -> str:
+ """Build connector key that includes rank info for KV transfers.
+
+ Format matches PR #2677: ``{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}``
+ """
+ return f"{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}"
+
+
+def build_rank_aware_send_keys(
+ request_id: str,
+ from_stage: str,
+ to_stage: str,
+ topo: KVTPTopology,
+ hook: Callable | None = None,
+) -> list[str]:
+ """Build send-side connector keys, checking injectable hook first."""
+ if callable(hook):
+ keys = list(hook(request_id, from_stage, to_stage))
+ if keys:
+ return keys
+ if topo.source_tp_size <= 1 and topo.target_tp_size <= 1:
+ return [f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}"]
+ target_ranks = get_kv_target_ranks(topo)
+ return [get_kv_connector_key(request_id, from_stage, 0, topo.local_rank, r) for r in target_ranks]
+
+
+def build_rank_aware_recv_keys(
+ request_id: str,
+ from_stage: str,
+ to_stage: str,
+ topo: KVTPTopology,
+ hook: Callable | None = None,
+) -> list[tuple[str, int | None]]:
+ """Build recv-side connector keys with sender rank info.
+
+ Returns a list of ``(key, from_rank)`` tuples. ``from_rank`` is
+ ``None`` when TP <= 1 (single sender, no per-rank routing needed).
+ For TP > 1, ``from_rank`` identifies which sender rank owns the
+ key so that the connector can route metadata queries to the
+ correct endpoint.
+ """
+ if callable(hook):
+ raw = list(hook(request_id, from_stage, to_stage))
+ if raw:
+ if isinstance(raw[0], tuple):
+ return raw
+ # Hook returned plain strings (e.g. OmniConnectorModelRunnerMixin.
+ # get_rank_aware_kv_keys). Reconstruct from_rank from topology so
+ # Mooncake connector can route metadata queries to the correct
+ # sender endpoint in heterogeneous TP.
+ # TODO: have the mixin return (key, from_rank) tuples directly
+ # to avoid this indirect reconstruction.
+ source_ranks = get_kv_source_ranks(topo)
+ if len(raw) == len(source_ranks):
+ return list(zip(raw, source_ranks))
+ return [(k, None) for k in raw]
+ if topo.source_tp_size <= 1 and topo.target_tp_size <= 1:
+ return [(f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}", None)]
+ source_ranks = get_kv_source_ranks(topo)
+ return [(get_kv_connector_key(request_id, from_stage, 0, r, topo.local_rank), r) for r in source_ranks]
+
+
+# ------------------------------------------------------------------ #
+# KV tensor head slicing (heterogeneous TP)
+# ------------------------------------------------------------------ #
+
+
+def slice_kv_tensor_heads(
+ tensor: torch.Tensor | None,
+ offset_in_shard: int,
+ num_slices: int,
+) -> torch.Tensor | None:
+ """Slice one KV tensor along its head dimension (dim 1)."""
+ if tensor is None:
+ return None
+ if not isinstance(tensor, torch.Tensor):
+ return tensor
+ if tensor.dim() < 2:
+ raise ValueError(f"Expected KV tensor with a head dimension, got shape={tuple(tensor.shape)}")
+ if num_slices <= 0:
+ raise ValueError(f"num_slices must be > 0, got {num_slices}")
+ if not (0 <= offset_in_shard < num_slices):
+ raise ValueError(f"offset_in_shard must be in [0, {num_slices}), got {offset_in_shard}")
+
+ heads_in_shard = tensor.shape[1]
+ if heads_in_shard % num_slices != 0:
+ raise ValueError(
+ "KV head count must be divisible for heterogeneous TP slicing: "
+ f"heads_in_shard={heads_in_shard}, num_slices={num_slices}"
+ )
+
+ heads_per_slice = heads_in_shard // num_slices
+ start = offset_in_shard * heads_per_slice
+ end = start + heads_per_slice
+ return tensor[:, start:end, ...].contiguous()
+
+
+def slice_layer_blocks(
+ layer_blocks: dict[str, Any],
+ offset_in_shard: int,
+ num_slices: int,
+) -> dict[str, list[torch.Tensor | None]]:
+ """Slice all KV layers for one logical receiver rank."""
+ sliced_blocks: dict[str, list[torch.Tensor | None]] = {}
+ for cache_name in ("key_cache", "value_cache"):
+ cache_list = layer_blocks.get(cache_name, [])
+ sliced_blocks[cache_name] = [
+ slice_kv_tensor_heads(tensor, offset_in_shard, num_slices) for tensor in cache_list
+ ]
+ return sliced_blocks
+
+
+# ------------------------------------------------------------------ #
+# Multi-rank merge and receiver-side slice
+# ------------------------------------------------------------------ #
+
+
+def merge_received_rank_shards(
+ payloads: list[dict[str, Any]],
+ merger: Callable | None = None,
+) -> dict[str, Any] | None:
+ """Merge multiple source-rank KV shards for one target rank.
+
+ When *merger* is provided (injectable hook), it is called directly.
+ Otherwise the default merges along the head dimension (dim 1).
+ """
+ if callable(merger):
+ return merger(payloads)
+ if not payloads:
+ return None
+ if len(payloads) == 1:
+ return payloads[0]
+
+ base_payload = payloads[0]
+ if not isinstance(base_payload, dict) or "layer_blocks" not in base_payload:
+ return base_payload
+
+ merged: dict[str, Any] = {
+ "request_id": base_payload.get("request_id"),
+ "block_ids": list(base_payload.get("block_ids", [])),
+ "metadata": dict(base_payload.get("metadata", {})),
+ }
+ merged_layer_blocks: dict[str, list[torch.Tensor | None]] = {}
+
+ for cache_name in ("key_cache", "value_cache"):
+ cache_lists = [payload.get("layer_blocks", {}).get(cache_name, []) for payload in payloads]
+ num_layers = max((len(cache_list) for cache_list in cache_lists), default=0)
+ merged_cache: list[torch.Tensor | None] = []
+
+ for layer_idx in range(num_layers):
+ layer_tensors = [
+ cache_list[layer_idx]
+ for cache_list in cache_lists
+ if layer_idx < len(cache_list) and cache_list[layer_idx] is not None
+ ]
+ if not layer_tensors:
+ merged_cache.append(None)
+ elif len(layer_tensors) == 1 or not isinstance(layer_tensors[0], torch.Tensor):
+ merged_cache.append(layer_tensors[0])
+ else:
+ merged_cache.append(torch.cat(layer_tensors, dim=1).contiguous())
+
+ merged_layer_blocks[cache_name] = merged_cache
+
+ merged["layer_blocks"] = merged_layer_blocks
+ return merged
+
+
+def slice_received_rank_shard(
+ payload: dict[str, Any] | None,
+ topo: KVTPTopology,
+ slicer: Callable | None = None,
+) -> dict[str, Any] | None:
+ """Optionally slice a received payload to extract this rank's portion.
+
+ Used when ``to_tp > from_tp``: the sender sent full heads and each
+ receiver rank slices out its own subset.
+ """
+ if callable(slicer):
+ return slicer(payload)
+ if not payload or topo.target_tp_size <= topo.source_tp_size or "layer_blocks" not in payload:
+ return payload
+
+ metadata = payload.get("metadata", {})
+ slice_metadata = metadata.get("tp_head_slice") if isinstance(metadata, dict) else None
+ if isinstance(slice_metadata, dict) and slice_metadata.get("applied"):
+ tagged_rank = slice_metadata.get("target_rank")
+ if tagged_rank is not None and tagged_rank != topo.local_rank:
+ logger.warning(
+ "Received pre-sliced KV payload for unexpected target rank: expected=%s got=%s",
+ topo.local_rank,
+ tagged_rank,
+ )
+ return payload
+
+ ratio = topo.target_tp_size // topo.source_tp_size
+ offset_in_sender = topo.local_rank % ratio
+ updated_metadata = dict(metadata) if isinstance(metadata, dict) else {}
+ updated_metadata["tp_head_slice"] = {
+ "applied": True,
+ "side": "receiver",
+ "target_rank": topo.local_rank,
+ "from_tp": topo.source_tp_size,
+ "to_tp": topo.target_tp_size,
+ "offset_in_shard": offset_in_sender,
+ "num_slices": ratio,
+ }
+ return {
+ "request_id": payload.get("request_id"),
+ "layer_blocks": slice_layer_blocks(payload["layer_blocks"], offset_in_sender, ratio),
+ "block_ids": list(payload.get("block_ids", [])),
+ "metadata": updated_metadata,
+ }
+
+
def normalize_layer_kv(
layer_kv: LayerKV,
*,
diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py
index c8a96e6d25..6c92d7952d 100644
--- a/vllm_omni/engine/__init__.py
+++ b/vllm_omni/engine/__init__.py
@@ -79,6 +79,10 @@ class OmniEngineCoreRequest(EngineCoreRequest):
class OmniEngineCoreOutput(EngineCoreOutput):
pooling_output: dict[str, torch.Tensor] | None = None
+ # Finished flag for streaming input segment
+ is_segment_finished: bool | None = False
+ # Streaming update prompt length
+ new_prompt_len_snapshot: int | None = None
class OmniEngineCoreOutputs(EngineCoreOutputs):
diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py
index d43f1b8fdc..d98ce7d419 100644
--- a/vllm_omni/engine/arg_utils.py
+++ b/vllm_omni/engine/arg_utils.py
@@ -3,7 +3,7 @@
import json
import os
import tempfile
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields
from typing import Any
from vllm.engine.arg_utils import EngineArgs
@@ -20,6 +20,8 @@
_ARCH_TO_MODEL_TYPE: dict[str, str] = {
"CosyVoice3Model": "cosyvoice3",
"OmniVoiceModel": "omnivoice",
+ "VoxCPM2TalkerForConditionalGeneration": "voxcpm2",
+ "VoxCPMForConditionalGeneration": "voxcpm",
}
# Maps model architecture names to tokenizer subfolder paths within HF repos.
@@ -40,6 +42,8 @@ def _register_omni_hf_configs() -> None:
from vllm_omni.model_executor.models.voxtral_tts.configuration_voxtral_tts import (
VoxtralTTSConfig,
)
+ from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig
+ from vllm_omni.transformers_utils.configs.voxcpm2 import VoxCPM2Config
except Exception as exc: # pragma: no cover - best-effort optional registration
logger.warning("Skipping omni HF config registration due to import error: %s", exc)
return
@@ -57,6 +61,8 @@ def _register_omni_hf_configs() -> None:
("cosyvoice3", CosyVoice3Config),
("omnivoice", OmniVoiceConfig),
("voxtral_tts", VoxtralTTSConfig),
+ ("voxcpm", VoxCPMConfig),
+ ("voxcpm2", VoxCPM2Config),
]:
try:
AutoConfig.register(model_type, config_cls)
@@ -121,6 +127,9 @@ class OmniEngineArgs(EngineArgs):
(e.g. ["text", "audio"]). If None, all modalities supported by
the model are used.
log_stats: Whether to log engine statistics. Defaults to False.
+ custom_pipeline_args: Dictionary of arguments for custom pipeline
+ initialization (e.g., ``{"pipeline_class": "my.Module"}``).
+ Passed through to the diffusion stage engine.
"""
stage_id: int = 0
@@ -140,6 +149,7 @@ class OmniEngineArgs(EngineArgs):
stage_configs_path: str | None = None
output_modalities: list[str] | None = None
log_stats: bool = False
+ custom_pipeline_args: dict[str, Any] | None = None
def __post_init__(self) -> None:
load_omni_general_plugins()
@@ -290,3 +300,254 @@ def create_model_config(self) -> OmniModelConfig:
def output_modality(self) -> OutputModality:
"""Parse engine_output_type into a type-safe OutputModality flag."""
return OutputModality.from_string(self.engine_output_type)
+
+
+# ============================================================================
+# CLI argument routing
+# ============================================================================
+#
+# vLLM-Omni's CLI flags live in three buckets:
+#
+# ┌──────────────────┐ ┌──────────────────┐ ┌──────────────────┐
+# │ OrchestratorArgs │ │ OmniEngineArgs │ │ (upstream vllm) │
+# │ │ │ │ │ server/api │
+# │ stage_timeout │ │ max_num_seqs │ │ host, port │
+# │ worker_backend │ │ gpu_mem_util │ │ ssl_keyfile │
+# │ deploy_config │ │ dtype, quant │ │ api_key │
+# │ ... │ │ ... │ │ ... │
+# └──────────────────┘ └──────────────────┘ └──────────────────┘
+# │ │ │
+# ▼ ▼ ▼
+# orchestrator each stage uvicorn /
+# consumes engine FastAPI
+#
+# Fields in ``SHARED_FIELDS`` (e.g. ``model``, ``log_stats``) flow to BOTH
+# orchestrator and engine by design.
+#
+# Invariants enforced by ``tests/test_arg_utils.py``:
+#
+# 1. ``OrchestratorArgs`` ∩ ``OmniEngineArgs`` ⊆ ``SHARED_FIELDS``
+# 2. Every CLI flag is classifiable into one of the three buckets
+# 3. User-typed flags that match none of the above are logged as dropped
+#
+# Adding a new orchestrator-only flag → add a field to ``OrchestratorArgs``.
+# Everything else is automatic.
+
+
+@dataclass(frozen=True)
+class OrchestratorArgs:
+ """CLI flags consumed by the orchestrator.
+
+ Contract: every field here is either
+ (a) orchestrator-only (never needed by a stage engine), OR
+ (b) orchestrator-read-then-redistributed (e.g. ``async_chunk`` is read
+ from CLI, written to ``DeployConfig``, then propagated to every
+ stage via ``merge_pipeline_deploy`` — not via direct kwargs
+ forwarding).
+
+ Fields that BOTH orchestrator and engine genuinely need (e.g. ``model``,
+ ``log_stats``) should be listed in ``SHARED_FIELDS`` below; ``split_kwargs``
+ will copy them to both buckets.
+ """
+
+ # === Lifecycle ===
+ stage_init_timeout: int = 300
+ init_timeout: int = 600
+
+ # === Cross-stage Communication ===
+ shm_threshold_bytes: int = 65536
+ batch_timeout: int = 10
+
+ # === Cluster / Backend ===
+ worker_backend: str = "multi_process"
+ ray_address: str | None = None
+
+ # === Config Files ===
+ stage_configs_path: str | None = None
+ deploy_config: str | None = None
+ stage_overrides: str | None = None # raw JSON string; parsed downstream
+
+ # === Mode Switches (orchestrator reads, DeployConfig redistributes) ===
+ async_chunk: bool | None = None
+
+ # === Observability ===
+ log_stats: bool = False
+
+ # === Headless Mode (also forwarded to engine — see SHARED_FIELDS) ===
+ stage_id: int | None = None
+
+ # === Pre-built Objects ===
+ parallel_config: Any = None
+
+ # === Multi-stage guards ===
+ # --tokenizer is captured here so it does not propagate to every stage
+ # uniformly (different stages often need different tokenizers, e.g.
+ # qwen3_omni thinker vs talker). Users wanting a per-stage tokenizer
+ # should set it in the deploy YAML.
+ tokenizer: str | None = None
+
+
+# Fields that live in BOTH OrchestratorArgs and OmniEngineArgs by design.
+# Changes to this set are a review red flag — revisit the contract.
+SHARED_FIELDS: frozenset[str] = frozenset(
+ {
+ "model", # orch: detect model_type; engine: load weights
+ "stage_id", # orch: route (headless); engine: identity
+ "log_stats", # both want the flag
+ "stage_configs_path", # orch: load legacy YAML; engine: may reference for validation
+ }
+)
+
+
+def orchestrator_field_names() -> frozenset[str]:
+ """Return the names of every field on OrchestratorArgs."""
+ return frozenset(f.name for f in fields(OrchestratorArgs))
+
+
+def internal_blacklist_keys() -> frozenset[str]:
+ """Return the set of CLI keys that must never be forwarded as per-stage
+ engine overrides.
+
+ Derived from ``OrchestratorArgs`` fields minus ``SHARED_FIELDS``, so
+ adding a new orchestrator-owned flag is a one-line change to the
+ dataclass — this function updates automatically.
+ """
+ return orchestrator_field_names() - SHARED_FIELDS
+
+
+def split_kwargs(
+ kwargs: dict[str, Any],
+ *,
+ engine_cls: type | None = None,
+ user_typed: set[str] | None = None,
+ strict: bool = False,
+) -> tuple[OrchestratorArgs, dict[str, Any]]:
+ """Partition CLI kwargs into (orchestrator, engine) buckets.
+
+ Args:
+ kwargs: Raw dict, typically ``vars(args)``.
+ engine_cls: Engine dataclass used to whitelist-filter the engine
+ bucket. Defaults to ``OmniEngineArgs``. Pass a custom class
+ for testing.
+ user_typed: Keys the user actually typed on the command line. Used
+ to warn when a user-typed flag is unclassifiable.
+ strict: If True, raise ``ValueError`` on ambiguous (double-classified
+ but not in ``SHARED_FIELDS``) fields. Default False to keep the
+ rollout non-breaking; flip to True in tests and CI.
+
+ Returns:
+ ``(orchestrator_args, engine_kwargs)``. ``engine_kwargs`` has already
+ been whitelist-filtered against ``engine_cls`` — safe to pass directly
+ to ``engine_cls(**engine_kwargs)``.
+ """
+ if engine_cls is None:
+ engine_cls = OmniEngineArgs
+
+ orch_fields = orchestrator_field_names()
+ engine_fields = {f.name for f in fields(engine_cls)}
+
+ orch_kwargs: dict[str, Any] = {}
+ engine_candidate: dict[str, Any] = {}
+ shared_values: dict[str, Any] = {}
+ unclassified: dict[str, Any] = {}
+
+ for key, value in kwargs.items():
+ in_orch = key in orch_fields
+ in_engine = key in engine_fields
+ is_shared = key in SHARED_FIELDS
+
+ if is_shared:
+ shared_values[key] = value
+ elif in_orch and in_engine:
+ # Declared in both but not marked shared → ambiguous.
+ msg = (
+ f"Field {key!r} is defined on both OrchestratorArgs and "
+ f"{engine_cls.__name__} but is not in SHARED_FIELDS. "
+ f"This causes double-routing. Either remove the duplicate or "
+ f"add {key!r} to SHARED_FIELDS if the sharing is intentional."
+ )
+ if strict:
+ raise ValueError(msg)
+ logger.error(msg)
+ # Default: treat as orchestrator-only to preserve existing behavior.
+ orch_kwargs[key] = value
+ elif in_orch:
+ orch_kwargs[key] = value
+ elif in_engine:
+ engine_candidate[key] = value
+ else:
+ unclassified[key] = value
+
+ # Warn on user-typed but unclassifiable flags so we don't silently drop
+ # something the user cared about (fixes the class of bug that spawned #873).
+ if unclassified and user_typed:
+ user_typed_unknown = sorted(k for k in unclassified if k in user_typed)
+ if user_typed_unknown:
+ logger.warning(
+ "CLI flags not consumed by vllm-omni and dropped before "
+ "per-stage engine construction: %s. If these are vllm "
+ "frontend/uvicorn flags (host, port, ssl_*, api_key, …) this "
+ "is expected; otherwise check your spelling.",
+ user_typed_unknown,
+ )
+
+ # Engine bucket: shared + engine-only. We do NOT pass through unclassified
+ # fields — that's exactly the server/uvicorn noise we want to shed.
+ engine_kwargs = {**shared_values, **engine_candidate}
+
+ # Construct the orchestrator dataclass. Shared fields that OrchestratorArgs
+ # also declares get copied into its constructor.
+ orch_init: dict[str, Any] = dict(orch_kwargs)
+ for key, value in shared_values.items():
+ if key in orch_fields:
+ orch_init[key] = value
+ orch_args = OrchestratorArgs(**orch_init)
+
+ return orch_args, engine_kwargs
+
+
+def derive_server_dests_from_vllm_parser() -> frozenset[str]:
+ """Derive the set of argparse dests that belong to vllm's frontend/server.
+
+ Returns every dest registered by ``make_arg_parser`` that is NOT a field
+ of ``OmniEngineArgs`` and NOT a field of ``OrchestratorArgs``. Useful for
+ CI tests to assert all CLI flags are classifiable without maintaining
+ a hardcoded server list.
+
+ Returns empty frozenset if vllm's parser cannot be built (e.g. in a
+ minimal test environment).
+ """
+ try:
+ from vllm.entrypoints.openai.cli_args import make_arg_parser
+ from vllm.utils.argparse_utils import FlexibleArgumentParser
+ except ImportError:
+ logger.debug("Cannot import vllm parser — server-dest derivation skipped")
+ return frozenset()
+
+ try:
+ parser = make_arg_parser(FlexibleArgumentParser())
+ all_dests = {a.dest for a in parser._actions if a.dest and a.dest != "help"}
+ except Exception as exc:
+ logger.debug("Failed to build vllm parser: %s", exc)
+ return frozenset()
+
+ engine_fields = {f.name for f in fields(OmniEngineArgs)}
+ orch_fields = orchestrator_field_names()
+
+ return frozenset(all_dests - engine_fields - orch_fields - SHARED_FIELDS)
+
+
+def orchestrator_args_from_argparse(args: Any) -> OrchestratorArgs:
+ """Build an ``OrchestratorArgs`` from an ``argparse.Namespace``.
+
+ Only copies attributes that exist on the namespace — missing fields fall
+ back to the dataclass default. Useful when the full parser is already
+ built and ``vars(args)`` would include noise.
+ """
+ kwargs: dict[str, Any] = {}
+ for f in fields(OrchestratorArgs):
+ if hasattr(args, f.name):
+ value = getattr(args, f.name)
+ if value is not None or f.default is None:
+ kwargs[f.name] = value
+ return OrchestratorArgs(**kwargs)
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 77b386124c..35b6be70f3 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -25,12 +25,15 @@
import janus
import torch
from omegaconf import OmegaConf
+from vllm import envs as vllm_envs
+from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
+from vllm_omni.config.stage_config import strip_parent_engine_args
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
from vllm_omni.diffusion.stage_diffusion_proc import (
@@ -63,6 +66,7 @@
)
from vllm_omni.engine.stage_init_utils import (
StartedLlmStage,
+ _inject_inferred_kv_tp_topology,
acquire_device_locks,
build_diffusion_config,
build_engine_args_dict,
@@ -80,7 +84,11 @@
setup_stage_devices,
terminate_alive_proc,
)
-from vllm_omni.entrypoints.utils import load_and_resolve_stage_configs
+from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
+from vllm_omni.entrypoints.utils import (
+ inject_omni_kv_config,
+ load_and_resolve_stage_configs,
+)
from vllm_omni.inputs.preprocess import OmniInputPreprocessor
from vllm_omni.platforms import current_omni_platform
@@ -90,6 +98,27 @@
logger = init_logger(__name__)
+# ============================================================================
+# Parent-EngineArgs field-routing contracts (consumed by
+# AsyncOmniEngine._strip_parent_engine_args when ``stage_configs_path`` is set).
+# ============================================================================
+
+# Fields that must survive the "equal to default → strip" filter because
+# diffusion stages need them even when equal to vllm's default value
+# (e.g. colocate worker setup relies on worker_extension_cls being forwarded).
+_PARENT_ARGS_KEEP: frozenset[str] = frozenset({"worker_extension_cls"})
+
+# Omni orchestrator-level fields consumed by ``_resolve_stage_configs`` that
+# must never leak into per-stage EngineArgs (``stage_configs_path`` would
+# trigger the ``create_model_config`` guard).
+_PARENT_ARGS_STRIP: frozenset[str] = frozenset({"stage_configs_path"})
+
+# Fields always populated by callers (via ``from_cli_args`` / ``asdict``) so
+# their presence as an override is never a surprise — suppress the
+# "override ignored" warning for these.
+_PARENT_ARGS_NO_WARN: frozenset[str] = frozenset({"model"})
+
+
def _patch_generation_config_if_needed(model_config: Any) -> None:
"""Ensure try_get_generation_config won't crash for models whose HF
config.json lacks model_type (e.g. CosyVoice3). We probe it once;
@@ -380,6 +409,12 @@ def _launch_llm_stage(
omni_kv["omni_to_stage"] = omni_to
omni_kv.setdefault("stage_id", metadata.stage_id)
engine_args_dict["omni_kv_config"] = omni_kv
+ if self.stage_configs:
+ _inject_inferred_kv_tp_topology(
+ engine_args_dict.get("omni_kv_config"),
+ metadata.stage_id,
+ self.stage_configs,
+ )
vllm_config, executor_class = build_vllm_config(
stage_cfg,
self.model,
@@ -426,21 +461,24 @@ def _launch_llm_stage(
proc=proc,
)
logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id)
- # Keep the stage-specific device visibility until vLLM
- # finishes starting all child processes.
- if self.single_stage_mode and self._omni_master_server is not None:
- launch_stack.close()
- else:
- assert proc is not None
- assert handshake_address is not None
- complete_stage_handshake(proc, handshake_address, addresses, vllm_config)
- logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id)
finally:
if previous_visible_devices is None:
current_omni_platform.unset_device_control_env_var()
else:
current_omni_platform.set_device_control_env_var(previous_visible_devices)
+ # After StageEngineCoreProc has been spawned it carries its
+ # stage-specific device visibility into descendants, so the
+ # slow HELLO/READY handshake can run without holding the
+ # process-wide launch lock.
+ if self.single_stage_mode and self._omni_master_server is not None:
+ launch_stack.close()
+ else:
+ assert proc is not None
+ assert handshake_address is not None
+ complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout)
+ logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id)
+
assert started_stage is not None
return started_stage
except Exception:
@@ -752,10 +790,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
setup_stage_devices(configured_stage_id, metadata.runtime_cfg)
omni_conn_cfg, omni_from, omni_to = omni_kv_connector
if omni_conn_cfg:
- from vllm_omni.entrypoints.utils import inject_omni_kv_config
-
inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to)
- inject_kv_stage_info(stage_cfg, configured_stage_id)
+ inject_kv_stage_info(stage_cfg, configured_stage_id, self.stage_configs)
if self.single_stage_mode:
assert self._omni_master_server is not None
stage_clients[stage_idx] = self._launch_diffusion_stage(
@@ -764,11 +800,14 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
self._omni_master_server,
)
else:
+ use_inline = True if self.num_stages == 1 else False
stage_clients[stage_idx] = initialize_diffusion_stage(
self.model,
stage_cfg,
metadata,
+ stage_init_timeout=stage_init_timeout,
batch_size=self.diffusion_batch_size,
+ use_inline=use_inline,
)
logger.info(
"[AsyncOmniEngine] Stage %s initialized (diffusion, batch_size=%d)",
@@ -927,6 +966,7 @@ async def _run_orchestrator() -> None:
if isinstance(_sc, StageEngineCoreClientBase):
_sc.receiver_connectors = self.stage_receiver_connectors[_sid]
+ pd_config = self._detect_pd_config()
orchestrator = Orchestrator(
request_async_queue=self.request_queue.async_q,
output_async_queue=self.output_queue.async_q,
@@ -936,6 +976,7 @@ async def _run_orchestrator() -> None:
output_processors=self.output_processors,
stage_vllm_configs=self.stage_vllm_configs,
connectors=self.omni_connectors,
+ pd_config=pd_config,
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
@@ -1098,7 +1139,6 @@ def _enqueue_cfg_companions(
params=companion_params,
supported_tasks=self.supported_tasks,
)
- request = _upgrade_to_omni_request(request, companion_prompt)
request.external_req_id = cid
self.output_processors[0].add_request(
@@ -1158,12 +1198,64 @@ def _normalize_cache_config(cache_backend: str | None, cache_config: Any | None)
cache_config = AsyncOmniEngine._get_default_cache_config(cache_backend)
return cache_config
+ def _detect_pd_config(self) -> dict[str, Any] | None:
+ """Detect PD (Prefill-Decode) disaggregation config from stage_configs.
+ Returns a dict with 'pd_pair' and 'bootstrap_addr', or None.
+ """
+ pd_pair = PDDisaggregationMixin.detect_pd_separation_from_stage_configs(self.stage_configs)
+ if pd_pair is None:
+ return None
+ prefill_idx, decode_idx = pd_pair
+
+ # Extract bootstrap address from prefill stage engine_args
+ bootstrap_addr: str | None = None
+ try:
+ prefill_cfg = self.stage_configs[prefill_idx]
+ ea = getattr(prefill_cfg, "engine_args", None)
+ kv_cfg = getattr(ea, "kv_transfer_config", None) if ea is not None else None
+ if kv_cfg is not None:
+ port = vllm_envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+ kv_ip = getattr(kv_cfg, "kv_ip", None) or "127.0.0.1"
+ bootstrap_addr = f"http://{kv_ip}:{port}"
+ except Exception as exc:
+ logger.warning("[AsyncOmniEngine] Could not extract PD bootstrap address: %s", exc)
+
+ logger.info(
+ "[AsyncOmniEngine] PD disaggregation detected: prefill=stage-%d, decode=stage-%d, bootstrap=%s",
+ prefill_idx,
+ decode_idx,
+ bootstrap_addr,
+ )
+ prefill_engine_id: str | None = None
+ try:
+ prefill_client = self.stage_clients[prefill_idx]
+ kv_cfg = getattr(getattr(prefill_client, "vllm_config", None), "kv_transfer_config", None)
+ prefill_engine_id = getattr(kv_cfg, "engine_id", None)
+ except Exception as exc:
+ logger.warning("[AsyncOmniEngine] Could not extract prefill engine_id: %s", exc)
+
+ return {
+ "pd_pair": (prefill_idx, decode_idx),
+ "bootstrap_addr": bootstrap_addr,
+ "prefill_engine_id": prefill_engine_id,
+ }
+
@staticmethod
def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"""Create a default single-stage diffusion config from kwargs."""
# We temporally create a default config for diffusion stage.
# In the future, we should merge the default config with the user-provided config.
normalized_kwargs = dict(kwargs)
+ default_sampling_params = normalized_kwargs.get("default_sampling_params")
+ if isinstance(default_sampling_params, str):
+ try:
+ default_sampling_params = json.loads(default_sampling_params)
+ except json.JSONDecodeError:
+ logger.warning("Invalid default_sampling_params JSON, ignoring stage defaults.")
+ default_sampling_params = None
+ if not isinstance(default_sampling_params, dict):
+ default_sampling_params = None
+ stage_default_sampling_params = default_sampling_params.get("0", {}) if default_sampling_params else {}
# TODO: hack, convert dtype to string to avoid non-premitive omegaconf create error.
if "dtype" in normalized_kwargs and not isinstance(normalized_kwargs["dtype"], str):
@@ -1227,6 +1319,8 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"enable_cpu_offload": kwargs.get("enable_cpu_offload", False),
"enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False),
"enforce_eager": kwargs.get("enforce_eager", False),
+ "boundary_ratio": kwargs.get("boundary_ratio", None),
+ "flow_shift": kwargs.get("flow_shift", None),
"diffusion_load_format": kwargs.get("diffusion_load_format", "default"),
"custom_pipeline_args": kwargs.get("custom_pipeline_args", None),
"worker_extension_cls": kwargs.get("worker_extension_cls", None),
@@ -1258,6 +1352,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"devices": devices,
},
"engine_args": stage_engine_args,
+ "default_sampling_params": stage_default_sampling_params,
"final_output": True,
"final_output_type": "image",
}
@@ -1265,10 +1360,50 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion"
return default_stage_cfg
+ @staticmethod
+ def _strip_single_engine_args(kwargs: dict[str, Any]) -> dict[str, Any]:
+ """Remove parent ``EngineArgs`` fields from *kwargs*.
+
+ When ``stage_configs_path`` is set, per-stage engine args are defined
+ in the YAML. Top-level single-engine fields (``compilation_config``,
+ ``tensor_parallel_size``, …) must not leak into per-stage configs via
+ the ``base_engine_args`` merge in ``load_stage_configs_from_yaml`` —
+ they can cause type errors (e.g. ``compilation_config`` as a JSON
+ string rejected by ``VllmConfig``) or silently override YAML values.
+
+ Logs a warning for any parent field whose value differs from the
+ dataclass default, so users know their explicit overrides are ignored.
+ See the module-level ``_PARENT_ARGS_*`` constants for the routing
+ contracts this method enforces.
+ """
+ parent_fields: dict[str, dataclasses.Field] = {f.name: f for f in dataclasses.fields(EngineArgs)}
+ result, overridden = strip_parent_engine_args(
+ kwargs,
+ parent_fields=parent_fields,
+ keep_keys=_PARENT_ARGS_KEEP,
+ strip_keys=_PARENT_ARGS_STRIP,
+ no_warn_keys=_PARENT_ARGS_NO_WARN,
+ )
+
+ if overridden:
+ logger.warning(
+ "stage_configs_path is set — the following top-level engine "
+ "args are ignored (per-stage YAML takes precedence): %s",
+ ", ".join(sorted(overridden)),
+ )
+
+ return result
+
def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[str, list[Any]]:
"""Resolve stage configs and inject defaults shared by orchestrator/headless."""
stage_configs_path = kwargs.get("stage_configs_path", None)
+ deploy_config_path = kwargs.pop("deploy_config", None)
+ stage_overrides_json = kwargs.pop("stage_overrides", None)
+ # Set of CLI keys the user actually typed; ``None`` means we have no
+ # parser-level info (e.g. programmatic Omni() call) and the lower
+ # layers should treat all kwargs as explicit.
+ cli_explicit_keys = kwargs.pop("_cli_explicit_keys", None)
explicit_stage_configs = kwargs.pop("stage_configs", None)
if explicit_stage_configs is not None:
logger.warning(
@@ -1276,13 +1411,32 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
"Ignoring it and resolving stages from stage_configs_path/model factory."
)
- # Use the legacy config loading path (load_and_resolve_stage_configs).
- # StageConfigFactory wiring will be done in config refactor [2/N].
+ if stage_configs_path is not None:
+ base_kwargs = self._strip_single_engine_args(kwargs)
+ else:
+ base_kwargs = kwargs
+
+ # Parse --stage-overrides JSON string if provided
+ stage_overrides = None
+ if stage_overrides_json:
+ if isinstance(stage_overrides_json, str):
+ try:
+ stage_overrides = json.loads(stage_overrides_json)
+ except json.JSONDecodeError as exc:
+ raise ValueError(
+ f"--stage-overrides is not valid JSON: {exc}. Got: {stage_overrides_json!r}"
+ ) from exc
+ else:
+ stage_overrides = stage_overrides_json
+
config_path, stage_configs = load_and_resolve_stage_configs(
model,
stage_configs_path,
- kwargs,
+ base_kwargs,
default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs),
+ deploy_config_path=deploy_config_path,
+ stage_overrides=stage_overrides,
+ cli_explicit_keys=cli_explicit_keys,
)
# Inject diffusion LoRA-related knobs from kwargs if not present in the stage config.
diff --git a/vllm_omni/engine/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py
new file mode 100644
index 0000000000..b9dfae833e
--- /dev/null
+++ b/vllm_omni/engine/cfg_companion_tracker.py
@@ -0,0 +1,125 @@
+"""CFG companion request tracker for the Omni orchestrator.
+
+Encapsulates all bookkeeping for Classifier-Free Guidance companion
+requests (parent/companion ID mapping, completion tracking,
+deferred forwarding, and cleanup).
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+logger = logging.getLogger(__name__)
+
+
+class CfgCompanionTracker:
+ """Manages CFG companion request lifecycle in the orchestrator scheduling loop."""
+
+ def __init__(self) -> None:
+ self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id}
+ self._companion_ids: set[str] = set()
+ self._companion_to_parent: dict[str, str] = {} # companion -> parent
+ self._done: dict[str, set[str]] = {} # parent -> completed companion ids
+ self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result
+
+ def is_companion(self, req_id: str) -> bool:
+ return req_id in self._companion_ids
+
+ def has_companions(self, parent_id: str) -> bool:
+ return parent_id in self._companion_map
+
+ def all_companions_done(self, parent_id: str) -> bool:
+ role_map = self._companion_map.get(parent_id, {})
+ done_set = self._done.get(parent_id, set())
+ return all(cid in done_set for cid in role_map.values())
+
+ def get_companion_request_ids(self, parent_id: str) -> dict[str, str]:
+ """Return ``{role: companion_request_id}`` for a parent."""
+ return self._companion_map.get(parent_id, {})
+
+ def register_parent(self, parent_id: str) -> None:
+ self._companion_map.setdefault(parent_id, {})
+ self._done.setdefault(parent_id, set())
+
+ def register_companion(self, parent_id: str, role: str, companion_id: str) -> None:
+ self.register_parent(parent_id)
+ self._companion_map[parent_id][role] = companion_id
+ self._companion_ids.add(companion_id)
+ self._companion_to_parent[companion_id] = parent_id
+
+ def attach_cfg_request_ids(self, parent_id: str, sampling_params: Any) -> Any:
+ cfg_ids = self.get_companion_request_ids(parent_id)
+ if not cfg_ids:
+ return sampling_params
+
+ if isinstance(sampling_params, OmniDiffusionSamplingParams):
+ sampling_params = sampling_params.clone()
+ sampling_params.cfg_kv_request_ids = cfg_ids
+ logger.info(
+ "Attaching cfg_kv_request_ids=%s to request %s",
+ cfg_ids,
+ parent_id,
+ )
+ return sampling_params
+
+ def on_companion_completed(self, companion_id: str) -> str | None:
+ """Mark done. Returns parent_id only if parent is pending and all companions finished."""
+ parent_id = self._companion_to_parent.get(companion_id)
+ if parent_id is None:
+ return None
+ done_set = self._done.get(parent_id)
+ assert done_set is not None, f"Companion {companion_id} completed before parent {parent_id} was registered"
+ if companion_id in done_set:
+ return None
+ done_set.add(companion_id)
+ logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id)
+ if parent_id in self._pending_parents and self.all_companions_done(parent_id):
+ return parent_id
+ return None
+
+ def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None:
+ """Hold parent result while waiting for companions to finish."""
+ # TODO: Add timeout/error recovery when the orchestrator grows a
+ # companion-failure path. Today deferred parents are released only when
+ # companions finish or the external layer aborts the request.
+ self._pending_parents[parent_id] = {
+ "engine_outputs": engine_outputs,
+ "stage_id": stage_id,
+ }
+ logger.debug("Parent %s deferred, waiting for CFG companions", parent_id)
+
+ def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None:
+ return self._pending_parents.pop(parent_id, None)
+
+ def cleanup_parent(self, parent_id: str) -> list[str]:
+ companion_ids = list(self._companion_map.pop(parent_id, {}).values())
+ for companion_id in companion_ids:
+ self._companion_ids.discard(companion_id)
+ self._companion_to_parent.pop(companion_id, None)
+ self._done.pop(parent_id, None)
+ self._pending_parents.pop(parent_id, None)
+ return companion_ids
+
+ def abort_parents(self, request_ids: list[str]) -> list[str]:
+ all_request_ids = list(request_ids)
+ seen = set(all_request_ids)
+ parents_to_cleanup: set[str] = set()
+
+ for req_id in request_ids:
+ # The orchestrator calls this with parent request IDs. If a raw
+ # companion ID is passed here, keep it as a direct abort target and
+ # avoid tearing down parent tracking state implicitly.
+ if req_id not in self._companion_ids:
+ parents_to_cleanup.add(req_id)
+
+ for parent_id in parents_to_cleanup:
+ companion_ids = self.cleanup_parent(parent_id)
+ for companion_id in companion_ids:
+ if companion_id not in seen:
+ seen.add(companion_id)
+ all_request_ids.append(companion_id)
+
+ return all_request_ids
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index d955b30251..aea7b9a8b7 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -27,6 +27,7 @@
from vllm_omni.engine import (
OmniEngineCoreRequest,
)
+from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
from vllm_omni.engine.serialization import serialize_additional_information
from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics
from vllm_omni.metrics.stats import StageStats
@@ -41,6 +42,8 @@ def build_engine_core_request_from_tokens(
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
model_config: ModelConfig | None = None,
+ resumable: bool = False,
+ mm_features: list | None = None,
) -> OmniEngineCoreRequest:
"""Build an OmniEngineCoreRequest directly from an OmniTokensPrompt.
@@ -75,7 +78,7 @@ def build_engine_core_request_from_tokens(
return OmniEngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
- mm_features=None,
+ mm_features=mm_features,
sampling_params=sampling_params,
pooling_params=pooling_params,
arrival_time=arrival_time,
@@ -83,6 +86,7 @@ def build_engine_core_request_from_tokens(
cache_salt=None,
data_parallel_rank=None,
prompt_embeds=prompt_embeds,
+ resumable=resumable,
additional_information=additional_info_payload,
)
@@ -103,6 +107,22 @@ class OrchestratorRequestState:
# Metrics: timestamp when request was submitted to each stage
stage_submit_ts: dict[int, float] = field(default_factory=dict)
+ mm_processor_kwargs: dict | None = None
+ mm_features: list | None = None
+
+ streaming: StreamingInputState = field(default_factory=lambda: StreamingInputState())
+
+
+@dataclass
+class StreamingInputState:
+ # Flag of streaming input request
+ enabled: bool = False
+ # Flag of segment of streaming input finished
+ segment_finished: bool = False
+ # Streaming update prompt length
+ new_prompt_len_snapshot: int | None = None
+ # Model/bridge-specific runtime states (e.g., thinker->talker)
+ bridge_states: dict[str, Any] = field(default_factory=dict)
class Orchestrator:
@@ -123,6 +143,7 @@ def __init__(
*,
async_chunk: bool = False,
connectors: dict[tuple[str, str], Any] | None = None,
+ pd_config: dict[str, Any] | None = None,
) -> None:
self.request_async_queue = request_async_queue
self.output_async_queue = output_async_queue
@@ -138,15 +159,21 @@ def __init__(
# Sender-side omni connectors keyed by (from_stage, to_stage)
self.connectors: dict[tuple[str, str], Any] = connectors or {}
+ # PD disaggregation state
+ self._pd_pair: tuple[int, int] | None = None
+ self._pd_bootstrap_addr: str | None = None
+ self._pd_prefill_engine_id: str | None = None
+ self._pd_kv_params: dict[str, Any] = {}
+ if pd_config is not None:
+ self._pd_pair = pd_config.get("pd_pair")
+ self._pd_bootstrap_addr = pd_config.get("bootstrap_addr")
+ self._pd_prefill_engine_id = pd_config.get("prefill_engine_id")
+
# Per-request state
self.request_states: dict[str, OrchestratorRequestState] = {}
# CFG companion tracking
- self._companion_map: dict[str, dict[str, str]] = {}
- self._companion_to_parent: dict[str, str] = {}
- self._companion_ids: set[str] = set()
- self._companion_done: dict[str, set[str]] = {}
- self._deferred_parents: dict[str, dict[str, Any]] = {}
+ self._cfg_tracker = CfgCompanionTracker()
# Per-stage metrics accumulators.
self._batch_seq: list[int] = [0] * self.num_stages
@@ -250,6 +277,23 @@ async def _orchestration_loop(self) -> None:
idle = False
req_state = self.request_states.get(output.request_id)
if req_state is not None:
+ if getattr(output, "error", None) is not None:
+ parent_id = self._companion_to_parent.get(output.request_id, output.request_id)
+ await self.output_async_queue.put(
+ {
+ "type": "error",
+ "request_id": parent_id,
+ "stage_id": stage_id,
+ "error": output.error,
+ }
+ )
+ role_map = self._companion_map.get(parent_id, {})
+ for cid in role_map.values():
+ self.request_states.pop(cid, None)
+ self._cleanup_companion_state(parent_id)
+ self.request_states.pop(parent_id, None)
+ continue
+
stage_metrics = self._build_stage_metrics(stage_id, output.request_id, [output], req_state)
await self._route_output(stage_id, output, req_state, stage_metrics)
continue
@@ -321,7 +365,7 @@ async def _route_output(
# CFG companion handling: companions don't produce user-visible output
# and don't forward to the next stage directly.
- if finished and req_id in self._companion_ids:
+ if finished and self._cfg_tracker.is_companion(req_id):
await self._handle_cfg_companion_ready(req_id)
self.request_states.pop(req_id, None)
return
@@ -349,63 +393,80 @@ async def _route_output(
}
)
+ # PD disaggregation: extract KV transfer params from prefill stage output
+ if self._pd_pair is not None and finished and stage_id == self._pd_pair[0]:
+ kv_params = getattr(output, "kv_transfer_params", None)
+ if kv_params is not None:
+ self._pd_kv_params[req_id] = kv_params if isinstance(kv_params, dict) else dict(kv_params)
+ logger.debug(
+ "[Orchestrator][PD] stored kv_transfer_params for req=%s (keys=%s)",
+ req_id,
+ list(self._pd_kv_params[req_id].keys()),
+ )
+ else:
+ logger.warning(
+ "[Orchestrator][PD] prefill stage output for req=%s has no kv_transfer_params; "
+ "KV transfer may fail. Ensure apply_mooncake_connector_patch() was called.",
+ req_id,
+ )
+
if (
- finished
+ (finished or (req_state.streaming.enabled and req_state.streaming.segment_finished))
and stage_id < req_state.final_stage_id
and not self.async_chunk
- and not self._next_stage_already_submitted(stage_id, req_state)
+ and (not self._next_stage_already_submitted(stage_id, req_state) or req_state.streaming.enabled)
):
- if req_id in self._companion_map and not self._all_companions_done(req_id):
- self._deferred_parents[req_id] = {
- "stage_id": stage_id,
- "output": output,
- }
+ if (
+ finished
+ and self._cfg_tracker.has_companions(req_id)
+ and not self._cfg_tracker.all_companions_done(req_id)
+ ):
+ self._cfg_tracker.defer_parent(req_id, output, stage_id)
else:
- await self._forward_to_next_stage(req_id, stage_id, output, req_state)
+ await self._forward_to_next_stage(
+ req_id,
+ stage_id,
+ output,
+ req_state,
+ is_streaming_session=req_state.streaming.enabled,
+ is_final_update=False,
+ )
+ if req_state.streaming.enabled and finished:
+ # For streaming sessions, send the terminal (resumable=False) update only on a finish
+ await self._forward_to_next_stage(
+ req_id,
+ stage_id,
+ output,
+ req_state,
+ is_streaming_session=True,
+ is_final_update=True,
+ )
if finished and stage_id == req_state.final_stage_id:
- self._cleanup_companion_state(req_id)
+ # PD: clean up any lingering KV params for this request
+ self._pd_kv_params.pop(req_id, None)
+ self._cfg_tracker.cleanup_parent(req_id)
self.request_states.pop(req_id, None)
- def _cleanup_companion_state(self, parent_id: str) -> None:
- """Remove all companion tracking state for a completed parent."""
- role_map = self._companion_map.pop(parent_id, {})
- for cid in role_map.values():
- self._companion_ids.discard(cid)
- self._companion_to_parent.pop(cid, None)
- self._companion_done.pop(parent_id, None)
- self._deferred_parents.pop(parent_id, None)
-
- def _all_companions_done(self, parent_id: str) -> bool:
- """Check whether all CFG companions for a parent request have finished."""
- role_map = self._companion_map.get(parent_id, {})
- if not role_map:
- return True
- done_set = self._companion_done.get(parent_id, set())
- return all(cid in done_set for cid in role_map.values())
-
def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool:
return (stage_id + 1) in req_state.stage_submit_ts
async def _handle_cfg_companion_ready(self, req_id: str) -> None:
"""Mark a CFG companion as done; if all companions are done, flush deferred parent."""
- parent_id = self._companion_to_parent.get(req_id)
+ parent_id = self._cfg_tracker.on_companion_completed(req_id)
if parent_id is None:
return
- done_set = self._companion_done.setdefault(parent_id, set())
- if req_id in done_set:
+ deferred = self._cfg_tracker.pop_pending_parent(parent_id)
+ if deferred is None:
return
- done_set.add(req_id)
- if parent_id in self._deferred_parents and self._all_companions_done(parent_id):
- deferred = self._deferred_parents.pop(parent_id)
- parent_state = self.request_states.get(parent_id)
- if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state):
- await self._forward_to_next_stage(
- parent_id,
- deferred["stage_id"],
- deferred["output"],
- parent_state,
- )
+ parent_state = self.request_states.get(parent_id)
+ if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state):
+ await self._forward_to_next_stage(
+ parent_id,
+ deferred["stage_id"],
+ deferred["engine_outputs"],
+ parent_state,
+ )
async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> None:
"""Forward split requests once stage-0 KV is ready, not only when decode fully finishes."""
@@ -419,21 +480,63 @@ async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineC
req_state = self.request_states.get(req_id)
if req_state is None:
continue
- if req_id in self._companion_ids:
+ if self._cfg_tracker.is_companion(req_id):
await self._handle_cfg_companion_ready(req_id)
continue
if stage_id >= req_state.final_stage_id:
continue
if self._next_stage_already_submitted(stage_id, req_state):
continue
- if req_id in self._companion_map and not self._all_companions_done(req_id):
- self._deferred_parents[req_id] = {
- "stage_id": stage_id,
- "output": raw_output,
- }
+ if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id):
+ self._cfg_tracker.defer_parent(req_id, raw_output, stage_id)
else:
await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state)
+ def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any:
+ """Build decode-side sampling params with KV transfer params for PD routing.
+
+ Clones the sampling params and injects kv_transfer_params that tell the
+ decode engine where to pull the KV cache from (prefill engine's bootstrap addr).
+ """
+ sp = sp.clone()
+ if sp.extra_args is None:
+ sp.extra_args = {}
+
+ # Get KV params captured from the prefill output (must include remote_request_id).
+ kv_prefill_params = self._pd_kv_params.pop(req_id, None)
+ if not kv_prefill_params or "remote_request_id" not in kv_prefill_params:
+ raise RuntimeError(
+ f"[Orchestrator][PD] Missing prefill kv_transfer_params.remote_request_id for req={req_id}"
+ )
+
+ decode_kv_params: dict[str, Any] = {
+ "transfer_id": f"xfer-{req_id}",
+ }
+
+ if self._pd_bootstrap_addr:
+ decode_kv_params["remote_bootstrap_addr"] = self._pd_bootstrap_addr
+
+ if self._pd_prefill_engine_id:
+ decode_kv_params["remote_engine_id"] = self._pd_prefill_engine_id
+
+ # Overlay params from prefill side (includes remote_request_id set by monkey patch).
+ decode_kv_params.update(kv_prefill_params)
+
+ # Ensure these flags are set correctly after any overlay.
+ decode_kv_params["do_remote_prefill"] = True
+ decode_kv_params["do_remote_decode"] = False
+ if not decode_kv_params.get("transfer_id"):
+ decode_kv_params["transfer_id"] = f"xfer-{req_id}"
+
+ sp.extra_args["kv_transfer_params"] = decode_kv_params
+
+ logger.debug(
+ "[Orchestrator][PD] decode kv_transfer_params for req=%s: %s",
+ req_id,
+ decode_kv_params,
+ )
+ return sp
+
def _build_stage_metrics(
self,
stage_id: int,
@@ -511,6 +614,9 @@ async def _forward_to_next_stage(
stage_id: int,
output: Any,
req_state: OrchestratorRequestState,
+ *,
+ is_streaming_session: bool = False,
+ is_final_update: bool = False,
) -> None:
"""Forward output from current stage to the next stage.
@@ -520,6 +626,7 @@ async def _forward_to_next_stage(
next_stage_id = stage_id + 1
next_client = self.stage_clients[next_stage_id]
params = req_state.sampling_params_list[next_stage_id]
+ next_stage_resumable = is_streaming_session and not is_final_update
if next_client.stage_type == "diffusion":
self.stage_clients[stage_id].set_engine_outputs([output])
@@ -535,20 +642,7 @@ async def _forward_to_next_stage(
else:
diffusion_prompt = req_state.prompt
- # Attach CFG companion KV request IDs so the diffusion model
- # runner can fetch companion KV caches alongside the primary one.
- cfg_ids = self._companion_map.get(req_id)
- if cfg_ids:
- from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
- if isinstance(params, OmniDiffusionSamplingParams):
- params = copy.deepcopy(params)
- params.cfg_kv_request_ids = cfg_ids
- logger.info(
- "[Orchestrator] Attaching cfg_kv_request_ids=%s to req %s",
- cfg_ids,
- req_id,
- )
+ params = self._cfg_tracker.attach_cfg_request_ids(req_id, params)
source_stage_ids = list(getattr(next_client, "engine_input_source", None) or [stage_id])
kv_sender_info = self._build_kv_sender_info(sender_stage_ids=source_stage_ids)
@@ -569,6 +663,52 @@ async def _forward_to_next_stage(
req_state.stage_submit_ts[next_stage_id] = _time.time()
return
+ # PD disaggregation: prefill → decode routing uses original prompt + KV transfer params
+ if self._pd_pair is not None and (stage_id, next_stage_id) == self._pd_pair:
+ # Save prefill stage outputs so thinker2talker can merge embeddings later
+ self.stage_clients[stage_id].set_engine_outputs([output])
+
+ params = self._build_pd_decode_params(req_id, params)
+
+ # Use the original user prompt for the decode stage (not processed embeddings)
+ original_prompt = req_state.prompt
+ raw_decode_inputs = [original_prompt] if not isinstance(original_prompt, list) else original_prompt
+
+ decode_inputs: list[dict[str, Any]] = []
+ for decode_input in raw_decode_inputs:
+ if isinstance(decode_input, dict):
+ decode_inputs.append(decode_input)
+ continue
+ prompt_token_ids = getattr(decode_input, "prompt_token_ids", None)
+ if prompt_token_ids is None:
+ raise TypeError(
+ "[Orchestrator][PD] decode input must be dict or have prompt_token_ids, "
+ f"got {type(decode_input).__name__} for req={req_id}"
+ )
+ decode_inputs.append({"prompt_token_ids": list(prompt_token_ids)})
+
+ for decode_input in decode_inputs:
+ request = build_engine_core_request_from_tokens(
+ request_id=req_id,
+ prompt=decode_input,
+ params=params,
+ model_config=self.stage_vllm_configs[next_stage_id].model_config,
+ mm_features=req_state.mm_features, # Pass mm_features for M-RoPE
+ )
+ request.external_req_id = request.request_id
+
+ self.output_processors[next_stage_id].add_request(
+ request=request,
+ prompt=None,
+ parent_req=None,
+ request_index=0,
+ queue=None,
+ )
+ await next_client.add_request_async(request)
+
+ req_state.stage_submit_ts[next_stage_id] = _time.time()
+ return
+
self.stage_clients[stage_id].set_engine_outputs([output])
# Process inputs for next stage
@@ -576,6 +716,7 @@ async def _forward_to_next_stage(
next_inputs = next_client.process_engine_inputs(
stage_list=self.stage_clients,
prompt=req_state.prompt,
+ streaming_context=(req_state.streaming if req_state.streaming.enabled else None),
)
except Exception:
logger.exception(
@@ -630,11 +771,17 @@ async def _forward_to_next_stage(
next_stage_id,
)
+ # Only AR thinker stages consume encoder mm_features; downstream
+ # (talker/code2wav/…) must not see them (avoids encoder-cache misses).
+ _ms = getattr(next_client, "model_stage", None)
+ _mm_features = req_state.mm_features if _ms == "thinker" else None
request = build_engine_core_request_from_tokens(
request_id=req_id,
prompt=next_input,
params=params,
model_config=self.stage_vllm_configs[next_stage_id].model_config,
+ mm_features=_mm_features,
+ resumable=next_stage_resumable,
)
# TODO: Here we directly use the req id to assign.
@@ -676,6 +823,13 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
raw_outputs.timestamp,
None,
)
+ for eco in raw_outputs.outputs:
+ if not hasattr(eco, "request_id"):
+ continue
+ req_state = self.request_states.get(eco.request_id)
+ if req_state:
+ req_state.streaming.segment_finished = eco.is_segment_finished
+ req_state.streaming.new_prompt_len_snapshot = eco.new_prompt_len_snapshot
if processed.reqs_to_abort:
await self.stage_clients[stage_id].abort_requests_async(processed.reqs_to_abort)
@@ -711,19 +865,22 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None:
# Track request state - use original_prompt so downstream stages
# (e.g. thinker2talker) can access the raw dict with multi_modal_data.
+ request = prompt
+ is_streaming = bool(getattr(request, "resumable", False))
req_state = OrchestratorRequestState(
request_id=request_id,
prompt=original_prompt,
sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id,
+ mm_features=getattr(prompt, "mm_features", None), # Save mm_features for PD
)
+ req_state.streaming.enabled = is_streaming
req_state.stage_submit_ts[stage_id] = _time.time()
self.request_states[request_id] = req_state
# Stage-0 prompt is already a fully-formed OmniEngineCoreRequest
# (pre-processed by AsyncOmniEngine.add_request, output processor
# already registered there) - submit directly.
- request = prompt
stage_client = self.stage_clients[stage_id]
if stage_client.stage_type == "diffusion":
if isinstance(prompt, list):
@@ -759,6 +916,7 @@ async def _handle_streaming_update(self, msg: dict[str, Any]) -> None:
if "sampling_params_list" in msg and msg["sampling_params_list"]:
req_state.sampling_params_list = msg["sampling_params_list"]
+ req_state.streaming.enabled = True
req_state.stage_submit_ts[stage_id] = _time.time()
stage_client = self.stage_clients[stage_id]
@@ -847,13 +1005,7 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
companion_prompt = msg["prompt"]
sampling_params_list = msg["sampling_params_list"]
- # Register companion mapping
- if parent_id not in self._companion_map:
- self._companion_map[parent_id] = {}
- self._companion_map[parent_id][role] = companion_id
- self._companion_ids.add(companion_id)
- self._companion_to_parent[companion_id] = parent_id
- self._companion_done.setdefault(parent_id, set())
+ self._cfg_tracker.register_companion(parent_id, role, companion_id)
companion_state = OrchestratorRequestState(
request_id=companion_id,
@@ -878,22 +1030,10 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
async def _handle_abort(self, msg: dict[str, Any]) -> None:
"""Handle an abort message from the main thread."""
request_ids = msg["request_ids"]
- # Also abort any CFG companions for aborted parents
- companion_ids_to_abort: list[str] = []
- for req_id in request_ids:
- role_map = self._companion_map.pop(req_id, {})
- for cid in role_map.values():
- companion_ids_to_abort.append(cid)
- self._companion_ids.discard(cid)
- self._companion_to_parent.pop(cid, None)
- self.request_states.pop(cid, None)
- self._companion_done.pop(req_id, None)
- self._deferred_parents.pop(req_id, None)
-
- all_ids_to_abort = list(request_ids) + companion_ids_to_abort
+ all_ids_to_abort = self._cfg_tracker.abort_parents(request_ids)
for stage_id in range(self.num_stages):
await self.stage_clients[stage_id].abort_requests_async(all_ids_to_abort)
- for req_id in request_ids:
+ for req_id in all_ids_to_abort:
self.request_states.pop(req_id, None)
logger.info("[Orchestrator] Aborted request(s) %s", request_ids)
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
index 43d02e85b8..67b4dd1650 100644
--- a/vllm_omni/engine/output_processor.py
+++ b/vllm_omni/engine/output_processor.py
@@ -118,9 +118,10 @@ def _consolidate_multimodal_tensors(self) -> None:
if isinstance(v, list) and v and isinstance(v[0], torch.Tensor):
try:
if k == "audio":
- # When the audio tensor shape is inconsistent, torch.cat will fail.
- # We need to use torch.cat in -1 dimension.
- continue
+ # Concatenate delta audio chunks (1-D) into the full waveform.
+ # Each entry is a per-step slice; flatten to -1 so chunks with
+ # inconsistent leading dims can still be joined on the sample axis.
+ self.mm_accumulated[k] = torch.cat([t.reshape(-1) for t in v], dim=0)
elif k == "sr":
# Sample rate is a constant scalar, keep last value.
self.mm_accumulated[k] = v[-1]
@@ -232,10 +233,9 @@ def _new_completion_output(
# Reuse base text/logprobs logic, then annotate with pooling_result.
base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason, routed_experts)
try:
+ if not hasattr(base_output, "multimodal_output"):
+ setattr(base_output, "multimodal_output", {})
if self.mm_accumulated is not None:
- # Attach accumulated multimodal dict on the completion output
- if not hasattr(base_output, "multimodal_output"):
- setattr(base_output, "multimodal_output", {})
mm_out = getattr(base_output, "multimodal_output")
if isinstance(mm_out, dict):
for k, v in self.mm_accumulated.items():
diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py
index 193ab6a656..073c2340a4 100644
--- a/vllm_omni/engine/stage_engine_core_client.py
+++ b/vllm_omni/engine/stage_engine_core_client.py
@@ -14,7 +14,9 @@
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import AsyncMPClient, DPLBAsyncMPClient
-from vllm_omni.distributed.omni_connectors.utils.initialization import KV_TRANSFER_PORT_OFFSET
+from vllm_omni.distributed.omni_connectors.utils.initialization import (
+ KV_TRANSFER_PORT_OFFSET,
+)
from vllm_omni.engine.stage_init_utils import StageMetadata
if TYPE_CHECKING:
@@ -294,6 +296,8 @@ def _initialize_kv_sender_endpoint(self) -> None:
from_stage = omni_kv_config.get("omni_from_stage", from_stage)
try:
+ # Orchestrator always reports rank-0's port; receiver
+ # workers add their own local_rank * KV_RANK_PORT_STRIDE.
sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage)
except (TypeError, ValueError):
logger.warning(
@@ -332,6 +336,7 @@ def get_kv_sender_info(
self._kv_sender_host = self._resolve_contact_host()
if self._kv_sender_host is None:
return None
+ # rank-0 base port; receiver workers adjust per KV_RANK_PORT_STRIDE.
return {
"host": self._kv_sender_host,
"zmq_port": base_port + kv_transfer_port_offset + int(self.stage_id),
@@ -345,11 +350,21 @@ def process_engine_inputs(
self,
stage_list: list[Any],
prompt: OmniTokensPrompt | list[OmniTokensPrompt] | None = None,
+ streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""Process inputs from upstream stages."""
from vllm_omni.inputs.data import OmniTokensPrompt
if self.custom_process_input_func is not None:
+ # Keep legacy arg call for non-streaming processors.
+ if bool(getattr(streaming_context, "enabled", False)):
+ return self.custom_process_input_func(
+ stage_list,
+ self.engine_input_source,
+ prompt,
+ self.requires_multimodal_data,
+ streaming_context,
+ )
return self.custom_process_input_func(
stage_list,
self.engine_input_source,
diff --git a/vllm_omni/engine/stage_engine_core_proc.py b/vllm_omni/engine/stage_engine_core_proc.py
index 05d8f107c2..689378a798 100644
--- a/vllm_omni/engine/stage_engine_core_proc.py
+++ b/vllm_omni/engine/stage_engine_core_proc.py
@@ -37,8 +37,6 @@
logger = init_logger(__name__)
-_HANDSHAKE_POLL_TIMEOUT_S = 600
-
class StageEngineCoreProc(EngineCoreProc):
"""Stage-specific engine core process for vLLM-Omni.
@@ -145,13 +143,14 @@ def complete_stage_handshake(
handshake_address: str,
addresses: EngineZmqAddresses,
vllm_config: VllmConfig,
+ handshake_timeout: int,
) -> None:
"""Perform the HELLO/INIT/READY handshake with an already-spawned proc.
On failure the process is terminated before re-raising.
"""
try:
- _perform_handshake(proc, handshake_address, addresses, vllm_config)
+ _perform_handshake(proc, handshake_address, addresses, vllm_config, handshake_timeout)
except Exception:
shutdown([proc])
raise
@@ -162,6 +161,7 @@ def _perform_handshake(
handshake_address: str,
addresses: EngineZmqAddresses,
vllm_config: VllmConfig,
+ handshake_timeout: int,
) -> None:
"""Run the HELLO / INIT / READY handshake with the subprocess."""
with zmq_socket_ctx(handshake_address, zmq.ROUTER, bind=True) as handshake_socket:
@@ -169,7 +169,7 @@ def _perform_handshake(
poller.register(handshake_socket, zmq.POLLIN)
poller.register(proc.sentinel, zmq.POLLIN)
- identity, msg = _recv(poller, handshake_socket, proc, "HELLO")
+ identity, msg = _recv(poller, handshake_socket, proc, "HELLO", handshake_timeout)
if msg.get("status") != "HELLO":
raise RuntimeError(f"Expected HELLO, got: {msg}")
@@ -179,7 +179,7 @@ def _perform_handshake(
)
handshake_socket.send_multipart([identity, msgspec.msgpack.encode(init_payload)])
- identity, msg = _recv(poller, handshake_socket, proc, "READY")
+ identity, msg = _recv(poller, handshake_socket, proc, "READY", handshake_timeout)
if msg.get("status") != "READY":
raise RuntimeError(f"Expected READY, got: {msg}")
num_gpu_blocks = msg.get("num_gpu_blocks")
@@ -192,13 +192,18 @@ def _recv(
handshake_socket: zmq.Socket,
proc: BaseProcess,
expected: str,
+ timeout_s: int = 600,
) -> tuple[bytes, dict]:
"""Wait for one handshake message; raise if the process dies first."""
- timeout_ms = _HANDSHAKE_POLL_TIMEOUT_S * 1000
+ timeout_ms = timeout_s * 1000
while True:
events = dict(poller.poll(timeout=timeout_ms))
if not events:
- raise TimeoutError(f"Timed out waiting for {expected} from StageEngineCoreProc")
+ raise TimeoutError(
+ f"Timed out waiting for {expected} from StageEngineCoreProc after {timeout_s}s. "
+ f"This typically indicates model loading or initialization is taking too long. "
+ f"Consider increasing `stage_init_timeout` for large models."
+ )
if handshake_socket in events:
identity, raw = handshake_socket.recv_multipart()
return identity, msgspec.msgpack.decode(raw)
diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py
index 09195faeca..89dfdc163c 100644
--- a/vllm_omni/engine/stage_init_utils.py
+++ b/vllm_omni/engine/stage_init_utils.py
@@ -13,7 +13,7 @@
import multiprocessing as mp
import os
import time
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any, Literal
@@ -101,8 +101,110 @@ def resolve_worker_cls(engine_args: dict[str, Any]) -> None:
raise ValueError(f"Unknown worker_type: {worker_type}")
-def inject_kv_stage_info(stage_cfg: Any, stage_id: int) -> None:
- """Inject stage metadata into omni_kv_config when present."""
+def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any:
+ """Read *key* from *obj* regardless of whether it's a dict or object."""
+ if hasattr(obj, "get"):
+ return obj.get(key, default)
+ return getattr(obj, key, default)
+
+
+def _tp_size_for_stage(stage_configs: Sequence[Any], stage_id: Any) -> int | None:
+ """Resolve tensor_parallel_size for *stage_id* from the loaded stage configs."""
+ id_strs = {str(stage_id)}
+ try:
+ id_strs.add(str(int(stage_id)))
+ except (TypeError, ValueError):
+ pass
+
+ for stage_cfg in stage_configs:
+ if str(getattr(stage_cfg, "stage_id", None)) not in id_strs:
+ continue
+ engine_args = getattr(stage_cfg, "engine_args", None)
+ if engine_args is None:
+ return 1
+ parallel_config = _get_attr_or_item(engine_args, "parallel_config")
+ if parallel_config is not None:
+ tp = _get_attr_or_item(parallel_config, "tensor_parallel_size", 1)
+ else:
+ tp = _get_attr_or_item(engine_args, "tensor_parallel_size", 1)
+ try:
+ return max(1, int(tp))
+ except (TypeError, ValueError):
+ return 1
+ return None
+
+
+def _inject_inferred_kv_tp_topology(
+ omni_kv: Any,
+ stage_id: int,
+ stage_configs: Sequence[Any],
+ engine_input_source: Sequence[int] | None = None,
+) -> None:
+ """Infer adjacent-stage TP topology and inject it into omni_kv_config.
+
+ This keeps heterogeneous TP working without requiring user-authored
+ rank_mapping blocks in config files.
+ """
+ if omni_kv is None:
+ return
+
+ if hasattr(omni_kv, "get"):
+ need_send = bool(omni_kv.get("need_send_cache", False))
+ need_recv = bool(omni_kv.get("need_recv_cache", False))
+ omni_from_stage = omni_kv.get("omni_from_stage")
+ omni_to_stage = omni_kv.get("omni_to_stage")
+ rank_mapping = omni_kv.get("rank_mapping")
+ else:
+ need_send = bool(getattr(omni_kv, "need_send_cache", False))
+ need_recv = bool(getattr(omni_kv, "need_recv_cache", False))
+ omni_from_stage = getattr(omni_kv, "omni_from_stage", None)
+ omni_to_stage = getattr(omni_kv, "omni_to_stage", None)
+ rank_mapping = getattr(omni_kv, "rank_mapping", None)
+
+ if not need_send and not need_recv:
+ return
+
+ current_tp = _tp_size_for_stage(stage_configs, stage_id)
+ if current_tp is None:
+ return
+
+ peer_stage_id = None
+ from_tp = None
+ to_tp = None
+ if str(omni_from_stage) == str(stage_id):
+ peer_stage_id = omni_to_stage
+ from_tp = current_tp
+ to_tp = _tp_size_for_stage(stage_configs, peer_stage_id)
+ elif str(omni_to_stage) == str(stage_id):
+ peer_stage_id = omni_from_stage
+ from_tp = _tp_size_for_stage(stage_configs, peer_stage_id)
+ to_tp = current_tp
+ elif need_recv and engine_input_source:
+ peer_stage_id = engine_input_source[0]
+ from_tp = _tp_size_for_stage(stage_configs, peer_stage_id)
+ to_tp = current_tp
+
+ if from_tp is None or to_tp is None:
+ return
+
+ if not isinstance(rank_mapping, dict):
+ rank_mapping = {}
+ rank_mapping.setdefault("from_tp", int(from_tp))
+ rank_mapping.setdefault("to_tp", int(to_tp))
+
+ if hasattr(omni_kv, "__setitem__"):
+ omni_kv["rank_mapping"] = rank_mapping
+ else:
+ setattr(omni_kv, "rank_mapping", rank_mapping)
+
+
+def inject_kv_stage_info(stage_cfg: Any, stage_id: int, stage_configs: Sequence[Any] | None = None) -> None:
+ """Inject stage_id, engine_input_source, and inferred TP topology into omni_kv_config.
+
+ When *stage_configs* is provided, also infers from_tp/to_tp for
+ heterogeneous TP topologies so the KV transfer manager can compute
+ rank mappings automatically.
+ """
try:
engine_args = stage_cfg.engine_args
if hasattr(engine_args, "get"):
@@ -125,6 +227,14 @@ def inject_kv_stage_info(stage_cfg: Any, stage_id: int) -> None:
omni_kv.setdefault("engine_input_source", list(engine_input_source))
elif hasattr(omni_kv, "__setitem__") and "engine_input_source" not in omni_kv:
omni_kv["engine_input_source"] = list(engine_input_source)
+
+ if stage_configs:
+ _inject_inferred_kv_tp_topology(
+ omni_kv,
+ stage_id=stage_id,
+ stage_configs=stage_configs,
+ engine_input_source=engine_input_source,
+ )
except Exception as e:
logger.debug("Failed to inject stage info into omni_kv_config: %s", e)
@@ -168,6 +278,20 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
stage_id: int = stage_config.stage_id
stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm")
engine_args = stage_config.engine_args
+
+ if current_omni_platform.is_rocm():
+ if engine_args.get("attention_backend") is None:
+ from vllm._aiter_ops import rocm_aiter_ops
+
+ if rocm_aiter_ops.is_enabled():
+ engine_args["attention_backend"] = "ROCM_AITER_FA"
+ # Before vLLM v0.19.0, the default attention backend is TRITON_ATTN for ROCm.
+ # Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm.
+ # However, the compatibility of ROCM_ATTN with Omni is not guaranteed.
+ # Therefore, we still use TRITON_ATTN as the default attention backend,
+ # when the selected_backend is not specified.
+ engine_args["attention_backend"] = "TRITON_ATTN"
+
runtime_cfg = getattr(stage_config, "runtime", {})
engine_input_source: list[int] = getattr(stage_config, "engine_input_source", [])
final_output: bool = getattr(stage_config, "final_output", False)
@@ -178,8 +302,9 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
default_sampling_params: OmniSamplingParams = SPClass(**default_sp)
custom_process_input_func: Callable | None = None
- if hasattr(stage_config, "custom_process_input_func"):
- mod_path, fn_name = stage_config.custom_process_input_func.rsplit(".", 1)
+ _cpif_path = getattr(stage_config, "custom_process_input_func", None)
+ if _cpif_path:
+ mod_path, fn_name = _cpif_path.rsplit(".", 1)
custom_process_input_func = getattr(importlib.import_module(mod_path), fn_name)
prompt_expand_func: Callable | None = None
@@ -309,6 +434,20 @@ def build_vllm_config(
filtered_engine_args_dict = filter_dataclass_kwargs(OmniEngineArgs, engine_args_dict)
omni_engine_args = OmniEngineArgs(**filtered_engine_args_dict)
+
+ # Multi-stage pipelines (qwen3_tts code2wav, etc.) set max_model_len
+ # larger than HF max_position_embeddings by design. vLLM's validator
+ # rejects that without the env flag.
+ if filtered_engine_args_dict.get("max_model_len") is not None and not os.environ.get(
+ "VLLM_ALLOW_LONG_MAX_MODEL_LEN"
+ ):
+ os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
+ logger.debug(
+ "Auto-set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for stage %s (max_model_len=%s).",
+ stage_config.stage_id,
+ filtered_engine_args_dict["max_model_len"],
+ )
+
vllm_config = omni_engine_args.create_engine_config(
usage_context=UsageContext.LLM_CLASS,
headless=headless,
@@ -321,7 +460,7 @@ def build_vllm_config(
def acquire_device_locks(
stage_id: int,
engine_args_dict: dict[str, Any],
- stage_init_timeout: int = 300,
+ stage_init_timeout: int,
) -> list[int]:
"""Acquire exclusive file locks on devices needed by this stage.
@@ -514,7 +653,9 @@ def initialize_diffusion_stage(
model: str,
stage_cfg: Any,
metadata: StageMetadata,
+ stage_init_timeout: int,
batch_size: int = 1,
+ use_inline: bool = False,
) -> Any:
"""Build a diffusion stage client.
@@ -522,14 +663,16 @@ def initialize_diffusion_stage(
model: Model name or path.
stage_cfg: Stage configuration.
metadata: Extracted stage metadata.
+ stage_init_timeout: Timeout in seconds for stage initialization handshake
batch_size: Maximum number of requests to batch together in the
diffusion engine. Passed through to ``StageDiffusionClient``
and ultimately to ``AsyncOmni``.
+ use_inline: If True, uses the inline diffusion client instead of subprocess.
"""
- from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
+ from vllm_omni.diffusion.stage_diffusion_client import create_diffusion_client
od_config = build_diffusion_config(model, stage_cfg, metadata)
- return StageDiffusionClient(model, od_config, metadata, batch_size=batch_size)
+ return create_diffusion_client(model, od_config, metadata, stage_init_timeout, batch_size, use_inline)
def _shutdown_or_close_resource(resource: Any, resource_name: str, stage_id: int) -> None:
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index 129ef3c99d..9606cc80d0 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -78,7 +78,6 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None:
self.final_output_task: asyncio.Task | None = None
self.config_path = self.engine.config_path
- self.stage_configs = self.engine.stage_configs
self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None)
self.input_processor = self.engine.input_processor
@@ -209,6 +208,13 @@ async def generate(
# Start final output dispatcher on the first call to generate()
self._final_output_handler()
+ # Expand sampling params for PD disaggregation (user may provide N-1 params)
+ if (
+ sampling_params_list is not None
+ and isinstance(sampling_params_list, Sequence)
+ and not isinstance(sampling_params_list, (str, bytes))
+ ):
+ sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list))
sampling_params_list = self.resolve_sampling_params_list(sampling_params_list)
# Track per-request metrics
@@ -228,20 +234,27 @@ async def generate(
req_state.metrics = metrics
self.request_states[request_id] = req_state
+ # PD disaggregation: modify prefill-stage sampling params per request
+ req_sp_list = list(sampling_params_list)
+ pd_pair = self._get_pd_separation_pair()
+ if pd_pair is not None:
+ p_id = pd_pair[0]
+ req_sp_list[p_id] = self._prepare_prefill_sampling_params(request_id, req_sp_list[p_id])
+
# Add request(s) to stage 0. For streaming inputs, submit
# chunks incrementally through streaming_update.
if isinstance(prompt, AsyncGenerator):
input_stream_task = await self._add_streaming_input_request(
request_id=request_id,
input_stream=prompt,
- sampling_params_list=sampling_params_list,
+ sampling_params_list=req_sp_list,
final_stage_id=final_stage_id_for_e2e,
)
else:
await self.engine.add_request_async(
request_id=request_id,
prompt=prompt,
- sampling_params_list=sampling_params_list,
+ sampling_params_list=req_sp_list,
final_stage_id=final_stage_id_for_e2e,
)
submit_ts = time.time()
@@ -296,7 +309,6 @@ async def _add_streaming_input_request(
if not stage0_params.skip_clone:
stage0_params = stage0_params.clone()
stage0_params.skip_clone = True
- stage0_params.output_kind = RequestOutputKind.DELTA
has_submitted_first_chunk = False
diff --git a/vllm_omni/entrypoints/cfg_companion_tracker.py b/vllm_omni/entrypoints/cfg_companion_tracker.py
deleted file mode 100644
index 9c2e835f07..0000000000
--- a/vllm_omni/entrypoints/cfg_companion_tracker.py
+++ /dev/null
@@ -1,233 +0,0 @@
-"""CFG companion request tracker for the Omni orchestrator.
-
-Encapsulates all bookkeeping for Classifier-Free Guidance companion
-requests (prompt expansion, parent/companion ID mapping, completion
-tracking, deferred forwarding, failure propagation, and timeouts)
-so that ``Omni._run_generation`` stays clean.
-"""
-
-from __future__ import annotations
-
-import copy
-import logging
-import os
-import time
-from collections.abc import Callable, Sequence
-from typing import Any
-
-from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams
-
-logger = logging.getLogger(__name__)
-
-
-class CfgCompanionTracker:
- """Manages CFG companion request lifecycle in the orchestrator scheduling loop."""
-
- def __init__(
- self,
- prompt_expand_func: Callable[..., Any] | None,
- stage0_sampling_params: Any,
- timeout_s: float | None = None,
- ) -> None:
- self._expand_func = prompt_expand_func
- self._sp0 = stage0_sampling_params
- self._timeout_s = (
- timeout_s if timeout_s is not None else float(os.environ.get("VLLM_CFG_PENDING_TIMEOUT_S", "120"))
- )
-
- self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id}
- self._companion_ids: set[str] = set()
- self._companion_to_parent: dict[str, str] = {} # companion -> parent
- self._done: dict[str, set[str]] = {} # parent -> completed companion ids
- self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result
- self._failed_parents: set[str] = set()
-
- @property
- def is_active(self) -> bool:
- return bool(self._companion_ids)
-
- @property
- def num_companions(self) -> int:
- return len(self._companion_ids)
-
- @property
- def stage0_sampling_params(self) -> Any:
- return self._sp0
-
- def expand_prompts(
- self,
- request_id_to_prompt: dict[str, Any],
- ) -> list[tuple[str, Any]]:
- """Expand user prompts into ``(companion_id, prompt)`` pairs via model-specific func."""
- if not self._expand_func:
- return []
-
- pairs: list[tuple[str, Any]] = []
- for rid, prompt in request_id_to_prompt.items():
- expanded = self._expand_func(prompt, self._sp0)
- if not expanded:
- continue
- role_map: dict[str, str] = {}
- for ep in expanded:
- cid = f"{rid}{ep.request_id_suffix}"
- role_map[ep.role] = cid
- self._companion_ids.add(cid)
- self._companion_to_parent[cid] = rid
- pairs.append((cid, ep.prompt))
- self._companion_map[rid] = role_map
- self._done[rid] = set()
-
- logger.debug(
- "CFG expansion: %d parent(s) -> %d companion(s)",
- len(self._companion_map),
- len(self._companion_ids),
- )
- return pairs
-
- def is_companion(self, req_id: str) -> bool:
- return req_id in self._companion_ids
-
- def has_companions(self, parent_id: str) -> bool:
- return parent_id in self._companion_map
-
- def all_companions_done(self, parent_id: str) -> bool:
- role_map = self._companion_map.get(parent_id, {})
- done_set = self._done.get(parent_id, set())
- return all(cid in done_set for cid in role_map.values())
-
- def get_companion_request_ids(self, parent_id: str) -> dict[str, str]:
- """Return ``{role: companion_request_id}`` for a parent."""
- return self._companion_map.get(parent_id, {})
-
- def is_parent_failed(self, parent_id: str) -> bool:
- return parent_id in self._failed_parents
-
- # -- Lifecycle events --
-
- def on_companion_error(self, companion_id: str) -> tuple[str | None, bool]:
- """Record failure. Returns ``(parent_id, parent_was_aborted)``."""
- parent_id = self._companion_to_parent.get(companion_id)
- if parent_id is None:
- return None, False
- self._failed_parents.add(parent_id)
- logger.error("CFG companion %s failed; marking parent %s as failed", companion_id, parent_id)
- aborted = parent_id in self._pending_parents
- if aborted:
- self._pending_parents.pop(parent_id, None)
- return parent_id, aborted
-
- def on_companion_completed(self, companion_id: str) -> str | None:
- """Mark done. Returns parent_id only if parent is pending and all companions finished."""
- parent_id = self._companion_to_parent.get(companion_id)
- if parent_id is None:
- return None
- self._done[parent_id].add(companion_id)
- logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id)
- if parent_id in self._pending_parents and self.all_companions_done(parent_id):
- return parent_id
- return None
-
- def consume_parent_failure(self, parent_id: str) -> None:
- self._failed_parents.discard(parent_id)
-
- # -- Deferred parent management --
-
- def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None:
- """Hold parent result while waiting for companions to finish."""
- self._pending_parents[parent_id] = {
- "engine_outputs": engine_outputs,
- "stage_id": stage_id,
- "pending_since": time.time(),
- }
- logger.debug("Parent %s deferred, waiting for CFG companions", parent_id)
-
- def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None:
- return self._pending_parents.pop(parent_id, None)
-
- def check_timeouts(self) -> list[str]:
- """Return and remove parent IDs that exceeded the pending timeout."""
- if not self._pending_parents:
- return []
- now = time.time()
- timed_out: list[str] = []
- for pid in list(self._pending_parents):
- pending_since = self._pending_parents[pid].get("pending_since", now)
- if now - pending_since > self._timeout_s:
- self._pending_parents.pop(pid)
- self._failed_parents.discard(pid)
- timed_out.append(pid)
- logger.error("Parent %s timed out waiting for CFG companions (>%.0fs)", pid, self._timeout_s)
- return timed_out
-
- # -- Forward parent with CFG KV --
-
- def forward_parent_with_cfg(
- self,
- req_id: str,
- parent_result: dict[str, Any],
- stage_list: Sequence[Any],
- connectors: dict[tuple[str, str], Any],
- sampling_params_list: Sequence[OmniSamplingParams],
- request_id_to_prompt: dict[str, Any],
- final_stage_id_to_prompt: dict[str, int],
- metrics: Any,
- remaining_by_stage: list[int],
- ) -> bool:
- """Forward a parent request to the next stage with CFG KV request IDs attached."""
- stage_id = parent_result["stage_id"]
- next_stage_id = stage_id + 1
- if next_stage_id > final_stage_id_to_prompt.get(req_id, 0):
- return True
-
- next_stage = stage_list[next_stage_id]
- try:
- with metrics.stage_postprocess_timer(stage_id, req_id):
- next_inputs = next_stage.process_engine_inputs(
- stage_list,
- [request_id_to_prompt[req_id]],
- source_outputs_override=parent_result["engine_outputs"],
- )
- except Exception as e:
- logger.exception(
- "Process engine inputs error for req %s at stage %d: %s",
- req_id,
- next_stage_id,
- e,
- )
- return False
-
- sp_next = copy.deepcopy(sampling_params_list[next_stage_id])
- if isinstance(sp_next, OmniDiffusionSamplingParams):
- sp_next.cfg_kv_request_ids = self.get_companion_request_ids(req_id)
- logger.info(
- "Attaching cfg_kv_request_ids=%s to request %s",
- sp_next.cfg_kv_request_ids,
- req_id,
- )
-
- connector_key = (str(stage_id), str(next_stage_id))
- connector = connectors.get(connector_key)
- sent_via_connector = False
- if connector:
- sent_via_connector = try_send_via_connector(
- connector=connector,
- stage_id=stage_id,
- next_stage_id=next_stage_id,
- req_id=req_id,
- next_inputs=next_inputs,
- sampling_params=sp_next,
- original_prompt=request_id_to_prompt[req_id],
- next_stage_queue_submit_fn=stage_list[next_stage_id].submit,
- metrics=metrics,
- )
-
- if not sent_via_connector:
- raise RuntimeError(
- f"Failed to send CFG request {req_id} to stage-{next_stage_id} via connector. "
- "Configure a connector for this edge or inspect connector logs for details."
- )
-
- logger.debug("Forwarded CFG-enabled request %s to stage-%d", req_id, next_stage_id)
- remaining_by_stage[next_stage_id] += 1
- return True
diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py
index 8970e58984..4c3d311ec5 100644
--- a/vllm_omni/entrypoints/chat_utils.py
+++ b/vllm_omni/entrypoints/chat_utils.py
@@ -2,7 +2,7 @@
async def extract_audio_from_video_async(video_url: str) -> tuple[np.ndarray, int | float]:
- """Extract audio from a video URL using librosa.
+ """Extract audio from a video URL using vllm's load_audio.
Returns a (audio_array, sample_rate) tuple compatible with audio format.
All blocking I/O operations are run in a thread pool.
@@ -26,9 +26,9 @@ def _write_temp_file_sync(data: bytes, suffix: str) -> str:
return temp_file.name
def _load_audio_sync(file_path: str) -> tuple[np.ndarray, int | float]:
- import librosa
+ from vllm.multimodal.media.audio import load_audio
- return librosa.load(file_path, sr=16000)
+ return load_audio(file_path, sr=16000)
def _cleanup_file_sync(file_path: str) -> None:
try:
diff --git a/vllm_omni/entrypoints/cli/benchmark/serve.py b/vllm_omni/entrypoints/cli/benchmark/serve.py
index 906e8851a4..d281432e59 100644
--- a/vllm_omni/entrypoints/cli/benchmark/serve.py
+++ b/vllm_omni/entrypoints/cli/benchmark/serve.py
@@ -1,4 +1,5 @@
import argparse
+import os
from vllm.benchmarks.serve import add_cli_args
@@ -6,15 +7,149 @@
from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase
+def add_daily_omni_cli_args(parser: argparse.ArgumentParser) -> None:
+ """Add CLI arguments specific to Daily-Omni dataset.
+
+ This function should be called by the CLI entrypoint to add additional
+ arguments for daily-omni benchmark support.
+
+ Args:
+ parser: The ArgumentParser instance to extend
+ """
+ # Daily-Omni specific arguments
+ daily_omni_group = parser.add_argument_group("Daily-Omni Dataset Options")
+
+ daily_omni_group.add_argument(
+ "--daily-omni-qa-json",
+ type=str,
+ default=None,
+ help="Path to local upstream qa.json. When set, QA rows are read from this file and "
+ "the HuggingFace dataset is not loaded (no network). Use with --daily-omni-video-dir "
+ "for fully offline runs. --dataset-path / Hub split flags are then ignored for QA loading.",
+ )
+ daily_omni_group.add_argument(
+ "--daily-omni-video-dir",
+ type=str,
+ default=None,
+ help="Root directory of extracted Daily-Omni videos (contents of Videos.tar: "
+ "each video_id in its own subdir with {video_id}_video.mp4). "
+ "When using file URLs, you MUST start the vLLM server with "
+ "--allowed-local-media-path set to this same directory (or a parent), "
+ "otherwise requests fail with 'Cannot load local files without "
+ "--allowed-local-media-path'.",
+ )
+ daily_omni_group.add_argument(
+ "--daily-omni-inline-local-video",
+ action="store_true",
+ default=False,
+ help="For local videos only: embed MP4 as base64 data URLs in benchmark "
+ "requests so the server does not need --allowed-local-media-path. "
+ "Increases request size and client memory; use for small --num-prompts. "
+ "When using --daily-omni-input-mode audio or all, local WAV files are "
+ "embedded the same way.",
+ )
+ daily_omni_group.add_argument(
+ "--daily-omni-input-mode",
+ type=str,
+ choices=["all", "visual", "audio"],
+ default="all",
+ help="Daily-Omni input protocol (mirrors upstream Lliar-liar/Daily-Omni "
+ "--input_mode). 'visual': video only (default). 'audio': WAV only, "
+ "requires {video_id}/{video_id}_audio.wav under --daily-omni-video-dir. "
+ "'all': video + WAV together. Sets mm_processor_kwargs.use_audio_in_video=false "
+ "and matches official separate video/audio streams.",
+ )
+ daily_omni_group.add_argument(
+ "--daily-omni-save-eval-items",
+ action="store_true",
+ default=False,
+ help="Include per-request Daily-Omni accuracy rows (gold/predicted/correct) "
+ "in the saved JSON under key daily_omni_eval_items. "
+ "Alternatively set env DAILY_OMNI_SAVE_EVAL_ITEMS=1.",
+ )
+
+ # Note: --dataset-name daily-omni via get_samples patch; use either Hub (--dataset-path
+ # liarliar/Daily-Omni) or local --daily-omni-qa-json (offline).
+
+
+def add_seed_tts_cli_args(parser: argparse.ArgumentParser) -> None:
+ """CLI for Seed-TTS zero-shot TTS benchmark (``--dataset-name seed-tts``)."""
+ g = parser.add_argument_group("Seed-TTS Dataset Options")
+ g.add_argument(
+ "--seed-tts-locale",
+ type=str,
+ choices=["en", "zh"],
+ default="en",
+ help="Which Seed-TTS split to load: en/meta.lst or zh/meta.lst under the dataset root.",
+ )
+ g.add_argument(
+ "--seed-tts-root",
+ type=str,
+ default=None,
+ help="Override root directory that contains en/ and zh/ (meta.lst + prompt-wavs). "
+ "If set, --dataset-path can still name the HF repo for logging; this path is used for files.",
+ )
+ g.add_argument(
+ "--seed-tts-file-ref-audio",
+ action="store_true",
+ default=False,
+ help="Send ref_audio as file:// URIs (smaller HTTP bodies). Requires the API server "
+ "to be started with --allowed-local-media-path covering the Seed-TTS dataset root. "
+ "Default is inline data:audio/wav;base64 so Qwen3-TTS works without that flag.",
+ )
+ g.add_argument(
+ "--seed-tts-inline-ref-audio",
+ action="store_true",
+ default=False,
+ help=argparse.SUPPRESS,
+ )
+ g.add_argument(
+ "--seed-tts-system-prompt",
+ type=str,
+ default=None,
+ help="Override chat system message for --backend openai-chat-omni (Qwen3-Omni TTS). "
+ "Default follows official Qwen3-Omni identity + zero-shot voice-clone instructions.",
+ )
+ g.add_argument(
+ "--seed-tts-wer-eval",
+ action="store_true",
+ default=False,
+ help="Keep synthesized audio as 24 kHz mono PCM for WER (works with "
+ "--backend openai-audio-speech or openai-chat-omni). Scoring follows "
+ "BytedanceSpeech/seed-tts-eval (Whisper-large-v3 / Paraformer-zh + jiwer). "
+ "Sets SEED_TTS_WER_EVAL=1. Install: pip install 'vllm-omni[seed-tts-eval]'. "
+ "Optional: SEED_TTS_EVAL_DEVICE, SEED_TTS_HF_WHISPER_MODEL.",
+ )
+ g.add_argument(
+ "--seed-tts-wer-save-items",
+ action="store_true",
+ default=False,
+ help="Include per-utterance ASR rows in the saved JSON under key seed_tts_wer_eval_items. "
+ "Or set SEED_TTS_WER_SAVE_ITEMS=1.",
+ )
+
+
class OmniBenchmarkServingSubcommand(OmniBenchmarkSubcommandBase):
"""The `serve` subcommand for vllm bench."""
name = "serve"
- help = "Benchmark the online serving throughput."
+ help = "Benchmark the online serving throughput. Supports Daily-Omni and Seed-TTS datasets."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
+
+ # Add Daily-Omni specific arguments
+ add_daily_omni_cli_args(parser)
+ add_seed_tts_cli_args(parser)
+
+ for action in parser._actions:
+ if action.dest == "dataset_name" and action.choices is not None:
+ extra = [c for c in ("daily-omni", "seed-tts") if c not in action.choices]
+ if extra:
+ action.choices = list(action.choices) + extra
+
+ # Update help messages for omni-specific features
for action in parser._actions:
if action.dest == "percentile_metrics":
action.help = (
@@ -48,4 +183,10 @@ def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
@staticmethod
def cmd(args: argparse.Namespace) -> None:
+ if getattr(args, "daily_omni_save_eval_items", False):
+ os.environ["DAILY_OMNI_SAVE_EVAL_ITEMS"] = "1"
+ if getattr(args, "seed_tts_wer_eval", False):
+ os.environ["SEED_TTS_WER_EVAL"] = "1"
+ if getattr(args, "seed_tts_wer_save_items", False):
+ os.environ["SEED_TTS_WER_SAVE_ITEMS"] = "1"
main(args)
diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py
index 6e9adc2461..8bccfbb591 100644
--- a/vllm_omni/entrypoints/cli/serve.py
+++ b/vllm_omni/entrypoints/cli/serve.py
@@ -9,6 +9,7 @@
import json
import os
import signal
+import sys
from types import FrameType
from typing import Any
@@ -21,6 +22,7 @@
from vllm_omni.entrypoints.cli.logo import log_logo
from vllm_omni.entrypoints.openai.api_server import omni_run_server
+from vllm_omni.entrypoints.utils import detect_explicit_cli_keys
logger = init_logger(__name__)
@@ -79,6 +81,9 @@ class OmniServeCommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI."""
name = "serve"
+ # Parser stashed at subparser_init so ``cmd`` can resolve each user-typed
+ # flag to its real ``dest`` via the parser's action table.
+ _parser: FlexibleArgumentParser | None = None
@staticmethod
def cmd(args: argparse.Namespace) -> None:
@@ -90,6 +95,10 @@ def cmd(args: argparse.Namespace) -> None:
if hasattr(args, "model_tag") and args.model_tag is not None:
args.model = args.model_tag
+ # Stash the set of long-option keys the user actually typed so the
+ # stage-config factory can give YAML precedence over argparse defaults.
+ args._cli_explicit_keys = detect_explicit_cli_keys(sys.argv[1:], OmniServeCommand._parser)
+
if args.headless:
run_headless(args)
else:
@@ -138,11 +147,33 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
help="Default task type for TTS models (CustomVoice, VoiceDesign, or Base). "
"If not specified, will be inferred from model path.",
)
+ # TODO(@lishunyang12): deprecate once all models migrate to --deploy-config
omni_config_group.add_argument(
"--stage-configs-path",
type=str,
default=None,
- help="Path to the stage configs file. If not specified, the stage configs will be loaded from the model.",
+ help="[Deprecated — will be removed in a future release] Path to a legacy "
+ "stage configs YAML (stage_args format). Prefer --deploy-config for new-format deploy YAMLs.",
+ )
+ omni_config_group.add_argument(
+ "--deploy-config",
+ type=str,
+ default=None,
+ help="Path to a deploy config YAML (new format with stages/engine_args). "
+ "Mutually exclusive with --stage-configs-path.",
+ )
+ omni_config_group.add_argument(
+ "--stage-overrides",
+ type=str,
+ default=None,
+ help="Per-stage JSON overrides. Example: "
+ '\'{"0": {"gpu_memory_utilization": 0.8}, "2": {"enforce_eager": true}}\'',
+ )
+ omni_config_group.add_argument(
+ "--async-chunk",
+ action=argparse.BooleanOptionalAction,
+ default=None,
+ help="Override the deploy YAML's ``async_chunk:`` bool. Unset leaves the YAML value in force.",
)
omni_config_group.add_argument(
"--stage-id",
@@ -406,6 +437,9 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
+ # Stash via type(self) so the docs hook (which execs this function in a
+ # sandboxed globals dict via ``DummySelf``) doesn't fail on a NameError.
+ type(self)._parser = serve_parser
return serve_parser
@@ -461,10 +495,15 @@ def run_headless(args: argparse.Namespace) -> None:
raise ValueError("headless mode requires worker_backend=multi_process")
args_dict = vars(args).copy()
+ # Preserve the explicit-keys set captured at parse time so per-stage yaml
+ # values (e.g. stage 1's ``gpu_memory_utilization: 0.5``) are not
+ # overwritten by argparse defaults for flags the user didn't type.
+ cli_explicit_keys = args_dict.pop("_cli_explicit_keys", None)
config_path, stage_configs = load_and_resolve_stage_configs(
model,
args_dict.get("stage_configs_path"),
args_dict,
+ cli_explicit_keys=cli_explicit_keys,
)
# Locate the stage config that matches stage_id.
diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py
index a3bfe98ce2..8ef7e2ee5b 100644
--- a/vllm_omni/entrypoints/omni.py
+++ b/vllm_omni/entrypoints/omni.py
@@ -66,6 +66,13 @@ def generate(
py_generator: bool = False,
use_tqdm: bool | Callable[..., tqdm] = True,
) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]:
+ # Expand sampling params for PD disaggregation (user may provide N-1 params)
+ if (
+ sampling_params_list is not None
+ and isinstance(sampling_params_list, Sequence)
+ and not isinstance(sampling_params_list, (str, bytes))
+ ):
+ sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list))
sampling_params_list = self.resolve_sampling_params_list(sampling_params_list)
try:
if py_generator:
@@ -125,10 +132,17 @@ def _run_generation(
req_state.metrics = metrics
self.request_states[req_id] = req_state
+ # PD disaggregation: modify stage-0 (prefill) sampling params per request
+ req_sp_list = list(sampling_params_list)
+ pd_pair = self._get_pd_separation_pair()
+ if pd_pair is not None:
+ p_id = pd_pair[0]
+ req_sp_list[p_id] = self._prepare_prefill_sampling_params(req_id, req_sp_list[p_id])
+
self.engine.add_request(
request_id=req_id,
prompt=prompt,
- sampling_params_list=sampling_params_list,
+ sampling_params_list=req_sp_list,
final_stage_id=final_stage_id,
)
submit_ts = time.time()
diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py
index 1a7ffc4a50..dca494efe7 100644
--- a/vllm_omni/entrypoints/omni_base.py
+++ b/vllm_omni/entrypoints/omni_base.py
@@ -1,6 +1,8 @@
from __future__ import annotations
+import argparse
import os
+import sys
import time
import types
import weakref
@@ -14,7 +16,8 @@
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
from vllm_omni.entrypoints.client_request_state import ClientRequestState
-from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e
+from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
+from vllm_omni.entrypoints.utils import detect_explicit_cli_keys, get_final_stage_id_for_e2e
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
from vllm_omni.outputs import OmniRequestOutput
@@ -65,9 +68,51 @@ def omni_snapshot_download(model_id: str) -> str:
OutputMessageHandleResult = tuple[Literal[True], None, None, None] | tuple[Literal[False], str, int, ClientRequestState]
-class OmniBase:
+class OmniBase(PDDisaggregationMixin):
"""Shared runtime foundation for AsyncOmni and Omni."""
+ @classmethod
+ def from_cli_args(
+ cls,
+ args: argparse.Namespace,
+ *,
+ parser: argparse.ArgumentParser | None = None,
+ **overrides: Any,
+ ) -> OmniBase:
+ """Construct an ``Omni`` / ``AsyncOmni`` from an ``argparse.Namespace``.
+
+ Mirrors the ``EngineArgs.from_cli_args`` pattern used upstream and in
+ ``OmniEngineArgs.from_cli_args``. This is the recommended entry point
+ for any argparse-based caller (offline scripts, tests, CI): it
+ expands ``vars(args)`` into kwargs and automatically captures which
+ flags the user typed on the command line so that argparse defaults
+ do not silently override deploy YAML values.
+
+ Passing ``parser`` is strongly recommended: without it, flag-to-dest
+ resolution falls back to a name-based heuristic that misidentifies
+ flags with ``dest=`` overrides, alias flags, and ``--disable-X`` /
+ ``store_false`` pairs. See :func:`detect_explicit_cli_keys`.
+
+ Args:
+ args: Parsed argparse namespace from ``parser.parse_args()``.
+ parser: The argparse parser used to produce ``args``. When
+ provided, each user-typed flag is resolved to its real
+ ``dest`` via the parser's action table.
+ **overrides: Extra keyword arguments that take precedence over
+ attributes on ``args``.
+
+ Example::
+
+ parser = FlexibleArgumentParser()
+ parser.add_argument("--model", required=True)
+ args = parser.parse_args()
+ omni = Omni.from_cli_args(args, parser=parser) # preferred
+ omni = Omni.from_cli_args(args, parser=parser, model="other")
+ """
+ kwargs: dict[str, Any] = {**vars(args), **overrides}
+ kwargs["_cli_explicit_keys"] = detect_explicit_cli_keys(sys.argv[1:], parser)
+ return cls(**kwargs)
+
def __init__(
self,
model: str,
@@ -77,16 +122,24 @@ def __init__(
stage_init_timeout = kwargs.pop("stage_init_timeout", 300)
init_timeout = kwargs.pop("init_timeout", 600)
log_stats = kwargs.pop("log_stats", False)
- async_chunk = kwargs.pop("async_chunk", False)
+ # NOTE: read-only lookup — must NOT pop. Popping here drops the key
+ # before it reaches ``StageConfigFactory._create_from_registry``, so
+ # ``--no-async-chunk`` (``async_chunk=False``) silently fails to
+ # override the deploy YAML's ``async_chunk: true`` default.
+ async_chunk = kwargs.get("async_chunk")
output_modalities = kwargs.pop("output_modalities", None)
diffusion_batch_size: int = kwargs.pop("diffusion_batch_size", 1)
if "log_requests" in kwargs:
raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.")
model = omni_snapshot_download(model)
+ self._name = self.__class__.__name__
self.model = model
self.log_stats = log_stats
- self.async_chunk = async_chunk
+ # Provisional value (mirrors the CLI/caller kwarg); the engine resolves
+ # pipeline + deploy YAML + CLI precedence below and the final value is
+ # re-assigned from ``self.engine.async_chunk`` after init.
+ self.async_chunk = bool(async_chunk) if async_chunk is not None else False
self.output_modalities = output_modalities or []
self.tts_batch_max_items: int = kwargs.pop("tts_batch_max_items", 32)
@@ -104,7 +157,11 @@ def __init__(
self._weak_finalizer = weakref.finalize(self, _weak_shutdown_engine, self.engine)
et = time.time()
logger.info("[%s] AsyncOmniEngine initialized in %.2f seconds", self.__class__.__name__, et - st)
- self.async_chunk = bool(self.async_chunk or getattr(self.engine, "async_chunk", False))
+ # Authoritative: ``AsyncOmniEngine`` resolves (pipeline + deploy YAML +
+ # CLI overrides) through ``StageConfigFactory`` and stores the final
+ # value on ``engine.async_chunk``; mirror it here so ``--no-async-chunk``
+ # (explicit ``False``) is not fallen-back-through by ``or``.
+ self.async_chunk = bool(getattr(self.engine, "async_chunk", False))
self.request_states: dict[str, ClientRequestState] = {}
@@ -125,10 +182,18 @@ def __init__(
model,
)
+ # PD disaggregation state (detects if a prefill/decode stage pair is configured)
+ self._init_pd_state()
+
@property
def num_stages(self) -> int:
return self.engine.num_stages
+ @property
+ def stage_configs(self) -> list:
+ """Expose engine stage configs for PD disaggregation detection and validation."""
+ return self.engine.stage_configs
+
@property
def is_running(self) -> bool:
return self.engine.is_alive()
diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py
index 4519ae8c0c..745b719d5b 100644
--- a/vllm_omni/entrypoints/openai/api_server.py
+++ b/vllm_omni/entrypoints/openai/api_server.py
@@ -18,6 +18,7 @@
from typing import Annotated, Any, Literal, cast
import httpx
+import numpy as np
import vllm.envs as envs
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, Request, UploadFile, WebSocket
from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse
@@ -52,7 +53,6 @@
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.orca_metrics import metrics_header
-from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.server_utils import get_uvicorn_log_config
@@ -107,6 +107,7 @@
VideoListResponse,
VideoResponse,
)
+from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech
from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler
@@ -120,6 +121,7 @@
logger = init_logger(__name__)
router = APIRouter()
+MAX_UINT32_SEED = 2**32 - 1
profiler_router = APIRouter()
@@ -1202,6 +1204,22 @@ async def streaming_speech(websocket: WebSocket):
@router.websocket("/v1/realtime")
async def realtime_websocket(websocket: WebSocket):
"""WebSocket endpoint for OpenAI-style realtime interactions."""
+ engine_client = getattr(websocket.app.state, "engine_client", None)
+ if engine_client is not None and getattr(engine_client, "async_chunk", False):
+ await websocket.accept()
+ await websocket.send_json(
+ {
+ "type": "error",
+ "error": (
+ "The /v1/realtime API is not supported when async_chunk is enabled on the server. "
+ "Use a stage configuration with async_chunk disabled and restart the server before using "
+ "this endpoint."
+ ),
+ "code": "unsupported",
+ }
+ )
+ await websocket.close()
+ return
serving = getattr(websocket.app.state, "openai_serving_realtime", None)
if serving is None:
await websocket.accept()
@@ -1303,14 +1321,64 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# Get engine client (AsyncOmni) from app state
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
- # Validate model field (warn if mismatch, don't error)
if request.model is not None and request.model != model_name:
- logger.warning(
- f"Model mismatch: request specifies '{request.model}' but "
- f"server is running '{model_name}'. Using server model."
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=(f"Model mismatch: request specifies '{request.model}' but server is running '{model_name}'."),
)
try:
+ # Unify request construction for any multi-stage pipeline to avoid
+ # divergence between /v1/images and /v1/chat/completions.
+ if len(stage_configs) > 1:
+ chat_handler = getattr(raw_request.app.state, "openai_serving_chat", None)
+ if chat_handler is None:
+ logger.warning("openai_serving_chat is not initialized for multi-stage /v1/images/generations")
+ raise HTTPException(
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
+ detail="openai_serving_chat is not initialized for multi-stage image generation.",
+ )
+
+ effective_seed = request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED)
+ extra_body: dict[str, Any] = {
+ "seed": effective_seed,
+ "num_outputs_per_prompt": request.n,
+ }
+ if request.size is not None:
+ parse_size(request.size)
+ width, height = parse_size(request.size)
+ app_state_args = getattr(raw_request.app.state, "args", None)
+ _check_max_generated_image_size(app_state_args, width, height)
+ extra_body["size"] = request.size
+ if request.negative_prompt is not None:
+ extra_body["negative_prompt"] = request.negative_prompt
+ if request.num_inference_steps is not None:
+ extra_body["num_inference_steps"] = request.num_inference_steps
+ if request.guidance_scale is not None:
+ extra_body["guidance_scale"] = request.guidance_scale
+ if request.true_cfg_scale is not None:
+ extra_body["true_cfg_scale"] = request.true_cfg_scale
+ if request.generator_device is not None:
+ extra_body["generator_device"] = request.generator_device
+ if request.lora is not None:
+ # Keep /images validation semantics: invalid LoRA should fail with 400.
+ _parse_lora_request(request.lora)
+ extra_body["lora"] = request.lora
+
+ generation_result = await chat_handler.generate_diffusion_images(
+ prompt=request.prompt,
+ extra_body=extra_body,
+ request_id=f"img_gen-{random_uuid()}",
+ )
+ if isinstance(generation_result, ErrorResponse):
+ return JSONResponse(
+ status_code=generation_result.error.code if generation_result.error else 400,
+ content=generation_result.model_dump(),
+ )
+ flat_images, _, _ = generation_result
+ image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images]
+ return ImageGenerationResponse(created=int(time.time()), data=image_data)
+
# Build params - pass through user values directly
prompt: OmniTextPrompt = {"prompt": request.prompt}
if request.negative_prompt is not None:
@@ -1351,7 +1419,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
_update_if_not_none(
- gen_params, "seed", request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
+ gen_params, "seed", request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED)
)
_update_if_not_none(gen_params, "generator_device", request.generator_device)
_update_if_not_none(gen_params, "layers", request.layers)
@@ -1425,6 +1493,9 @@ async def edit_images(
background: str | None = Form("auto"),
output_compression: Annotated[int, Form(ge=0, le=100)] = 100,
user: str | None = Form(None), # unused now
+ # vllm-omni extensions for image editing
+ mask_image: str | UploadFile | None = None,
+ reference_image: str | UploadFile | None = None,
# vllm-omni extensions for diffusion control
negative_prompt: str | None = Form(None),
num_inference_steps: int | None = Form(None),
@@ -1445,8 +1516,9 @@ async def edit_images(
# 1. get engine and model
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
if model is not None and model != model_name:
- logger.warning(
- f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model."
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=(f"Model mismatch: request specifies '{model}' but server is running '{model_name}'."),
)
# 2. get output format & compression
output_format = _choose_output_format(output_format, background)
@@ -1469,15 +1541,35 @@ async def edit_images(
input_images_list.extend(urls)
if not input_images_list:
raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required")
- pil_images = await _load_input_images(input_images_list)
- if len(pil_images) > 1 and not _supports_multimodal_image_inputs(raw_request, engine_client):
+ # Reject oversized multi-image edit requests before fetching or decoding
+ # any inputs. This keeps over-limit URL requests from burning network,
+ # CPU, and memory on work that will be rejected anyway.
+ max_input_images = _get_max_edit_input_images(raw_request, engine_client)
+ if max_input_images is not None and len(input_images_list) > max_input_images:
+ detail = (
+ "Received multiple input images. Only a single image is supported by this model."
+ if max_input_images == 1
+ else (
+ f"Received {len(input_images_list)} input images. "
+ f"At most {max_input_images} images are supported by this model."
+ )
+ )
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
- detail="Received multiple input images. Only a single image is supported by this model.",
+ detail=detail,
)
+ pil_images = await _load_input_images(input_images_list)
prompt["multi_modal_data"] = {}
prompt["multi_modal_data"]["image"] = pil_images
+ if mask_image is not None:
+ loaded = await _load_input_images([mask_image])
+ prompt["multi_modal_data"]["mask_image"] = loaded[0]
+
+ if reference_image is not None:
+ loaded = await _load_input_images([reference_image])
+ prompt["multi_modal_data"]["reference_image"] = loaded[0]
+
# 3 Build sample params
gen_params = OmniDiffusionSamplingParams()
# 3.0 Init with system default values
@@ -1549,7 +1641,7 @@ async def edit_images(
# a proper generator is initialized in the backend.
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
- _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, 2**32 - 1))
+ _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, MAX_UINT32_SEED))
_update_if_not_none(gen_params, "generator_device", generator_device)
_update_if_not_none(gen_params, "layers", layers)
_update_if_not_none(gen_params, "resolution", resolution)
@@ -1639,18 +1731,25 @@ def _get_engine_and_model(raw_request: Request):
return engine_client, model_name, normalized_stage_configs
-def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool:
+def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any:
diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client
get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None)
- od_config = (
+ return (
get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None)
)
+
+def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int | None:
+ od_config = _get_diffusion_od_config(raw_request, engine_client)
if od_config is None:
# Preserve the existing compatibility behavior when the diffusion
# config is not exposed on the serving surface.
- return True
- return bool(getattr(od_config, "supports_multimodal_inputs", False))
+ return None
+
+ if not bool(getattr(od_config, "supports_multimodal_inputs", False)):
+ return 1
+
+ return getattr(od_config, "max_multimodal_image_inputs", None)
def _get_lora_from_json_str(lora_body):
@@ -1767,6 +1866,34 @@ def _update_if_not_none(object: Any, key: str, val: Any) -> None:
setattr(object, key, val)
+def _normalize_image(image: Any) -> Any:
+ """Normalize a single image output to a PIL-compatible format."""
+ if isinstance(image, Image.Image):
+ return image
+ if not isinstance(image, np.ndarray):
+ raise ValueError(f"Unsupported image type: {type(image)}")
+ if not np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.floating):
+ raise ValueError(f"Unsupported dtype: {image.dtype}")
+ if isinstance(image, np.ndarray):
+ while image.ndim > 3:
+ image = image[0]
+ if image.min() < 0:
+ if image.min() < -1.01 or image.max() > 1.01:
+ logger.warning(
+ f"Image float range [{image.min():.2f}, {image.max():.2f}] outside expected [-1, 1]. "
+ f"Clipping to [-1, 1] before normalization."
+ )
+ image = np.clip(image, -1.0, 1.0) * 0.5 + 0.5
+ elif image.max() > 1.01:
+ logger.warning(
+ f"Image float range [{image.min():.2f}, {image.max():.2f}] outside expected [0, 1]. "
+ f"Clipping to [0, 1] before normalization."
+ )
+ image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
+ image = Image.fromarray(image)
+ return image
+
+
def _extract_images_from_result(result: Any) -> list[Any]:
images = []
if hasattr(result, "images") and result.images:
@@ -1777,6 +1904,10 @@ def _extract_images_from_result(result: Any) -> list[Any]:
images = request_output["images"]
elif hasattr(request_output, "images") and request_output.images:
images = request_output.images
+ # Handle when generate more than one image
+ if images and isinstance(images[0], np.ndarray) and images[0].shape[0] > 1 and images[0].ndim == 5:
+ # Unwrap batch: (N, T, H, W, C) -> [img1, img2, ...]
+ images = list(images[0])
# Flatten nested lists (e.g., from layered models like Qwen-Image-Layered).
# Note: This only flattens one level deep. Deeper nesting is not supported.
flattened = []
@@ -1785,7 +1916,7 @@ def _extract_images_from_result(result: Any) -> list[Any]:
flattened.extend(img)
else:
flattened.append(img)
- return flattened
+ return [_normalize_image(img) for img in flattened]
async def _load_input_images(
@@ -1955,18 +2086,6 @@ def video_response_from_request(model_name: str, req: VideoGenerationRequest) ->
return resp
-async def decode_and_save_video_output(output: Any, file_name: str) -> str:
- if not output.b64_json:
- raise RuntimeError(f"Video output for {file_name} did not include b64_json content.")
-
- try:
- video_bytes = base64.b64decode(output.b64_json)
- except Exception as decode_exc:
- raise RuntimeError(f"Failed to decode generated video payload for {file_name}") from decode_exc
-
- return await STORAGE_MANAGER.save(video_bytes, file_name)
-
-
def _cleanup_video(video_id: str, output_path: str | None):
try:
if output_path is not None:
@@ -1990,15 +2109,12 @@ async def _run_video_generation_job(
started_at = time.perf_counter()
output_path = None
try:
- response = await handler.generate_videos(request, video_id, reference_image=reference_image)
- if not response.data:
- raise RuntimeError("Video generation completed but returned no outputs.")
-
- if (video_count := len(response.data)) > 1:
- logger.warning("Video request %s generated %s outputs but we only expected one.", video_id, video_count)
+ video_bytes, stage_durations, peak_memory_mb = await handler.generate_video_bytes(
+ request, video_id, reference_image=reference_image
+ )
file_name = f"{video_id}.{job.file_extension}"
- output_path = await decode_and_save_video_output(response.data[0], file_name)
+ output_path = await STORAGE_MANAGER.save(video_bytes, file_name)
logger.info("Video request %s persisted %s output file.", video_id, output_path)
await VIDEO_STORE.update_fields(
@@ -2009,6 +2125,8 @@ async def _run_video_generation_job(
"file_name": file_name,
"completed_at": int(time.time()),
"inference_time_s": time.perf_counter() - started_at,
+ "stage_durations": stage_durations,
+ "peak_memory_mb": peak_memory_mb,
},
)
except Exception as exc:
@@ -2055,6 +2173,10 @@ async def _parse_video_form(
true_cfg_scale: float | None = Form(default=None),
seed: int | None = Form(default=None),
negative_prompt: str | None = Form(default=None),
+ enable_frame_interpolation: bool = Form(default=False),
+ frame_interpolation_exp: int = Form(default=1, ge=1),
+ frame_interpolation_scale: float = Form(default=1.0, gt=0.0),
+ frame_interpolation_model_path: str | None = Form(default=None),
lora: str | None = Form(default=None),
extra_params: str | None = Form(default=None),
) -> tuple[VideoGenerationRequest, "OmniOpenAIServingVideo", str, ReferenceImage | None]:
@@ -2091,6 +2213,10 @@ async def _parse_video_form(
"true_cfg_scale": true_cfg_scale,
"seed": seed,
"negative_prompt": negative_prompt,
+ "enable_frame_interpolation": enable_frame_interpolation,
+ "frame_interpolation_exp": frame_interpolation_exp,
+ "frame_interpolation_scale": frame_interpolation_scale,
+ "frame_interpolation_model_path": frame_interpolation_model_path,
"lora": _parse_form_json(lora, expected_type=dict),
"extra_params": _parse_form_json(extra_params, expected_type=dict),
}
@@ -2108,10 +2234,12 @@ async def _parse_video_form(
app_model_name, app_stage_configs = _resolve_video_runtime_context(raw_request)
effective_model_name = handler.model_name or app_model_name or request.model or "unknown"
if request.model is not None and effective_model_name is not None and request.model != effective_model_name:
- logger.warning(
- "Model mismatch: request specifies '%s' but server is running '%s'. Using server model.",
- request.model,
- effective_model_name,
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=(
+ f"Model mismatch: request specifies '{request.model}' but server is running "
+ f"'{effective_model_name}'."
+ ),
)
handler.set_stage_configs_if_missing(app_stage_configs)
except HTTPException:
@@ -2182,7 +2310,7 @@ async def create_video_sync(
request_id = f"video_sync-{random_uuid()}"
started_at = time.perf_counter()
try:
- video_bytes = await asyncio.wait_for(
+ video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for(
handler.generate_video_bytes(request, request_id, reference_image=reference_image),
timeout=VIDEO_SYNC_TIMEOUT_S,
)
@@ -2208,6 +2336,8 @@ async def create_video_sync(
"X-Request-Id": request_id,
"X-Model": effective_model_name,
"X-Inference-Time-S": f"{inference_time_s:.3f}",
+ "X-Stage-Durations": json.dumps(stage_durations, separators=(",", ":")),
+ "X-Peak-Memory-MB": f"{peak_memory_mb:.3f}",
},
)
diff --git a/vllm_omni/entrypoints/openai/audio_utils_mixin.py b/vllm_omni/entrypoints/openai/audio_utils_mixin.py
index 13df32ebe0..b626f7eeb2 100644
--- a/vllm_omni/entrypoints/openai/audio_utils_mixin.py
+++ b/vllm_omni/entrypoints/openai/audio_utils_mixin.py
@@ -1,6 +1,8 @@
from io import BytesIO
import numpy as np
+import torch
+import torchaudio
from vllm.logger import init_logger
from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio
@@ -10,11 +12,6 @@
except ImportError:
soundfile = None
-try:
- import librosa
-except ImportError:
- librosa = None
-
logger = init_logger(__name__)
@@ -74,20 +71,53 @@ def create_audio(self, audio_obj: CreateAudio) -> AudioResponse:
return AudioResponse(audio_data=audio_data, media_type=media_type)
def _apply_speed_adjustment(self, audio_tensor: np.ndarray, speed: float, sample_rate: int):
- """Apply speed adjustment to the audio tensor while preserving pitch."""
+ """Apply speed adjustment to the audio tensor while preserving pitch.
+
+ Uses torchaudio's phase vocoder (Spectrogram → TimeStretch →
+ InverseSpectrogram) to stretch/compress audio in time without
+ changing pitch.
+ """
if speed == 1.0:
return audio_tensor, sample_rate
- if librosa is None:
- raise ImportError("librosa is required for speed adjustment. Please install it with: pip install librosa")
-
try:
- # librosa.effects.time_stretch requires a float audio tensor.
if not np.issubdtype(audio_tensor.dtype, np.floating):
audio_tensor = audio_tensor.astype(np.float32)
- stretched_audio = librosa.effects.time_stretch(y=audio_tensor, rate=speed)
- return stretched_audio, sample_rate
+ # Stereo numpy arrays use channels-last (T, C);
+ # torch expects channels-first (C, T).
+ channels_last = audio_tensor.ndim == 2
+ if channels_last:
+ waveform = torch.from_numpy(audio_tensor.T)
+ else:
+ waveform = torch.from_numpy(audio_tensor).unsqueeze(0)
+
+ # Match librosa.stft defaults: n_fft=2048, hop_length=n_fft//4
+ n_fft = 2048
+ hop_length = n_fft // 4
+ to_spec = torchaudio.transforms.Spectrogram(
+ n_fft=n_fft,
+ hop_length=hop_length,
+ power=None,
+ )
+ stretch = torchaudio.transforms.TimeStretch(
+ n_freq=n_fft // 2 + 1,
+ hop_length=hop_length,
+ )
+ to_wave = torchaudio.transforms.InverseSpectrogram(
+ n_fft=n_fft,
+ hop_length=hop_length,
+ )
+
+ spec = to_spec(waveform)
+ stretched = stretch(spec, speed)
+ expected_length = int(audio_tensor.shape[0] / speed)
+ result = to_wave(stretched, length=expected_length)
+
+ result = result.squeeze(0).numpy()
+ if channels_last:
+ result = result.T
+ return result, sample_rate
except Exception as e:
logger.error(f"An error occurred during speed adjustment: {e}")
raise ValueError("Failed to apply speed adjustment.") from e
diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py
index e180bef229..7c2c3164d9 100644
--- a/vllm_omni/entrypoints/openai/protocol/videos.py
+++ b/vllm_omni/entrypoints/openai/protocol/videos.py
@@ -150,6 +150,29 @@ class VideoGenerationRequest(BaseModel):
)
seed: int | None = Field(default=None, description="Random seed for reproducibility")
+ # vllm-omni extensions for post-generation frame interpolation.
+ enable_frame_interpolation: bool = Field(
+ default=False,
+ description="Enable post-generation RIFE frame interpolation before MP4 encoding.",
+ )
+ frame_interpolation_exp: int = Field(
+ default=1,
+ ge=1,
+ description="Interpolation exponent: 1=2x temporal resolution, 2=4x, etc.",
+ )
+ frame_interpolation_scale: float = Field(
+ default=1.0,
+ gt=0.0,
+ description="RIFE inference scale. Use 0.5 for high-resolution inputs to save memory.",
+ )
+ frame_interpolation_model_path: str | None = Field(
+ default=None,
+ description=(
+ "Local directory or Hugging Face repo ID containing RIFE flownet.pkl weights. "
+ "Defaults to elfgum/RIFE-4.22.lite."
+ ),
+ )
+
# vllm-omni extension for per-request LoRA.
lora: dict[str, Any] | None = Field(
default=None,
@@ -201,6 +224,14 @@ class VideoGenerationResponse(BaseModel):
created: int = Field(..., description="Unix timestamp of when the generation completed")
data: list[VideoData] = Field(..., description="Array of generated videos")
+ stage_durations: dict[str, float] = Field(
+ default_factory=dict,
+ description="Profiler stage durations reported by the diffusion pipeline.",
+ )
+ peak_memory_mb: float = Field(
+ default=0.0,
+ description="Peak device memory usage in MB reported by the diffusion pipeline.",
+ )
class VideoError(BaseModel):
@@ -250,6 +281,14 @@ class VideoResponse(BaseModel):
description="Filename of the saved output video files for this job.",
)
inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.")
+ stage_durations: dict[str, float] = Field(
+ default_factory=dict,
+ description="Profiler stage durations reported by the diffusion pipeline.",
+ )
+ peak_memory_mb: float = Field(
+ default=0.0,
+ description="Peak device memory usage in MB reported by the diffusion pipeline.",
+ )
@property
def file_extension(self) -> str:
diff --git a/vllm_omni/entrypoints/openai/realtime_connection.py b/vllm_omni/entrypoints/openai/realtime_connection.py
new file mode 100644
index 0000000000..1d5470f569
--- /dev/null
+++ b/vllm_omni/entrypoints/openai/realtime_connection.py
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+import json
+from collections.abc import AsyncGenerator
+from uuid import uuid4
+
+import numpy as np
+from vllm.entrypoints.openai.engine.protocol import UsageInfo
+from vllm.entrypoints.openai.realtime.connection import RealtimeConnection as VllmRealtimeConnection
+from vllm.entrypoints.openai.realtime.protocol import TranscriptionDelta, TranscriptionDone
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class RealtimeConnection(VllmRealtimeConnection):
+ """Omni realtime connection with audio-only server events.
+
+ Reuses upstream vLLM websocket/session lifecycle and only customizes
+ generation output handling to emit audio deltas.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Last audio buffer seen for this realtime generation (cumulative or concatenation
+ # of increments); used to turn server cumulative PCM into true deltas.
+ self._realtime_audio_ref: np.ndarray | None = None
+
+ async def start_generation(self):
+ await super().start_generation()
+
+ @staticmethod
+ def _tensor_to_numpy(value) -> np.ndarray | None:
+ if value is None:
+ return None
+ if isinstance(value, np.ndarray):
+ arr = value
+ elif hasattr(value, "detach"):
+ arr = value.detach().float().cpu().numpy()
+ else:
+ try:
+ arr = np.asarray(value)
+ except Exception:
+ return None
+ if arr.ndim > 1:
+ arr = arr.reshape(-1)
+ return arr.astype(np.float32, copy=False)
+
+ @staticmethod
+ def _numpy_audio_prefix_match(prev: np.ndarray, curr: np.ndarray) -> bool:
+ n = prev.shape[0]
+ if n == 0:
+ return True
+ if curr.shape[0] < n:
+ return False
+ return bool(np.allclose(curr[:n], prev, rtol=1e-3, atol=2e-4))
+
+ def _raw_waveform_to_deltas(self, arr: np.ndarray) -> list[np.ndarray]:
+ """Convert one streaming PCM f32 chunk into incremental piece(s) for the client.
+
+ Some engine paths emit a growing cumulative waveform each step; others emit
+ true per-step deltas. We support both without duplicating audio on the client.
+ """
+ if arr.size == 0:
+ return []
+ ref = self._realtime_audio_ref
+ if ref is None:
+ self._realtime_audio_ref = arr.copy()
+ return [arr]
+ if self._numpy_audio_prefix_match(ref, arr):
+ delta = arr[ref.shape[0] :]
+ self._realtime_audio_ref = arr.copy()
+ return [delta] if delta.size > 0 else []
+ # True per-step delta (not a prefix extension of what we have seen).
+ self._realtime_audio_ref = np.concatenate([ref, arr])
+ return [arr]
+
+ def _extract_audio_chunks(self, output) -> tuple[list[np.ndarray], int]:
+ mm = getattr(output, "multimodal_output", None)
+ if not isinstance(mm, dict):
+ return [], 24000
+
+ sr = mm.get("sr") or mm.get("sample_rate") or mm.get("audio_sample_rate") or 24000
+ key = "audio" if "audio" in mm else ("model_outputs" if "model_outputs" in mm else None)
+ if key is None:
+ return [], int(sr)
+
+ raw_audio = mm.get(key)
+ chunks: list[np.ndarray] = []
+ if isinstance(raw_audio, (list, tuple)):
+ if len(raw_audio) > 0:
+ arr = self._tensor_to_numpy(raw_audio[-1])
+ if arr is not None and arr.size > 0:
+ chunks.extend(self._raw_waveform_to_deltas(arr))
+ else:
+ arr = self._tensor_to_numpy(raw_audio)
+ if arr is not None and arr.size > 0:
+ chunks.extend(self._raw_waveform_to_deltas(arr))
+ return chunks, int(sr)
+
+ @staticmethod
+ def _pcm16_b64(audio_f32: np.ndarray) -> str:
+ clipped = np.clip(audio_f32, -1.0, 1.0)
+ pcm16 = (clipped * 32767.0).astype(np.int16)
+ return base64.b64encode(pcm16.tobytes()).decode("utf-8")
+
+ async def _run_generation(
+ self,
+ streaming_input_gen: AsyncGenerator,
+ input_stream: asyncio.Queue[list[int]],
+ ):
+ request_id = f"rt-{self.connection_id}-{uuid4()}"
+ sent_audio = False
+ audio_done_sent = False
+ full_text = ""
+ sent_text_len = 0
+ prompt_token_ids_len = 0
+ completion_tokens_len = 0
+ self._realtime_audio_ref = None
+
+ try:
+ result_gen = self.serving.engine_client.generate(
+ prompt=streaming_input_gen,
+ request_id=request_id,
+ )
+
+ async for output in result_gen:
+ if output.outputs and len(output.outputs) > 0:
+ output0 = output.outputs[0]
+ token_ids = list(output0.token_ids)
+ if token_ids:
+ input_stream.put_nowait(token_ids)
+ # token_ids are cumulative per request
+ completion_tokens_len = len(token_ids)
+ if not prompt_token_ids_len and output.prompt_token_ids:
+ prompt_token_ids_len = len(output.prompt_token_ids)
+ cumulative_text = output0.text or ""
+ if cumulative_text:
+ if len(cumulative_text) >= sent_text_len:
+ delta_text = cumulative_text[sent_text_len:]
+ else:
+ delta_text = cumulative_text
+ sent_text_len = len(cumulative_text)
+ full_text = cumulative_text
+ else:
+ delta_text = ""
+
+ if delta_text:
+ await self.send(TranscriptionDelta(delta=delta_text))
+
+ audio_chunks, sample_rate = self._extract_audio_chunks(output)
+
+ for chunk in audio_chunks:
+ sent_audio = True
+ await self.send_json(
+ {
+ "type": "response.audio.delta",
+ "audio": self._pcm16_b64(chunk),
+ "format": "pcm16",
+ "sample_rate_hz": sample_rate,
+ }
+ )
+
+ if not self._is_connected:
+ break
+
+ usage = UsageInfo(
+ prompt_tokens=prompt_token_ids_len,
+ completion_tokens=completion_tokens_len,
+ total_tokens=prompt_token_ids_len + completion_tokens_len,
+ )
+ await self.send(TranscriptionDone(text=full_text, usage=usage))
+
+ if sent_audio:
+ await self.send_json({"type": "response.audio.done", "has_audio": True})
+ audio_done_sent = True
+ except Exception as e:
+ logger.exception("Error in generation: %s", e)
+ await self.send_error(str(e), "processing_error")
+ finally:
+ # Always send terminal event so clients don't hang forever.
+ if self._is_connected and not audio_done_sent:
+ try:
+ await self.send_json({"type": "response.audio.done", "has_audio": sent_audio})
+ except Exception:
+ logger.exception("Failed to send response.audio.done")
+ while not self.audio_queue.empty():
+ self.audio_queue.get_nowait()
+
+ async def send_json(self, payload: dict):
+ await self.websocket.send_text(json.dumps(payload))
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py
index e84a49aac2..8cddac6a6c 100644
--- a/vllm_omni/entrypoints/openai/serving_chat.py
+++ b/vllm_omni/entrypoints/openai/serving_chat.py
@@ -85,7 +85,12 @@
from vllm_omni.entrypoints.openai.image_api_utils import validate_layered_layers
from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse
from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio
-from vllm_omni.entrypoints.openai.utils import parse_lora_request
+from vllm_omni.entrypoints.openai.utils import (
+ get_stage_type,
+ get_supported_speakers_from_hf_config,
+ parse_lora_request,
+ validate_requested_speaker,
+)
from vllm_omni.lora.request import LoRARequest
from vllm_omni.outputs import OmniRequestOutput
@@ -106,6 +111,7 @@ class OmniOpenAIServingChat(OpenAIServingChat, AudioMixin):
_diffusion_mode: bool = False
_diffusion_engine: AsyncOmni | None = None
_diffusion_model_name: str = ""
+ _supported_speakers: set[str] | None = None
@classmethod
def for_diffusion(
@@ -132,6 +138,18 @@ def for_diffusion(
instance._diffusion_model_name = model_name
return instance
+ def _get_supported_speakers(self) -> set[str]:
+ """Load supported speakers from model config (cached)."""
+ if self._supported_speakers is not None:
+ return self._supported_speakers
+ try:
+ self._supported_speakers = get_supported_speakers_from_hf_config(self.model_config.hf_config)
+ return self._supported_speakers
+ except Exception as e:
+ logger.warning("Could not load speakers from model config: %s", e)
+ self._supported_speakers = set()
+ return self._supported_speakers
+
async def create_chat_completion(
self,
request: ChatCompletionRequest,
@@ -260,7 +278,10 @@ async def create_chat_completion(
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
- return self.create_error_response(f"{e} {e.__cause__}")
+ message = str(e)
+ if e.__cause__ is not None:
+ message = f"{message} {e.__cause__}"
+ return self.create_error_response(message)
request_id = f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
@@ -274,6 +295,8 @@ async def create_chat_completion(
)
num_inference_steps = None
+ cfg_text_scale = None
+ cfg_img_scale = None
# Omni multistage image generation: Stage-0 (AR) should receive a clean
# text prompt (and optional conditioning image/size) so the model's own
# processor can construct the correct inputs.
@@ -322,6 +345,8 @@ async def create_chat_completion(
except Exception:
pass
negative_prompt = extra_body.get("negative_prompt")
+ cfg_text_scale = extra_body.get("cfg_text_scale")
+ cfg_img_scale = extra_body.get("cfg_img_scale")
engine_prompt_image: dict[str, Any] | None = None
is_img2img = False
@@ -377,14 +402,18 @@ async def create_chat_completion(
sampling_params_list = self._build_sampling_params_list_from_request(request)
# Apply user-specified overrides to diffusion stage(s) for image generation
- if _image_gen_height is not None or _image_gen_width is not None or num_inference_steps is not None:
- for idx, sp in enumerate(sampling_params_list):
- if hasattr(sp, "height") and _image_gen_height is not None:
- sp.height = _image_gen_height
- if hasattr(sp, "width") and _image_gen_width is not None:
- sp.width = _image_gen_width
- if hasattr(sp, "num_inference_steps") and num_inference_steps is not None:
- sp.num_inference_steps = num_inference_steps
+ for idx, sp in enumerate(sampling_params_list):
+ if hasattr(sp, "height") and _image_gen_height is not None:
+ sp.height = _image_gen_height
+ if hasattr(sp, "width") and _image_gen_width is not None:
+ sp.width = _image_gen_width
+ if hasattr(sp, "num_inference_steps") and num_inference_steps is not None:
+ sp.num_inference_steps = num_inference_steps
+ if hasattr(sp, "extra_args") and sp.extra_args is not None:
+ if cfg_text_scale is not None:
+ sp.extra_args["cfg_text_scale"] = cfg_text_scale
+ if cfg_img_scale is not None:
+ sp.extra_args["cfg_img_scale"] = cfg_img_scale
self._log_inputs(
request_id,
@@ -540,10 +569,11 @@ async def _preprocess_chat(
engine_prompt["cache_salt"] = request.cache_salt
speaker = getattr(request, "speaker", None)
- if speaker is not None and isinstance(speaker, str) and speaker.strip():
+ normalized = validate_requested_speaker(speaker, self._get_supported_speakers())
+ if normalized is not None:
if "additional_information" not in engine_prompt or engine_prompt["additional_information"] is None:
engine_prompt["additional_information"] = {}
- engine_prompt["additional_information"]["speaker"] = [speaker.lower().strip()]
+ engine_prompt["additional_information"]["speaker"] = [normalized]
language = getattr(request, "language", None)
if language is not None and isinstance(language, str) and language.strip():
@@ -698,11 +728,17 @@ def _apply_request_overrides(
for field_name in self._OPENAI_SAMPLING_FIELDS:
value = getattr(request, field_name, None)
- if value is not None:
+ if (value is not None and not isinstance(value, list)) or (isinstance(value, list) and len(value) > 0):
setattr(params, field_name, value)
return params
+ @staticmethod
+ def _set_if_supported(obj: Any, **kwargs: Any) -> None:
+ for key, value in kwargs.items():
+ if value is not None and hasattr(obj, key):
+ setattr(obj, key, value)
+
def _build_sampling_params_list_from_request(
self,
request: ChatCompletionRequest,
@@ -1549,6 +1585,7 @@ async def chat_completion_full_generator(
role,
reasoning_parser,
)
+ final_res = omni_outputs.request_output
elif omni_outputs.final_output_type == "audio":
choices_data = self._create_audio_choice(omni_outputs, role, request, stream=False)
elif omni_outputs.final_output_type == "image":
@@ -2027,6 +2064,254 @@ def _create_image_choice(
return choices
# ==================== Diffusion Mode Methods ====================
+ def _build_multistage_generation_inputs(
+ self,
+ *,
+ engine: AsyncOmni,
+ prompt: str,
+ extra_body: dict[str, Any],
+ reference_images: list[Image.Image],
+ gen_params: OmniDiffusionSamplingParams,
+ ) -> tuple[OmniTextPrompt, list[Any]]:
+ """Build the shared multistage generation prompt and stage params."""
+ stage_configs = getattr(engine, "stage_configs", None) or []
+ default_params_list = list(getattr(engine, "default_sampling_params_list", []) or [])
+
+ height = gen_params.height
+ width = gen_params.width
+ seed = gen_params.seed
+ generator_device = gen_params.generator_device
+ num_outputs_per_prompt = gen_params.num_outputs_per_prompt
+ num_inference_steps = extra_body.get("num_inference_steps")
+ guidance_scale = extra_body.get("guidance_scale")
+ true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale")
+ negative_prompt = extra_body.get("negative_prompt")
+ num_frames = extra_body.get("num_frames")
+ guidance_scale_2 = extra_body.get("guidance_scale_2")
+ lora_body = extra_body.get("lora")
+ layers = extra_body.get("layers")
+ resolution = extra_body.get("resolution")
+
+ engine_prompt_data: dict[str, Any] | None = None
+ modalities = ["image"]
+ if reference_images:
+ if len(reference_images) == 1:
+ engine_prompt_data = {"img2img": reference_images[0]}
+ modalities = ["img2img"]
+ else:
+ engine_prompt_data = {"image": reference_images}
+
+ engine_prompt: OmniTextPrompt = {"prompt": prompt}
+ engine_prompt["modalities"] = modalities
+ if negative_prompt is not None:
+ engine_prompt["negative_prompt"] = negative_prompt
+
+ mm_processor_kwargs: dict[str, Any] = {}
+ if height is not None:
+ mm_processor_kwargs["target_h"] = height
+ if width is not None:
+ mm_processor_kwargs["target_w"] = width
+ if mm_processor_kwargs:
+ engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
+ if engine_prompt_data is not None:
+ engine_prompt["multi_modal_data"] = engine_prompt_data
+
+ comprehension_idx = None
+ for idx, stage in enumerate(stage_configs):
+ if getattr(stage, "is_comprehension", False):
+ comprehension_idx = idx
+ break
+
+ sampling_params_list: list[Any] = []
+ for idx, stage_cfg in enumerate(stage_configs):
+ stage_type = get_stage_type(stage_cfg)
+ if idx < len(default_params_list):
+ default_stage_params = default_params_list[idx]
+ if hasattr(default_stage_params, "clone"):
+ try:
+ default_stage_params = default_stage_params.clone()
+ except Exception:
+ pass
+ elif stage_type == "diffusion":
+ default_stage_params = gen_params.clone()
+ else:
+ default_stage_params = SamplingParams()
+
+ if (
+ comprehension_idx is not None
+ and idx == comprehension_idx
+ and seed is not None
+ and hasattr(default_stage_params, "seed")
+ ):
+ default_stage_params.seed = seed
+
+ if stage_type == "diffusion":
+ self._set_if_supported(
+ default_stage_params,
+ height=height,
+ width=width,
+ seed=seed,
+ generator_device=generator_device,
+ num_outputs_per_prompt=num_outputs_per_prompt,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ true_cfg_scale=true_cfg_scale,
+ num_frames=num_frames,
+ guidance_scale_2=guidance_scale_2,
+ layers=layers,
+ resolution=resolution,
+ )
+ if lora_body and isinstance(lora_body, dict):
+ try:
+ lora_req, lora_scale = parse_lora_request(lora_body)
+ if lora_req is not None:
+ default_stage_params.lora_request = lora_req
+ if lora_scale is not None:
+ default_stage_params.lora_scale = lora_scale
+ except Exception as e: # pragma: no cover - safeguard
+ logger.warning("Failed to parse LoRA request: %s", e)
+
+ sampling_params_list.append(default_stage_params)
+
+ return engine_prompt, sampling_params_list
+
+ async def generate_diffusion_images(
+ self,
+ *,
+ prompt: str,
+ extra_body: dict[str, Any] | None = None,
+ reference_images: list[str] | None = None,
+ request_id: str | None = None,
+ ) -> tuple[list[Image.Image], dict[str, Any], float] | ErrorResponse:
+ """Generate diffusion images and return raw images plus generation stats."""
+ if request_id is None:
+ request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
+ if extra_body is None:
+ extra_body = {}
+ if reference_images is None:
+ reference_images = []
+
+ engine = self._diffusion_engine if self._diffusion_engine is not None else self.engine_client
+
+ height = extra_body.get("height")
+ width = extra_body.get("width")
+ if "size" in extra_body:
+ try:
+ size_str = extra_body["size"]
+ if isinstance(size_str, str) and "x" in size_str.lower():
+ w, h = size_str.lower().split("x")
+ width, height = int(w), int(h)
+ except ValueError:
+ logger.warning("Invalid size format: %s", extra_body.get("size"))
+
+ seed = extra_body.get("seed")
+ generator_device = extra_body.get("generator_device")
+ negative_prompt = extra_body.get("negative_prompt")
+ num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1)
+ lora_body = extra_body.get("lora")
+
+ pil_images: list[Image.Image] = []
+ for img_b64 in reference_images:
+ try:
+ img_bytes = base64.b64decode(img_b64)
+ pil_images.append(Image.open(BytesIO(img_bytes)))
+ except Exception as e:
+ logger.warning("Failed to decode reference image: %s", e)
+
+ gen_params = OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_outputs_per_prompt=num_outputs_per_prompt,
+ seed=seed,
+ )
+ self._set_if_supported(
+ gen_params,
+ generator_device=generator_device,
+ num_inference_steps=extra_body.get("num_inference_steps"),
+ guidance_scale=extra_body.get("guidance_scale"),
+ true_cfg_scale=extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale"),
+ num_frames=extra_body.get("num_frames"),
+ guidance_scale_2=extra_body.get("guidance_scale_2"),
+ layers=extra_body.get("layers"),
+ resolution=extra_body.get("resolution"),
+ )
+
+ if lora_body and isinstance(lora_body, dict):
+ try:
+ lora_req, lora_scale = parse_lora_request(lora_body)
+ if lora_req is not None:
+ gen_params.lora_request = lora_req
+ if lora_scale is not None:
+ gen_params.lora_scale = lora_scale
+ except Exception as e: # pragma: no cover - safeguard
+ logger.warning("Failed to parse LoRA request: %s", e)
+
+ gen_prompt: OmniTextPrompt = {
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ }
+ if pil_images:
+ if len(pil_images) == 1:
+ gen_prompt["multi_modal_data"] = {"image": pil_images[0]}
+ else:
+ od_config = getattr(engine, "od_config", None)
+ supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False)
+ if od_config is None:
+ supports_multimodal_inputs = True
+ if supports_multimodal_inputs:
+ gen_prompt["multi_modal_data"] = {"image": pil_images}
+ else:
+ return self._create_error_response(
+ "Multiple input images are not supported by the current diffusion model. "
+ "For multi-image editing, start the server with Qwen-Image-Edit-2509 "
+ "and send multiple images in the user message content.",
+ status_code=400,
+ )
+
+ if isinstance(engine, AsyncOmni):
+ diffusion_engine = cast(AsyncOmni, engine)
+ stage_configs = getattr(diffusion_engine, "stage_configs", None) or []
+ if len(stage_configs) > 1:
+ engine_prompt, sampling_params_list = self._build_multistage_generation_inputs(
+ engine=diffusion_engine,
+ prompt=prompt,
+ extra_body=extra_body,
+ reference_images=pil_images,
+ gen_params=gen_params,
+ )
+ else:
+ engine_prompt = gen_prompt
+ sampling_params_list = [gen_params]
+
+ result = None
+ async for output in diffusion_engine.generate(
+ prompt=engine_prompt,
+ sampling_params_list=sampling_params_list,
+ request_id=request_id,
+ ):
+ result = output
+ if result is None:
+ return self._create_error_response("No output generated from AsyncOmni", status_code=500)
+ else:
+ result = await engine.generate(
+ prompt=gen_prompt,
+ sampling_params=gen_params,
+ request_id=request_id,
+ )
+
+ images = getattr(result.request_output, "images", [])
+ stage_durations = result.stage_durations
+ peak_memory_mb = result.peak_memory_mb
+
+ flat_images: list[Image.Image] = []
+ for item in images:
+ if isinstance(item, list):
+ flat_images.extend(item)
+ else:
+ flat_images.append(item)
+
+ return flat_images, stage_durations, peak_memory_mb
+
async def _create_diffusion_chat_completion(
self,
request: ChatCompletionRequest,
@@ -2087,6 +2372,8 @@ async def _create_diffusion_chat_completion(
num_inference_steps = extra_body.get("num_inference_steps")
guidance_scale = extra_body.get("guidance_scale")
true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale")
+ cfg_text_scale = extra_body.get("cfg_text_scale")
+ cfg_img_scale = extra_body.get("cfg_img_scale")
seed = extra_body.get("seed")
negative_prompt = extra_body.get("negative_prompt")
num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1)
@@ -2141,6 +2428,10 @@ async def _create_diffusion_chat_completion(
gen_params.guidance_scale = guidance_scale
if true_cfg_scale is not None:
gen_params.true_cfg_scale = true_cfg_scale
+ if cfg_text_scale is not None:
+ gen_params.extra_args["cfg_text_scale"] = cfg_text_scale
+ if cfg_img_scale is not None:
+ gen_params.extra_args["cfg_img_scale"] = cfg_img_scale
if num_frames is not None:
gen_params.num_frames = num_frames
if guidance_scale_2 is not None:
@@ -2161,7 +2452,7 @@ async def _create_diffusion_chat_completion(
except Exception as e: # pragma: no cover - safeguard
logger.warning("Failed to parse LoRA request: %s", e)
- # Add reference image if provided
+ # Add reference image if provided (from messages content)
if pil_images:
if len(pil_images) == 1:
gen_prompt["multi_modal_data"] = {}
@@ -2185,10 +2476,30 @@ async def _create_diffusion_chat_completion(
# Generate image
diffusion_engine = cast(AsyncOmni, self._diffusion_engine)
+ stage_configs = list(getattr(diffusion_engine, "stage_configs", []) or [])
+ default_params_list = list(getattr(diffusion_engine, "default_sampling_params_list", []) or [])
+
+ sampling_params_list: list[Any] = []
+ for idx, stage_cfg in enumerate(stage_configs):
+ if get_stage_type(stage_cfg) == "diffusion":
+ sampling_params_list.append(gen_params)
+ continue
+
+ default_stage_params = default_params_list[idx] if idx < len(default_params_list) else SamplingParams()
+ if hasattr(default_stage_params, "clone"):
+ try:
+ default_stage_params = default_stage_params.clone()
+ except Exception as e:
+ logger.warning("Failed to clone default params for stage %d: %s", idx, e)
+ sampling_params_list.append(default_stage_params)
+
+ if not sampling_params_list:
+ sampling_params_list = [gen_params]
+
result = None
async for output in diffusion_engine.generate(
prompt=gen_prompt,
- sampling_params_list=[gen_params], # Pass as single-stage params
+ sampling_params_list=sampling_params_list,
request_id=request_id,
):
result = output
diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py
index 87ef6a4e9b..ba8292f0c2 100644
--- a/vllm_omni/entrypoints/openai/serving_speech.py
+++ b/vllm_omni/entrypoints/openai/serving_speech.py
@@ -49,12 +49,16 @@
_FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"}
_COSYVOICE3_TTS_MODEL_STAGES = {"cosyvoice3_talker"}
_OMNIVOICE_TTS_MODEL_STAGES = {"omnivoice_generator"}
+_VOXCPM_TTS_MODEL_STAGES = {"latent_generator", "vae"}
+_VOXCPM2_TTS_MODEL_STAGES = {"latent_generator"}
_TTS_MODEL_STAGES: set[str] = (
_VOXTRAL_TTS_MODEL_STAGES
| _QWEN3_TTS_MODEL_STAGES
| _FISH_TTS_MODEL_STAGES
| _COSYVOICE3_TTS_MODEL_STAGES
| _OMNIVOICE_TTS_MODEL_STAGES
+ | _VOXCPM_TTS_MODEL_STAGES
+ | _VOXCPM2_TTS_MODEL_STAGES
)
_TTS_LANGUAGES: set[str] = {
"Auto",
@@ -212,6 +216,8 @@ def __init__(self, *args, **kwargs):
"Re-upload voices after each restart if needed."
)
self._tts_tokenizer = None
+ self._voxcpm2_tokenizer = None
+ self._voxcpm2_split_map: dict[int, list[int]] = {}
logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}")
@@ -280,6 +286,11 @@ def _detect_tts_model_type(self) -> str | None:
if self._tts_stage is None:
return None
model_stage = getattr(self._tts_stage.engine_args, "model_stage", None)
+ model_arch = getattr(self._tts_stage.engine_args, "model_arch", None)
+ if model_arch == "VoxCPM2TalkerForConditionalGeneration":
+ return "voxcpm2"
+ if model_arch == "VoxCPMForConditionalGeneration":
+ return "voxcpm"
if model_stage in _QWEN3_TTS_MODEL_STAGES:
return "qwen3_tts"
if model_stage in _VOXTRAL_TTS_MODEL_STAGES:
@@ -290,6 +301,12 @@ def _detect_tts_model_type(self) -> str | None:
return "cosyvoice3"
if model_stage in _OMNIVOICE_TTS_MODEL_STAGES:
return "omnivoice"
+ if model_stage in (_VOXCPM_TTS_MODEL_STAGES | _VOXCPM2_TTS_MODEL_STAGES):
+ has_vae_stage = any(
+ getattr(getattr(stage, "engine_args", None), "model_stage", None) == "vae"
+ for stage in self.engine_client.stage_configs
+ )
+ return "voxcpm" if has_vae_stage or model_stage == "vae" else "voxcpm2"
return None
def _compute_max_instructions_length(self) -> int:
@@ -314,6 +331,8 @@ def _compute_max_instructions_length(self) -> int:
def _load_supported_speakers(self) -> set[str]:
"""Load supported speakers (case-insensitive) from the model configuration."""
try:
+ if self._tts_model_type == "voxcpm":
+ return set()
if self._tts_model_type == "voxtral_tts":
config = self.engine_client.model_config.hf_config.audio_config
else:
@@ -373,6 +392,8 @@ def _estimate_ref_code_len(self, ref_audio: object) -> int | None:
def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int:
"""Estimate prompt length so the placeholder matches model-side embeddings."""
try:
+ if self._tts_model_type == "voxcpm":
+ return 1
from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
Qwen3TTSTalkerForConditionalGeneration,
)
@@ -436,6 +457,25 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object)
logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e)
return 2048
+ async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
+ """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length)."""
+ from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt
+
+ self._voxcpm2_encode("") # lazy-init tokenizer + split_map
+ ref_audio = None
+ ref_sr = None
+ if request.ref_audio is not None:
+ ref_audio, ref_sr = await self._resolve_ref_audio(request.ref_audio)
+ return build_voxcpm2_prompt(
+ hf_config=self.engine_client.model_config.hf_config,
+ tokenizer=self._voxcpm2_tokenizer,
+ split_map=self._voxcpm2_split_map,
+ text=request.input,
+ ref_audio=ref_audio,
+ ref_sr=ref_sr,
+ ref_text=request.ref_text,
+ )
+
def _get_uploaded_audio_data(self, voice_name: str) -> str | None:
"""Get base64 encoded audio data for uploaded voice."""
voice_name_lower = voice_name.lower()
@@ -787,8 +827,31 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non
return self._validate_fish_tts_request(request)
if self._tts_model_type == "cosyvoice3":
return self._validate_cosyvoice3_request(request)
+ if self._tts_model_type == "voxcpm":
+ return self._validate_voxcpm_request(request)
+ if self._tts_model_type == "voxcpm2":
+ return None # VoxCPM2 accepts any text input
return self._validate_qwen_tts_request(request)
+ def _voxcpm2_encode(self, text: str) -> list[int]:
+ """Tokenize text for VoxCPM2, splitting multichar Chinese tokens."""
+ from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import (
+ build_cjk_split_map,
+ split_multichar_chinese,
+ )
+
+ if self._voxcpm2_tokenizer is None:
+ from transformers import AutoTokenizer
+
+ model_name = self.engine_client.model_config.model
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ self._voxcpm2_split_map = build_cjk_split_map(tokenizer)
+ self._voxcpm2_tokenizer = tokenizer
+ logger.info("VoxCPM2 serving: built multichar split map (%d entries)", len(self._voxcpm2_split_map))
+
+ ids = self._voxcpm2_tokenizer.encode(text, add_special_tokens=True)
+ return split_multichar_chinese(ids, self._voxcpm2_split_map)
+
def _validate_ref_audio_format(self, ref_audio: str) -> str | None:
"""Validate ref_audio is a supported URI format. Returns error or None."""
if not (
@@ -826,6 +889,43 @@ def _validate_voxtral_tts_request(self, request: OpenAICreateSpeechRequest) -> s
return None
+ def _validate_voxcpm_request(self, request: OpenAICreateSpeechRequest) -> str | None:
+ """Validate VoxCPM request parameters. Returns error message or None."""
+ if not request.input or not request.input.strip():
+ return "Input text cannot be empty"
+
+ if request.voice is not None:
+ return "'voice' is not supported for VoxCPM"
+ if request.instructions is not None:
+ return "'instructions' is not supported for VoxCPM"
+ if request.language is not None:
+ return "'language' is not supported for VoxCPM"
+ if request.task_type not in (None, "Base"):
+ return "VoxCPM only supports plain TTS or voice cloning with ref_audio/ref_text"
+ if request.x_vector_only_mode is not None:
+ return "'x_vector_only_mode' is not supported for VoxCPM"
+ if request.speaker_embedding is not None:
+ return "'speaker_embedding' is not supported for VoxCPM"
+ if request.initial_codec_chunk_frames is not None:
+ return "'initial_codec_chunk_frames' is not supported for VoxCPM"
+
+ if request.ref_audio is not None:
+ fmt_err = self._validate_ref_audio_format(request.ref_audio)
+ if fmt_err:
+ return fmt_err
+ if not request.ref_text or not request.ref_text.strip():
+ return "Voice cloning requires 'ref_text' (transcript of the reference audio)"
+ elif request.ref_text is not None:
+ return "'ref_text' requires 'ref_audio' for VoxCPM voice cloning"
+
+ if request.max_new_tokens is not None:
+ if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN:
+ return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}"
+ if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX:
+ return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}"
+
+ return None
+
def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None:
"""Validate Qwen TTS request parameters. Returns error message or None."""
# Infer Base task when ref_audio or ref_text is provided without explicit task_type.
@@ -919,6 +1019,13 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str
fmt_err = self._validate_ref_audio_format(request.ref_audio)
if fmt_err:
return fmt_err
+ if not getattr(request, "x_vector_only_mode", False) and (
+ not request.ref_text or not request.ref_text.strip()
+ ):
+ return (
+ "Base task requires non-empty 'ref_text' (transcript of "
+ "the reference audio) unless 'x_vector_only_mode' is enabled"
+ )
# Validate cross-parameter dependencies
if task_type != "Base":
@@ -1017,11 +1124,15 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int
URLs, ``data:`` base64 URIs, and ``file:`` local paths (the latter
gated by ``--allowed-local-media-path``).
"""
- model_config = self.model_config
- connector = MediaConnector(
- allowed_local_media_path=model_config.allowed_local_media_path,
- allowed_media_domains=model_config.allowed_media_domains,
- )
+ # In diffusion mode, model_config may not be available
+ if self._diffusion_mode:
+ connector = MediaConnector()
+ else:
+ model_config = self.model_config
+ connector = MediaConnector(
+ allowed_local_media_path=model_config.allowed_local_media_path,
+ allowed_media_domains=model_config.allowed_media_domains,
+ )
wav_np, sr = await connector.fetch_audio_async(ref_audio_str)
wav_np = np.asarray(wav_np, dtype=np.float32)
if wav_np.ndim > 1:
@@ -1152,6 +1263,18 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any
Processes each parameter if present, skips if not.
Values are wrapped in lists as required by the model.
"""
+ if self._tts_model_type == "voxcpm":
+ params: dict[str, Any] = {
+ "text": [request.input],
+ "cfg_value": [2.0],
+ "inference_timesteps": [10],
+ "min_len": [2],
+ "max_new_tokens": [request.max_new_tokens or 4096],
+ }
+ if request.ref_text is not None:
+ params["ref_text"] = [request.ref_text]
+ return params
+
params: dict[str, Any] = {}
# Text content (always required)
@@ -1392,8 +1515,36 @@ async def _prepare_speech_generation(
prompt = await self._build_fish_speech_prompt_async(request, ref_audio_data=ref_audio_data)
tts_params = {}
elif self._tts_model_type == "omnivoice":
+ if not request.input or not request.input.strip():
+ raise ValueError("Input text cannot be empty")
+ tts_params = {}
+ prompt: dict[str, Any] = {"input": request.input}
+ # Resolve ref_audio: explicit request param or uploaded voice
+ ref_src = request.ref_audio
+ if not ref_src and request.voice:
+ vl = request.voice.lower()
+ if vl in self.uploaded_speakers:
+ sp = self.uploaded_speakers[vl]
+ if sp.get("embedding_source") == "audio":
+ ref_src = self._get_uploaded_audio_data(request.voice)
+ if not ref_src:
+ raise ValueError(f"Audio for voice '{request.voice}' missing")
+ prompt["ref_text"] = sp.get("ref_text")
+ if ref_src:
+ fmt_err = self._validate_ref_audio_format(ref_src)
+ if fmt_err:
+ raise ValueError(fmt_err)
+ wav, sr = await self._resolve_ref_audio(ref_src)
+ prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr)
+ if request.ref_text:
+ prompt["ref_text"] = request.ref_text
+ if request.language:
+ prompt["lang"] = request.language
+ if request.instructions:
+ prompt["instruct"] = request.instructions
+ elif self._tts_model_type == "voxcpm2":
+ prompt = await self._build_voxcpm2_prompt(request)
tts_params = {}
- prompt = request.input # Diffusion engine takes raw text
elif self._is_tts:
validation_error = self._validate_tts_request(request)
if validation_error:
@@ -1420,6 +1571,24 @@ async def _prepare_speech_generation(
ph_len = await self._estimate_prompt_len_async(tts_params)
prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params}
else:
+ # Qwen omni models (Qwen3-Omni, Qwen2.5-Omni) use a "talker"
+ # stage whose preprocess requires chat-templated tokens. The
+ # async-chunk orchestrator prewarms the talker via
+ # compute_talker_prompt_ids_length(), which scans for Qwen
+ # chat-template markers (im_start_token_id 151644). A raw-text
+ # prompt produces a 1-token placeholder that crashes the talker's
+ # prefill/decode handoff. Reject early with an actionable message.
+ stage_names = {
+ getattr(getattr(s, "engine_args", None), "model_stage", None) for s in self.engine_client.stage_configs
+ }
+ if "talker" in stage_names:
+ raise ValueError(
+ "The /v1/audio/speech endpoint is only supported for "
+ "dedicated TTS models (e.g., Qwen3-TTS, Voxtral, Fish "
+ "Speech, CosyVoice3, OmniVoice, VoxCPM2). For omni "
+ "models like Qwen3-Omni, use /v1/chat/completions with "
+ '\'"modalities": ["audio"]\' instead.'
+ )
tts_params = {}
prompt = {"prompt": request.input}
@@ -1430,6 +1599,10 @@ async def _prepare_speech_generation(
model_type = "voxtral_tts"
elif self._tts_model_type == "cosyvoice3":
model_type = "cosyvoice3"
+ elif self._tts_model_type == "voxcpm":
+ model_type = "voxcpm"
+ elif self._tts_model_type == "voxcpm2":
+ model_type = "voxcpm2"
elif self._is_tts:
model_type = tts_params.get("task_type", ["unknown"])[0]
else:
@@ -1560,13 +1733,26 @@ async def _create_diffusion_speech(
from vllm_omni.outputs import OmniRequestOutput
try:
+ if not request.input or not request.input.strip():
+ raise ValueError("Input text cannot be empty")
+
request_id = f"speech-{random_uuid()}"
- prompt = request.input
+ prompt: dict[str, Any] = {"input": request.input}
+ if request.ref_audio:
+ wav, sr = await self._resolve_ref_audio(request.ref_audio)
+ prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr)
+ if request.ref_text:
+ prompt["ref_text"] = request.ref_text
+ if request.language:
+ prompt["lang"] = request.language
+ if request.instructions:
+ prompt["instruct"] = request.instructions
logger.info(
- "Diffusion TTS speech request %s: text=%r",
+ "Diffusion TTS speech request %s: text=%r, voice_clone=%s",
request_id,
- prompt[:50] + "..." if len(prompt) > 50 else prompt,
+ request.input[:50] + "..." if len(request.input) > 50 else request.input,
+ "ref_audio" in prompt,
)
generator = self._diffusion_engine.generate(
diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py
index bddfd48003..741295c7c2 100644
--- a/vllm_omni/entrypoints/openai/serving_video.py
+++ b/vllm_omni/entrypoints/openai/serving_video.py
@@ -33,6 +33,18 @@ class ReferenceImage:
data: Image.Image
+@dataclass
+class VideoGenerationArtifacts:
+ """Normalized outputs and profiler metadata extracted from one request."""
+
+ videos: list[Any]
+ audios: list[Any | None]
+ audio_sample_rate: int
+ output_fps: int
+ stage_durations: dict[str, float]
+ peak_memory_mb: float
+
+
class OmniOpenAIServingVideo:
"""OpenAI-style video generation handler for omni diffusion models."""
@@ -77,12 +89,8 @@ async def _run_and_extract(
reference_id: str,
*,
reference_image: ReferenceImage | None = None,
- ) -> tuple[list[Any], list[Any | None], int, int]:
- """Run the generation pipeline and extract video/audio outputs.
-
- Returns:
- Tuple of (videos, audios, audio_sample_rate, output_fps).
- """
+ ) -> VideoGenerationArtifacts:
+ """Run the generation pipeline and extract video/audio/profiler outputs."""
prompt: OmniTextPrompt = OmniTextPrompt(prompt=request.prompt)
if request.negative_prompt is not None:
prompt["negative_prompt"] = request.negative_prompt
@@ -105,6 +113,10 @@ async def _run_and_extract(
if vp.fps is not None:
gen_params.fps = vp.fps
gen_params.frame_rate = float(vp.fps)
+ gen_params.enable_frame_interpolation = request.enable_frame_interpolation
+ gen_params.frame_interpolation_exp = request.frame_interpolation_exp
+ gen_params.frame_interpolation_scale = request.frame_interpolation_scale
+ gen_params.frame_interpolation_model_path = request.frame_interpolation_model_path
if request.num_inference_steps is not None:
gen_params.num_inference_steps = request.num_inference_steps
@@ -152,8 +164,15 @@ async def _run_and_extract(
videos = self._extract_video_outputs(result)
audios = self._extract_audio_outputs(result, expected_count=len(videos))
audio_sample_rate = self._resolve_audio_sample_rate(result)
- output_fps = vp.fps or self._resolve_fps(result) or 24
- return videos, audios, audio_sample_rate, output_fps
+ output_fps = (vp.fps or self._resolve_fps(result) or 24) * self._resolve_video_fps_multiplier(result)
+ return VideoGenerationArtifacts(
+ videos=videos,
+ audios=audios,
+ audio_sample_rate=audio_sample_rate,
+ output_fps=output_fps,
+ stage_durations=self._extract_stage_durations(result),
+ peak_memory_mb=self._extract_peak_memory_mb(result),
+ )
async def generate_videos(
self,
@@ -162,28 +181,38 @@ async def generate_videos(
*,
reference_image: ReferenceImage | None = None,
) -> VideoGenerationResponse:
- videos, audios, audio_sample_rate, output_fps = await self._run_and_extract(
- request, reference_id, reference_image=reference_image
- )
+ artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image)
+
+ video_codec_options = {"preset": "ultrafast", "threads": "0"}
+ if request.extra_params is not None and isinstance(request.extra_params, dict):
+ if "video_codec_options" in request.extra_params:
+ video_codec_options = request.extra_params["video_codec_options"]
+
_t_encode_start = time.perf_counter()
video_data = [
VideoData(
b64_json=(
- encode_video_base64(video, fps=output_fps)
- if audios[idx] is None
+ encode_video_base64(video, fps=artifacts.output_fps, video_codec_options=video_codec_options)
+ if artifacts.audios[idx] is None
else encode_video_base64(
video,
- fps=output_fps,
- audio=audios[idx],
- audio_sample_rate=audio_sample_rate,
+ fps=artifacts.output_fps,
+ audio=artifacts.audios[idx],
+ audio_sample_rate=artifacts.audio_sample_rate,
+ video_codec_options=video_codec_options,
)
)
)
- for idx, video in enumerate(videos)
+ for idx, video in enumerate(artifacts.videos)
]
_t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000
logger.info("Video response encoding (MP4+base64): %.2f ms", _t_encode_ms)
- return VideoGenerationResponse(created=int(time.time()), data=video_data)
+ return VideoGenerationResponse(
+ created=int(time.time()),
+ data=video_data,
+ stage_durations=artifacts.stage_durations,
+ peak_memory_mb=artifacts.peak_memory_mb,
+ )
async def generate_video_bytes(
self,
@@ -191,25 +220,48 @@ async def generate_video_bytes(
reference_id: str,
*,
reference_image: ReferenceImage | None = None,
- ) -> bytes:
+ ) -> tuple[bytes, dict[str, float], float]:
"""Generate a video and return raw MP4 bytes, bypassing base64 encoding."""
- videos, audios, audio_sample_rate, output_fps = await self._run_and_extract(
- request, reference_id, reference_image=reference_image
- )
- if len(videos) > 1:
+ artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image)
+ if len(artifacts.videos) > 1:
logger.warning(
- "Video request %s generated %d outputs; returning only the first.", reference_id, len(videos)
+ "Video request %s generated %d outputs; returning only the first.",
+ reference_id,
+ len(artifacts.videos),
)
- audio = audios[0]
+ audio = artifacts.audios[0]
+
+ video_codec_options = {"preset": "ultrafast", "threads": "0"}
+ if request.extra_params is not None and isinstance(request.extra_params, dict):
+ if "video_codec_options" in request.extra_params:
+ video_codec_options = request.extra_params["video_codec_options"]
+
_t_encode_start = time.perf_counter()
video_bytes = _encode_video_bytes(
- videos[0],
- fps=output_fps,
- **({"audio": audio, "audio_sample_rate": audio_sample_rate} if audio is not None else {}),
+ artifacts.videos[0],
+ fps=artifacts.output_fps,
+ **({"audio": audio, "audio_sample_rate": artifacts.audio_sample_rate} if audio is not None else {}),
+ video_codec_options=video_codec_options,
)
_t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000
logger.info("Video response encoding (MP4 bytes): %.2f ms", _t_encode_ms)
- return video_bytes
+ return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb
+
+ @staticmethod
+ def _resolve_video_fps_multiplier(result: Any) -> int:
+ custom_output = getattr(result, "custom_output", None)
+ if isinstance(custom_output, dict):
+ multiplier = custom_output.get("video_fps_multiplier")
+ if multiplier is not None:
+ return int(multiplier)
+ request_output = getattr(result, "request_output", None)
+ if request_output is not None:
+ custom_output = getattr(request_output, "custom_output", None)
+ if isinstance(custom_output, dict):
+ multiplier = custom_output.get("video_fps_multiplier")
+ if multiplier is not None:
+ return int(multiplier)
+ return 1
@staticmethod
def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None:
@@ -483,3 +535,16 @@ def _coerce_audio_sample_rate(value: Any) -> int | None:
return None
return sample_rate if sample_rate > 0 else None
+
+ @staticmethod
+ def _extract_stage_durations(result: Any) -> dict[str, float]:
+ stage_durations = getattr(result, "stage_durations", None)
+ return stage_durations if isinstance(stage_durations, dict) else {}
+
+ @staticmethod
+ def _extract_peak_memory_mb(result: Any) -> float:
+ peak_memory_mb = getattr(result, "peak_memory_mb", 0.0)
+ try:
+ return float(peak_memory_mb or 0.0)
+ except (TypeError, ValueError):
+ return 0.0
diff --git a/vllm_omni/entrypoints/openai/utils.py b/vllm_omni/entrypoints/openai/utils.py
index 84b28ef5b1..f411526fdb 100644
--- a/vllm_omni/entrypoints/openai/utils.py
+++ b/vllm_omni/entrypoints/openai/utils.py
@@ -53,3 +53,33 @@ def parse_lora_request(lora_body: Any) -> tuple[LoRARequest | None, float | None
scale = float(lora_scale) if lora_scale is not None else None
return LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)), scale
+
+
+def get_supported_speakers_from_hf_config(hf_config: Any) -> set[str]:
+ """Extract supported speaker names from a model hf_config."""
+ config = (
+ hf_config.get("talker_config") if isinstance(hf_config, dict) else getattr(hf_config, "talker_config", None)
+ )
+ if config is None:
+ return set()
+
+ for spk_attr in ("speaker_id", "spk_id"):
+ speakers_dict = config.get(spk_attr) if isinstance(config, dict) else getattr(config, spk_attr, None)
+ if speakers_dict and isinstance(speakers_dict, dict):
+ return {speaker.lower() for speaker in speakers_dict}
+ return set()
+
+
+def validate_requested_speaker(speaker: str | None, supported_speakers: set[str]) -> str | None:
+ """Normalize and validate an optional speaker value.
+
+ Returns the normalized speaker string when provided, otherwise ``None``.
+ Raises ``ValueError`` when the speaker is not in the supported list.
+ """
+ if not isinstance(speaker, str) or not speaker.strip():
+ return None
+
+ normalized = speaker.lower().strip()
+ if supported_speakers and normalized not in supported_speakers:
+ raise ValueError(f"Invalid speaker '{speaker}'. Supported: {', '.join(sorted(supported_speakers))}")
+ return normalized
diff --git a/vllm_omni/entrypoints/openai/video_api_utils.py b/vllm_omni/entrypoints/openai/video_api_utils.py
index 69178fb3d3..3fb991225c 100644
--- a/vllm_omni/entrypoints/openai/video_api_utils.py
+++ b/vllm_omni/entrypoints/openai/video_api_utils.py
@@ -202,7 +202,13 @@ def _coerce_audio_to_numpy(audio: Any) -> np.ndarray:
return arr.astype(np.float32)
-def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> bytes:
+def _encode_video_bytes(
+ video: Any,
+ fps: int,
+ audio: Any | None = None,
+ audio_sample_rate: int | None = None,
+ video_codec_options: dict[str, str] | None = None,
+) -> bytes:
"""Encode a video payload into MP4 bytes, optionally muxing audio."""
from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
@@ -213,7 +219,16 @@ def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sa
frames_np = np.stack(frames, axis=0)
if frames_np.ndim == 4 and frames_np.shape[-1] == 4:
frames_np = frames_np[..., :3]
- frames_u8 = (np.clip(frames_np, 0.0, 1.0) * 255).round().clip(0, 255).astype(np.uint8)
+
+ if frames_np.dtype == np.uint8:
+ frames_u8 = frames_np
+ else:
+ frames_np = np.clip(frames_np, 0.0, 1.0)
+ frames_np *= 255.0
+ frames_u8 = np.round(frames_np).astype(np.uint8)
+
+ # Ensure contiguous memory layout for faster PyAV muxing
+ frames_u8 = np.ascontiguousarray(frames_u8)
audio_np = _coerce_audio_to_numpy(audio) if audio is not None else None
@@ -222,10 +237,19 @@ def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sa
audio_np,
fps=float(fps),
audio_sample_rate=audio_sample_rate or 24000,
+ video_codec_options=video_codec_options,
)
-def encode_video_base64(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> str:
+def encode_video_base64(
+ video: Any,
+ fps: int,
+ audio: Any | None = None,
+ audio_sample_rate: int | None = None,
+ video_codec_options: dict[str, str] | None = None,
+) -> str:
"""Encode a video (frames/array/tensor) to base64 MP4."""
- video_bytes = _encode_video_bytes(video, fps=fps, audio=audio, audio_sample_rate=audio_sample_rate)
+ video_bytes = _encode_video_bytes(
+ video, fps=fps, audio=audio, audio_sample_rate=audio_sample_rate, video_codec_options=video_codec_options
+ )
return base64.b64encode(video_bytes).decode("utf-8")
diff --git a/vllm_omni/entrypoints/pd_utils.py b/vllm_omni/entrypoints/pd_utils.py
index 0e3d65f553..413d5d6b44 100644
--- a/vllm_omni/entrypoints/pd_utils.py
+++ b/vllm_omni/entrypoints/pd_utils.py
@@ -23,9 +23,19 @@
class PDDisaggregationMixin:
"""Mixin supplying PD disaggregation helpers to OmniBase."""
+ def _get_pd_separation_pair(self) -> tuple[int, int] | None:
+ """PD prefill/decode indices when ``_init_pd_state`` ran; else ``None``.
+
+ Partial test doubles may skip ``OmniBase.__init__``; treat missing state as
+ no PD disaggregation instead of raising ``AttributeError``.
+ """
+ return getattr(self, "_pd_separation_pair", None)
+
def _init_pd_state(self) -> None:
"""Initialise PD disaggregation state."""
- self._pd_separation_pair: tuple[int, int] | None = self._detect_pd_separation()
+ self._pd_separation_pair: tuple[int, int] | None = self.detect_pd_separation_from_stage_configs(
+ self.stage_configs
+ )
self._pd_connector_info: dict[str, Any] | None = None
self._pd_kv_params_by_req: dict[str, dict[str, Any]] = {}
self._pd_kv_params_lock = threading.Lock()
@@ -40,11 +50,19 @@ def _init_pd_state(self) -> None:
d_id,
)
- def _detect_pd_separation(self) -> tuple[int, int] | None:
- """Scan stage_list for a prefill/decode pair. Returns (p_id, d_id) or None."""
+ @staticmethod
+ def detect_pd_separation_from_stage_configs(stage_configs: list[Any]) -> tuple[int, int] | None:
+ """Scan stage configs for a prefill/decode pair.
+
+ Returns:
+ (prefill_idx, decode_idx) if one pair exists, None if not found.
+
+ Raises:
+ ValueError: if multiple candidate PD pairs are found.
+ """
prefill_by_id: dict[int, int] = {}
decode_indices: list[int] = []
- for i, stage in enumerate(self.stage_list):
+ for i, stage in enumerate(stage_configs):
if getattr(stage, "is_prefill_only", False):
prefill_by_id[i] = i
sid = getattr(stage, "stage_id", i)
@@ -55,7 +73,7 @@ def _detect_pd_separation(self) -> tuple[int, int] | None:
pd_pairs: list[tuple[int, int]] = []
for j in decode_indices:
- source_ids = getattr(self.stage_list[j], "engine_input_source", [])
+ source_ids = getattr(stage_configs[j], "engine_input_source", [])
for src in source_ids:
if src in prefill_by_id:
pd_pairs.append((prefill_by_id[src], j))
@@ -107,10 +125,11 @@ def _normalize_kv_transfer_params(self, kv_params: Any) -> dict[str, Any] | None
def _validate_pd_separation_config(self) -> None:
"""Validate PD stage configurations are consistent."""
- assert self._pd_separation_pair is not None
- p_id, d_id = self._pd_separation_pair
- p_stage = self.stage_list[p_id]
- d_stage = self.stage_list[d_id]
+ pair = self._get_pd_separation_pair()
+ assert pair is not None
+ p_id, d_id = pair
+ p_stage = self.stage_configs[p_id]
+ d_stage = self.stage_configs[d_id]
def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]:
ea = stage.engine_args
@@ -158,11 +177,12 @@ def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]:
def _get_pd_connector_info(self) -> dict[str, Any] | None:
"""Extract prefill engine KV connector info."""
- if self._pd_separation_pair is None:
+ pair = self._get_pd_separation_pair()
+ if pair is None:
return None
- p_id, _ = self._pd_separation_pair
- p_stage = self.stage_list[p_id]
+ p_id, _ = pair
+ p_stage = self.stage_configs[p_id]
ea = p_stage.engine_args
kv_cfg = getattr(ea, "kv_transfer_config", None)
@@ -241,18 +261,17 @@ def _extract_kv_transfer_params(self, engine_outputs: Any) -> dict[str, Any] | N
def _is_pd_routing(self, stage_id: int, next_stage_id: int) -> bool:
"""True when edge stage_id → next_stage_id is the prefill→decode boundary."""
- return self._pd_separation_pair is not None and self._pd_separation_pair == (
- stage_id,
- next_stage_id,
- )
+ pair = self._get_pd_separation_pair()
+ return pair is not None and pair == (stage_id, next_stage_id)
def _maybe_expand_sampling_params(self, sampling_params_list: list) -> list:
"""Auto-duplicate thinker SP for decode stage when user provides N-1 params."""
- if self._pd_separation_pair is None:
+ pair = self._get_pd_separation_pair()
+ if pair is None:
return sampling_params_list
- if len(sampling_params_list) != len(self.stage_list) - 1:
+ if len(sampling_params_list) != len(self.stage_configs) - 1:
return sampling_params_list
- p_id, d_id = self._pd_separation_pair
+ p_id, d_id = pair
sp_list = list(sampling_params_list)
sp_list.insert(d_id, sp_list[p_id])
return sp_list
diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py
index 84391c2ea8..5757d38990 100644
--- a/vllm_omni/entrypoints/utils.py
+++ b/vllm_omni/entrypoints/utils.py
@@ -1,3 +1,4 @@
+import argparse
import os
import types
from collections import Counter
@@ -5,10 +6,12 @@
from pathlib import Path
from typing import Any, get_args, get_origin
+import yaml
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config, get_hf_file_to_dict
from vllm.transformers_utils.repo_utils import file_or_path_exists
+from vllm_omni.config.stage_config import StageConfigFactory
from vllm_omni.config.yaml_util import create_config, load_yaml_config, merge_configs
from vllm_omni.entrypoints.stage_utils import _to_dict
from vllm_omni.platforms import current_omni_platform
@@ -23,6 +26,65 @@
}
+def detect_explicit_cli_keys(
+ argv: list[str],
+ parser: argparse.ArgumentParser | None = None,
+) -> set[str]:
+ """Walk ``argv`` and return the set of ``dest`` attribute names the user
+ explicitly provided (e.g. ``--max-num-seqs 64`` → ``max_num_seqs``).
+
+ Used to distinguish user-typed CLI args from argparse default values so
+ deploy YAMLs are not silently overridden by parser defaults. Shared
+ across online (``vllm serve``) and offline (scripts, examples, tests,
+ CI) entry points — offline callers that parse CLI args via argparse
+ should invoke this on ``sys.argv[1:]`` and pass the result through to
+ ``AsyncOmni`` / ``Omni`` via the ``_cli_explicit_keys`` kwarg.
+
+ When ``parser`` is provided, each token is looked up in the parser's
+ action table to find its real ``dest``. This correctly handles flags
+ with ``dest=`` overrides, alias flags (e.g. ``--usp`` /
+ ``--ulysses-degree`` both mapping to ``ulysses_degree``), and
+ ``--disable-foo`` / ``store_false`` patterns that map to a differently
+ named dest. Callers with access to an ``argparse.ArgumentParser`` should
+ always pass it.
+
+ When ``parser`` is ``None``, a name-based heuristic is used as a
+ fallback (hyphens → underscores, plus a ``no_`` prefix strip for
+ ``argparse.BooleanOptionalAction``). This is correct for simple flags
+ but silently misidentifies ``--disable-X``-style flags and explicit
+ ``dest=`` overrides, so prefer the parser-aware form.
+ """
+ if parser is not None:
+ dest_map: dict[str, str] = {}
+ for action in parser._actions:
+ for opt in action.option_strings:
+ dest_map[opt] = action.dest
+ explicit: set[str] = set()
+ for tok in argv:
+ if not tok.startswith("--"):
+ continue
+ flag = tok.split("=", 1)[0]
+ dest = dest_map.get(flag)
+ if dest is not None:
+ explicit.add(dest)
+ return explicit
+
+ # Fallback: name-based heuristic (legacy path for callers without a parser).
+ explicit = set()
+ for tok in argv:
+ if not tok.startswith("--"):
+ continue
+ name = tok[2:].split("=", 1)[0]
+ if not name:
+ continue
+ attr = name.replace("-", "_")
+ explicit.add(attr)
+ # BooleanOptionalAction: --no-foo records as dest `foo`, not `no_foo`.
+ if attr.startswith("no_"):
+ explicit.add(attr[3:])
+ return explicit
+
+
def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None:
"""Inject connector configuration into stage engine arguments."""
# Prepare omni_kv_config dict
@@ -273,29 +335,59 @@ def resolve_model_config_path(model: str) -> str:
return str(stage_config_path)
-def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list:
+def load_stage_configs_from_model(
+ model: str,
+ base_engine_args: dict | None = None,
+ deploy_config_path: str | None = None,
+ stage_overrides: dict[str, dict[str, Any]] | None = None,
+ cli_explicit_keys: set[str] | None = None,
+) -> list:
"""Load stage configurations from model's default config file.
- .. deprecated::
- This is the legacy OmegaConf-based loading path. New code should use
- ``StageConfigFactory.create_from_model()`` instead. This function will
- be removed once all callers are migrated (see PR series [2/N]).
+ For models registered in the pipeline registry (new path), uses
+ ``StageConfigFactory.create_from_model()`` which merges
+ PipelineConfig + DeployConfig + CLI overrides.
- Loads stage configurations based on the model type and device type.
- First tries to load a device-specific YAML file from stage_configs/{device_type}/
- directory. If not found, falls back to the default config file.
+ For other models (legacy path), loads stage configs from YAML.
Args:
model: Model name or path (used to determine model_type)
+ base_engine_args: Base engine args to merge as CLI overrides.
+ deploy_config_path: Optional explicit deploy config path.
+ stage_overrides: Per-stage overrides from --stage-overrides.
+ cli_explicit_keys: Set of CLI keys the user actually typed. When
+ provided, only these keys override deploy YAML; argparse defaults
+ stay subordinate to YAML. ``None`` means treat every kwarg as
+ explicit (programmatic ``Omni()`` calls).
Returns:
List of stage configuration dictionaries
-
- Raises:
- FileNotFoundError: If no stage config file exists for the model type
"""
if base_engine_args is None:
base_engine_args = {}
+
+ cli_overrides = _convert_dataclasses_to_dict(dict(base_engine_args))
+ # Per-stage JSON overrides are always explicit (the user typed --stage-overrides).
+ explicit = set(cli_explicit_keys) if cli_explicit_keys is not None else None
+ if stage_overrides:
+ for stage_id_str, overrides in stage_overrides.items():
+ for key, val in overrides.items():
+ stage_key = f"stage_{stage_id_str}_{key}"
+ cli_overrides[stage_key] = val
+ if explicit is not None:
+ explicit.add(stage_key)
+
+ stages = StageConfigFactory.create_from_model(
+ model,
+ cli_overrides=cli_overrides,
+ deploy_config_path=deploy_config_path,
+ cli_explicit_keys=explicit,
+ )
+ if stages is not None:
+ # Convert StageConfig objects to OmegaConf for backward compat
+ return [stage.to_omegaconf() for stage in stages]
+
+ # Legacy fallback: load from YAML
stage_config_path = resolve_model_config_path(model)
if stage_config_path is None:
return []
@@ -312,10 +404,9 @@ def load_stage_configs_from_yaml(
base_engine_args: dict | None = None,
prefer_stage_engine_args: bool = True,
) -> list:
- """Load stage configurations from a YAML file.
+ """Load stage configurations from a YAML file (legacy OmegaConf path).
- .. deprecated::
- Legacy OmegaConf-based loader. Will be removed in PR series [2/N].
+ TODO(@lishunyang12): remove once all models use PipelineConfig + DeployConfig.
Args:
config_path: Path to the YAML configuration file
@@ -449,22 +540,75 @@ def load_and_resolve_stage_configs(
stage_configs_path: str | None,
kwargs: dict | None,
default_stage_cfg_factory: Any = None,
+ deploy_config_path: str | None = None,
+ stage_overrides: dict[str, dict[str, Any]] | None = None,
+ cli_explicit_keys: set[str] | None = None,
) -> tuple[str, list]:
"""Load stage configurations from model or YAML file with fallback to defaults.
Args:
model: Model name or path
- stage_configs_path: Optional path to YAML file containing stage configurations
+ stage_configs_path: Optional path to legacy YAML (stage_args format)
kwargs: Engine arguments to merge with stage configs
default_stage_cfg_factory: Optional callable that takes no args and returns
default stage config list when no configs are found
+ deploy_config_path: Optional path to deploy YAML (new format).
+ Mutually exclusive with ``stage_configs_path``.
+ stage_overrides: Per-stage overrides from ``--stage-overrides`` JSON.
+ Keys are stage_id strings, values are dicts of overrides.
Returns:
Tuple of (config_path, stage_configs)
"""
- if stage_configs_path is None:
+ if stage_configs_path is not None and deploy_config_path is not None:
+ raise ValueError(
+ "--stage-configs-path and --deploy-config are mutually exclusive: "
+ "they use different path resolution rules and loading paths. "
+ "Use --deploy-config for new-format YAMLs (preferred); "
+ "--stage-configs-path is kept only for the legacy `stage_args` format "
+ "and will be removed in a future release."
+ )
+ if stage_configs_path is not None and deploy_config_path is None:
+ if not os.path.exists(stage_configs_path):
+ raise FileNotFoundError(
+ f"--stage-configs-path {stage_configs_path!r} does not exist. "
+ "Legacy `stage_configs/` yamls were replaced by `vllm_omni/deploy/.yaml`; "
+ "use --deploy-config. See docs/configuration/stage_configs.md."
+ )
+ with open(stage_configs_path, encoding="utf-8") as f:
+ _peek = yaml.safe_load(f) or {}
+ if "stages" in _peek and "stage_args" not in _peek:
+ deploy_config_path = stage_configs_path
+ stage_configs_path = None
+ else:
+ logger.warning(
+ "--stage-configs-path is deprecated; migrate %r and use --deploy-config.",
+ stage_configs_path,
+ )
+
+ if deploy_config_path is not None:
+ config_path = deploy_config_path
+ stage_configs = load_stage_configs_from_model(
+ model,
+ base_engine_args=kwargs,
+ deploy_config_path=deploy_config_path,
+ stage_overrides=stage_overrides,
+ cli_explicit_keys=cli_explicit_keys,
+ )
+ if not stage_configs:
+ if default_stage_cfg_factory is not None:
+ default_stage_cfg = default_stage_cfg_factory()
+ stage_configs = create_config(default_stage_cfg)
+ else:
+ stage_configs = []
+ elif stage_configs_path is None:
config_path = resolve_model_config_path(model)
- stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs)
+ stage_configs = load_stage_configs_from_model(
+ model,
+ base_engine_args=kwargs,
+ stage_overrides=stage_overrides,
+ cli_explicit_keys=cli_explicit_keys,
+ )
if not stage_configs:
if default_stage_cfg_factory is not None:
default_stage_cfg = default_stage_cfg_factory()
diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py
index 9cb6c44335..e4c33a58c2 100644
--- a/vllm_omni/inputs/data.py
+++ b/vllm_omni/inputs/data.py
@@ -227,6 +227,10 @@ class OmniDiffusionSamplingParams:
frame_rate: float | None = None # Floating-point rate used by the diffusion model when it differs from `fps`.
height_not_provided: bool = False
width_not_provided: bool = False
+ enable_frame_interpolation: bool = False
+ frame_interpolation_exp: int = 1
+ frame_interpolation_scale: float = 1.0
+ frame_interpolation_model_path: str | None = None
# Timesteps
timesteps: torch.Tensor | None = None
@@ -263,6 +267,10 @@ class OmniDiffusionSamplingParams:
cfg_text_kv_metadata: dict[str, Any] | None = None
cfg_img_kv_metadata: dict[str, Any] | None = None
cfg_kv_request_ids: dict[str, str] | None = None
+ cfg_active_branch: str | None = None
+ cfg_branch_roles: list[str] | None = None
+ cfg_branch_past_key_values: dict[str, Any] | None = None
+ cfg_branch_kv_metadata: dict[str, dict[str, Any]] | None = None
# Component modules
modules: dict[str, Any] = field(default_factory=dict)
diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py
index acbbc28b4c..cbb775680c 100644
--- a/vllm_omni/model_executor/models/bagel/bagel.py
+++ b/vllm_omni/model_executor/models/bagel/bagel.py
@@ -1,4 +1,3 @@
-from collections import deque
from collections.abc import Iterable, Mapping, Sequence
from math import isqrt
from typing import Any
@@ -442,14 +441,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._pending_img2img_info: list[tuple[int, int, int, int]] = []
self._ropes_pending: list[dict[str, Any]] = []
self._ropes_metadata: dict[str, dict[str, Any]] = {}
- self._cfg_companion_queue: deque[tuple[tuple[int, int, int, int], int]] = deque()
-
- # Per-request position offset for decode after img2img prefill.
- # Prefill rewrites positions (VAE→0, ViT→1, text→2..N) but the model
- # runner assigns decode positions starting from prefill_len, not N+1.
- # offset = rope - prefill_len (a negative number).
- self._pending_decode_offsets: list[int] = []
- self._decode_position_offsets: dict[str, int] = {}
+ self._last_img2img_info: tuple[int, int, int, int] | None = None
from transformers import AutoTokenizer
@@ -461,7 +453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._start_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_start|>"))
self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>"))
self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>"))
-
self._vae_token_mask: torch.Tensor | None = None
self.device = get_local_device()
self._install_mot_modules(config)
@@ -540,9 +531,7 @@ def _clear_warmup_state(self):
self._ropes_pending.clear()
self._ropes_metadata.clear()
self._pending_img2img_info.clear()
- self._cfg_companion_queue.clear()
- self._pending_decode_offsets.clear()
- self._decode_position_offsets.clear()
+ self._last_img2img_info = None
self._vae_token_mask = None
def get_kv_transfer_metadata(
@@ -554,12 +543,10 @@ def get_kv_transfer_metadata(
meta = self._ropes_metadata.pop(req_id, None)
if meta is None:
return None
- # In think-mode img2img the prefill rope doesn't account for decoded
- # thinking tokens; correct it to num_computed_tokens + offset.
- # Skip correction when num_computed_tokens is unavailable (None).
- offset = self._decode_position_offsets.pop(req_id, 0)
- if offset != 0 and "ropes" in meta and num_computed_tokens is not None:
- meta["ropes"] = [num_computed_tokens + offset]
+ if num_computed_tokens is not None and "image_shape" in meta:
+ prefill_rope = meta["ropes"][0] if meta.get("ropes") else 0
+ if num_computed_tokens > prefill_rope:
+ meta["ropes"] = [num_computed_tokens]
return meta
def prepare_runner_inputs(
@@ -572,48 +559,29 @@ def prepare_runner_inputs(
num_scheduled_tokens: list[int],
input_ids_buffer: torch.Tensor | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
- """Model-runner hook: adjust inputs before ``forward()``.
-
- Returns ``(input_ids, positions)`` — possibly modified.
-
- Two adjustments for BAGEL img2img:
-
- 1. **Restore input_ids** when ``inputs_embeds`` is present so that
- ``_adjust_positions_for_img2img`` can locate the
- ``<|fim_middle|>`` placeholder.
- 2. **Decode position offset**: prefill rewrites positions to a
- compact scheme (rope ≪ prefill_len). The runner assigns decode
- positions from ``num_computed_tokens``, which is far too large;
- apply the stored per-request offset.
- """
+ """Restore input_ids so _adjust_positions_for_img2img can locate
+ the <|fim_middle|> placeholder for thinking-mode pre_text_len
+ detection."""
if inputs_embeds is not None and input_ids is None and input_ids_buffer is not None:
input_ids = input_ids_buffer
-
- if self._decode_position_offsets and positions is not None:
- token_start = 0
- for i, rid in enumerate(req_ids):
- sched = num_scheduled_tokens[i]
- offset = self._decode_position_offsets.get(rid, 0)
- if offset != 0 and num_computed_tokens[i] > 0:
- positions[token_start : token_start + sched] += offset
- token_start += sched
-
return input_ids, positions
def flush_pending_metadata(self, req_ids: list[str]) -> None:
- """Map pending metadata (batch order) to req_ids after forward()."""
+ """Map pending metadata (batch order) to req_ids after forward().
+
+ Guard: if a request already has metadata with ``image_shape``
+ (written during img2img prefill), don't overwrite it with
+ decode-step metadata that lacks ``image_shape``.
+ """
pending = self._ropes_pending
self._ropes_pending = []
for i, meta in enumerate(pending):
if i < len(req_ids):
- if req_ids[i] not in self._ropes_metadata:
- self._ropes_metadata[req_ids[i]] = meta
-
- pending_offsets = self._pending_decode_offsets
- self._pending_decode_offsets = []
- for i, offset in enumerate(pending_offsets):
- if i < len(req_ids) and offset != 0:
- self._decode_position_offsets[req_ids[i]] = offset
+ rid = req_ids[i]
+ existing = self._ropes_metadata.get(rid)
+ if existing and "image_shape" in existing and "image_shape" not in meta:
+ continue
+ self._ropes_metadata[rid] = meta
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
@@ -727,16 +695,7 @@ def _process_img2img_input(self, multimodal_input):
num_vit = vit_emb.shape[0] + 2
info = (num_vae, num_vit, int(H), int(W))
self._pending_img2img_info.append(info)
- # Only the gen (main) request should add a companion queue entry.
- # Companion requests (cfg_text, cfg_img) also call this method with
- # the same image, so guard by checking whether this exact info
- # tuple is already enqueued. For batched img2img with multiple
- # concurrent gen requests this correctly adds one entry per unique
- # image; images with identical (num_vae, num_vit, H, W) that arrive
- # in the same batch are indistinguishable here and will share one
- # entry, but that is an uncommon edge case.
- if not any(entry[0] == info for entry in self._cfg_companion_queue):
- self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img
+ self._last_img2img_info = info
return tuple(results)
@@ -755,31 +714,18 @@ def forward(
positions = self._adjust_positions_for_img2img(positions, input_ids)
use_mot = True
- elif self._cfg_companion_queue:
- # Guard: if this looks like a pure decode step (small token count,
- # no multimodal embeddings), the queue has stale entries from a
- # previous prefill cycle — clear them instead of consuming.
- if inputs_embeds is None and seq_len <= 2:
- self._cfg_companion_queue.clear()
- else:
- cached, remaining = self._cfg_companion_queue[0]
- remaining -= 1
- num_vae, num_vit, img_H, img_W = cached
- num_img2img = num_vae + 1 + num_vit # +1 separator
- seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0]
-
- if inputs_embeds is not None and seq_len >= num_img2img:
- self._pending_img2img_info = [cached]
- positions = self._adjust_positions_for_img2img(positions, input_ids)
- use_mot = True
- else:
- rope = int(positions[seq_len - 1].item()) + 1
- self._ropes_pending.append({"ropes": [rope]})
+ elif self._last_img2img_info is not None:
+ info = self._last_img2img_info
+ num_vae, num_vit, _, _ = info
+ num_img2img = num_vae + 1 + num_vit
- if remaining == 0:
- self._cfg_companion_queue.popleft()
- else:
- self._cfg_companion_queue[0] = (cached, remaining)
+ if seq_len >= num_img2img:
+ self._pending_img2img_info = [info]
+ positions = self._adjust_positions_for_img2img(positions, input_ids)
+ use_mot = True
+ else:
+ rope = int(positions[seq_len - 1].item()) + 1
+ self._ropes_pending.append({"ropes": [rope]})
if use_mot:
return self._mot_forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs)
@@ -790,27 +736,18 @@ def _adjust_positions_for_img2img(
positions: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
- """Rewrite position IDs to match the original BAGEL position scheme:
-
- If there are ``pre_text_len`` text tokens before the img2img block::
-
- pre_text → 0, 1, ..., M-1
- VAE → M (all share)
- separator→ M
- ViT → M+1 (all share)
- post_text→ M+2, M+3, ...
+ """Rewrite position IDs for img2img.
- When no text precedes the img2img block (M=0), this reduces to the
- simpler scheme: VAE→0, ViT→1, text→2, 3, ...
+ Supports an optional ``pre_text_len`` prefix (thinking-mode) detected
+ via the ``<|fim_middle|>`` token in *input_ids*:
- Also computes ``self._vae_token_mask`` (bool tensor, True for actual
- VAE latent patches that should use gen-mode weights) and pushes
- per-request ropes + image_shape to the FIFO consumed by
- ``get_kv_transfer_metadata``.
+ pre_text -> 0 .. M-1
+ VAE -> M (all share)
+ separator-> M
+ ViT -> M+1 (all share)
+ post_text-> M+2, M+3, ...
- For img2img requests, also stores a decode position offset so that
- subsequent autoregressive decode steps use positions that continue
- from the rewritten scheme rather than from the original prefill length.
+ When M=0 (standard img2img) this reduces to VAE->0, ViT->1, text->2..
"""
info_list = self._pending_img2img_info
self._pending_img2img_info = []
@@ -836,70 +773,64 @@ def _adjust_positions_for_img2img(
req_len = end - start
if img2img_idx < len(info_list):
- num_vae, num_vit, img_H, img_W = info_list[img2img_idx]
+ cur_info = info_list[img2img_idx]
+ elif self._last_img2img_info is not None:
+ cur_info = self._last_img2img_info
+ else:
+ cur_info = None
+
+ if cur_info is not None:
+ num_vae, num_vit, img_H, img_W = cur_info
num_img2img = num_vae + 1 + num_vit # +1 separator
if req_len >= num_img2img:
- # Detect offset of img2img tokens within this request
- # by searching for the img2img placeholder token ID.
pre_text_len = 0
if input_ids is not None:
- req_ids = input_ids[start:end]
- mask = req_ids == self._img2img_token_id
- indices = mask.nonzero(as_tuple=True)[0]
+ req_ids_slice = input_ids[start:end]
+ indices = (req_ids_slice == self._img2img_token_id).nonzero(as_tuple=True)[0]
if indices.numel() > 0:
pre_text_len = int(indices[0].item())
- img_start = start + pre_text_len
+ M = pre_text_len
+ img_start = start + M
post_text_start = img_start + num_img2img
- # pre_text_pos: position base for image tokens
- pre_text_pos = pre_text_len
- # Pre-image text: sequential positions 0..pre_text_pos-1
- if pre_text_len > 0:
+ if M > 0:
new_positions[start:img_start] = torch.arange(
- 0, pre_text_pos, device=positions.device, dtype=positions.dtype
+ 0, M, device=positions.device, dtype=positions.dtype
)
- # VAE tokens: all share position pre_text_pos
- new_positions[img_start : img_start + num_vae] = pre_text_pos
- # Separator: position pre_text_pos
- new_positions[img_start + num_vae] = pre_text_pos
- # ViT tokens: all share position pre_text_pos+1
+ new_positions[img_start : img_start + num_vae] = M
+ new_positions[img_start + num_vae] = M # separator
vit_start = img_start + num_vae + 1
- new_positions[vit_start : vit_start + num_vit] = pre_text_pos + 1
+ new_positions[vit_start : vit_start + num_vit] = M + 1
- # Post-image text: sequential positions pre_text_pos+2, pre_text_pos+3, ...
num_post_text = end - post_text_start
if num_post_text > 0:
new_positions[post_text_start:end] = torch.arange(
- pre_text_pos + 2,
- pre_text_pos + 2 + num_post_text,
+ M + 2,
+ M + 2 + num_post_text,
device=positions.device,
dtype=positions.dtype,
)
- # VAE gen-mode mask: only actual VAE latent patches (not markers)
- vae_patches_start = img_start + 1 # skip start_marker
- vae_patches_end = img_start + num_vae - 1 # before end_marker
+ vae_patches_start = img_start + 1
+ vae_patches_end = img_start + num_vae - 1
if vae_patches_end > vae_patches_start:
vae_mask[vae_patches_start:vae_patches_end] = True
- rope = pre_text_pos + 2 + num_post_text
+ rope = M + 2 + num_post_text
self._ropes_pending.append(
{
"ropes": [rope],
"image_shape": [img_H, img_W],
}
)
- decode_offset = rope - req_len
- self._pending_decode_offsets.append(decode_offset)
img2img_idx += 1
continue
rope = int(new_positions[end - 1].item()) + 1
self._ropes_pending.append({"ropes": [rope]})
- self._pending_decode_offsets.append(0)
self._vae_token_mask = vae_mask if vae_mask.any() else None
return new_positions
diff --git a/vllm_omni/model_executor/models/common/__init__.py b/vllm_omni/model_executor/models/common/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py
new file mode 100644
index 0000000000..3a904442fa
--- /dev/null
+++ b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py
@@ -0,0 +1,654 @@
+"""Qwen3 Code Predictor -- optimized re-prefill, no KV cache.
+
+Shared by Qwen3-Omni and Qwen3-TTS talker models.
+
+* SDPA attention (F.scaled_dot_product_attention) with native GQA support
+* HF-compatible numerics (float32 RMSNorm, float32 RoPE, separate linear layers)
+* Per-call embedding buffer to avoid cross-request aliasing
+* Pre-allocated position_ids (read-only, safe to persist)
+* torch.compile (epilogue_fusion=False) on inner transformer by default
+* Optional manual CUDA graph capture per batch-size bucket
+* Inline sampling (top-k + top-p) -- no custom op overhead
+"""
+
+from __future__ import annotations
+
+import dataclasses
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)
+
+
+# ===================================================================
+# HF-numerics-compatible layers for code predictor
+# ===================================================================
+#
+# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32,
+# rotate_half RoPE) to produce outputs numerically identical to the
+# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel,
+# get_rope) introduce small precision differences that compound across
+# the autoregressive steps of the code predictor, causing severe
+# audio quality degradation.
+#
+# See: https://github.com/vllm-project/vllm-omni/issues/2274
+
+
+class _RMSNorm(nn.Module):
+ """RMSNorm matching HuggingFace's implementation exactly.
+
+ Computes variance in float32 to avoid bfloat16 precision loss.
+ """
+
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+class _RotaryEmbedding(nn.Module):
+ """RoPE matching HuggingFace's implementation exactly.
+
+ Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False).
+ """
+
+ def __init__(self, config) -> None:
+ super().__init__()
+ head_dim = getattr(
+ config,
+ "head_dim",
+ config.hidden_size // config.num_attention_heads,
+ )
+ rope_theta = getattr(config, "rope_theta", 10000.0)
+ inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ # position_ids: [batch, seq_len]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 (matching HF)
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# ===================================================================
+# Attention
+# ===================================================================
+
+
+class CodePredictorAttention(nn.Module):
+ """Multi-head self-attention for code predictor.
+
+ Uses ``F.scaled_dot_product_attention`` with HF-compatible RoPE and RMSNorm.
+ No KV cache -- the code predictor always re-prefills the full (short)
+ sequence each AR step.
+
+ Input : [B, seq_len, hidden_size]
+ Output: [B, seq_len, hidden_size]
+ """
+
+ def __init__(self, config, *, prefix: str = "") -> None:
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.head_dim = getattr(
+ config,
+ "head_dim",
+ config.hidden_size // config.num_attention_heads,
+ )
+ self.hidden_size = config.hidden_size
+ self.scaling = self.head_dim**-0.5
+ self._use_gqa = self.num_kv_heads != self.num_heads
+
+ # Separate q/k/v projections matching HF (no fused packing)
+ bias = getattr(config, "attention_bias", False)
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.q_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+ bsz, seq_len, _ = hidden_states.shape
+ hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim)
+ hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim)
+
+ q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2)
+ k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2)
+ v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads
+ cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim]
+ sin = sin.unsqueeze(1)
+ q = (q * cos) + (_rotate_half(q) * sin)
+ k = (k * cos) + (_rotate_half(k) * sin)
+
+ attn_out = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ scale=self.scaling,
+ is_causal=True,
+ enable_gqa=self._use_gqa,
+ )
+
+ attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1)
+ return self.o_proj(attn_out)
+
+
+# ===================================================================
+# MLP
+# ===================================================================
+
+
+class CodePredictorMLP(nn.Module):
+ """SiLU-gated MLP for code predictor, matching HF's implementation."""
+
+ def __init__(self, config, *, prefix: str = "") -> None:
+ super().__init__()
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+
+
+# ===================================================================
+# Decoder Layer
+# ===================================================================
+
+
+class CodePredictorDecoderLayer(nn.Module):
+ """Transformer decoder layer (SDPA, no KV cache)."""
+
+ def __init__(self, config, *, prefix: str = "") -> None:
+ super().__init__()
+ self.self_attn = CodePredictorAttention(config, prefix=f"{prefix}.self_attn")
+ self.mlp = CodePredictorMLP(config, prefix=f"{prefix}.mlp")
+ self.input_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(hidden_states, position_embeddings)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+# ===================================================================
+# Base Transformer Model (re-prefill, no KV cache)
+# ===================================================================
+
+
+class CodePredictorBaseModel(nn.Module):
+ """Inner transformer for code predictor.
+
+ Signature: ``forward(inputs_embeds, position_ids) -> hidden_states``
+ """
+
+ def __init__(
+ self,
+ config,
+ *,
+ embedding_dim: int | None = None,
+ use_parallel_embedding: bool = False,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.config = config
+
+ emb_dim = int(embedding_dim) if embedding_dim is not None else int(config.hidden_size)
+ if use_parallel_embedding:
+ self.codec_embedding = nn.ModuleList(
+ [VocabParallelEmbedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)]
+ )
+ else:
+ self.codec_embedding = nn.ModuleList(
+ [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)]
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ CodePredictorDecoderLayer(config, prefix=f"{prefix}.layers.{idx}")
+ for idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = _RotaryEmbedding(config)
+
+ def get_input_embeddings(self) -> nn.ModuleList:
+ return self.codec_embedding
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, position_embeddings)
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ param = params_dict.get(name)
+ if param is None:
+ continue
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+# ===================================================================
+# Wrapper Configuration
+# ===================================================================
+
+
+@dataclasses.dataclass
+class CodePredictorWrapperConfig:
+ """Controls behavioral differences between model-specific code predictors."""
+
+ use_cuda_graphs: bool = False
+ use_parallel_embedding: bool = False
+ use_projection: bool = False
+ return_proj_buf: bool = False
+ sampling_mode: str = "stored"
+
+
+# ===================================================================
+# Code Predictor Wrapper (optimized re-prefill, persistent buffers)
+# ===================================================================
+
+
+class CodePredictorWrapper(nn.Module):
+ """Optimized code predictor -- re-prefill approach, no KV cache.
+
+ Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1)
+ through the transformer. The extra O(T^2) FLOPs are negligible for
+ short sequences, and this avoids all KV-cache management overhead.
+
+ Optimizations:
+ 1. Per-call embedding buffer -- avoids cross-request aliasing.
+ 2. Pre-allocated position_ids -- no torch.arange per step.
+ 3. Cached module references -- bypass ModuleList indexing.
+ 4. torch.compile on inner transformer.
+ 5. Inline sampling (top-k + top-p) -- no custom op overhead.
+ 6. Optional manual CUDA graph capture per batch-size bucket.
+ """
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ cp_config,
+ wrapper_config: CodePredictorWrapperConfig,
+ talker_hidden_size: int | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self._vllm_config = vllm_config
+ self.config = cp_config
+ self._wrapper_config = wrapper_config
+ self.prefix = prefix
+
+ self._num_groups = int(cp_config.num_code_groups)
+ self._cp_hidden = int(cp_config.hidden_size)
+
+ # For Omni backward compat (accessed by the talker)
+ self.num_code_groups = self._num_groups
+
+ # Determine embedding dimension
+ _talker_hidden = int(talker_hidden_size) if talker_hidden_size is not None else self._cp_hidden
+
+ self.model = CodePredictorBaseModel(
+ cp_config,
+ embedding_dim=_talker_hidden,
+ use_parallel_embedding=wrapper_config.use_parallel_embedding,
+ prefix=f"{prefix}.model" if prefix else "model",
+ )
+
+ self.lm_head = nn.ModuleList(
+ [nn.Linear(cp_config.hidden_size, cp_config.vocab_size, bias=False) for _ in range(self._num_groups - 1)]
+ )
+
+ # Projection: Identity when hidden sizes match or not needed
+ if wrapper_config.use_projection and _talker_hidden != self._cp_hidden:
+ self.small_to_mtp_projection = nn.Linear(_talker_hidden, self._cp_hidden, bias=True)
+ else:
+ self.small_to_mtp_projection = nn.Identity()
+
+ # Sampling defaults for "stored" mode
+ self._top_k: int = 50
+ self._top_p: float = 0.8
+
+ # Lazily initialised state
+ self._proj_buf: torch.Tensor | None = None
+ self._model_dtype: torch.dtype | None = None
+ self._compiled_model_fwd = None
+ self._bucket_sizes: list[int] = []
+ self._bucket_pos_ids: dict[int, torch.Tensor] = {}
+ self._lm_heads_list: list[nn.Module] | None = None
+ self._codec_embeds_list: list[nn.Module] | None = None
+ self._cuda_graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor]] = {}
+
+ def get_input_embeddings(self) -> nn.ModuleList:
+ return self.model.get_input_embeddings()
+
+ def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8) -> None:
+ """Configure sampling parameters to maintain consistency with previous implementation."""
+ self._top_k = top_k
+ self._top_p = top_p
+ logger.debug("Sampling parameters updated: top_k=%d, top_p=%.2f", top_k, top_p)
+
+ # ------------------------------------------------------------------
+ # Lazy-init helpers
+ # ------------------------------------------------------------------
+
+ def _ensure_buffers(self, device: torch.device, dtype: torch.dtype, bsz: int) -> None:
+ """Ensure the projection buffer can hold at least *bsz* rows."""
+ max_seq = self._num_groups + 1
+ if (
+ self._proj_buf is not None
+ and self._proj_buf.device == device
+ and self._proj_buf.dtype == dtype
+ and self._proj_buf.shape[0] >= bsz
+ ):
+ return
+ self._proj_buf = torch.zeros(bsz, max_seq, self._cp_hidden, dtype=dtype, device=device)
+
+ def _setup_compile(self) -> None:
+ """Lazily set up torch.compile with optional CUDA graph capture."""
+ if self._compiled_model_fwd is not None:
+ return
+
+ # Cache model parameter dtype so forward() doesn't need to query it
+ # on every call. Also ensures warmup buffers match model precision
+ # even when upstream modules produce a different dtype (#2385).
+ self._model_dtype = next(self.model.parameters()).dtype
+ self._lm_heads_list = list(self.lm_head)
+ self._codec_embeds_list = list(self.model.codec_embedding)
+
+ if not current_omni_platform.supports_torch_inductor():
+ logger.warning_once("code_predictor: torch.compile disabled")
+ self._compiled_model_fwd = self.model.forward
+ return
+
+ # torch.compile fuses RMSNorm/RoPE in ways that lose float32
+ # precision, compounding across AR steps. Use epilogue_fusion=False
+ # to disable the problematic fusions while still getting kernel
+ # fusion benefits for the linear layers and SDPA.
+ self._compiled_model_fwd = torch.compile(
+ self.model.forward,
+ dynamic=False,
+ options={"epilogue_fusion": False},
+ )
+ self._warmup_buckets()
+
+ if self._wrapper_config.use_cuda_graphs:
+ self._capture_cuda_graphs()
+ logger.info("code_predictor: torch.compile (no epilogue fusion) + CUDA graphs")
+ else:
+ logger.info("code_predictor: torch.compile (dynamic=False, no epilogue fusion)")
+
+ def _padded_bsz(self, bsz: int) -> int:
+ """Round batch size up to nearest power-of-2 bucket."""
+ for bucket in self._bucket_sizes:
+ if bsz <= bucket:
+ return bucket
+ return bsz
+
+ def _warmup_buckets(self) -> None:
+ """Warmup power-of-2 batch-size buckets to front-load Inductor compilation."""
+ max_bsz = self._vllm_config.scheduler_config.max_num_seqs
+ bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz]
+ if max_bsz not in bucket_sizes:
+ bucket_sizes.append(max_bsz)
+ self._bucket_sizes = sorted(bucket_sizes)
+
+ max_seq = self._num_groups + 1
+ device = next(self.model.parameters()).device
+
+ # Ensure proj_buf matches model parameter dtype to avoid dtype
+ # mismatch during warmup compilation (see #2385).
+ self._ensure_buffers(device, self._model_dtype, max(self._bucket_sizes))
+ proj_buf = self._proj_buf
+
+ for bsz in self._bucket_sizes:
+ pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1).contiguous()
+ self._bucket_pos_ids[bsz] = pos_ids
+ for _ in range(3):
+ self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids)
+ logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes)
+
+ def _capture_cuda_graphs(self) -> None:
+ """Capture a CUDA graph per bucket using vLLM's global graph pool."""
+ from vllm.platforms import current_platform
+
+ pool = current_platform.get_global_graph_pool()
+ max_seq = self._num_groups + 1
+ proj_buf = self._proj_buf
+
+ for bsz in self._bucket_sizes:
+ static_input = proj_buf[:bsz, :max_seq, :]
+ pos_ids = self._bucket_pos_ids[bsz]
+
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g, pool=pool):
+ static_output = self._compiled_model_fwd(static_input, pos_ids)
+
+ self._cuda_graphs[bsz] = (g, static_output)
+
+ logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes)
+
+ # ------------------------------------------------------------------
+ # Forward -- re-prefill + inline sampling
+ # ------------------------------------------------------------------
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ layer0_code: torch.Tensor,
+ layer0_embed: torch.Tensor,
+ last_talker_hidden: torch.Tensor,
+ do_sample: bool = True,
+ temperature: float = 0.9,
+ top_k: int = 50,
+ top_p: float = 1.0,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ """Predict residual codebooks 1..G-1 autoregressively via re-prefill."""
+ bsz = int(layer0_code.shape[0])
+ num_groups = self._num_groups
+ device = layer0_code.device
+
+ # _setup_compile caches _model_dtype on first call; use it for buffers
+ # so they always match model weight precision (#2385).
+ self._setup_compile()
+ dtype = self._model_dtype
+
+ padded_bsz = self._padded_bsz(bsz)
+ self._ensure_buffers(device, dtype, padded_bsz)
+
+ proj_buf = self._proj_buf
+ max_seq = num_groups + 1
+ projection = self.small_to_mtp_projection
+ model_fwd = self._compiled_model_fwd
+ lm_heads = self._lm_heads_list
+ codec_embeds = self._codec_embeds_list
+
+ # Zero the padded region of the buffer
+ proj_buf[:padded_bsz].zero_()
+
+ # Fill buffer positions 0 (talker hidden) & 1 (layer0 embed)
+ proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1)
+ proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1)
+
+ # Get pre-computed pos_ids for this bucket
+ full_pos_ids = self._bucket_pos_ids.get(padded_bsz)
+ if full_pos_ids is None:
+ full_pos_ids = (
+ torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1).contiguous()
+ )
+
+ # Use captured CUDA graph if available, otherwise call compiled fn.
+ cuda_graph_entry = self._cuda_graphs.get(padded_bsz)
+
+ # Prepare sampling parameters
+ stored_mode = self._wrapper_config.sampling_mode == "stored"
+ if stored_mode:
+ s_top_k = self._top_k
+ s_top_p = self._top_p
+ else:
+ use_sampling = do_sample and temperature > 0
+ inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0
+ if use_sampling and top_p != 1.0:
+ raise NotImplementedError(
+ "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0."
+ )
+
+ # Output codes -- shape depends on return mode
+ if self._wrapper_config.return_proj_buf:
+ all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device)
+ all_codes[:, 0] = layer0_code.reshape(bsz, -1)[:, :1]
+ else:
+ all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device)
+ all_codes[:, 0] = layer0_code.reshape(bsz)
+
+ # Autoregressive loop: predict layers 1..G-1
+ for step in range(1, num_groups):
+ # Run transformer (CUDA graph replay or compiled forward)
+ if cuda_graph_entry is not None:
+ cuda_graph_entry[0].replay()
+ hidden_out = cuda_graph_entry[1]
+ else:
+ hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids)
+
+ logits = lm_heads[step - 1](hidden_out[:bsz, step, :])
+
+ # Sample next code
+ if stored_mode:
+ # "stored" mode: top-k -> top-p -> softmax -> multinomial
+ if s_top_k > 0:
+ topk_vals, _ = logits.topk(s_top_k, dim=-1)
+ logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf"))
+ if s_top_p < 1.0:
+ sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
+ sorted_probs = F.softmax(sorted_logits, dim=-1)
+ cumulative_probs = sorted_probs.cumsum(dim=-1)
+ remove_mask = (cumulative_probs - sorted_probs) >= s_top_p
+ sorted_logits[remove_mask] = float("-inf")
+ logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
+ probs = F.softmax(logits, dim=-1)
+ code = torch.multinomial(probs, num_samples=1)
+ else:
+ # "per_call" mode: temperature-scaled + top-k
+ if use_sampling:
+ scaled = logits * inv_temperature
+ if top_k > 0:
+ topk_vals, _ = scaled.topk(top_k, dim=-1)
+ scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf"))
+ probs = F.softmax(scaled, dim=-1)
+ code = torch.multinomial(probs, num_samples=1)
+ else:
+ code = logits.argmax(dim=-1, keepdim=True)
+
+ # Store code
+ if self._wrapper_config.return_proj_buf:
+ all_codes[:, step] = code
+ else:
+ all_codes[:, step] = code.reshape(bsz)
+
+ # Embed predicted code -> project -> next buffer position
+ if step < num_groups - 1 or self._wrapper_config.return_proj_buf:
+ new_embed = codec_embeds[step - 1](code)
+ proj_buf[:bsz, step + 1, :] = projection(new_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
+
+ if self._wrapper_config.return_proj_buf:
+ return all_codes, proj_buf[:bsz].clone()
+ return all_codes
+
+ # ------------------------------------------------------------------
+ # Weight loading
+ # ------------------------------------------------------------------
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ """Load weights directly (no fused projection remapping needed)."""
+ loaded: set[str] = set()
+ model_weights: list[tuple[str, torch.Tensor]] = []
+ other_weights: list[tuple[str, torch.Tensor]] = []
+
+ for name, w in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ if name.startswith("model."):
+ model_weights.append((name[len("model.") :], w))
+ else:
+ other_weights.append((name, w))
+
+ loaded_model = self.model.load_weights(model_weights)
+ loaded |= {f"model.{n}" for n in loaded_model}
+
+ params = dict(self.named_parameters(remove_duplicate=False))
+ for name, w in other_weights:
+ param = params.get(name)
+ if param is None:
+ continue
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, w)
+ loaded.add(name)
+
+ return loaded
diff --git a/vllm_omni/model_executor/models/cosyvoice3/assets/mel_filters.npz b/vllm_omni/model_executor/models/cosyvoice3/assets/mel_filters.npz
deleted file mode 100644
index 28ea26909d..0000000000
Binary files a/vllm_omni/model_executor/models/cosyvoice3/assets/mel_filters.npz and /dev/null differ
diff --git a/vllm_omni/model_executor/models/cosyvoice3/utils.py b/vllm_omni/model_executor/models/cosyvoice3/utils.py
index 52c52655e8..0bf0cccb16 100644
--- a/vllm_omni/model_executor/models/cosyvoice3/utils.py
+++ b/vllm_omni/model_executor/models/cosyvoice3/utils.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
-import os
from functools import cache, lru_cache
import numpy as np
@@ -9,7 +8,8 @@
import torch.nn.functional as F
import torchaudio
import torchaudio.compliance.kaldi as kaldi
-from librosa.filters import mel as librosa_mel_fn
+
+from vllm_omni.utils.audio import mel_filter_bank
logger = logging.getLogger(__name__)
@@ -34,8 +34,13 @@ def _get_mel_basis(
fmax: float | None,
device_str: str,
) -> torch.Tensor:
- mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
- return torch.from_numpy(mel).float().to(torch.device(device_str))
+ return mel_filter_bank(
+ sr=sampling_rate,
+ n_fft=n_fft,
+ n_mels=num_mels,
+ fmin=fmin,
+ fmax=fmax,
+ ).to(torch.device(device_str))
@lru_cache
@@ -122,42 +127,8 @@ def exact_div(x, y):
@cache
def mel_filters(device, n_mels: int) -> torch.Tensor:
- """
- load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
- Allows decoupling librosa dependency; saved using:
-
- np.savez_compressed(
- "mel_filters.npz",
- mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
- mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
- )
- """
- assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
-
- filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
- if not os.path.exists(filters_path):
- source_url = "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz"
- os.makedirs(os.path.dirname(filters_path), exist_ok=True)
- try:
- import urllib.request
-
- with urllib.request.urlopen(source_url, timeout=30) as resp:
- with open(filters_path, "wb") as f_out:
- f_out.write(resp.read())
- logger.info("Downloaded mel_filters.npz from %s", source_url)
- except Exception as e:
- raise FileNotFoundError(
- "Missing CosyVoice3 mel filter asset:\n"
- f" {filters_path}\n"
- "Auto-download failed. Download it manually from:\n"
- f" {source_url}\n"
- "Example:\n"
- f" mkdir -p {os.path.dirname(filters_path)} && "
- f"curl -L {source_url} -o {filters_path}"
- ) from e
-
- with np.load(filters_path, allow_pickle=False) as f:
- return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+ """Compute mel filterbank matrix for projecting STFT into a Mel spectrogram."""
+ return mel_filter_bank(sr=16000, n_fft=400, n_mels=n_mels).to(device)
def log_mel_spectrogram(
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
index 8bbb643ebe..22a2744ff5 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
@@ -310,6 +310,7 @@ def __init__(
self._compiled_model_fwd: object | None = None
self._compile_attempted = False
self._compile_failed = False
+ self._disable_compile_for_graph = False
def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes
@@ -327,11 +328,20 @@ def _setup_compile(self) -> None:
if self._compile_attempted:
return
self._compile_attempted = True
+ if self._disable_compile_for_graph:
+ try:
+ self._compiled_model_fwd = torch.compile(
+ self.model.forward,
+ dynamic=True,
+ options={"epilogue_fusion": False},
+ )
+ except Exception as exc:
+ logger.warning("Fast AR torch.compile (graph mode) failed: %s", exc)
+ self._compiled_model_fwd = self.model.forward
+ return
try:
self._compiled_model_fwd = torch.compile(
self.model.forward,
- # Keep the helper compiler separate from vLLM's outer
- # cudagraph-managed Stage-0 execution.
mode="default",
dynamic=True,
fullgraph=False,
@@ -366,10 +376,10 @@ def warmup_compile(
@torch.inference_mode()
def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor:
- # Default-on compile only pays off for single-request decode. For
- # batched decode, eager preserves loaded throughput and avoids the
- # regression seen with batch>1 compiled execution.
- model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
+ if self._disable_compile_for_graph:
+ model_fwd = self._compiled_model_fwd or self.model.forward
+ else:
+ model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
try:
return model_fwd(step_input, step_pos_ids)
except Exception as exc:
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
index 3813597caa..62776cbb31 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
@@ -194,6 +194,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.has_postprocess = True
self.mtp_hidden_size = int(self.text_config.hidden_size)
self.talker_mtp_output_key = "audio_codes"
+ self.talker_mtp_graph_safe = True
self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"}
# Qwen3 transformer backbone.
@@ -236,6 +237,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
slow_ar_config=self.text_config,
prefix="fast_ar",
)
+ if self.talker_mtp_graph_safe:
+ self.fast_ar._disable_compile_for_graph = True
# Constant logit mask: allow only semantic tokens + im_end.
vocab = int(self.text_config.vocab_size)
@@ -680,18 +683,13 @@ def talker_mtp(
inputs_embeds_out = input_embeds.reshape(bsz, -1).clone()
semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id)
- if semantic_mask.any():
- semantic_codes = audio_codes[semantic_mask].clamp(min=0)
- offsets = (
- torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
- ).unsqueeze(0)
- codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
-
- # Normalize by sqrt(num_codebooks + 1) as in the reference model
- # (scale_codebook_embeddings=True for fish_qwen3_omni).
- inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt(
- self._num_codebooks + 1
- )
+ semantic_codes = audio_codes.clamp(min=0, max=self._codebook_size - 1)
+ offsets = (
+ torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
+ ).unsqueeze(0)
+ codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
+ norm_embeds = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1)
+ inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), norm_embeds, inputs_embeds_out)
return inputs_embeds_out, audio_codes.to(dtype=torch.long)
@@ -802,14 +800,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if truncated:
logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated)
- try:
- self.fast_ar.warmup_compile(
- device=self.codebook_embeddings.weight.device,
- dtype=torch.bfloat16,
- batch_sizes=(1,),
- )
- except Exception as exc:
- logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)
+ if not getattr(self, "talker_mtp_graph_safe", False):
+ try:
+ self.fast_ar.warmup_compile(
+ device=self.codebook_embeddings.weight.device,
+ dtype=torch.bfloat16,
+ batch_sizes=(1,),
+ )
+ except Exception as exc:
+ logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)
codec_device = self.codebook_embeddings.weight.device
_load_dac_codec(
diff --git a/vllm_omni/model_executor/models/fish_speech/prompt_utils.py b/vllm_omni/model_executor/models/fish_speech/prompt_utils.py
index 923e97b63a..8b8d8559ea 100644
--- a/vllm_omni/model_executor/models/fish_speech/prompt_utils.py
+++ b/vllm_omni/model_executor/models/fish_speech/prompt_utils.py
@@ -38,10 +38,7 @@ def _encode_plain_text(tokenizer: Any, text: str) -> list[int]:
def _encode_control_token(tokenizer: Any, token: str) -> list[int]:
- vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else {}
- token_id = vocab.get(token)
- if token_id is None:
- token_id = tokenizer.convert_tokens_to_ids(token)
+ token_id = tokenizer.convert_tokens_to_ids(token)
if token_id is None or token_id == getattr(tokenizer, "unk_token_id", None):
raise ValueError(f"Fish Speech tokenizer is missing required control token: {token}")
return [int(token_id)]
diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
index 6d25274f90..6304eeab29 100644
--- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
+++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
@@ -77,7 +77,9 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.tensor_schema import TensorSchema
+from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import Sampler
from vllm_omni.model_executor.models.hunyuan_image3.autoencoder_kl_3d import AutoencoderKLConv3D
from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import LightProjector, Siglip2VisionTransformer
@@ -175,8 +177,11 @@ def contains_unexpected_keyword(name, keywords):
return True
return False
+ skipped_unexpected: set[str] = set()
+
for name, loaded_weight in weights:
if contains_unexpected_keyword(name, unexpected_keywords):
+ skipped_unexpected.add(name)
continue
if "rotary_emb.inv_freq" in name:
@@ -362,6 +367,17 @@ def contains_unexpected_keyword(name, keywords):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
+
+ if skipped_unexpected:
+ logger.warning_once(
+ "Skipped %d weights matching unexpected_keywords "
+ "(e.g. vae, vision_model, patch_embed, timestep_emb). "
+ "If upstream renamed components, these may be silently "
+ "lost. Skipped names: %s",
+ len(skipped_unexpected),
+ sorted(skipped_unexpected)[:10],
+ )
+
return loaded_params
@@ -1149,6 +1165,8 @@ class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
HunyuanImage3Inputs: TypeAlias = HunyuanImage3PixelInputs
+ prefer_model_sampler = True
+
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1199,6 +1217,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = PPMissingLayer()
+ # --- AR-stage components ---
+ # These are needed for image encoding in the AR stage.
+ # If a future text-only stage is added, gate on vllm_config.model_config.model_stage.
+
# vae
self.vae = AutoencoderKLConv3D.from_config(config.vae)
self.patch_embed = UNetDown(
@@ -1226,6 +1248,63 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._mrope_joint_img_sep_token_id = tokenizer.convert_tokens_to_ids("")
self._mrope_max_num_patches = config.vit_processor.get("max_num_patches", 729)
+ # Special token IDs for logits processors (stage transitions).
+ # These mirror the official tokenization_hunyuan_image_3.py setup.
+ self._end_of_think_id = tokenizer.convert_tokens_to_ids(" ")
+ self._recaption_id = tokenizer.convert_tokens_to_ids("")
+ self._end_of_recaption_id = tokenizer.convert_tokens_to_ids(" ")
+ self._answer_id = tokenizer.convert_tokens_to_ids("")
+ self._end_of_answer_id = tokenizer.convert_tokens_to_ids(" ")
+ image_base_size = getattr(config, "image_base_size", 1024)
+ self._size_token_id = tokenizer.convert_tokens_to_ids(f"")
+ self._start_ratio_id = tokenizer.convert_tokens_to_ids("")
+ self._end_ratio_id = tokenizer.convert_tokens_to_ids("")
+ ratio_33 = tokenizer.convert_tokens_to_ids("")
+ ratio_36 = tokenizer.convert_tokens_to_ids("")
+ self._ratio_other_slices = [(ratio_33, ratio_36 + 1)]
+ # Build the full set of ratio token IDs for use as stop tokens.
+ self._all_ratio_ids = set(range(self._start_ratio_id, self._end_ratio_id + 1))
+ for s, e in self._ratio_other_slices:
+ self._all_ratio_ids.update(range(s, e))
+
+ # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I).
+ engine_output_type = getattr(vllm_config.model_config, "engine_output_type", None)
+ self._is_comprehension = engine_output_type in (None, "text")
+
+ # For comprehension mode, block image generation tokens but allow
+ # text structure tokens (, , etc.) so the model can
+ # follow its natural generation pattern. Stop tokens in YAML will
+ # terminate at or EOS.
+ self._blocked_token_ids: set[int] = set()
+ if self._is_comprehension:
+ self._blocked_token_ids.update(
+ [
+ self._mrope_boi_token_id, #
+ self._mrope_eoi_token_id, #
+ self._size_token_id, #
+ ]
+ )
+ self._blocked_token_ids.update(self._all_ratio_ids)
+
+ # For generation mode, build stage transition map.
+ # Official logic: → [],
+ # → [, , ]
+ # After , restrict vocab to ratio tokens only.
+ # Stage-transition forced sequences, keyed by trigger token.
+ self._stage_transitions: dict[int, list[int]] = {}
+ if not self._is_comprehension:
+ self._stage_transitions[self._end_of_think_id] = [
+ self._recaption_id,
+ ]
+ self._stage_transitions[self._end_of_recaption_id] = [
+ self._answer_id,
+ self._mrope_boi_token_id,
+ self._size_token_id,
+ ]
+
+ self._sampler: Sampler | None = None
+ self._eos_token_id: int = tokenizer.eos_token_id
+
self._replace_rotary_embeddings()
def _replace_rotary_embeddings(self):
@@ -1257,6 +1336,12 @@ def _replace_rotary_embeddings(self):
head_dim,
rope_theta,
)
+ if replaced == 0:
+ raise RuntimeError(
+ "HunyuanImage3: _replace_rotary_embeddings replaced 0 layers. "
+ "The custom interleaved 2D mRoPE is not active — model outputs "
+ "will be incorrect. Check that model.layers[*].self_attn.rotary_emb exists."
+ )
def _parse_and_validate_image_input(
self,
@@ -1274,6 +1359,10 @@ def _parse_and_validate_image_input(
if vit_pixel_values is None or vae_pixel_values is None:
return None
+ # Handle empty batch (e.g., during profiling with 0 images / T2T mode)
+ if vit_pixel_values.numel() == 0 or vae_pixel_values.numel() == 0:
+ return None
+
return HunyuanImage3PixelInputs(
type="pixel_values",
pixel_values={
@@ -1472,6 +1561,112 @@ def compute_logits(
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
+ # ------------------------------------------------------------------
+ # Custom sampler — applies HunyuanImage3-specific logits processors
+ # before the standard sampling step.
+ #
+ # Comprehension (I2T / T2T):
+ # Block generation-specific special tokens so sampling can't
+ # accidentally produce , , ratio tokens, etc.
+ #
+ # Generation (IT2I / T2I think):
+ # 1. _StageTransitionLogitsProcessor — force token sequences at
+ # transition boundaries ( → , etc.)
+ # 2. _ConditionalSliceVocabLogitsProcessor — after ,
+ # restrict vocab to ratio tokens only (greedy).
+ # ------------------------------------------------------------------
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> SamplerOutput | None:
+ if logits is None or logits.numel() == 0:
+ return None
+
+ if self._sampler is None:
+ self._sampler = Sampler()
+
+ min_score = torch.finfo(logits.dtype).min
+
+ assert logits.shape[0] == 1, f"HunyuanImage3 sampler requires max_num_seqs=1, got batch size {logits.shape[0]}"
+
+ for req_idx in range(logits.shape[0]):
+ decoded_tokens: list[int] = (
+ sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else []
+ )
+ last_token = decoded_tokens[-1] if decoded_tokens else -1
+
+ if self._is_comprehension:
+ for tid in self._blocked_token_ids:
+ logits[req_idx, tid] = min_score
+ else:
+ forced = self._get_forced_token(decoded_tokens)
+ if forced is not None:
+ logits[req_idx].fill_(min_score)
+ logits[req_idx, forced] = 0
+ elif last_token == self._size_token_id:
+ self._apply_ratio_restriction(logits, req_idx, min_score)
+ elif last_token in self._all_ratio_ids:
+ logits[req_idx].fill_(min_score)
+ logits[req_idx, self._eos_token_id] = 0
+
+ return self._sampler(logits=logits, sampling_metadata=sampling_metadata)
+
+ def _get_forced_token(self, decoded_tokens: list[int]) -> int | None:
+ """Derive the next forced token from output history (stateless).
+
+ Scans decoded_tokens backwards for the most recent trigger token,
+ then prefix-matches the forced sequence against what followed.
+ Returns the next token to force, or None if the sequence is complete
+ or history has diverged from the expected forced sequence.
+ """
+ for i in range(len(decoded_tokens) - 1, -1, -1):
+ trigger = decoded_tokens[i]
+ if trigger not in self._stage_transitions:
+ continue
+
+ forced_seq = self._stage_transitions[trigger]
+ emitted = decoded_tokens[i + 1 :]
+
+ matched = 0
+ for expected, actual in zip(forced_seq, emitted):
+ if actual != expected:
+ # History diverged from the expected forced sequence.
+ # Stop applying transition forcing for safety.
+ return None
+ matched += 1
+
+ if matched < len(forced_seq):
+ return forced_seq[matched]
+ return None
+
+ return None
+
+ def _apply_ratio_restriction(
+ self,
+ logits: torch.Tensor,
+ req_idx: int,
+ min_score: float,
+ ) -> None:
+ """Port of official _ConditionalSliceVocabLogitsProcessor.__call__.
+
+ After the size token, only allow ratio tokens and pick greedily.
+ """
+ original = logits[req_idx].clone()
+ logits[req_idx].fill_(min_score)
+ # Allow primary ratio range.
+ logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = original[
+ self._start_ratio_id : self._end_ratio_id + 1
+ ]
+ # Allow extra ratio slices.
+ for s, e in self._ratio_other_slices:
+ logits[req_idx, s:e] = original[s:e]
+ # Force greedy: keep only the argmax.
+ max_id = logits[req_idx].argmax().item()
+ logits[req_idx].fill_(min_score)
+ logits[req_idx, max_id] = 0
+
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
@@ -1507,9 +1702,9 @@ def get_mrope_input_positions(
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec] | None = None,
*,
- hf_config: PretrainedConfig,
- image_grid_thw: list[list[int]] | torch.Tensor,
- video_grid_thw: list[list[int]] | torch.Tensor,
+ hf_config: PretrainedConfig | None = None,
+ image_grid_thw: list[list[int]] | torch.Tensor | None = None,
+ video_grid_thw: list[list[int]] | torch.Tensor | None = None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
index 56cb8788ee..85fe4b0051 100644
--- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
+++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
@@ -50,6 +50,7 @@
PromptUpdate,
PromptUpdateDetails,
)
+from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema
@@ -150,7 +151,6 @@ def __init__(self, model: "MiMoAudioLLMForConditionalGeneration", max_batch_size
dtype = next(model.hidden_states_downcast.parameters()).dtype
hidden_size = model.local_config.hidden_size
- self.pool = torch.cuda.graph_pool_handle()
self.input_tensor = torch.zeros((max_batch_size, 1, hidden_size), dtype=dtype, device=device)
self.sampler = MiMoLocalSamplerTensor(
temperature=torch.ones(max_batch_size, dtype=torch.float32, device=device),
@@ -231,7 +231,7 @@ def capture(
cuda_graph = torch.cuda.CUDAGraph()
if eager_run_first:
model.base_local_forward(input_tensor, local_sampler=sampler)
- with torch.cuda.graph(cuda_graph, buffer.pool):
+ with torch.cuda.graph(cuda_graph, pool=current_platform.get_global_graph_pool()):
output_tensor = model.base_local_forward(input_tensor, local_sampler=sampler)
return cls(
@@ -263,7 +263,6 @@ def __init__(self, model: "MiMoAudioLLMForConditionalGeneration", max_batch_size
hidden_size = model.input_local_config.hidden_size
group_size = model.group_size
- self.pool = torch.cuda.graph_pool_handle()
self.input_tensor = torch.zeros((max_batch_size, group_size, hidden_size), dtype=dtype, device=device)
self.lock = threading.Lock()
@@ -311,7 +310,7 @@ def capture(
out = model.input_local_transformer(inputs_embeds=input_tensor, return_dict=True, is_causal=False)
_ = out.last_hidden_state
- with torch.cuda.graph(cuda_graph, buffer.pool):
+ with torch.cuda.graph(cuda_graph, pool=current_platform.get_global_graph_pool()):
out = model.input_local_transformer(inputs_embeds=input_tensor, return_dict=True, is_causal=False)
output_tensor = out.last_hidden_state
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py
new file mode 100644
index 0000000000..d7fa44fd7e
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+
+from .ming_flash_omni import MingFlashOmniForConditionalGeneration
+from .ming_flash_omni_thinker import (
+ MingFlashOmniThinkerDummyInputsBuilder,
+ MingFlashOmniThinkerForConditionalGeneration,
+ MingFlashOmniThinkerMultiModalProcessor,
+ MingFlashOmniThinkerProcessingInfo,
+)
+
+__all__ = [
+ "MingFlashOmniForConditionalGeneration",
+ "MingFlashOmniThinkerForConditionalGeneration",
+ "MingFlashOmniThinkerProcessingInfo",
+ "MingFlashOmniThinkerMultiModalProcessor",
+ "MingFlashOmniThinkerDummyInputsBuilder",
+]
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py
new file mode 100644
index 0000000000..6ca1990114
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py
@@ -0,0 +1,246 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team.
+# Copyright (c) 2022 OpenAI
+# Adapted from Ming repository modeling_whisper_encoder.py
+# https://github.com/inclusionAI/Ming
+
+import operator
+from collections.abc import Iterable
+from itertools import accumulate
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func
+from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids
+
+logger = init_logger(__name__)
+
+
+class MultiHeadAttention(nn.Module):
+ """Multi-head attention with packed sequence support.
+ Adapted from Qwen3-TTS WhisperEncoder.
+ """
+
+ def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True):
+ super().__init__()
+ self.n_head = n_head
+ self.query = Linear(n_state, n_state)
+ self.key = Linear(n_state, n_state, bias=False)
+ self.value = Linear(n_state, n_state)
+ self.out = Linear(n_state, n_state)
+
+ if use_flash_attn and not HAS_FLASH_ATTN:
+ logger.warning("flash-attn is not available. Fallback to manual PyTorch version")
+ self.use_flash_attn = use_flash_attn and HAS_FLASH_ATTN
+
+ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+ """Forward pass with packed sequence support.
+
+ Args:
+ x: [total_tokens, n_state] packed sequence
+ cu_seqlens: [num_seqs + 1] cumulative sequence lengths, e.g. [0, len1, len1+len2, ...]
+
+ Returns:
+ [total_tokens, n_state] attention output
+ """
+ q = self.query(x)
+ k = self.key(x)
+ v = self.value(x)
+
+ n_ctx, n_state = q.shape
+ head_dim = n_state // self.n_head
+
+ q = q.view(n_ctx, self.n_head, head_dim)
+ k = k.view(n_ctx, self.n_head, head_dim)
+ v = v.view(n_ctx, self.n_head, head_dim)
+
+ # Try flash attention varlen
+ if self.use_flash_attn and cu_seqlens is not None and q.dtype in [torch.float16, torch.bfloat16]:
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
+ else:
+ attn_output = self._manual_attention(q, k, v, cu_seqlens)
+
+ # Reshape back: [T, H, D] -> [T, H*D]
+ attn_output = attn_output.contiguous().view(n_ctx, n_state)
+ return self.out(attn_output)
+
+ def _manual_attention(
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor
+ ) -> torch.Tensor:
+ """Manual attention for variable-length sequences (fallback)."""
+ _, n_head, head_dim = q.shape
+ scale = head_dim**-0.5
+
+ # Unpack sequences and pad to max length
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ batch_size = len(seqlens)
+ max_seqlen = max(seqlens)
+
+ # Create padded tensors
+ q_padded = torch.zeros(batch_size, max_seqlen, n_head, head_dim, dtype=q.dtype, device=q.device)
+ k_padded = torch.zeros_like(q_padded)
+ v_padded = torch.zeros_like(q_padded)
+
+ # Fill with actual sequences
+ for i in range(batch_size):
+ start_idx = cu_seqlens[i]
+ end_idx = cu_seqlens[i + 1]
+ seq_len = seqlens[i]
+ q_padded[i, :seq_len] = q[start_idx:end_idx]
+ k_padded[i, :seq_len] = k[start_idx:end_idx]
+ v_padded[i, :seq_len] = v[start_idx:end_idx]
+
+ # Transpose for attention: [B, H, T, D]
+ q_padded = q_padded.transpose(1, 2)
+ k_padded = k_padded.transpose(1, 2)
+ v_padded = v_padded.transpose(1, 2)
+
+ # Create attention mask for variable lengths: 0 for valid positions, -inf for padding
+ padding_mask = (
+ torch.arange(max_seqlen, device=q.device)[None, :] >= torch.tensor(seqlens, device=q.device)[:, None]
+ )
+ attn_mask = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q.dtype, device=q.device)
+ attn_mask = attn_mask.masked_fill(padding_mask.unsqueeze(1).unsqueeze(2), -torch.finfo(q.dtype).max)
+
+ # Compute attention
+ attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
+ attn_scores = attn_scores + attn_mask
+ attn_weights = F.softmax(attn_scores, dim=-1)
+ context = torch.matmul(attn_weights, v_padded)
+
+ # Transpose back: [B, H, T, D] -> [B, T, H, D]
+ context = context.transpose(1, 2).contiguous()
+ output_packed = torch.cat([context[i, : seqlens[i]] for i in range(batch_size)], dim=0)
+
+ return output_packed
+
+
+class ResidualAttentionBlock(nn.Module):
+ """Whisper-style residual attention block with packed sequence support.
+
+ Adapted from
+ https://github.com/openai/whisper/blob/v20250625/whisper/model.py
+ vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
+ """
+
+ def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True):
+ super().__init__()
+ self.attn = MultiHeadAttention(n_state, n_head, use_flash_attn=use_flash_attn)
+ self.attn_ln = nn.LayerNorm(n_state)
+
+ n_mlp = n_state * 4
+ self.mlp = nn.Sequential(
+ Linear(n_state, n_mlp),
+ nn.GELU(),
+ Linear(n_mlp, n_state),
+ )
+ self.mlp_ln = nn.LayerNorm(n_state)
+
+ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+ x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
+ x = x + self.mlp(self.mlp_ln(x))
+ return x
+
+
+class WhisperAudioEncoder(nn.Module):
+ """Whisper audio encoder for Ming with packed sequence support.
+
+ Adapted from
+ https://github.com/openai/whisper/blob/v20250625/whisper/model.py
+ vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
+ """
+
+ def __init__(
+ self,
+ n_mels: int = 128,
+ n_ctx: int = 15000,
+ n_state: int = 1280,
+ n_head: int = 20,
+ n_layer: int = 32,
+ use_flash_attn: bool = True,
+ ):
+ super().__init__()
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+ # self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+ self.blocks = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head, use_flash_attn=use_flash_attn) for _ in range(n_layer)]
+ )
+ self.ln_post = nn.LayerNorm(n_state)
+ self.audio_emb_dim = n_state
+
+ self.n_layer = n_layer
+ self.n_mels = n_mels
+ self.use_flash_attn = use_flash_attn
+
+ def forward(
+ self,
+ x_list: list[torch.Tensor],
+ audio_lens: list[int],
+ ) -> torch.Tensor:
+ """Forward pass with packed sequence format for variable-length inputs.
+
+ Args:
+ x_list: List of [n_mels, T_i] mel spectrogram features for each audio
+ audio_lens: List of original audio lengths in frames
+
+ Returns:
+ [total_T', n_state] packed encoded audio features, where
+ total_T' is the sum of all encoded sequence lengths
+ """
+ # Cast inputs to model dtype
+ target_dtype = self.conv1.weight.dtype
+ x_list = [x.to(target_dtype) for x in x_list]
+
+ encoded_list = []
+ encoded_lens = []
+ for mel_spec in x_list:
+ # mel_spec: [n_mels, T] - process through conv layers
+ x = mel_spec.unsqueeze(0) # [1, n_mels, T]
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.squeeze(0).transpose(0, 1) # [T', n_state]
+
+ # Add positional embedding
+ seq_len = x.shape[0]
+ positional_embedding = self.positional_embedding[:seq_len, :]
+ x = (x + positional_embedding).to(x.dtype)
+
+ encoded_list.append(x)
+ encoded_lens.append(seq_len)
+
+ x_packed = torch.cat(encoded_list, dim=0) # [total_T', n_state]
+
+ cu_seqlens = list(accumulate(encoded_lens, func=operator.add, initial=0))
+ cu_seqlens = torch.tensor(cu_seqlens, device=x_packed.device, dtype=torch.int32)
+
+ for block in self.blocks:
+ x_packed = block(x_packed, cu_seqlens=cu_seqlens)
+
+ x_packed = self.ln_post(x_packed)
+ return x_packed
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict: dict[str, torch.Tensor] = {
+ **dict(self.named_parameters(remove_duplicate=False)),
+ **dict(self.named_buffers()),
+ }
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ if name not in params_dict:
+ logger.warning("Skipping unknown audio encoder weight: %s", name)
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py
new file mode 100644
index 0000000000..87728890b6
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py
@@ -0,0 +1,223 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved.
+# Adapted from Ming repository modeling_bailingmm2.py
+# https://github.com/inclusionAI/Ming
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Ming-flash-omni-2.0 unified model (thinker + imagegen + talker)."""
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import (
+ SupportsMRoPE,
+ SupportsMultiModal,
+ SupportsPP,
+)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.utils import (
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.sequence import IntermediateTensors
+
+from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights
+from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config, MingFlashOmniConfig
+
+from .ming_flash_omni_thinker import (
+ MingFlashOmniThinkerDummyInputsBuilder,
+ MingFlashOmniThinkerMultiModalProcessor,
+ MingFlashOmniThinkerProcessingInfo,
+)
+
+logger = init_logger(__name__)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ MingFlashOmniThinkerMultiModalProcessor,
+ info=MingFlashOmniThinkerProcessingInfo,
+ dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder,
+)
+class MingFlashOmniForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsMRoPE,
+ CustomProcessMixin,
+):
+ """Unified Ming-flash-omni-2.0 model combining thinker, imagegen, and talker."""
+
+ supports_multimodal = True
+ requires_raw_input_tokens: bool = True
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.have_multimodal_outputs = True
+ self.has_preprocess = False
+ self.has_postprocess = False
+
+ config = vllm_config.model_config.hf_config
+
+ self.vllm_config = vllm_config
+ self.config = config
+
+ if isinstance(config, MingFlashOmniConfig):
+ thinker_config = config.thinker_config
+ else:
+ thinker_config = config
+
+ self.thinker_config: BailingMM2Config = thinker_config
+ self.model_stage = vllm_config.model_config.model_stage
+
+ if self.model_stage == "thinker":
+ thinker_vllm_config = vllm_config.with_hf_config(
+ thinker_config, architectures=["MingFlashOmniThinkerForConditionalGeneration"]
+ )
+ self.thinker = init_vllm_registered_model(
+ vllm_config=thinker_vllm_config,
+ prefix=maybe_prefix(prefix, "thinker"),
+ architectures=["MingFlashOmniThinkerForConditionalGeneration"],
+ )
+ self.model = self.thinker
+ self.imagegen = None
+ self.talker = None
+
+ elif self.model_stage == "imagegen":
+ # TODO: Implement image generator stage
+ raise NotImplementedError(
+ "Image generation stage is not yet implemented. Please use model_stage='thinker' for now."
+ )
+
+ elif self.model_stage == "talker":
+ # TODO: Implement talker (TTS) stage
+ raise NotImplementedError(
+ "Talker (TTS) stage is not yet implemented. Please use model_stage='thinker' for now."
+ )
+
+ else:
+ raise ValueError(
+ f"Invalid model_stage: {self.model_stage}. Must be one of: 'thinker', 'imagegen', 'talker'"
+ )
+
+ # Set up intermediate tensors
+ self.make_empty_intermediate_tensors = (
+ self.thinker.make_empty_intermediate_tensors if self.model_stage == "thinker" else lambda: None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs,
+ ) -> OmniOutput:
+ return self.model.forward(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata=None,
+ ) -> torch.Tensor | None:
+ if hasattr(self.model, "compute_logits"):
+ return self.model.compute_logits(hidden_states, sampling_metadata)
+ return None
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata,
+ ):
+ if hasattr(self.model, "sample"):
+ return self.model.sample(logits, sampling_metadata)
+ raise NotImplementedError("sample method not available on current stage")
+
+ def get_mrope_input_positions(self, *args, **kwargs):
+ if hasattr(self.model, "get_mrope_input_positions"):
+ return self.model.get_mrope_input_positions(*args, **kwargs)
+ raise NotImplementedError("get_mrope_input_positions not available on current stage")
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loaded_weights = set()
+ thinker_weights = []
+ imagegen_weights = []
+ talker_weights = []
+
+ for name, value in weights:
+ if name.startswith("thinker."):
+ thinker_weights.append((name, value))
+ elif name.startswith("imagegen."):
+ imagegen_weights.append((name, value))
+ elif name.startswith("talker."):
+ talker_weights.append((name, value))
+ else:
+ # Weights without prefix go to thinker by default
+ thinker_weights.append((name, value))
+
+ if self.model_stage == "thinker" and thinker_weights:
+ # Remove "thinker." prefix before loading
+ thinker_weights_stripped = [
+ (name.replace("thinker.", "", 1) if name.startswith("thinker.") else name, value)
+ for name, value in thinker_weights
+ ]
+ thinker_loaded = self.thinker.load_weights(thinker_weights_stripped)
+ thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker")
+ loaded_weights.update(thinker_loaded)
+
+ # TODO: Load imagegen weights when implemented
+ # TODO: Load talker weights when implemented
+
+ return loaded_weights
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ return MultiModelKeys.from_string_field(
+ language_model="thinker.language_model",
+ connector=["thinker.linear_proj.", "thinker.linear_proj_audio."],
+ tower_model=["thinker.vision.", "thinker.audio."],
+ )
+
+ @property
+ def sampler(self):
+ if hasattr(self.model, "sampler"):
+ return self.model.sampler
+ return None
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings=None,
+ *,
+ is_multimodal=None,
+ ) -> torch.Tensor:
+ return self.model.embed_input_ids(
+ input_ids,
+ multimodal_embeddings,
+ is_multimodal=is_multimodal,
+ )
+
+ def embed_multimodal(self, **kwargs):
+ return self.model.embed_multimodal(**kwargs)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py
new file mode 100644
index 0000000000..bde7477b94
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py
@@ -0,0 +1,893 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team.
+# Adapted from Ming repository modeling_bailingmm2.py and processing_bailingmm2.py
+# https://github.com/inclusionAI/Ming
+
+"""Ming-flash-omni-2.0 Thinker stage implementation (multimodal understanding)."""
+
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from typing import Annotated, Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.feature_extraction_utils import BatchFeature
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.inputs import MultiModalDataDict
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import (
+ MultiModalEmbeddings,
+ SupportsMRoPE,
+ SupportsMultiModal,
+ SupportsPP,
+)
+from vllm.model_executor.models.qwen2_5_vl import (
+ Qwen2_5_VLImageInputs,
+ Qwen2_5_VLImagePixelInputs,
+ Qwen2_5_VLVideoInputs,
+ Qwen2_5_VLVideoPixelInputs,
+)
+from vllm.model_executor.models.qwen2_vl import (
+ Qwen2VLProcessingInfo,
+)
+from vllm.model_executor.models.utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ _merge_multimodal_embeddings,
+ maybe_prefix,
+)
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalFeatureSpec,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+ AudioProcessorItems,
+ ImageProcessorItems,
+ MultiModalDataItems,
+ MultiModalDataParser,
+ VideoProcessorItems,
+)
+from vllm.multimodal.processing import (
+ BaseDummyInputsBuilder,
+ BaseMultiModalProcessor,
+ PromptReplacement,
+ PromptUpdate,
+ PromptUpdateDetails,
+)
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config
+from vllm_omni.transformers_utils.processors.ming import (
+ PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
+ PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
+ PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
+ MingFlashOmniProcessor,
+ MingWhisperFeatureExtractor,
+)
+
+from .audio_encoder import WhisperAudioEncoder
+from .modeling_bailing_moe_v2 import BailingMoeV2ForCausalLM
+from .projectors import AudioProjector, VisionProjector
+from .vision_encoder import MingVisionEncoder
+
+logger = init_logger(__name__)
+
+
+class MingAudioInput(TensorSchema):
+ """
+ Dimensions:
+ - b: Batch size
+ - l: Total audio frames (clips concatenated along the time axis)
+ - nm: Number of mel bins
+ - N: Max number of audio clips per batch item
+ """
+
+ audio_feats: Annotated[
+ torch.Tensor,
+ TensorShape("b", "l", "nm"),
+ ]
+
+ audio_feats_lengths: Annotated[
+ torch.Tensor,
+ TensorShape("b", "N"),
+ ]
+
+
+class MingFlashOmniThinkerProcessingInfo(Qwen2VLProcessingInfo):
+ def get_hf_config(self) -> BailingMM2Config:
+ return self.ctx.get_hf_config(BailingMM2Config)
+
+ def get_hf_processor(self, **kwargs: object):
+ return self.ctx.get_hf_processor(MingFlashOmniProcessor, **kwargs)
+
+ def get_target_channels(self) -> int:
+ # See `_normalize_audio_tensor` in vllm_omni/transformers_utils/processors/ming.py
+ return 1
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None, "video": None, "audio": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ mm_counts = mm_counts or {}
+ requested_modalities = {m for m, c in mm_counts.items() if c > 0}
+ mm_max_tokens: dict[str, int] = {}
+
+ if requested_modalities & {"image", "video"}:
+ vl_tokens = super().get_mm_max_tokens_per_item(
+ seq_len=seq_len,
+ mm_counts=mm_counts,
+ )
+ mm_max_tokens.update({m: vl_tokens[m] for m in ["image", "video"] if m in requested_modalities})
+
+ if "audio" in requested_modalities:
+ # TODO: consider computing from audio config
+ mm_max_tokens["audio"] = 3000
+
+ return mm_max_tokens
+
+ def get_feature_extractor(self, **kwargs: object) -> MingWhisperFeatureExtractor:
+ hf_processor = self.get_hf_processor(**kwargs)
+ feature_extractor = hf_processor.audio_processor
+ assert isinstance(feature_extractor, MingWhisperFeatureExtractor)
+ return feature_extractor
+
+ def get_data_parser(self):
+ feature_extractor = self.get_feature_extractor()
+ return MultiModalDataParser(
+ target_sr=feature_extractor.sampling_rate,
+ target_channels=self.get_target_channels(),
+ expected_hidden_size=self._get_expected_hidden_size(),
+ )
+
+
+class MingFlashOmniThinkerDummyInputsBuilder(BaseDummyInputsBuilder[MingFlashOmniThinkerProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+ num_audios = mm_counts.get("audio", 0)
+
+ hf_processor = self.info.get_hf_processor()
+
+ audio_token: str = hf_processor.audio_token
+ image_token: str = hf_processor.image_token
+ video_token: str = hf_processor.video_token
+
+ return image_token * num_images + video_token * num_videos + audio_token * num_audios
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+ num_audios = mm_counts.get("audio", 0)
+
+ # Default dimensions for dummy data
+ image_width, image_height = 448, 448
+ video_width, video_height = 448, 448
+ num_frames = 8
+ audio_duration = 3.0 # seconds
+ sample_rate = 16000
+
+ audio_length = int(audio_duration * sample_rate)
+
+ mm_data: MultiModalDataDict = {
+ "image": self._get_dummy_images(
+ width=image_width,
+ height=image_height,
+ num_images=num_images,
+ ),
+ "video": self._get_dummy_videos(
+ width=video_width,
+ height=video_height,
+ num_frames=num_frames,
+ num_videos=num_videos,
+ ),
+ "audio": [(np.random.randn(audio_length).astype(np.float32), sample_rate) for _ in range(num_audios)],
+ }
+
+ return mm_data
+
+
+class MingFlashOmniThinkerMultiModalProcessor(BaseMultiModalProcessor[MingFlashOmniThinkerProcessingInfo]):
+ """Multimodal processor for Ming-flash-omni Thinker stage.
+
+ Handles preprocessing of 1) image, 2) video, and 3) audio inputs,
+ and expands placeholder tokens to the correct number of patch tokens.
+ """
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ tokenizer = self.info.get_tokenizer()
+ # might want to add a fallback to resolve token ids
+ # vocab = tokenizer.get_vocab()
+ thinker_config = self.info.get_hf_config()
+
+ # patch/delimiter token IDs (used in replacement sequences)
+ image_start_token_id = thinker_config.llm_config.image_start_token
+ image_patch_token_id = thinker_config.llm_config.image_patch_token
+ image_end_token_id = thinker_config.llm_config.image_end_token
+
+ video_start_token_id = thinker_config.llm_config.video_start_token
+ frame_patch_token_id = thinker_config.llm_config.video_patch_token
+ video_end_token_id = thinker_config.llm_config.video_end_token
+
+ audio_start_token_id = thinker_config.llm_config.audio_start_token
+ audio_patch_token_id = thinker_config.llm_config.audio_patch_token
+ audio_end_token_id = thinker_config.llm_config.audio_end_token
+
+ vision_config = thinker_config.vision_config
+ spatial_merge_size = vision_config.spatial_merge_size if vision_config else 2
+
+ newline_token_ids: list[int] = tokenizer.encode("\n", add_special_tokens=False)
+
+ out_mm_data = out_mm_kwargs.get_data()
+
+ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
+ """Generate token sequence for an image."""
+ grid_thw = out_mm_data.get("image_grid_thw")
+ if grid_thw is None:
+ raise ValueError(
+ "image_grid_thw missing from processor output; "
+ "cannot determine image patch count for prompt replacement."
+ )
+ if isinstance(grid_thw, torch.Tensor):
+ thw = grid_thw[item_idx]
+ num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
+ else:
+ thw = grid_thw[item_idx]
+ num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
+
+ # Build token sequence: *N \n
+ # the newline token is added in purpose from original model processing
+ tokens: list[int] = []
+ tokens.append(image_start_token_id)
+ tokens.extend([image_patch_token_id] * num_patches)
+ tokens.append(image_end_token_id)
+ # Refer to Ming's BailingMM2Processor._expand_image_tokens
+ # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py
+ tokens.extend(newline_token_ids)
+
+ # Only tokens receive multimodal embeddings
+ return PromptUpdateDetails.select_token_id(tokens, image_patch_token_id)
+
+ def get_replacement_video(item_idx: int) -> PromptUpdateDetails:
+ """Generate token sequence for a video."""
+ grid_thw = out_mm_data.get("video_grid_thw", None)
+ if grid_thw is None:
+ raise ValueError(
+ "video_grid_thw missing from processor output; "
+ "cannot determine video patch count for prompt replacement."
+ )
+ if isinstance(grid_thw, torch.Tensor):
+ thw = grid_thw[item_idx]
+ num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
+ else:
+ thw = grid_thw[item_idx]
+ num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
+
+ # Build token sequence: *N \n
+ # the newline token is added in purpose from original model processing
+ tokens: list[int] = []
+ tokens.append(video_start_token_id)
+ tokens.extend([frame_patch_token_id] * num_patches)
+ tokens.append(video_end_token_id)
+ tokens.extend(newline_token_ids)
+
+ # Only tokens receive multimodal embeddings
+ return PromptUpdateDetails.select_token_id(tokens, frame_patch_token_id)
+
+ def get_replacement_audio(item_idx: int) -> PromptUpdateDetails:
+ """Generate token sequence for an audio."""
+ encoder_feats_lengths = out_mm_data.get("encoder_feats_lengths", None)
+ if encoder_feats_lengths is None:
+ raise ValueError(
+ "encoder_feats_lengths missing from processor output; "
+ "cannot determine audio patch count for prompt replacement."
+ )
+ if isinstance(encoder_feats_lengths, torch.Tensor):
+ num_patches = int(encoder_feats_lengths[item_idx].item())
+ else:
+ num_patches = encoder_feats_lengths[item_idx]
+
+ # Build token sequence: *N
+ tokens: list[int] = []
+ tokens.append(audio_start_token_id)
+ tokens.extend([audio_patch_token_id] * num_patches)
+ tokens.append(audio_end_token_id)
+
+ # Only tokens receive multimodal embeddings
+ return PromptUpdateDetails.select_token_id(tokens, audio_patch_token_id)
+
+ # Build prompt updates and process replacement
+ updates: list[PromptUpdate] = []
+
+ if "image" in mm_items and mm_items.get_items("image", ImageProcessorItems):
+ updates.append(
+ PromptReplacement(
+ modality="image",
+ target=PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
+ replacement=get_replacement_image,
+ )
+ )
+ if "video" in mm_items and mm_items.get_items("video", VideoProcessorItems):
+ updates.append(
+ PromptReplacement(
+ modality="video",
+ target=PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
+ replacement=get_replacement_video,
+ )
+ )
+ if "audio" in mm_items and mm_items.get_items("audio", AudioProcessorItems):
+ updates.append(
+ PromptReplacement(
+ modality="audio",
+ target=PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
+ replacement=get_replacement_audio,
+ )
+ )
+ return updates
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ config: dict[str, MultiModalFieldConfig] = {}
+
+ # Image fields, pixel_values is flat (concatenated patches from all images)
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ if "pixel_values" in hf_inputs:
+ image_sizes = image_grid_thw.prod(-1)
+ config["pixel_values"] = MultiModalFieldConfig.flat_from_sizes(
+ "image",
+ image_sizes,
+ )
+ if "image_grid_thw" in hf_inputs:
+ config["image_grid_thw"] = MultiModalFieldConfig.batched("image")
+
+ # Video fields, same flat layout as images
+ video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+ if "pixel_values_videos" in hf_inputs:
+ video_sizes = video_grid_thw.prod(-1)
+ config["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes(
+ "video",
+ video_sizes,
+ )
+ if "video_grid_thw" in hf_inputs:
+ config["video_grid_thw"] = MultiModalFieldConfig.batched("video")
+
+ # Audio fields
+ if "audio_feats" in hf_inputs:
+ config["audio_feats"] = MultiModalFieldConfig.batched("audio")
+ if "audio_feats_lengths" in hf_inputs:
+ config["audio_feats_lengths"] = MultiModalFieldConfig.batched("audio")
+ if "encoder_feats_lengths" in hf_inputs:
+ config["encoder_feats_lengths"] = MultiModalFieldConfig.batched("audio")
+ if "placeholder_audio_loc_lens" in hf_inputs:
+ config["placeholder_audio_loc_lens"] = MultiModalFieldConfig.batched("audio")
+
+ return config
+
+ def _hf_processor_applies_updates(
+ self,
+ prompt_text: str,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ tokenization_kwargs: Mapping[str, object],
+ ) -> bool:
+ return False
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ """Call sub-processors for multimodal inputs and tokenize.
+
+ We call the image/audio sub-processors directly (instead of going
+ through `MingFlashOmniProcessor.__call__`) so that the high-level
+ placeholder tokens remain **unexpanded** in the tokenized output.
+ """
+ hf_processor = self.info.get_hf_processor()
+ tokenizer = self.info.get_tokenizer()
+
+ data: dict[str, object] = {}
+
+ images = mm_data.get("images", None)
+ if images is not None:
+ image_outputs = hf_processor.image_processor(
+ images=images,
+ videos=None,
+ return_tensors="pt",
+ )
+ data.update(image_outputs)
+
+ videos = mm_data.get("videos", None)
+ if videos is not None:
+ video_outputs = hf_processor.image_processor(
+ images=None,
+ videos=videos,
+ return_tensors="pt",
+ )
+ # Rename keys to distinguish from images
+ if "pixel_values" in video_outputs:
+ video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values")
+ if "image_grid_thw" in video_outputs:
+ video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw")
+ data.update(video_outputs)
+
+ audios = mm_data.get("audios", None)
+ if audios is not None:
+ # vLLM's AudioProcessorItems provides raw numpy arrays (already resampled).
+ # MingWhisperAudioProcessor expects (waveform, sr) tuples,
+ # so wrap them with the target sample rate.
+ target_sr = hf_processor.audio_processor.sampling_rate
+ audio_tuples = [(a, target_sr) if not isinstance(a, tuple) else a for a in audios]
+
+ audio_outputs = hf_processor.audio_processor(
+ audio_tuples,
+ return_tensors="pt",
+ )
+ data.update(audio_outputs)
+
+ # Tokenize text with placeholders still intact
+ text_outputs = tokenizer(prompt, return_tensors="pt", **tok_kwargs)
+ data.update(text_outputs)
+
+ return BatchFeature(data=data)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ MingFlashOmniThinkerMultiModalProcessor,
+ info=MingFlashOmniThinkerProcessingInfo,
+ dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder,
+)
+class MingFlashOmniThinkerForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsMRoPE,
+ CustomProcessMixin,
+):
+ """Ming Thinker stage: multimodal understanding
+ (text + image + video + audio) -> text generation.
+ """
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={"model.": "language_model."},
+ )
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ # vllm_omni/transformers_utils/processors/ming.py
+ if modality.startswith("image"):
+ return ""
+ elif modality.startswith("video"):
+ return ""
+ elif modality.startswith("audio"):
+ return ""
+
+ raise ValueError("Only image, video, or audio modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+
+ thinker_config: BailingMM2Config = config
+ if (
+ thinker_config.llm_config is None
+ or thinker_config.vision_config is None
+ or thinker_config.audio_config is None
+ ):
+ raise ValueError(
+ "MingFlashOmniThinker requires `llm_config`, `vision_config`, and `audio_config` in `thinker_config`."
+ )
+
+ llm_config = thinker_config.llm_config
+
+ self.config = llm_config
+ self.thinker_config = thinker_config
+ self.have_multimodal_outputs = True
+
+ # Initialize LLM as a component
+ with self._mark_language_model(vllm_config):
+ llm_vllm_config = vllm_config.with_hf_config(llm_config)
+ self.language_model = BailingMoeV2ForCausalLM(
+ vllm_config=llm_vllm_config, prefix=maybe_prefix(prefix, "llm")
+ )
+
+ # Ming thinker is inherently multimodal; initialize both towers eagerly.
+ with self._mark_tower_model(vllm_config, {"image", "video"}):
+ self.vision = MingVisionEncoder(
+ vision_config=thinker_config.vision_config,
+ quant_config=vllm_config.quant_config,
+ prefix=maybe_prefix(prefix, "vision"),
+ )
+ self.linear_proj = VisionProjector(
+ vision_dim=self.vision.image_emb_dim,
+ llm_dim=llm_config.hidden_size,
+ mlp_depth=getattr(thinker_config, "mlp_depth", 2),
+ )
+ logger.info("Initialized MingVisionEncoder and VisionProjector")
+
+ audio_cfg = thinker_config.audio_config
+ whisper_cfg = getattr(audio_cfg, "whisper_encoder_config", {}) or {}
+ with self._mark_tower_model(vllm_config, "audio"):
+ self.audio = WhisperAudioEncoder(
+ **whisper_cfg,
+ use_flash_attn=True,
+ )
+ self.linear_proj_audio = AudioProjector(
+ audio_dim=self.audio.audio_emb_dim,
+ llm_dim=llm_config.hidden_size,
+ ds_kernel_size=getattr(audio_cfg, "ds_kernel_size", 3),
+ ds_stride=getattr(audio_cfg, "ds_stride", 2),
+ mlp_depth=getattr(thinker_config, "mlp_depth", 1),
+ )
+ logger.info("Initialized WhisperAudioEncoder and AudioProjector")
+
+ # Expose interfaces
+ self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors
+
+ logger.info("MingFlashOmniThinker initialized with vision and audio towers")
+
+ def extract_image_feature(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ """Extract and project image features.
+
+ Args:
+ pixel_values: Flattened pixel values from vision processor.
+ grid_thw: [num_images, 3] tensor of (t, h, w) grid dimensions.
+
+ Returns:
+ [seq_len, hidden_size] L2-normalized image embeddings.
+ """
+ if self.vision is None:
+ raise ValueError("Vision encoder not initialized")
+
+ with torch.amp.autocast(pixel_values.device.type, dtype=torch.bfloat16):
+ image_embeds = self.vision(pixel_values, grid_thw=grid_thw)
+
+ if self.vision.use_deepstack:
+ image_embeds = image_embeds[:, : self.vision.image_emb_dim]
+
+ image_embeds = self.linear_proj(image_embeds)
+ image_embeds = F.normalize(image_embeds, dim=-1)
+ return image_embeds
+
+ def extract_audio_feature(
+ self, audio_feats: torch.Tensor, audio_feats_lengths: torch.Tensor
+ ) -> tuple[torch.Tensor, ...]:
+ """Extract and project audio features.
+
+ Args:
+ audio_feats: [B, L_total, n_mels] wrapped mel features — multiple audio
+ clips per batch item are concatenated along the time dimension
+ (time-first, as produced by MingWhisperFeatureExtractor).
+ audio_feats_lengths: [B, N] lengths of each audio clip per batch item.
+ N is the max number of clips per item; zero-padded entries are skipped.
+
+ Returns:
+ Tuple of per-clip [T'_i, hidden_size] projected audio embeddings.
+ """
+ if self.audio is None:
+ raise ValueError("Audio encoder not initialized")
+
+ # Unwrap packed [B, L_total, n_mels] into a list of [n_mels, T'_i] tensors,
+ # one per audio clip, as expected by WhisperAudioEncoder.
+ x_list: list[torch.Tensor] = []
+ audio_lens: list[int] = []
+ for i in range(audio_feats_lengths.shape[0]):
+ feat_index = 0
+ for j in range(audio_feats_lengths.shape[1]):
+ feat_len = int(audio_feats_lengths[i, j].item())
+ if feat_len == 0:
+ break
+ mel_seg = audio_feats[i, feat_index : feat_index + feat_len].transpose(0, 1)
+ x_list.append(mel_seg)
+ audio_lens.append(feat_len)
+ feat_index += feat_len
+
+ audio_packed = self.audio(x_list, audio_lens)
+
+ # Compute per-clip lengths after Whisper Conv1d (kernel=3, stride=2, pad=1)
+ encoded_lens = [(audio_len - 3 + 2) // 2 + 1 for audio_len in audio_lens]
+
+ # Project packed
+ proj_packed, proj_lens = self.linear_proj_audio.forward_packed(audio_packed, encoded_lens)
+
+ normalize = getattr(self.thinker_config.audio_config, "norm_query_embeds", False)
+ if normalize:
+ proj_packed = F.normalize(proj_packed, dim=-1)
+
+ proj_packed = proj_packed.to(audio_feats.dtype)
+
+ # Split into per-clip tensors
+ return proj_packed.split(proj_lens)
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ """Parse and validate multimodal kwargs into per-modality dicts."""
+ mm_input_by_modality: dict[str, Qwen2_5_VLImageInputs | Qwen2_5_VLVideoInputs | MingAudioInput] = {}
+
+ for key in kwargs:
+ if key == "pixel_values" and "image" not in mm_input_by_modality:
+ pixel_values = kwargs.get("pixel_values")
+ image_grid_thw = kwargs.get("image_grid_thw")
+ if pixel_values is not None and image_grid_thw is not None:
+ mm_input_by_modality["image"] = Qwen2_5_VLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values, # type: ignore[arg-type]
+ image_grid_thw=image_grid_thw, # type: ignore[arg-type]
+ )
+ elif key == "pixel_values_videos" and "video" not in mm_input_by_modality:
+ pixel_values_videos = kwargs.get("pixel_values_videos")
+ video_grid_thw = kwargs.get("video_grid_thw")
+ second_per_grid_ts = kwargs.get("second_per_grid_ts")
+ if pixel_values_videos is not None and video_grid_thw is not None:
+ mm_input_by_modality["video"] = Qwen2_5_VLVideoPixelInputs(
+ type="pixel_values_videos",
+ pixel_values_videos=pixel_values_videos, # type: ignore[arg-type]
+ video_grid_thw=video_grid_thw, # type: ignore[arg-type]
+ second_per_grid_ts=second_per_grid_ts, # type: ignore[arg-type]
+ )
+ elif key == "audio_feats" and "audio" not in mm_input_by_modality:
+ audio_feats = kwargs.get("audio_feats")
+ audio_feats_lengths = kwargs.get("audio_feats_lengths")
+ if audio_feats is not None and audio_feats_lengths is not None:
+ mm_input_by_modality["audio"] = MingAudioInput(
+ audio_feats=audio_feats, # type: ignore[arg-type]
+ audio_feats_lengths=audio_feats_lengths, # type: ignore[arg-type]
+ )
+
+ return mm_input_by_modality
+
+ def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> list[torch.Tensor]:
+ # Splits the flat [total_tokens, D] output of extract_image_feature
+ # into one tensor per image.
+ pixel_values = image_input["pixel_values"]
+ image_grid_thw = image_input["image_grid_thw"]
+ image_embeds = self.extract_image_feature(pixel_values, image_grid_thw)
+ merge_unit = self.thinker_config.vision_config.spatial_merge_size**2
+ sizes = (image_grid_thw.prod(dim=-1) // merge_unit).tolist()
+ return list(image_embeds.split([int(s) for s in sizes], dim=0))
+
+ def _process_video_input(self, video_input: Qwen2_5_VLVideoInputs) -> list[torch.Tensor]:
+ pixel_values_videos = video_input["pixel_values_videos"]
+ video_grid_thw = video_input["video_grid_thw"]
+ video_embeds = self.extract_image_feature(pixel_values_videos, video_grid_thw)
+ merge_unit = self.thinker_config.vision_config.spatial_merge_size**2
+ sizes = (video_grid_thw.prod(dim=-1) // merge_unit).tolist()
+ return list(video_embeds.split([int(s) for s in sizes], dim=0))
+
+ def _process_audio_input(self, audio_input: MingAudioInput) -> list[torch.Tensor]:
+ return list(self.extract_audio_feature(audio_input["audio_feats"], audio_input["audio_feats_lengths"]))
+
+ def _compute_modality_masks(self, input_ids: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+ """Compute vision and audio MoE-routing masks from input_ids.
+
+ Returns:
+ Tuple of (vision_mask, audio_mask), each shape [seq_len] bool.
+ """
+ llm_config = self.config
+
+ # vision mask
+ vision_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+ image_token = llm_config.image_patch_token
+ video_token = llm_config.video_patch_token
+ vision_mask = vision_mask | (input_ids == image_token)
+ vision_mask = vision_mask | (input_ids == video_token)
+
+ # audio mask
+ audio_token = llm_config.audio_patch_token
+ audio_mask = input_ids == audio_token
+
+ return vision_mask, audio_mask
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not mm_input_by_modality:
+ return []
+
+ # preserve the order of modalities
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ for modality, mm_input in mm_input_by_modality.items():
+ if modality == "image":
+ multimodal_embeddings += tuple(self._process_image_input(mm_input)) # type: ignore[arg-type]
+ elif modality == "video":
+ multimodal_embeddings += tuple(self._process_video_input(mm_input)) # type: ignore[arg-type]
+ elif modality == "audio":
+ multimodal_embeddings += tuple(self._process_audio_input(mm_input)) # type: ignore[arg-type]
+
+ return multimodal_embeddings
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: MultiModalEmbeddings | None = None,
+ *,
+ is_multimodal: torch.Tensor | None = None,
+ handle_oov_mm_token: bool = False,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.model.word_embeddings(input_ids)
+
+ if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+ return inputs_embeds
+
+ assert is_multimodal is not None, "`is_multimodal` mask required when `multimodal_embeddings` provided"
+ return _merge_multimodal_embeddings(
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=multimodal_embeddings,
+ is_multimodal=is_multimodal,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs,
+ ) -> OmniOutput:
+ # Compute MoE modality masks on every device
+ image_mask, audio_mask = self._compute_modality_masks(input_ids)
+ hidden_states = self.language_model.forward(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ image_mask=image_mask,
+ audio_mask=audio_mask,
+ )
+
+ # Capture embeddings for downstream stages
+ multimodal_outputs = {
+ "final_hidden_states": hidden_states,
+ }
+
+ return OmniOutput(
+ text_hidden_states=hidden_states,
+ multimodal_outputs=multimodal_outputs,
+ )
+
+ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) -> torch.Tensor | None:
+ return self.language_model.compute_logits(hidden_states, sampling_metadata)
+
+ def sample(self, logits: torch.Tensor, sampling_metadata):
+ return self.language_model.sample(logits, sampling_metadata)
+
+ @property
+ def sampler(self):
+ return self.language_model.sampler
+
+ def iter_mm_features(
+ self,
+ mm_features: list[MultiModalFeatureSpec],
+ ) -> Iterator[tuple[int, str, dict[str, Any]]]:
+ """Iterate over image/video features sorted by token position.
+
+ Yields: (offset, modality, feature_data) where feature_data contains:
+ - image: {"grid_t", "grid_h", "grid_w", "second_per_grid_t"}
+ - video: {"grid_t", "grid_h", "grid_w", "second_per_grid_t"}
+
+ Audio features are not yielded: Ming assigns them sequential
+ text positions (same T/H/W value) rather than 3D grid positions.
+ """
+ spatial_merge_size = self.config.spatial_merge_size
+
+ for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
+ if mm_feature.data is None:
+ continue
+
+ offset = mm_feature.mm_position.offset
+ modality = mm_feature.modality
+
+ if modality == "image":
+ t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
+ yield (
+ offset,
+ "image",
+ {
+ "grid_t": int(t),
+ "grid_h": int(h) // spatial_merge_size,
+ "grid_w": int(w) // spatial_merge_size,
+ "second_per_grid_t": 0.0,
+ },
+ )
+ elif modality == "video":
+ t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
+ second_per_grid_t = 1.0
+ spgt_field = mm_feature.data.get("second_per_grid_ts")
+ if spgt_field is not None:
+ second_per_grid_t = float(spgt_field.data.item())
+ yield (
+ offset,
+ "video",
+ {
+ "grid_t": int(t),
+ "grid_h": int(h) // spatial_merge_size,
+ "grid_w": int(w) // spatial_merge_size,
+ "second_per_grid_t": second_per_grid_t,
+ },
+ )
+
+ def get_mrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list[MultiModalFeatureSpec] | None = None,
+ **kwargs: object,
+ ) -> tuple[torch.Tensor, int]:
+ """Compute M-RoPE input positions using mm_features directly."""
+ llm_config = self.config
+ tokens_per_second: int = getattr(llm_config, "tokens_per_second", 2)
+ seq_len = len(input_tokens)
+
+ llm_pos_ids_list: list[np.ndarray] = []
+ st = 0 # index of next unprocessed token
+
+ for patch_offset, _modality, data in self.iter_mm_features(mm_features or []):
+ text_len = patch_offset - st
+ st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
+ if text_len > 0:
+ llm_pos_ids_list.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx)
+ st_idx += text_len
+
+ # 3-D grid positions for patch tokens
+ grid_t: int = data["grid_t"]
+ grid_h: int = data["grid_h"]
+ grid_w: int = data["grid_w"]
+ second_per_grid_t: float = data["second_per_grid_t"]
+
+ t_raw = np.arange(grid_t)
+ if second_per_grid_t > 0:
+ t_index = (t_raw * second_per_grid_t * tokens_per_second).astype(np.int64)
+ else:
+ t_index = t_raw.astype(np.int64)
+ t_index = np.repeat(t_index, grid_h * grid_w)
+
+ h_index = np.tile(np.arange(grid_h).repeat(grid_w), grid_t)
+ w_index = np.tile(np.arange(grid_w), grid_t * grid_h)
+
+ llm_pos_ids_list.append(np.stack([t_index, h_index, w_index]) + st_idx)
+
+ num_patches = grid_t * grid_h * grid_w
+ st = patch_offset + num_patches
+
+ if st < seq_len:
+ st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
+ tail_len = seq_len - st
+ llm_pos_ids_list.append(np.broadcast_to(np.arange(tail_len), (3, tail_len)) + st_idx)
+
+ if llm_pos_ids_list:
+ position_ids = torch.from_numpy(np.concatenate(llm_pos_ids_list, axis=1).astype(np.int64)) # (3, seq_len)
+ else:
+ # text-only, simple sequential positions
+ position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(3, -1)
+
+ mrope_position_delta = int(position_ids.max().item()) + 1 - seq_len
+ return position_ids, mrope_position_delta
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py b/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py
new file mode 100644
index 0000000000..1ff362c5b9
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py
@@ -0,0 +1,896 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
+# Adapted from Ming
+# https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Iterable
+
+import torch
+from torch import nn
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.config.cache import CacheConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import (
+ PPMissingLayer,
+ WeightsMapper,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+from vllm.sequence import IntermediateTensors
+from vllm.v1.outputs import SamplerOutput
+from vllm.v1.sample.sampler import Sampler
+
+from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
+from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMoeV2Config
+
+logger = init_logger(__name__)
+
+
+class MingVideoRopeMRotaryEmbedding(MRotaryEmbedding):
+ """MRotaryEmbedding with Ming's video_rope cos/sin interleaving.
+
+ Unlike standard mrope which maps contiguous frequency sections to T/H/W,
+ video_rope alternates H/W frequencies element-wise in the spatial section
+ and places temporal frequencies at the end:
+ Standard mrope: [T T T ... H H H ... W W W ...]
+ Video rope: [H W H W ... H W ... T T T ...]
+
+ Refer to Ming's BailingMoeV2RotaryEmbedding3D
+ https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py#L174
+ """
+
+ def _remap_video_rope(
+ self,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Remap 3D cos/sin to video_rope interleaved layout.
+
+ Args:
+ cos, sin: [3, num_tokens, rotary_dim // 2]
+ Returns:
+ cos, sin: [num_tokens, rotary_dim // 2]
+
+ Refer to Ming's apply_3d_rotary_pos_emb
+ https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py#L226
+ """
+ assert self.mrope_section is not None
+ hw_size = self.mrope_section[1] + self.mrope_section[2]
+
+ result_cos = torch.empty_like(cos[0])
+ result_sin = torch.empty_like(sin[0])
+
+ # Spatial frequencies: even indices from H (dim 1), odd from W (dim 2)
+ result_cos[:, 0:hw_size:2] = cos[1, :, 0:hw_size:2]
+ result_cos[:, 1:hw_size:2] = cos[2, :, 1:hw_size:2]
+ result_sin[:, 0:hw_size:2] = sin[1, :, 0:hw_size:2]
+ result_sin[:, 1:hw_size:2] = sin[2, :, 1:hw_size:2]
+
+ # Temporal frequencies at the end
+ result_cos[:, hw_size:] = cos[0, :, hw_size:]
+ result_sin[:, hw_size:] = sin[0, :, hw_size:]
+
+ return result_cos, result_sin
+
+ def forward_native(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ assert positions.ndim == 1 or positions.ndim == 2
+ assert key is not None
+
+ cos_sin_cache = self._match_cos_sin_cache_dtype(query)
+ num_tokens = positions.shape[-1]
+ cos_sin = cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+
+ if positions.ndim == 2:
+ cos, sin = self._remap_video_rope(cos, sin)
+
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., : self.rotary_dim]
+ query_pass = query[..., self.rotary_dim :]
+ query_rot = self.apply_rotary_emb.forward_native(query_rot, cos, sin)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., : self.rotary_dim]
+ key_pass = key[..., self.rotary_dim :]
+ key_rot = self.apply_rotary_emb.forward_native(key_rot, cos, sin)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+
+ return query, key
+
+ def forward_cuda(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # No custom Triton kernel for video_rope; fall back to native for 3D
+ # TODO: Consider custom optimization
+ if positions.ndim == 2:
+ return self.forward_native(positions, query, key, offsets)
+ return super().forward_cuda(positions, query, key, offsets)
+
+ def forward_cpu(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ return self.forward_native(positions, query, key, offsets)
+
+
+class BailingMoeV2MLP(nn.Module):
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ intermediate_size: int,
+ hidden_act: str = "silu",
+ quant_config: QuantizationConfig | None = None,
+ reduce_results: bool = True,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = intermediate_size
+
+ self.gate_up_proj = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.down_proj = RowParallelLinear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=f"{prefix}.down_proj",
+ )
+
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class BailingMoeV2Gate(nn.Module):
+ """MoE routing gate with grouped expert selection."""
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_experts
+
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+
+ self.gating_dim = config.hidden_size
+
+ self.gate = ReplicatedLinear(
+ self.gating_dim,
+ self.num_experts,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate",
+ )
+
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ self.expert_bias = nn.Parameter(torch.zeros(self.num_experts), requires_grad=False)
+
+ def group_limited_topk(self, scores: torch.Tensor):
+ """Group-limited top-k selection for expert routing."""
+ num_tokens, _ = scores.size()
+ # Organize experts into groups
+ group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+
+ # Mask experts based on selected groups
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
+ .reshape(num_tokens, -1)
+ )
+
+ masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
+ probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False)
+
+ return probs, top_indices
+
+ def forward(self, hidden_states):
+ # compute gating score
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ logits, _ = self.gate(hidden_states)
+
+ logits = logits.float()
+ scores = torch.sigmoid(logits)
+
+ scores_for_routing = scores + self.expert_bias
+ _, topk_idx = self.group_limited_topk(scores_for_routing)
+
+ scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
+
+ topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
+ topk_weight = topk_weight * self.routed_scaling_factor
+
+ return topk_idx, topk_weight, logits
+
+
+def _unpack_multi_routing(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Stateless routing function that unpacks pre-computed routing results.
+
+ Used as `custom_routing_function` for `FusedMoE`. The caller is expected
+ to pack (topk_weight, topk_idx) into `gating_output` before
+ calling FusedMoE.forward(), and this function unpacks them.
+
+ Args:
+ gating_output: [num_tokens, top_k * 2]
+ - [:, :top_k], topk_weight (float)
+ - [:, top_k:], topk_idx (float, cast back to int)
+ """
+ topk_weight = gating_output[:, :topk].contiguous()
+ topk_idx = gating_output[:, topk:]
+ return topk_weight.to(torch.float32), topk_idx.to(torch.int32)
+
+
+class BailingMoeV2SparseMoeBlock(nn.Module):
+ """Sparse MoE block with MultiRouter support for multimodal routing.
+
+ Keep the custom multi-router gating logic external.
+ """
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.num_experts_per_tok = config.num_experts_per_tok
+
+ if isinstance(self.config.num_shared_experts, int) and self.config.num_shared_experts > 0:
+ self.shared_experts = BailingMoeV2MLP(
+ config=self.config,
+ intermediate_size=self.config.moe_intermediate_size * self.config.num_shared_experts,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_experts",
+ )
+ else:
+ self.shared_experts = None
+
+ self.experts = SharedFusedMoE(
+ shared_experts=self.shared_experts,
+ num_experts=config.num_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ custom_routing_function=_unpack_multi_routing,
+ renormalize=False, # we handle normalization in the gate
+ reduce_results=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts",
+ )
+
+ self.experts.expert_mapping = FusedMoE.make_expert_params_mapping(
+ self.experts,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=config.num_experts,
+ )
+
+ self.router_type = self.config.router_type
+ if self.router_type == "topN":
+ self.gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.gate")
+ elif self.router_type == "MultiRouter":
+ self.gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.gate")
+ self.image_gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.image_gate")
+ self.audio_gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.audio_gate")
+ else:
+ raise ValueError(f"Unsupported router_type: {self.router_type}")
+
+ @staticmethod
+ def _normalize_mask(
+ mask: torch.Tensor,
+ bsz: int,
+ seq_len: int,
+ name: str,
+ ) -> torch.Tensor:
+ """Validate and reshape a modality mask to [bsz*seq_len, 1] bool."""
+ N = bsz * seq_len
+ if mask.ndim == 1:
+ # vLLM path: flat tokens [N]
+ assert mask.shape[0] == N, f"{name} length {mask.shape[0]} != N={N}"
+ elif mask.ndim == 2:
+ assert mask.shape == (bsz, seq_len), f"{name} shape {mask.shape} != ({bsz}, {seq_len})"
+ elif mask.ndim == 3:
+ assert mask.shape == (bsz, seq_len, 1), f"{name} shape {mask.shape} != ({bsz}, {seq_len}, 1)"
+ else:
+ raise ValueError(f"Unsupported {name} shape: {mask.shape}")
+
+ return mask.reshape(N, 1).bool()
+
+ def forward(self, hidden_states, image_mask: torch.Tensor, audio_mask: torch.Tensor):
+ # TODO(yuanheng-zhao): revise the shapes in the flow
+ assert 2 <= hidden_states.dim() <= 3, f"{self.__class__.__name__} only supports 2D or 3D inputs"
+ input_is_2d = hidden_states.ndim == 2
+ if input_is_2d:
+ hidden_states = hidden_states.unsqueeze(0)
+
+ bsz, seq_len, h = hidden_states.shape
+
+ if self.router_type == "MultiRouter":
+ image_mask = self._normalize_mask(image_mask, bsz, seq_len, "image_mask").to(hidden_states.device)
+ audio_mask = self._normalize_mask(audio_mask, bsz, seq_len, "audio_mask").to(hidden_states.device)
+
+ # if image_mask is not None and audio_mask is not None:
+ # assert torch.logical_and(image_mask, audio_mask).sum() == 0
+
+ image_topk_idx, image_topk_weight, _ = self.image_gate(hidden_states)
+ audio_topk_idx, audio_topk_weight, _ = self.audio_gate(hidden_states)
+ topk_idx, topk_weight, _ = self.gate(hidden_states)
+
+ topk_idx = torch.where(image_mask, image_topk_idx, topk_idx)
+ topk_weight = torch.where(image_mask, image_topk_weight, topk_weight)
+ topk_idx = torch.where(audio_mask, audio_topk_idx, topk_idx)
+ topk_weight = torch.where(audio_mask, audio_topk_weight, topk_weight)
+ else:
+ topk_idx, topk_weight, _ = self.gate(hidden_states)
+
+ # Pack pre-computed routing into a single tensor
+ packed_routing = torch.cat(
+ [
+ topk_weight.to(hidden_states.dtype),
+ topk_idx.to(hidden_states.dtype),
+ ],
+ dim=-1,
+ )
+
+ # SharedFusedMoE expects 2D hidden_states
+ hidden_states_2d = hidden_states.view(-1, h)
+ result = self.experts(hidden_states_2d, packed_routing)
+
+ if self.shared_experts is not None:
+ shared_output, fused_out = result
+ else:
+ shared_output, fused_out = None, result
+
+ final_hidden_states = fused_out + shared_output if shared_output is not None else fused_out
+
+ final_hidden_states = final_hidden_states.view(bsz, seq_len, h)
+
+ return final_hidden_states.squeeze(0) if input_is_2d else final_hidden_states
+
+
+class BailingMoeV2Attention(nn.Module):
+ """Multi-headed attention using vLLM's Attention layer with 3D RoPE support."""
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ layer_idx: int,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.head_dim = config.head_dim
+
+ tp_size = get_tensor_model_parallel_world_size()
+ assert self.num_heads % tp_size == 0
+ self.num_heads = self.num_heads // tp_size
+ self.num_kv_heads = max(1, self.num_kv_heads // tp_size)
+
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+
+ partial_rotary_factor = config.partial_rotary_factor
+ self.rope_dim = int(self.head_dim * partial_rotary_factor)
+
+ total_num_heads = config.num_attention_heads
+ total_num_kv_heads = config.num_key_value_heads
+ self.qkv_proj = QKVParallelLinear(
+ self.hidden_size,
+ self.head_dim,
+ total_num_heads,
+ total_num_kv_heads,
+ bias=config.use_qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.dense = RowParallelLinear(
+ total_num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.use_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense",
+ )
+
+ # apply vLLM RMSNorm here rather than BailingMoeV2RMSNorm, diff might exist
+ self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ # 3D Rotary embeddings for multimodal
+ if config.rope_scaling is None:
+ raise ValueError("rope_scaling must not be None")
+
+ rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ mrope_section = config.rope_scaling.get("mrope_section", [8, 12, 12])
+
+ if rope_type == "video_rope":
+ # Ming-specific video_rope with custom H/W interleaving
+ self.rotary_emb = MingVideoRopeMRotaryEmbedding(
+ head_size=self.head_dim,
+ rotary_dim=self.rope_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ is_neox_style=True,
+ dtype=torch.get_default_dtype(),
+ mrope_section=mrope_section,
+ )
+ else:
+ # Standard m_rope (rope_type "3D", "default", or None)
+ rope_scaling = dict(config.rope_scaling)
+ rope_scaling["rope_type"] = "default" # normalize for get_rope dispatch
+ rope_scaling["mrope_section"] = mrope_section
+ self.rotary_emb = get_rope(
+ head_size=self.head_dim,
+ max_position=config.max_position_embeddings,
+ is_neox_style=True,
+ rope_parameters={
+ "rope_theta": config.rope_theta,
+ "partial_rotary_factor": config.partial_rotary_factor,
+ **rope_scaling,
+ },
+ )
+
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ """Forward pass for attention with 3D RoPE.
+
+ Args:
+ positions: Position IDs, shape (3, num_tokens) for 3D rope
+ or (num_tokens,) for text-only
+ hidden_states: Input hidden states, shape (num_tokens, hidden_size)
+
+ Returns:
+ Attention output tensor, shape (num_tokens, hidden_size)
+ """
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ num_tokens = q.shape[0]
+ q = self.q_norm(q.view(num_tokens, self.num_heads, self.head_dim)).view(num_tokens, self.q_size)
+ k = self.k_norm(k.view(num_tokens, self.num_kv_heads, self.head_dim)).view(num_tokens, self.kv_size)
+
+ q, k = self.rotary_emb(positions, q, k)
+
+ attn_output = self.attn(q, k, v)
+
+ output, _ = self.dense(attn_output)
+ return output
+
+
+class BailingMoeV2DecoderLayer(nn.Module):
+ """Decoder layer with attention and MoE MLP."""
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ layer_idx: int,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+
+ self.attention = BailingMoeV2Attention(
+ config=config,
+ layer_idx=layer_idx,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attention",
+ )
+
+ # MLP or MoE based on layer index
+ if config.num_experts is not None and layer_idx >= config.first_k_dense_replace:
+ self.mlp = BailingMoeV2SparseMoeBlock(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.is_moe = True
+ else:
+ self.mlp = BailingMoeV2MLP(
+ config=config,
+ intermediate_size=config.intermediate_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.is_moe = False
+
+ # apply vLLM RMSNorm to replace BailingMoeV2RMSNorm, diff might exist
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ image_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Forward pass for decoder layer.
+
+ Args:
+ positions: Position IDs
+ hidden_states: Input hidden states
+ residual: Residual connection from previous layer
+ image_mask: Mask for image tokens (for MultiRouter MoE)
+ audio_mask: Mask for audio tokens (for MultiRouter MoE)
+
+ Returns:
+ Tuple of (hidden_states, residual)
+ """
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.attention(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+ if self.is_moe:
+ hidden_states = self.mlp(hidden_states, image_mask, audio_mask)
+ else:
+ # Dense MLP only takes hidden_states (no routing masks)
+ hidden_states = self.mlp(hidden_states)
+
+ return hidden_states, residual
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ "image_mask": 0,
+ "audio_mask": 0,
+ }
+)
+class BailingMoeV2Model(nn.Module):
+ """BailingMoeV2 Model adapted from:
+
+ Ming repo BailingMoeV2Model
+ https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py
+ vLLM repo BailingMoeModel
+ https://github.com/vllm-project/vllm/blob/7291d1b288558d48508e1a17c37b0aa170332264/vllm/model_executor/models/bailing_moe.py
+ """
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+
+ # BailingMoeV2Config
+ config = vllm_config.model_config.hf_text_config
+
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+
+ self.config = config
+ self.quant_config = quant_config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+
+ if get_pp_group().is_first_rank or (self.tie_word_embeddings and get_pp_group().is_last_rank):
+ self.word_embeddings = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.word_embeddings",
+ )
+ else:
+ self.word_embeddings = PPMissingLayer()
+
+ # Decoder layers with later pipeline parallelism support
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: BailingMoeV2DecoderLayer(
+ config=config,
+ layer_idx=int(prefix.split(".")[-1]),
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix,
+ ),
+ prefix=f"{prefix}.layers",
+ )
+
+ if get_pp_group().is_last_rank:
+ # apply vLLM RMSNorm to replace BailingMoeV2RMSNorm, diff might exist
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ image_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.word_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for layer in self.layers[self.start_layer : self.end_layer]:
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ image_mask=image_mask,
+ audio_mask=audio_mask,
+ )
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class BailingMoeV2ForCausalLM(nn.Module, CustomProcessMixin):
+ """BailingMoeV2 model for causal language modeling, adapted for vLLM.
+
+ Inherits from CustomProcessMixin to support custom preprocessing and postprocessing
+ for integration with omni model pipelines.
+ """
+
+ packed_modules_mapping = {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ # BailingMoeV2Config
+ config = vllm_config.model_config.hf_text_config
+ quant_config = vllm_config.quant_config
+
+ self.config = config
+ self.quant_config = quant_config
+
+ self.model = BailingMoeV2Model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"),
+ )
+
+ self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ if self.tie_word_embeddings:
+ self.lm_head.weight = self.model.word_embeddings.weight
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.sampler = Sampler()
+ self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ image_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ ):
+ hidden_states = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ image_mask=image_mask,
+ audio_mask=audio_mask,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata,
+ ) -> torch.Tensor | None:
+ logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata,
+ ) -> SamplerOutput | None:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, weight_name, shard_id)
+ # BailingMoE stores fused QKV in checkpoint as query_key_value
+ ("qkv_proj", "query_key_value", None),
+ # Dense MLP and shared_experts gate/up are stored separately
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Gate router linear layers: checkpoint `{r}.weight` -> model `{r}.gate.weight`
+ gate_name_mapper = WeightsMapper(
+ orig_to_new_substr={f".{r}.weight": f".{r}.gate.weight" for r in ("gate", "image_gate", "audio_gate")}
+ )
+
+ # FusedMoE expert params mapping is identical across all MoE layers
+ expert_params_mapping: list[tuple[str, str, int, str]] = []
+ for layer in self.model.layers:
+ if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
+ expert_params_mapping = layer.mlp.experts.expert_mapping
+ break
+
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in gate_name_mapper.apply(weights):
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name or "mlp.experts" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict.get(name)
+ if param is not None:
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ loaded_params.add(name)
+ break
+ else:
+ for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict.get(name)
+ if param is not None:
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id)
+ loaded_params.add(name)
+ break
+ else:
+ param = params_dict.get(name)
+ if param is not None:
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/projectors.py b/vllm_omni/model_executor/models/ming_flash_omni/projectors.py
new file mode 100644
index 0000000000..42e53d1c63
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/projectors.py
@@ -0,0 +1,184 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright (c) Ant Group. All rights reserved.
+# Adapted from Ming repository modeling_bailingmm2.py
+# https://github.com/inclusionAI/Ming
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+logger = init_logger(__name__)
+
+
+class Transpose(nn.Module):
+ """Used in nn.Sequential pipelines."""
+
+ def __init__(self, dim0: int, dim1: int):
+ super().__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x.transpose(self.dim0, self.dim1)
+
+
+class VisionProjector(nn.Module):
+ """MLP projector from vision encoder output to LLM hidden space.
+
+ Args:
+ vision_dim: Vision encoder output dimension (out_hidden_size).
+ llm_dim: LLM hidden dimension.
+ mlp_depth: Number of linear layers (>= 1).
+ """
+
+ def __init__(self, vision_dim: int, llm_dim: int, mlp_depth: int = 1):
+ super().__init__()
+ layers: list[nn.Module] = [nn.Linear(vision_dim, llm_dim)]
+ for _ in range(1, mlp_depth):
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(llm_dim, llm_dim))
+ self.proj = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Project vision features.
+
+ Args:
+ x: [seq_len, vision_dim] or [B, seq_len, vision_dim]
+
+ Returns:
+ Projected features with last dim = llm_dim.
+ """
+ return self.proj(x)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if not name.startswith("proj."):
+ name = f"proj.{name}"
+ if name not in params_dict:
+ logger.warning("Skipping unknown vision projector weight: %s", name)
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class AudioProjector(nn.Module):
+ """Projector for audio features.
+
+ Args:
+ audio_dim: Audio encoder output dimension (n_state).
+ llm_dim: LLM hidden dimension.
+ ds_kernel_size: Conv1d kernel size for downsampling.
+ ds_stride: Conv1d stride for downsampling.
+ mlp_depth: Total number of projection layers (>= 1).
+ """
+
+ def __init__(
+ self,
+ audio_dim: int,
+ llm_dim: int,
+ ds_kernel_size: int = 3,
+ ds_stride: int = 2,
+ mlp_depth: int = 1,
+ ):
+ super().__init__()
+ self.ds_kernel_size = ds_kernel_size
+ self.ds_stride = ds_stride
+
+ layers: list[nn.Module] = [
+ nn.Conv1d(
+ audio_dim,
+ llm_dim,
+ kernel_size=ds_kernel_size,
+ stride=ds_stride,
+ padding=ds_kernel_size // 2,
+ ),
+ Transpose(-1, -2), # [B, llm_dim, T'] -> [B, T', llm_dim]
+ ]
+ for _ in range(1, mlp_depth):
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(llm_dim, llm_dim))
+ self.proj = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Project audio features with temporal downsampling.
+
+ Args:
+ x: [B, T, audio_dim] audio encoder output (channel-last).
+
+ Returns:
+ [B, T', llm_dim] projected features (channel-last),
+ where T' = (T - ds_kernel_size + 2*(ds_kernel_size//2)) // ds_stride + 1.
+ """
+ # Conv1d expects [B, C, T], so transpose input
+ x = x.transpose(-1, -2) # [B, audio_dim, T]
+ return self.proj(x)
+
+ def forward_packed(
+ self,
+ packed: torch.Tensor,
+ encoded_lens: list[int],
+ ) -> tuple[torch.Tensor, list[int]]:
+ """Project packed audio features from the Whisper encoder.
+
+ Args:
+ packed: [total_T', audio_dim] packed encoder output.
+ encoded_lens: Length of each clip after Whisper encoding.
+
+ Returns:
+ Tuple of:
+ - [total_T'', llm_dim] packed projected features.
+ - List of projected lengths per clip.
+ """
+ conv1d = self.proj[0]
+ mlp = self.proj[2:]
+
+ # Split packed tensor per clip for Conv1d
+ segments = packed.split(encoded_lens)
+ conv_segments = []
+ proj_lens: list[int] = []
+ for seg in segments:
+ out = conv1d(seg.transpose(0, 1).unsqueeze(0)) # [1, llm_dim, T'_i]
+ out = out.squeeze(0).transpose(0, 1) # [T'_i, llm_dim]
+ conv_segments.append(out)
+ proj_lens.append(out.shape[0])
+
+ packed_proj = torch.cat(conv_segments, dim=0) # [total_T'', llm_dim]
+ packed_proj = mlp(packed_proj)
+ return packed_proj, proj_lens
+
+ def compute_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
+ """Compute output sequence length after Conv1d downsampling.
+
+ Args:
+ input_length: Original mel spectrogram lengths.
+
+ Returns:
+ Output lengths after both convolutions.
+ """
+ length = (input_length - 3 + 2 * 1) // 2 + 1
+ length = (length - self.ds_kernel_size + 2 * (self.ds_kernel_size // 2)) // self.ds_stride + 1
+ return length
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if not name.startswith("proj."):
+ name = f"proj.{name}"
+ if name not in params_dict:
+ logger.warning("Skipping unknown audio projector weight: %s", name)
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py
new file mode 100644
index 0000000000..7976d76ce8
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from Ming repository qwen3_moe_vit.py
+# https://github.com/inclusionAI/Ming
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.qwen3_omni_moe_thinker import (
+ Qwen3Omni_VisionTransformer,
+)
+from vllm.model_executor.models.utils import WeightsMapper
+
+logger = init_logger(__name__)
+
+
+def _adapt_vision_config(vision_config):
+ # Adapt Ming's Qwen3VLMoeVisionConfig to be compatible with vLLM's
+ # Qwen3Omni_VisionTransformer expectations.
+ if not hasattr(vision_config, "image_size") or vision_config.image_size is None:
+ if hasattr(vision_config, "num_position_embeddings") and vision_config.num_position_embeddings:
+ import math
+
+ num_grid = int(math.sqrt(vision_config.num_position_embeddings))
+ vision_config.image_size = num_grid * vision_config.patch_size
+ else:
+ vision_config.image_size = vision_config.patch_size * 14 # fallback
+
+ if not hasattr(vision_config, "apply_vit_abs_pos_embed"):
+ vision_config.apply_vit_abs_pos_embed = True
+
+ return vision_config
+
+
+class MingVisionEncoder(nn.Module):
+ """**Wrapper** around vLLM's Qwen3Omni_VisionTransformer for Ming."""
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_substr={
+ "deepstack_merger_list.": "merger_list.",
+ "merger.norm.": "merger.ln_q.",
+ "merger.linear_fc1.": "merger.mlp.0.",
+ "merger.linear_fc2.": "merger.mlp.2.",
+ }
+ )
+
+ def __init__(
+ self,
+ vision_config,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ adapted_config = _adapt_vision_config(vision_config)
+ norm_eps = 1e-6
+ self.encoder = Qwen3Omni_VisionTransformer(
+ vision_config=adapted_config,
+ norm_eps=norm_eps,
+ quant_config=quant_config,
+ prefix=f"{prefix}.encoder",
+ )
+ self.image_emb_dim = vision_config.out_hidden_size
+ self.use_deepstack = (
+ hasattr(vision_config, "deepstack_visual_indexes") and vision_config.deepstack_visual_indexes is not None
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.encoder.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.encoder.device
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ ) -> torch.Tensor:
+ """forward method of the vision encoder.
+
+ Args:
+ pixel_values: Flattened pixel values.
+ grid_thw: [num_images, 3] tensor of (t, h, w) grid sizes.
+
+ Returns:
+ If deepstack is enabled, returns concatenated multi-scale features
+ along the feature dim: [seq_len, hidden_size * (1 + num_deepstack)].
+ Otherwise returns [seq_len, hidden_size].
+ """
+ return self.encoder(pixel_values, grid_thw=grid_thw)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ import re
+
+ def _remap_merger_list_inner(name: str) -> str:
+ name = re.sub(r"(merger_list\.\d+)\.norm\.", r"\1.ln_q.", name)
+ name = re.sub(r"(merger_list\.\d+)\.linear_fc1\.", r"\1.mlp.0.", name)
+ name = re.sub(r"(merger_list\.\d+)\.linear_fc2\.", r"\1.mlp.2.", name)
+
+ return name
+
+ remapped_weights = self.hf_to_vllm_mapper.apply(weights)
+ remapped_weights = ((_remap_merger_list_inner(name), tensor) for name, tensor in remapped_weights)
+ loaded_params = self.encoder.load_weights(remapped_weights)
+
+ loaded_params = {f"encoder.{loaded_param}" for loaded_param in loaded_params}
+
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/omnivoice/omnivoice.py b/vllm_omni/model_executor/models/omnivoice/omnivoice.py
index a3603a3c39..7fde8f16fa 100644
--- a/vllm_omni/model_executor/models/omnivoice/omnivoice.py
+++ b/vllm_omni/model_executor/models/omnivoice/omnivoice.py
@@ -15,6 +15,7 @@
import numpy as np
import torch
import torch.nn as nn
+import torchaudio
from transformers.feature_extraction_utils import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -77,31 +78,21 @@ def _ensure_cached_runtime_components(self, model_dir: str, config: OmniVoiceCon
self.text_tokenizer = AutoTokenizer.from_pretrained(model_dir)
- # Audio tokenizer for encoding reference audio
+ # Audio tokenizer for encoding reference audio (requires transformers>=5.3)
audio_tokenizer_path = os.path.join(model_dir, "audio_tokenizer")
- if os.path.isdir(audio_tokenizer_path):
- try:
- from transformers import (
- AutoFeatureExtractor,
- HiggsAudioV2TokenizerModel,
- )
- except ImportError as e:
- raise ImportError(
- "OmniVoice voice cloning requires transformers with "
- "HiggsAudioV2TokenizerModel. Upgrade transformers or "
- "use text-only mode (no reference audio)."
- ) from e
+ try:
+ from transformers import (
+ AutoFeatureExtractor,
+ HiggsAudioV2TokenizerModel,
+ )
self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(audio_tokenizer_path, device_map="cpu")
self.feature_extractor = AutoFeatureExtractor.from_pretrained(audio_tokenizer_path)
self.audio_tokenizer.eval()
- else:
+ except ImportError:
self.audio_tokenizer = None
self.feature_extractor = None
- logger.warning(
- "audio_tokenizer not found at %s, voice cloning disabled",
- audio_tokenizer_path,
- )
+ logger.warning("Voice cloning disabled (requires transformers>=5.3.0).")
self._cached_model_dir = model_dir
@@ -166,20 +157,16 @@ def _call_hf_processor(
if self.feature_extractor is not None:
target_sr = self.feature_extractor.sampling_rate
if sr != target_sr:
- import torchaudio
-
audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr)
# Encode reference audio to 8-codebook tokens
- if self.audio_tokenizer is not None:
- with torch.inference_mode():
- ref_audio_tokens = self.audio_tokenizer.encode(audio_signal) # [8, T_ref]
- if ref_audio_tokens.dim() == 3:
- ref_audio_tokens = ref_audio_tokens.squeeze(0) # [8, T_ref]
- else:
- raise RuntimeError(
- "Audio tokenizer not available for voice cloning. Ensure audio_tokenizer/ exists in model directory."
- )
+ if self.audio_tokenizer is None:
+ raise RuntimeError("Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'")
+
+ with torch.inference_mode():
+ ref_audio_tokens = self.audio_tokenizer.encode(audio_signal) # [8, T_ref]
+ if ref_audio_tokens.dim() == 3:
+ ref_audio_tokens = ref_audio_tokens.squeeze(0) # [8, T_ref]
ft = BatchFeature(
{
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
new file mode 100644
index 0000000000..b44d08eb32
--- /dev/null
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
@@ -0,0 +1,78 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Qwen2.5-Omni pipeline topology (frozen).
+
+Stage 0: Thinker — multimodal understanding + text generation
+Stage 1: Talker — text embeddings → speech tokens
+Stage 2: Code2Wav — speech tokens → audio waveform
+"""
+
+from vllm_omni.config.stage_config import (
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+)
+
+_PROC = "vllm_omni.model_executor.stage_input_processors.qwen2_5_omni"
+
+QWEN2_5_OMNI_PIPELINE = PipelineConfig(
+ model_type="qwen2_5_omni",
+ model_arch="Qwen2_5OmniForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="thinker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ final_output_type="text",
+ owns_tokenizer=True,
+ requires_multimodal_data=True,
+ engine_output_type="latent",
+ sampling_constraints={"detokenize": True},
+ ),
+ StagePipelineConfig(
+ stage_id=1,
+ model_stage="talker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(0,),
+ engine_output_type="latent",
+ custom_process_input_func=f"{_PROC}.thinker2talker",
+ sampling_constraints={
+ "detokenize": True,
+ "stop_token_ids": [8294],
+ },
+ ),
+ StagePipelineConfig(
+ stage_id=2,
+ model_stage="code2wav",
+ execution_type=StageExecutionType.LLM_GENERATION,
+ input_sources=(1,),
+ final_output=True,
+ final_output_type="audio",
+ engine_output_type="audio",
+ sampling_constraints={"detokenize": True},
+ ),
+ ),
+)
+
+
+# Single-stage thinker-only variant for the abort test.
+QWEN2_5_OMNI_THINKER_ONLY_PIPELINE = PipelineConfig(
+ model_type="qwen2_5_omni_thinker_only",
+ model_arch="Qwen2_5OmniForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="thinker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ final_output_type="text",
+ owns_tokenizer=True,
+ requires_multimodal_data=True,
+ engine_output_type="latent",
+ sampling_constraints={"detokenize": True},
+ ),
+ ),
+)
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
index 0307034089..617f0f9e32 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
@@ -64,6 +64,10 @@
)
from vllm.sequence import IntermediateTensors
+from vllm_omni.quantization.component_config import (
+ resolve_encoder_quant_config,
+)
+
try:
import flash_attn
except (ImportError, ModuleNotFoundError):
@@ -359,6 +363,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.quant_config = quant_config
+ # Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) only quantize
+ # the Thinker LM. Vision encoder weights remain in BF16 with no FP8
+ # scale tensors; passing quant_config causes FP8 kernels to run on
+ # BF16 weights, producing garbage embeddings. Keep None for encoders.
+ visual_quant_config = resolve_encoder_quant_config(quant_config)
+
with self._mark_tower_model(vllm_config, "audio"):
if multimodal_config.get_limit_per_prompt("audio"):
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
@@ -370,7 +380,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
- quant_config=quant_config,
+ quant_config=visual_quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
diff --git a/vllm_omni/model_executor/models/qwen3_omni/pipeline.py b/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
new file mode 100644
index 0000000000..1c69ec7957
--- /dev/null
+++ b/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Qwen3-Omni-MoE pipeline topology (frozen).
+
+Stage 0: Thinker — multimodal understanding + text generation
+Stage 1: Talker — text embeddings → RVQ codec codes
+Stage 2: Code2Wav — RVQ codes → audio waveform
+"""
+
+from vllm_omni.config.stage_config import (
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+)
+
+_PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_omni"
+
+QWEN3_OMNI_PIPELINE = PipelineConfig(
+ model_type="qwen3_omni_moe",
+ model_arch="Qwen3OmniMoeForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="thinker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ final_output_type="text",
+ owns_tokenizer=True,
+ requires_multimodal_data=True,
+ hf_config_name="thinker_config",
+ engine_output_type="latent",
+ custom_process_next_stage_input_func=(f"{_PROC}.thinker2talker_async_chunk"),
+ sampling_constraints={"detokenize": True},
+ ),
+ StagePipelineConfig(
+ stage_id=1,
+ model_stage="talker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(0,),
+ hf_config_name="talker_config",
+ engine_output_type="latent",
+ custom_process_input_func=f"{_PROC}.thinker2talker",
+ custom_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"),
+ sampling_constraints={
+ "detokenize": False,
+ "stop_token_ids": [2150],
+ },
+ ),
+ StagePipelineConfig(
+ stage_id=2,
+ model_stage="code2wav",
+ execution_type=StageExecutionType.LLM_GENERATION,
+ input_sources=(1,),
+ final_output=True,
+ final_output_type="audio",
+ hf_config_name="thinker_config",
+ engine_output_type="audio",
+ custom_process_input_func=f"{_PROC}.talker2code2wav",
+ sampling_constraints={"detokenize": True},
+ ),
+ ),
+)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
index ed6df6af36..f06ecf41d2 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -180,6 +180,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
"trailing_text_hidden",
"tts_pad_embed_projected",
}
+ # Keys that need to be accumulated across streaming inputs
+ self.streaming_accumulated_keys: set[str] = {
+ "thinker_prefill_embeddings",
+ "thinker_hidden_states",
+ }
elif self.model_stage == "code2wav":
self.enable_update_additional_information = True
@@ -610,13 +615,14 @@ def _init_special_tokens_embeddings(self) -> set[str]:
# Speaker token IDs (for voice selection)
# In Qwen3, speaker_id mapping is in talker_config.speaker_id
+ # Keys are lowercased for case-insensitive matching with serving layer.
if hasattr(talker_hf_config, "speaker_id") and talker_hf_config.speaker_id:
- self.tts_text_spk_token_ids = talker_hf_config.speaker_id
+ self.tts_text_spk_token_ids = {k.lower(): v for k, v in talker_hf_config.speaker_id.items()}
else:
# Default to audio_start_token_id if no speaker mapping
self.tts_text_spk_token_ids = {
"default": talker_hf_config.audio_start_token_id,
- "Ethan": talker_hf_config.audio_start_token_id,
+ "ethan": talker_hf_config.audio_start_token_id,
"prefix_caching": talker_hf_config.audio_start_token_id,
}
@@ -890,10 +896,11 @@ def _thinker_to_talker_prefill(
Returns:
(input_ids, input_embeds) for talker
"""
+ target_len = thinker_result_ids.shape[-1]
im_start_indexes = torch.cat(
(
torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(),
- torch.tensor([thinker_result_ids.shape[-1]], device=input_ids.device, dtype=input_ids.dtype),
+ torch.tensor([target_len], device=input_ids.device, dtype=input_ids.dtype),
),
dim=-1,
) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here.
@@ -1028,8 +1035,35 @@ def talker_preprocess_decode(
return last_talker_hidden, text_step, update_dict
def _get_talker_user_parts(self, im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed):
+ clamped = min(
+ segment_end_index,
+ multimodal_mask.shape[0],
+ thinker_hidden.shape[0],
+ thinker_embed.shape[0],
+ )
+ if clamped < segment_end_index:
+ logger.warning(
+ "_get_talker_user_parts: segment_end_index %d clamped to %d "
+ "(embed=%d, hidden=%d, mask=%d). "
+ "This usually means _merge_pd_embeddings failed to merge "
+ "prefill embeddings – check PD prefill_mm keys.",
+ segment_end_index,
+ clamped,
+ thinker_embed.shape[0],
+ thinker_hidden.shape[0],
+ multimodal_mask.shape[0],
+ )
+ segment_end_index = clamped
+ seg_len = segment_end_index - im_start_index
+ if seg_len <= 0:
+ return torch.empty(
+ (0, self.config.talker_config.text_config.hidden_size),
+ device=thinker_hidden.device,
+ dtype=torch.bfloat16,
+ )
+
user_talker_part = torch.empty(
- (segment_end_index - im_start_index, self.config.talker_config.text_config.hidden_size),
+ (seg_len, self.config.talker_config.text_config.hidden_size),
device=thinker_hidden.device,
dtype=torch.bfloat16,
)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py
index 2ceaafdb67..819e22e181 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py
@@ -1,510 +1,28 @@
-"""Qwen3-Omni Code Predictor -- optimized re-prefill, no KV cache.
+"""Qwen3-Omni Code Predictor -- thin wrapper over CodePredictorWrapper."""
-* SDPA attention (F.scaled_dot_product_attention) with native GQA support
-* HF-compatible numerics (float32 RMSNorm, float32 RoPE, separate linear layers)
-* Per-call embedding buffer to avoid cross-request aliasing
-* Pre-allocated position_ids (read-only, safe to persist)
-* torch.compile (epilogue_fusion=False) on inner transformer by default
-* Inline sampling (top-k + top-p) -- no custom op overhead
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-from vllm_omni.platforms import current_omni_platform
-
-logger = init_logger(__name__)
-
-
-# ===================================================================
-# HF-numerics-compatible layers for code predictor
-# ===================================================================
-#
-# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32,
-# rotate_half RoPE) to produce outputs numerically identical to the
-# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel,
-# get_rope) introduce small precision differences that compound across
-# the autoregressive steps of the code predictor, causing severe
-# audio quality degradation.
-#
-# See: https://github.com/vllm-project/vllm-omni/issues/2274
-
-
-class _RMSNorm(nn.Module):
- """RMSNorm matching HuggingFace's implementation exactly.
-
- Computes variance in float32 to avoid bfloat16 precision loss.
- """
-
- def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-def _rotate_half(x: torch.Tensor) -> torch.Tensor:
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-class _RotaryEmbedding(nn.Module):
- """RoPE matching HuggingFace's implementation exactly.
-
- Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False).
- """
-
- def __init__(self, config) -> None:
- super().__init__()
- head_dim = getattr(
- config,
- "head_dim",
- config.hidden_size // config.num_attention_heads,
- )
- rope_theta = getattr(config, "rope_theta", 10000.0)
- inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- # position_ids: [batch, seq_len]
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
-
- # Force float32 (matching HF)
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-class Qwen3OmniCodePredictorAttention(nn.Module):
- """Multi-head self-attention for code predictor.
-
- Uses ``F.scaled_dot_product_attention`` with HF-compatible RoPE and RMSNorm.
- No KV cache -- the code predictor always re-prefills the full (short)
- sequence each AR step.
-
- Input : [B, seq_len, hidden_size]
- Output: [B, seq_len, hidden_size]
- """
-
- def __init__(
- self,
- config,
- prefix: str = "",
- ):
- super().__init__()
- cp_cfg = config.code_predictor_config
- self.num_heads = cp_cfg.num_attention_heads
- self.num_kv_heads = cp_cfg.num_key_value_heads
- self.head_dim = getattr(
- cp_cfg,
- "head_dim",
- cp_cfg.hidden_size // cp_cfg.num_attention_heads,
- )
- self.hidden_size = cp_cfg.hidden_size
- self.scaling = self.head_dim**-0.5
- self._use_gqa = self.num_kv_heads != self.num_heads
-
- # Separate q/k/v projections matching HF (no fused packing)
- self.q_proj = nn.Linear(
- self.hidden_size,
- self.num_heads * self.head_dim,
- bias=False,
- )
- self.k_proj = nn.Linear(
- self.hidden_size,
- self.num_kv_heads * self.head_dim,
- bias=False,
- )
- self.v_proj = nn.Linear(
- self.hidden_size,
- self.num_kv_heads * self.head_dim,
- bias=False,
- )
- self.o_proj = nn.Linear(
- self.num_heads * self.head_dim,
- self.hidden_size,
- bias=False,
- )
- self.q_norm = _RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps)
- self.k_norm = _RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- bsz, seq_len, _ = hidden_states.shape
- hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim)
- hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim)
-
- q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2)
- k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2)
- v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2)
-
- cos, sin = position_embeddings
- # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads
- cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim]
- sin = sin.unsqueeze(1)
- q = (q * cos) + (_rotate_half(q) * sin)
- k = (k * cos) + (_rotate_half(k) * sin)
-
- attn_out = F.scaled_dot_product_attention(
- q,
- k,
- v,
- scale=self.scaling,
- is_causal=True,
- enable_gqa=self._use_gqa,
- )
-
- attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1)
- output = self.o_proj(attn_out)
- return output
-
-
-# ===================================================================
-# MLP
-# ===================================================================
-
-
-class Qwen3OmniCodePredictorMLP(nn.Module):
- """SiLU-gated MLP for code predictor, matching HF's implementation."""
-
- def __init__(
- self,
- config,
- prefix: str = "",
- ):
- super().__init__()
- hidden_size = config.code_predictor_config.hidden_size
- intermediate_size = config.code_predictor_config.intermediate_size
-
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
-
-
-# ===================================================================
-# Decoder Layer
-# ===================================================================
-
-
-class Qwen3OmniCodePredictorDecoderLayer(nn.Module):
- """Transformer decoder layer (SDPA, no KV cache)."""
-
- def __init__(
- self,
- config,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.self_attn = Qwen3OmniCodePredictorAttention(
- config,
- prefix=f"{prefix}.self_attn",
- )
- self.mlp = Qwen3OmniCodePredictorMLP(
- config,
- prefix=f"{prefix}.mlp",
- )
- cp_cfg = config.code_predictor_config
- self.input_layernorm = _RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps)
- self.post_attention_layernorm = _RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps)
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- hidden_states = self.self_attn(hidden_states, position_embeddings)
- hidden_states = residual + hidden_states
+from vllm_omni.model_executor.models.common.qwen3_code_predictor import (
+ CodePredictorWrapper,
+ CodePredictorWrapperConfig,
+)
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
- return hidden_states
+class Qwen3OmniMoeTalkerCodePredictor(CodePredictorWrapper):
+ """Qwen3-Omni code predictor (no CUDA graphs, VocabParallelEmbedding)."""
-# ===================================================================
-# Base Transformer Model (re-prefill, no KV cache)
-# ===================================================================
-
-
-class Qwen3OmniCodePredictorBaseModel(nn.Module):
- """Inner transformer for code predictor.
-
- Signature: ``forward(inputs_embeds, position_ids) -> hidden_states``
- -- plain Tensor in, plain Tensor out (no namedtuple).
- """
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- config = vllm_config.model_config.hf_config.code_predictor_config
- self.config = config
-
- self.codec_embedding = nn.ModuleList(
- [VocabParallelEmbedding(config.vocab_size, config.hidden_size) for _ in range(config.num_code_groups - 1)]
- )
-
- self.layers = nn.ModuleList(
- [
- Qwen3OmniCodePredictorDecoderLayer(
- vllm_config.model_config.hf_config,
- prefix=f"{prefix}.layers.{idx}",
- )
- for idx in range(config.num_hidden_layers)
- ]
- )
- self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = _RotaryEmbedding(config)
-
- def forward(
- self,
- inputs_embeds: torch.Tensor,
- position_ids: torch.Tensor,
- ) -> torch.Tensor:
- hidden_states = inputs_embeds
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
- for layer in self.layers:
- hidden_states = layer(hidden_states, position_embeddings)
- hidden_states = self.norm(hidden_states)
- return hidden_states
-
-
-# ===================================================================
-# Code Predictor Wrapper (optimized re-prefill, persistent buffers)
-# ===================================================================
-
-
-class Qwen3OmniMoeTalkerCodePredictor(nn.Module):
- """Optimized code predictor -- re-prefill approach, no KV cache.
-
- Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1)
- through the transformer. The extra O(T^2) FLOPs are negligible for
- short sequences, and this avoids all KV-cache management overhead.
-
- Optimizations:
- 1. Per-call embedding buffer -- avoids cross-request aliasing.
- 2. Pre-allocated position_ids -- no torch.arange per step.
- 3. Cached module references -- bypass ModuleList indexing.
- 4. torch.compile on inner transformer.
- 5. Inline sampling (top-k + top-p) -- no custom op overhead.
- """
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
-
- config = vllm_config.model_config.hf_config
- self.config = config
- self.quant_config = vllm_config.quant_config
- self.prefix = prefix
-
- self.num_code_groups = config.code_predictor_config.num_code_groups
- self._hidden_size = config.code_predictor_config.hidden_size
-
- self.model = Qwen3OmniCodePredictorBaseModel(
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ cp_config = vllm_config.model_config.hf_config.code_predictor_config
+ super().__init__(
vllm_config=vllm_config,
+ cp_config=cp_config,
+ wrapper_config=CodePredictorWrapperConfig(
+ use_cuda_graphs=False,
+ use_parallel_embedding=True,
+ use_projection=False,
+ return_proj_buf=True,
+ sampling_mode="stored",
+ ),
+ talker_hidden_size=cp_config.hidden_size,
prefix=prefix,
)
-
- # One lm_head per residual layer (layers 1 .. G-1)
- self.lm_head = nn.ModuleList(
- [
- nn.Linear(
- config.code_predictor_config.hidden_size,
- config.code_predictor_config.vocab_size,
- bias=False,
- )
- for _ in range(self.num_code_groups - 1)
- ]
- )
-
- self.set_sampling_params()
-
- # Lazily initialised position ids (read-only, safe to persist)
- self._pos_ids: torch.Tensor | None = None
-
- # Cached plain-list refs (set once)
- self._lm_heads: list | None = None
- self._codec_embeds: list | None = None
-
- # Model forward (optionally compiled)
- self._model_fwd: object | None = None
-
- def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8):
- """Configure sampling parameters to maintain consistency with previous implementation."""
- self._top_k = top_k
- self._top_p = top_p
- logger.debug(f"Sampling parameters updated: top_k={top_k}, top_p={top_p}s")
-
- # ------------------------------------------------------------------
- # Lazy-init helpers
- # ------------------------------------------------------------------
-
- def _ensure_pos_ids(self, device: torch.device) -> None:
- if self._pos_ids is not None and self._pos_ids.device == device:
- return
- max_seq = self.num_code_groups + 1
- # [1, max_seq] for HF-style RoPE (will be expanded to [bsz, seq_len] at use)
- self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device).unsqueeze(0)
-
- def _ensure_cached_refs(self) -> None:
- if self._lm_heads is not None:
- return
- self._lm_heads = list(self.lm_head)
- self._codec_embeds = list(self.model.codec_embedding)
-
- def _ensure_model_fwd(self) -> None:
- if self._model_fwd is not None:
- return
- if current_omni_platform.supports_torch_inductor():
- # torch.compile fuses RMSNorm/RoPE in ways that lose float32
- # precision, compounding across AR steps. Use epilogue_fusion=False
- # to disable the problematic fusions while still getting kernel
- # fusion benefits for the linear layers and SDPA.
- self._model_fwd = torch.compile(
- self.model.forward,
- dynamic=True,
- options={
- "epilogue_fusion": False,
- },
- )
- logger.info("code_predictor: torch.compile enabled (no epilogue fusion)")
- else:
- self._model_fwd = self.model.forward
- logger.info("code_predictor: using eager mode (no torch.compile)")
-
- # ------------------------------------------------------------------
- # Forward -- re-prefill + inline sampling
- # ------------------------------------------------------------------
-
- @torch.inference_mode()
- def forward(
- self,
- layer0_code: torch.Tensor,
- layer0_embed: torch.Tensor,
- last_talker_hidden: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """Predict residual codebooks 1..G-1 autoregressively via re-prefill.
-
- Args:
- layer0_code: [bsz, 1] int64
- layer0_embed: [bsz, 1, hidden_size]
- last_talker_hidden: [bsz, 1, hidden_size]
-
- Returns:
- all_codes: [bsz, num_code_groups, 1]
- proj_buf: [bsz, num_code_groups + 1, hidden_size]
- pos 0 = last_talker_hidden (NOT a codec embed)
- pos 1 = layer0_embed
- pos 2.. = `codec_embedding[i](predicted_code_i)`
- """
- bsz = int(layer0_code.shape[0])
- device = layer0_code.device
- dtype = last_talker_hidden.dtype
- num_groups = self.num_code_groups
-
- # Lazy init (read-only caches only)
- self._ensure_pos_ids(device)
- self._ensure_model_fwd()
- self._ensure_cached_refs()
-
- # Allocate proj_buf locally each call to avoid cross-call aliasing
- max_seq = num_groups + 1
- proj_buf = torch.zeros(bsz, max_seq, self._hidden_size, dtype=dtype, device=device)
- pos_ids = self._pos_ids
- model_fwd = self._model_fwd
- lm_heads = self._lm_heads
- codec_embeds = self._codec_embeds
-
- # Output codes
- all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device)
- all_codes[:, 0] = layer0_code
-
- # Fill buffer positions 0 & 1
- proj_buf[:bsz, 0:1, :] = last_talker_hidden
- proj_buf[:bsz, 1:2, :] = layer0_embed
-
- # Autoregressive loop: predict layers 1..G-1
- for step in range(1, num_groups):
- seq_len = step + 1
- projected = proj_buf[:bsz, :seq_len, :]
- # position_ids: [batch, seq_len] for HF-style RoPE
- step_pos_ids = pos_ids[:, :seq_len].expand(bsz, -1)
-
- hidden_out = model_fwd(projected, step_pos_ids)
-
- # Inline sampling: top-k -> top-p -> softmax -> multinomial
- logits = lm_heads[step - 1](hidden_out[:, -1, :]) # [bsz, vocab]
- if self._top_k > 0:
- topk_vals, _ = logits.topk(self._top_k, dim=-1)
- logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf"))
- if self._top_p < 1.0:
- sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
- cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
- # Remove tokens with cumulative probability above top_p
- remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= self._top_p
- sorted_logits[remove_mask] = float("-inf")
- logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
- probs = F.softmax(logits, dim=-1)
- code = torch.multinomial(probs, num_samples=1) # [bsz, 1]
-
- all_codes[:, step] = code
-
- # Embed predicted code -> next buffer position
- new_embed = codec_embeds[step - 1](code) # [batch, 1, hidden_size]
- proj_buf[:bsz, step + 1 : step + 2, :] = new_embed
-
- return all_codes, proj_buf[:bsz]
-
- # ------------------------------------------------------------------
- # Weight loading
- # ------------------------------------------------------------------
-
- def load_weights(self, weights: list[tuple[str, torch.Tensor]]) -> set[str]:
- """Load weights directly (no fused projection remapping needed).
-
- Since we use separate nn.Linear for q/k/v/o and gate/up/down,
- weight names match the HF checkpoint directly.
- """
- params_dict = dict(self.named_parameters())
- loaded_params: set[str] = set()
-
- for name, loaded_weight in weights:
- # Skip rotary embeddings
- if "rotary_emb.inv_freq" in name:
- continue
-
- param = params_dict.get(name)
- if param is None:
- continue
-
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
-
- return loaded_params
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
index 671ffb6cb1..d03a96fd85 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
@@ -119,7 +119,10 @@
from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin,
)
-from vllm_omni.quantization.component_config import ComponentQuantizationConfig
+from vllm_omni.quantization.component_config import (
+ PRE_QUANTIZED_METHODS,
+ ComponentQuantizationConfig,
+)
try:
import flash_attn
@@ -1114,21 +1117,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.multimodal_config = multimodal_config
self.quant_config = quant_config
- # Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) quantize the
- # entire thinker — audio tower, visual encoder, and language model
- # all share the same quant method. Dynamic quantization methods
- # (e.g. --quantization fp8) should only target the language model.
- _PRE_QUANTIZED_METHODS = {"modelopt", "modelopt_fp4", "modelopt_mxfp8"}
+ # Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) only quantize
+ # the Thinker LM (language model). Vision and audio encoder weights
+ # remain in BF16 and have no corresponding scale tensors in the
+ # checkpoint. Dynamic quantization methods (e.g. --quantization fp8)
+ # should also only target the language model.
if isinstance(quant_config, ComponentQuantizationConfig):
audio_quant_config = quant_config.resolve("audio_tower")
visual_quant_config = quant_config.resolve("visual")
language_quant_config = quant_config.resolve("language_model")
elif quant_config is not None:
- if quant_config.get_name() in _PRE_QUANTIZED_METHODS:
- # Pre-quantized: pass quant_config to all subcomponents.
- audio_quant_config = quant_config
- visual_quant_config = quant_config
+ if quant_config.get_name() in PRE_QUANTIZED_METHODS:
+ # Pre-quantized: only the Thinker LM is quantized.
+ # Vision/audio encoder weights are BF16 with no FP8 scales;
+ # passing quant_config to them causes FP8 kernels to run on
+ # BF16 weights (producing garbage embeddings). Keep None.
+ audio_quant_config = None
+ visual_quant_config = None
language_quant_config = quant_config
else:
# Dynamic quantization: scope to language_model only.
diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
new file mode 100644
index 0000000000..5051715cea
--- /dev/null
+++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Qwen3-TTS pipeline: Talker (text → RVQ codec) → Code2Wav (codec → audio).
+
+Chunked vs end-to-end mode is dispatched from ``deploy.async_chunk``.
+"""
+
+from vllm_omni.config.stage_config import (
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+)
+
+_PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_tts"
+
+QWEN3_TTS_PIPELINE = PipelineConfig(
+ model_type="qwen3_tts",
+ # Pipeline-level default; the code2wav stage overrides per-stage below.
+ model_arch="Qwen3TTSTalkerForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="qwen3_tts",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ owns_tokenizer=True,
+ engine_output_type="latent",
+ async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"),
+ sampling_constraints={
+ "detokenize": False,
+ "stop_token_ids": [2150],
+ },
+ ),
+ StagePipelineConfig(
+ stage_id=1,
+ model_stage="code2wav",
+ execution_type=StageExecutionType.LLM_GENERATION,
+ input_sources=(0,),
+ final_output=True,
+ final_output_type="audio",
+ engine_output_type="audio",
+ model_arch="Qwen3TTSCode2Wav",
+ sync_process_input_func=f"{_PROC}.talker2code2wav",
+ sampling_constraints={"detokenize": True},
+ extras={"tts_args": {"max_instructions_length": 500}},
+ ),
+ ),
+)
diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml
deleted file mode 100644
index fd8ea3a3f4..0000000000
--- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml
+++ /dev/null
@@ -1,93 +0,0 @@
-model_type: qwen3_tts
-async_chunk: true
-
-stages:
- - stage_id: 0
- model_stage: qwen3_tts
- stage_type: llm
- is_comprehension: true
- input_sources: []
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 10
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- hf_overrides:
- architectures: [Qwen3TTSTalkerForConditionalGeneration]
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.08
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- model_stage: code2wav
- stage_type: llm
- input_sources: [0]
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- final_output: true
- final_output_type: audio
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 1
- model_arch: Qwen3TTSCode2Wav
- hf_overrides:
- architectures: [Qwen3TTSCode2Wav]
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.08
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 65536
- max_model_len: 65536
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- # Match the decoder sliding attention window to avoid chunk-boundary noise.
- codec_left_context_frames: 72
-
-edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py
index 1e84eaebaa..8d2f0686ae 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py
@@ -1,318 +1,27 @@
+"""Qwen3-TTS Code Predictor -- thin wrapper over CodePredictorWrapper."""
+
from __future__ import annotations
from collections.abc import Iterable
import torch
-import torch.nn as nn
-import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.config.vllm import set_current_vllm_config
-from vllm.logger import init_logger
-from vllm.model_executor.model_loader.weight_utils import (
- default_weight_loader,
-)
-from vllm_omni.platforms import current_omni_platform
+from vllm_omni.model_executor.models.common.qwen3_code_predictor import (
+ CodePredictorBaseModel,
+ CodePredictorWrapper,
+ CodePredictorWrapperConfig,
+)
from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig
-logger = init_logger(__name__)
-
-
-# ===================================================================
-# HF-numerics-compatible layers for code predictor
-# ===================================================================
-#
-# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32,
-# rotate_half RoPE) to produce outputs numerically identical to the
-# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel,
-# get_rope) introduce small precision differences that compound across
-# the 15 autoregressive steps of the code predictor, causing severe
-# audio quality degradation (UTMOS ~4.26 → ~2.66).
-#
-# See: https://github.com/vllm-project/vllm-omni/issues/2274
-
-
-class _RMSNorm(nn.Module):
- """RMSNorm matching HuggingFace's Qwen3TTSRMSNorm exactly.
-
- Computes variance in float32 to avoid bfloat16 precision loss.
- """
-
- def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-def _rotate_half(x: torch.Tensor) -> torch.Tensor:
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-class _RotaryEmbedding(nn.Module):
- """RoPE matching HuggingFace's Qwen3TTSRotaryEmbedding exactly.
-
- Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False).
- """
-
- def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig) -> None:
- super().__init__()
- head_dim = getattr(
- config,
- "head_dim",
- config.hidden_size // config.num_attention_heads,
- )
- # Standard default RoPE
- rope_theta = getattr(config, "rope_theta", 10000.0)
- inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- # position_ids: [batch, seq_len]
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
-
- # Force float32 (matching HF)
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-class _CodePredictorAttention(nn.Module):
- """Standalone multi-head attention for code predictor.
-
- Uses F.scaled_dot_product_attention with HF-compatible RoPE and RMSNorm.
- Input: [B, seq_len, hidden_size], output: [B, seq_len, hidden_size].
- """
-
- def __init__(
- self,
- config: Qwen3TTSTalkerCodePredictorConfig,
- *,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.num_kv_heads = config.num_key_value_heads
- self.head_dim = getattr(
- config,
- "head_dim",
- config.hidden_size // config.num_attention_heads,
- )
- self.scaling = self.head_dim**-0.5
- self._use_gqa = self.num_kv_heads != self.num_heads
-
- # Separate q/k/v projections matching HF (no fused packing)
- self.q_proj = nn.Linear(
- self.hidden_size,
- self.num_heads * self.head_dim,
- bias=getattr(config, "attention_bias", False),
- )
- self.k_proj = nn.Linear(
- self.hidden_size,
- self.num_kv_heads * self.head_dim,
- bias=getattr(config, "attention_bias", False),
- )
- self.v_proj = nn.Linear(
- self.hidden_size,
- self.num_kv_heads * self.head_dim,
- bias=getattr(config, "attention_bias", False),
- )
- self.o_proj = nn.Linear(
- self.num_heads * self.head_dim,
- self.hidden_size,
- bias=False,
- )
- self.q_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
- self.k_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- bsz, seq_len, _ = hidden_states.shape
- hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim)
- hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim)
-
- q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2)
- k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2)
- v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2)
-
- cos, sin = position_embeddings
- # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads
- cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim]
- sin = sin.unsqueeze(1)
- q = (q * cos) + (_rotate_half(q) * sin)
- k = (k * cos) + (_rotate_half(k) * sin)
-
- attn_out = F.scaled_dot_product_attention(
- q,
- k,
- v,
- scale=self.scaling,
- is_causal=True,
- enable_gqa=self._use_gqa,
- )
-
- attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1)
- output = self.o_proj(attn_out)
- return output
-
-
-class _CodePredictorMLP(nn.Module):
- """SiLU-gated MLP for code predictor, matching HF's Qwen3TTSTalkerTextMLP."""
-
- def __init__(
- self,
- config: Qwen3TTSTalkerCodePredictorConfig,
- *,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
- self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
-class _CodePredictorDecoderLayer(nn.Module):
- """Transformer decoder layer for code predictor (SDPA, no KV cache)."""
-
- def __init__(
- self,
- config: Qwen3TTSTalkerCodePredictorConfig,
- *,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.self_attn = _CodePredictorAttention(config, prefix=f"{prefix}.self_attn")
- self.mlp = _CodePredictorMLP(config, prefix=f"{prefix}.mlp")
- self.input_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- hidden_states = self.self_attn(hidden_states, position_embeddings)
- hidden_states = residual + hidden_states
-
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
- return hidden_states
-
-
-# ===================================================================
-# Code Predictor Transformer Model
-# ===================================================================
-
-
-class Qwen3TTSTalkerCodePredictorModelVLLM(nn.Module):
- """Transformer model for the code predictor (re-prefill, no KV cache)."""
-
- def __init__(
- self,
- config: Qwen3TTSTalkerCodePredictorConfig,
- *,
- talker_hidden_size: int | None = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.config = config
-
- self.layers = nn.ModuleList(
- [_CodePredictorDecoderLayer(config, prefix=f"{prefix}.layers.{i}") for i in range(config.num_hidden_layers)]
- )
- self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = _RotaryEmbedding(config)
-
- # Codec embeddings: one per residual group. Stored in talker hidden dim
- # (some checkpoints use talker_hidden_size != code_predictor hidden_size).
- emb_dim = int(talker_hidden_size) if talker_hidden_size is not None else int(config.hidden_size)
- self.codec_embedding = nn.ModuleList(
- [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)]
- )
-
- def get_input_embeddings(self) -> nn.ModuleList:
- return self.codec_embedding
-
- def forward(
- self,
- inputs_embeds: torch.Tensor,
- position_ids: torch.Tensor,
- ) -> torch.Tensor:
- hidden_states = inputs_embeds
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
- for layer in self.layers:
- hidden_states = layer(hidden_states, position_embeddings)
- hidden_states = self.norm(hidden_states)
- return hidden_states
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded_params: set[str] = set()
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- param = params_dict.get(name)
- if param is None:
- continue
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
-
-
-# ===================================================================
-# Code Predictor Wrapper (optimized re-prefill + torch.compile)
-# ===================================================================
-
-
-class Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM(nn.Module):
- """vLLM-native code_predictor for the AR talker (residual codebooks).
+# Backward-compat alias used by tests
+Qwen3TTSTalkerCodePredictorModelVLLM = CodePredictorBaseModel
- Re-prefill approach: each AR step forwards the full growing sequence
- through the 5-layer transformer. No KV cache needed. This trades
- ~O(T^2) extra attention FLOPs (negligible for T=16, 5 layers) for
- zero KV cache management overhead and a simpler execution model.
- Uses HF-compatible layers (plain nn.Linear, float32 RMSNorm, rotate_half
- RoPE) to ensure numerical fidelity with the reference implementation.
- Precision matters here because small errors compound across 15 AR steps.
-
- Optimizations preserved:
- 1. torch.compile on model forward -- fuses small kernel launches.
- 2. Pre-allocated embedding buffer [B, max_seq, H] -- no torch.cat per step.
- 3. Projection caching -- each token projected once and cached.
- 4. Pre-allocated position_ids -- no torch.arange per step.
- 5. Inline sampling -- no custom op / forward_context overhead.
- 6. Cached module references -- bypass nn.Module.__call__ overhead.
- 7. CUDA graphs per batch-size bucket.
- """
+class Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM(CodePredictorWrapper):
+ """Qwen3-TTS code predictor (CUDA graphs, per-call sampling, projection)."""
def __init__(
self,
@@ -322,250 +31,24 @@ def __init__(
talker_config: Qwen3TTSTalkerConfig,
prefix: str = "code_predictor",
) -> None:
- super().__init__()
- self._vllm_config = vllm_config
- self.config = config
- self.talker_config = talker_config
-
- self.model = Qwen3TTSTalkerCodePredictorModelVLLM(
- config,
+ super().__init__(
+ vllm_config=vllm_config,
+ cp_config=config,
+ wrapper_config=CodePredictorWrapperConfig(
+ use_cuda_graphs=True,
+ use_parallel_embedding=False,
+ use_projection=(config.hidden_size != talker_config.hidden_size),
+ return_proj_buf=False,
+ sampling_mode="per_call",
+ ),
talker_hidden_size=int(talker_config.hidden_size),
- prefix=f"{prefix}.model",
+ prefix=prefix,
)
-
- self.lm_head = nn.ModuleList(
- [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)]
- )
-
- if config.hidden_size != talker_config.hidden_size:
- self.small_to_mtp_projection = nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True)
- else:
- self.small_to_mtp_projection = nn.Identity()
-
- self._num_groups = int(config.num_code_groups)
- self._talker_hidden = int(talker_config.hidden_size)
- self._cp_hidden = int(config.hidden_size)
-
- # Pre-allocated buffers (lazily initialized on first forward).
- self._proj_buf: torch.Tensor | None = None
- self._model_dtype: torch.dtype | None = None
-
- # torch.compile + warmup state (lazily initialized in _setup_compile).
- self._compiled_model_fwd = None
- self._bucket_sizes: list[int] = []
- self._bucket_pos_ids: dict[int, torch.Tensor] = {}
- self._lm_heads_list: list[nn.Module] | None = None
- self._codec_embeds_list: list[nn.Module] | None = None
- self._cuda_graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor]] = {}
-
- def get_input_embeddings(self) -> nn.ModuleList:
- return self.model.get_input_embeddings()
+ # Store talker_config for backward compat (accessed by some callers)
+ self.talker_config = talker_config
+ self._vllm_config = vllm_config
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ """Load weights with vllm config context (required for VocabParallelEmbedding)."""
with set_current_vllm_config(self._vllm_config):
- loaded: set[str] = set()
- model_weights: list[tuple[str, torch.Tensor]] = []
- other_weights: list[tuple[str, torch.Tensor]] = []
- for name, w in weights:
- if name.startswith("model."):
- model_weights.append((name[len("model.") :], w))
- else:
- other_weights.append((name, w))
-
- loaded_model = self.model.load_weights(model_weights)
- loaded |= {f"model.{n}" for n in loaded_model}
-
- params = dict(self.named_parameters(remove_duplicate=False))
- for name, w in other_weights:
- if name not in params:
- continue
- default_weight_loader(params[name], w)
- loaded.add(name)
-
- return loaded
-
- # ------------------------------------------------------------------
- # Pre-allocated buffer management
- # ------------------------------------------------------------------
-
- def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None:
- max_seq = self._num_groups + 1
- if self._proj_buf is not None and self._proj_buf.device == device and self._proj_buf.dtype == dtype:
- return
- max_bsz = self._vllm_config.scheduler_config.max_num_seqs
- self._proj_buf = torch.zeros(
- max_bsz,
- max_seq,
- self._cp_hidden,
- dtype=dtype,
- device=device,
- )
-
- def _setup_compile(self) -> None:
- """Lazily set up torch.compile with manual CUDA graph capture."""
- if self._compiled_model_fwd is not None:
- return
- # Cache model parameter dtype so forward() doesn't need to query it
- # on every call. Also ensures warmup buffers match model precision
- # even when upstream modules produce a different dtype (#2385).
- self._model_dtype = next(self.model.parameters()).dtype
- self._lm_heads_list = list(self.lm_head)
- self._codec_embeds_list = list(self.model.codec_embedding)
- if not current_omni_platform.supports_torch_inductor():
- logger.warning_once("code_predictor: torch.compile disabled")
- self._compiled_model_fwd = self.model.forward
- return
-
- # torch.compile fuses RMSNorm/RoPE in ways that lose float32
- # precision, compounding across 15 AR steps. Use torch.compile
- # with options that disable the problematic fusions while still
- # getting kernel fusion benefits for the linear layers and SDPA.
- self._compiled_model_fwd = torch.compile(
- self.model.forward,
- dynamic=False,
- options={
- "epilogue_fusion": False,
- },
- )
- self._warmup_buckets()
- self._capture_cuda_graphs()
- logger.info("code_predictor: torch.compile (no epilogue fusion) + CUDA graphs")
-
- def _padded_bsz(self, bsz: int) -> int:
- for bucket in self._bucket_sizes:
- if bsz <= bucket:
- return bucket
- return bsz
-
- def _warmup_buckets(self) -> None:
- """Warmup power-of-2 batch-size buckets to front-load Inductor compilation."""
- max_bsz = self._vllm_config.scheduler_config.max_num_seqs
- bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz]
- if max_bsz not in bucket_sizes:
- bucket_sizes.append(max_bsz)
- self._bucket_sizes = sorted(bucket_sizes)
-
- max_seq = self._num_groups + 1
- device = next(self.model.parameters()).device
-
- # Ensure proj_buf matches model parameter dtype to avoid dtype
- # mismatch during warmup compilation (see #2385).
- self._ensure_buffers(device, self._model_dtype)
- proj_buf = self._proj_buf
- for bsz in self._bucket_sizes:
- # position_ids: [batch, seq_len] for HF-style RoPE
- pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1)
- self._bucket_pos_ids[bsz] = pos_ids
- for _ in range(3):
- self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids)
- logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes)
-
- def _capture_cuda_graphs(self) -> None:
- """Capture a CUDA graph per bucket using vLLM's global graph pool."""
- from vllm.platforms import current_platform
-
- pool = current_platform.get_global_graph_pool()
-
- max_seq = self._num_groups + 1
- proj_buf = self._proj_buf
-
- for bsz in self._bucket_sizes:
- static_input = proj_buf[:bsz, :max_seq, :]
- pos_ids = self._bucket_pos_ids[bsz]
-
- g = torch.cuda.CUDAGraph()
- with torch.cuda.graph(g, pool=pool):
- static_output = self._compiled_model_fwd(static_input, pos_ids)
-
- self._cuda_graphs[bsz] = (g, static_output)
-
- logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes)
-
- # ------------------------------------------------------------------
- # Optimized forward: re-prefill + torch.compile + projection cache
- # ------------------------------------------------------------------
-
- @torch.inference_mode()
- def forward(
- self,
- layer0_code: torch.Tensor,
- layer0_embed: torch.Tensor,
- last_talker_hidden: torch.Tensor,
- do_sample: bool = True,
- temperature: float = 0.9,
- top_k: int = 50,
- top_p: float = 1.0,
- ) -> torch.Tensor:
- """Predict residual codebooks 1..Q-1 autoregressively via re-prefill.
-
- torch.compile fuses the ~60 small kernel launches per step into fewer
- fused kernels, reducing kernel launch overhead by ~75%.
-
- Projection caching: each token is projected once via small_to_mtp_projection
- and cached in _proj_buf, avoiding redundant re-projection of past tokens.
- """
- bsz = int(layer0_code.shape[0])
- num_groups = self._num_groups
- device = layer0_code.device
-
- all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device)
- all_codes[:, 0] = layer0_code.reshape(bsz)
-
- # _setup_compile caches _model_dtype on first call; use it for buffers
- # so they always match model weight precision (#2385).
- self._setup_compile()
- dtype = self._model_dtype
- self._ensure_buffers(device, dtype)
-
- proj_buf = self._proj_buf
- max_seq = self._num_groups + 1
-
- projection = self.small_to_mtp_projection
- model_fwd = self._compiled_model_fwd
- lm_heads = self._lm_heads_list
- codec_embeds = self._codec_embeds_list
-
- use_sampling = do_sample and temperature > 0
- inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0
- if use_sampling and top_p != 1.0:
- raise NotImplementedError(
- "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0."
- )
-
- padded_bsz = self._padded_bsz(bsz)
- proj_buf[:padded_bsz].zero_()
-
- proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1)
- proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1)
- full_pos_ids = self._bucket_pos_ids.get(padded_bsz)
- if full_pos_ids is None:
- full_pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1)
-
- # Use captured CUDA graph if available, otherwise call compiled fn.
- cuda_graph_entry = self._cuda_graphs.get(padded_bsz)
-
- for step in range(1, num_groups):
- if cuda_graph_entry is not None:
- cuda_graph_entry[0].replay()
- hidden_out = cuda_graph_entry[1]
- else:
- hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids)
- logits = lm_heads[step - 1](hidden_out[:bsz, step, :])
-
- if use_sampling:
- scaled = logits * inv_temperature
- if top_k > 0:
- topk_vals, _ = scaled.topk(top_k, dim=-1)
- scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf"))
- probs = F.softmax(scaled, dim=-1)
- next_ids = torch.multinomial(probs, num_samples=1)
- else:
- next_ids = logits.argmax(dim=-1, keepdim=True)
-
- all_codes[:, step] = next_ids.reshape(bsz)
-
- if step < num_groups - 1:
- new_embed = codec_embeds[step - 1](next_ids)
- proj_buf[:bsz, step + 1, :] = projection(new_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
-
- return all_codes
+ return super().load_weights(weights)
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
index 9f8aff6aff..d9cbcf7d4e 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
@@ -13,7 +13,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from librosa.filters import mel as librosa_mel_fn
from transformers import AutoTokenizer
from transformers.activations import ACT2FN
from transformers.utils.hub import cached_file
@@ -24,9 +23,11 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.qwen3 import Qwen3Model
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix
+from vllm.multimodal.audio import AudioResampler
from vllm.sequence import IntermediateTensors
from vllm_omni.model_executor.models.output_templates import OmniOutput
+from vllm_omni.utils.audio import mel_filter_bank
from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
from .configuration_qwen3_tts import Qwen3TTSConfig, Qwen3TTSSpeakerEncoderConfig, Qwen3TTSTalkerConfig
@@ -258,14 +259,19 @@ def mel_spectrogram(
fmax: int | None = None,
center: bool = False,
) -> torch.Tensor:
- """Calculate mel spectrogram of an input signal using librosa mel filterbank and torch STFT."""
+ """Calculate mel spectrogram of an input signal using torchaudio mel filterbank and torch STFT."""
if torch.min(y) < -1.0:
logger.warning("Min value of input waveform signal is %s", torch.min(y))
if torch.max(y) > 1.0:
logger.warning("Max value of input waveform signal is %s", torch.max(y))
device = y.device
- mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
- mel_basis = torch.from_numpy(mel).float().to(device)
+ mel_basis = mel_filter_bank(
+ sr=sampling_rate,
+ n_fft=n_fft,
+ n_mels=num_mels,
+ fmin=fmin,
+ fmax=fmax,
+ ).to(device)
hann_window = torch.hann_window(win_size).to(device)
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
@@ -871,7 +877,7 @@ def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]:
Uses upstream vLLM's MediaConnector for http(s) URLs and ``file:``
URIs, with unrestricted local access (offline inference is trusted).
"""
- import librosa
+ from vllm.multimodal.media.audio import load_audio
if self._is_url(x):
from vllm.multimodal.media import MediaConnector
@@ -883,7 +889,7 @@ def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]:
with io.BytesIO(wav_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
else:
- audio, sr = librosa.load(x, sr=None, mono=True)
+ audio, sr = load_audio(x, sr=None, mono=True)
if isinstance(audio, np.ndarray) and audio.ndim > 1:
audio = np.mean(audio, axis=-1)
@@ -1089,9 +1095,8 @@ def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor:
# Resample to 24kHz for speaker encoder.
target_sr = int(getattr(self.config.speaker_encoder_config, "sample_rate", 24000))
if sr != target_sr:
- import librosa
-
- wav = librosa.resample(y=wav.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
+ resampler = AudioResampler(target_sr=target_sr)
+ wav = resampler.resample(wav.astype(np.float32), orig_sr=int(sr))
sr = target_sr
# Follow official implementation: mel_spectrogram expects 24kHz.
@@ -1434,11 +1439,16 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None:
)
if ref_ids is None:
ref_text = _as_singleton(info_dict.get("ref_text"))
- if not isinstance(ref_text, str) or not ref_text.strip():
- raise ValueError("Base in-context voice cloning requires `ref_text` or tokenized `ref_ids`.")
- ref_ids = tok(self._build_ref_text(ref_text), return_tensors="pt", padding=False)["input_ids"].to(
- device=input_ids.device
- )
+ if isinstance(ref_text, str) and ref_text.strip():
+ ref_ids = tok(
+ self._build_ref_text(ref_text),
+ return_tensors="pt",
+ padding=False,
+ )["input_ids"].to(device=input_ids.device)
+ else:
+ logger.warning("Base ICL: ref_text/ref_ids missing, falling back to x-vector-only mode.")
+ in_context_mode = False
+ if in_context_mode:
icl_input_embed, trailing_text_hidden = self._generate_icl_prompt(
text_id=input_ids[:, 3:-5],
ref_id=ref_ids[:, 3:-2],
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
index 503e6bbc83..14bfbc5eed 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
@@ -17,12 +17,13 @@
import urllib.request
from urllib.parse import urlparse
-import librosa
import numpy as np
import soundfile as sf
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
+from vllm.multimodal.audio import AudioResampler
+from vllm.multimodal.media.audio import load_audio as _load_audio_file
from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import (
@@ -154,13 +155,14 @@ def load_audio(
with io.BytesIO(wav_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
else:
- audio, sr = librosa.load(x, sr=None, mono=True)
+ audio, sr = _load_audio_file(x, sr=None, mono=True)
if audio.ndim > 1:
audio = np.mean(audio, axis=-1)
if sr != target_sr:
- audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)
+ resampler = AudioResampler(target_sr=target_sr)
+ audio = resampler.resample(audio, orig_sr=sr)
return audio.astype(np.float32)
@@ -208,7 +210,8 @@ def _normalize_audio_inputs(
if a.ndim > 1:
a = np.mean(a, axis=-1)
if int(sr) != target_sr:
- a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
+ resampler = AudioResampler(target_sr=target_sr)
+ a = resampler.resample(a.astype(np.float32), orig_sr=int(sr))
out.append(a.astype(np.float32))
return out
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz
deleted file mode 100644
index 28ea26909d..0000000000
Binary files a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz and /dev/null differ
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
index de2c69702c..f7e664c74d 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
@@ -17,16 +17,17 @@
from itertools import accumulate
import onnxruntime
-import sox
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as kaldi
-from librosa.filters import mel as librosa_mel_fn
from torch import Tensor
+from vllm_omni.model_executor.models.whisper_utils import Conv1d, ConvTranspose1d
+from vllm_omni.utils.audio import mel_filter_bank, peak_normalize
+
from .core_vq import DistributedGroupResidualVectorQuantization
-from .whisper_encoder import Conv1d, ConvTranspose1d, WhisperEncoder
+from .whisper_encoder import WhisperEncoder
def dynamic_range_compression_torch(x, c=1, clip_val=1e-5):
@@ -103,14 +104,14 @@ def extract(self, audio, **kwargs):
y = audio
if len(list(self.mel_basis.keys())) == 0:
- mel = librosa_mel_fn(
+ mel = mel_filter_bank(
sr=self.sampling_rate,
n_fft=self.filter_length,
n_mels=self.n_mel_channels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
)
- self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = mel.to(y.device)
self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(y.device)
y = torch.nn.functional.pad(
@@ -152,9 +153,6 @@ def __init__(self, audio_codec_with_xvector):
audio_codec_with_xvector, sess_options=option, providers=providers
)
- self.tfm = sox.Transformer()
- self.tfm.norm(db_level=-6)
-
self.mel_ext = MelSpectrogramFeatures(
filter_length=1024,
hop_length=160,
@@ -182,8 +180,7 @@ def extract_code(self, audio):
return norm_embedding.numpy(), ref_mel.permute(0, 2, 1).squeeze(0).numpy()
def sox_norm(self, audio):
- wav_norm = self.tfm.build_array(input_array=audio, sample_rate_in=16000)
- return wav_norm
+ return peak_normalize(audio, db_level=-6)
class WhisperEncoderVQ(WhisperEncoder):
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
index e3bd6e1c3a..7756720b2b 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
@@ -14,7 +14,6 @@
# limitations under the License.
import math
import operator
-import os
from functools import cache
from itertools import accumulate
@@ -24,6 +23,8 @@
from torch import Tensor, nn
from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func
+from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids
+from vllm_omni.utils.audio import mel_filter_bank
N_FFT = 400
HOP_LENGTH = 160
@@ -31,21 +32,8 @@
@cache
def mel_filters(device, n_mels: int) -> torch.Tensor:
- """
- load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
- Allows decoupling librosa dependency; saved using:
-
- np.savez_compressed(
- "mel_filters.npz",
- mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
- mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
- )
- """
- assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
-
- filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
- with np.load(filters_path, allow_pickle=False) as f:
- return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+ """Compute mel filterbank matrix for projecting STFT into a Mel spectrogram."""
+ return mel_filter_bank(sr=16000, n_fft=N_FFT, n_mels=n_mels).to(device)
def log_mel_spectrogram(
@@ -115,30 +103,6 @@ def get_mel_audio(audio, padding=False, audio_vq_ds_rate=1, n_mels=128):
return mel
-def sinusoids(length, channels, max_timescale=10000):
- """Returns sinusoids for positional embedding"""
- assert channels % 2 == 0
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
-
-
-class Conv1d(nn.Conv1d):
- def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
- return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
-
-
-class ConvTranspose1d(nn.ConvTranspose1d):
- def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
- return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
-
-
-class Linear(nn.Linear):
- def forward(self, x: Tensor) -> Tensor:
- return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
-
-
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int, use_flash_attention: bool = True):
super().__init__()
diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py
index 3b51f20023..5a466dbd62 100644
--- a/vllm_omni/model_executor/models/registry.py
+++ b/vllm_omni/model_executor/models/registry.py
@@ -145,6 +145,18 @@
"fish_speech_dac_decoder",
"FishSpeechDACDecoder",
),
+ ## VoxCPM
+ "VoxCPMForConditionalGeneration": (
+ "voxcpm",
+ "voxcpm",
+ "VoxCPMForConditionalGeneration",
+ ),
+ ## VoxCPM2
+ "VoxCPM2TalkerForConditionalGeneration": (
+ "voxcpm2",
+ "voxcpm2_talker",
+ "VoxCPM2TalkerForConditionalGeneration",
+ ),
## Voxtral TTS
"VoxtralTTSForConditionalGeneration": (
"voxtral_tts",
@@ -162,6 +174,23 @@
"dynin_omni",
"DyninOmniForConditionalGeneration",
),
+ ## Ming-flash-omni-2.0
+ "MingFlashOmniForConditionalGeneration": (
+ "ming_flash_omni",
+ "ming_flash_omni",
+ "MingFlashOmniForConditionalGeneration",
+ ),
+ "MingFlashOmniThinkerForConditionalGeneration": (
+ "ming_flash_omni",
+ "ming_flash_omni_thinker",
+ "MingFlashOmniThinkerForConditionalGeneration",
+ ),
+ # Alias: HF repo currently ships this architecture name in config.json
+ "BailingMM2NativeForConditionalGeneration": (
+ "ming_flash_omni",
+ "ming_flash_omni",
+ "MingFlashOmniForConditionalGeneration",
+ ),
}
diff --git a/vllm_omni/model_executor/models/voxcpm/__init__.py b/vllm_omni/model_executor/models/voxcpm/__init__.py
new file mode 100644
index 0000000000..3b064c0f68
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm/__init__.py
@@ -0,0 +1,7 @@
+from .configuration_voxcpm import VoxCPMConfig
+from .voxcpm import VoxCPMForConditionalGeneration
+
+__all__ = [
+ "VoxCPMConfig",
+ "VoxCPMForConditionalGeneration",
+]
diff --git a/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py b/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py
new file mode 100644
index 0000000000..ce1d809bd3
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py
@@ -0,0 +1,3 @@
+from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig
+
+__all__ = ["VoxCPMConfig"]
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm.py b/vllm_omni/model_executor/models/voxcpm/voxcpm.py
new file mode 100644
index 0000000000..6fa36fc420
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm/voxcpm.py
@@ -0,0 +1,886 @@
+from __future__ import annotations
+
+import json
+import os
+import sys
+import tempfile
+import warnings
+import wave
+from collections.abc import Callable, Generator, Iterable
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from tqdm import tqdm
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+
+from .voxcpm_loader import (
+ _build_prompt_cache_with_soundfile,
+ _device_to_string,
+ _force_cuda_available_for_npu,
+ _import_voxcpm_audio_vae_classes,
+ _import_voxcpm_base_model_class,
+ _is_torchcodec_load_error,
+ _normalize_dtype_name,
+ _prepare_runtime_model_dir,
+ _resolve_runtime_device,
+)
+from .voxcpm_runtime_utils import resolve_voxcpm_model_dir
+from .voxcpm_stage_wrappers import _DirectVoxCPMAudioVAE, _DirectVoxCPMLatentGenerator
+
+logger = init_logger(__name__)
+_VOXCPM_LATENT_MAGIC = 131071
+
+
+def _make_voxcpm_model_for_omni(base: type[Any]) -> type[Any]:
+ """Subclass upstream VoxCPMModel: local ``_inference`` + ``latents_only`` prompt-cache generation."""
+
+ from voxcpm.model.utils import get_dtype
+
+ class VoxCPMModelForOmni(base):
+ @torch.inference_mode()
+ def build_prompt_cache(self, *args: Any, **kwargs: Any):
+ try:
+ return super().build_prompt_cache(*args, **kwargs)
+ except (ImportError, ModuleNotFoundError, RuntimeError) as exc:
+ if not _is_torchcodec_load_error(exc):
+ raise
+ return _build_prompt_cache_with_soundfile(self, *args, **kwargs)
+
+ @torch.inference_mode()
+ def _inference(
+ self,
+ text: torch.Tensor,
+ text_mask: torch.Tensor,
+ feat: torch.Tensor,
+ feat_mask: torch.Tensor,
+ min_len: int = 2,
+ max_len: int = 2000,
+ inference_timesteps: int = 10,
+ cfg_value: float = 2.0,
+ streaming: bool = False,
+ streaming_prefix_len: int = 3,
+ ) -> Generator[tuple[torch.Tensor, torch.Tensor | list[torch.Tensor]], None, None]:
+ B, _, _, _ = feat.shape
+
+ feat_embed = self.feat_encoder(feat)
+ feat_embed = self.enc_to_lm_proj(feat_embed)
+
+ scale_emb = self.config.lm_config.scale_emb if self.config.lm_config.use_mup else 1.0
+ text_embed = self.base_lm.embed_tokens(text) * scale_emb
+ combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
+
+ prefix_feat_cond = feat[:, -1, ...]
+ pred_feat_seq: list[torch.Tensor] = []
+
+ audio_patch_count = int(feat_mask.sum().item())
+ if audio_patch_count > 0:
+ context_len = min(streaming_prefix_len - 1, audio_patch_count)
+ prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
+ pred_feat_seq = prompt_context_patches + pred_feat_seq
+
+ enc_outputs, kv_cache_tuple = self.base_lm(
+ inputs_embeds=combined_embed,
+ is_causal=True,
+ )
+ self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
+
+ enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
+ lm_hidden = enc_outputs[:, -1, :]
+
+ residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
+ inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
+ is_causal=True,
+ )
+ self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
+ residual_hidden = residual_enc_outputs[:, -1, :]
+
+ for step_idx in tqdm(range(max_len)):
+ dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
+ pred_feat = self.feat_decoder(
+ mu=dit_hidden,
+ patch_size=self.patch_size,
+ cond=prefix_feat_cond.transpose(1, 2).contiguous(),
+ n_timesteps=inference_timesteps,
+ cfg_value=cfg_value,
+ ).transpose(1, 2)
+
+ curr_embed = self.enc_to_lm_proj(self.feat_encoder(pred_feat.unsqueeze(1)))
+ pred_feat_seq.append(pred_feat.unsqueeze(1))
+ prefix_feat_cond = pred_feat
+
+ if streaming:
+ pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
+ feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+ yield feat_pred, pred_feat_seq
+
+ stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
+ if step_idx > min_len and stop_flag == 1:
+ break
+
+ lm_hidden = self.base_lm.forward_step(
+ curr_embed[:, 0, :],
+ torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device),
+ ).clone()
+ lm_hidden = self.fsq_layer(lm_hidden)
+ residual_hidden = self.residual_lm.forward_step(
+ lm_hidden + curr_embed[:, 0, :],
+ torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
+ ).clone()
+
+ if not streaming:
+ pred_feat_seq_cat = torch.cat(pred_feat_seq, dim=1)
+ feat_pred = rearrange(pred_feat_seq_cat, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+ yield feat_pred, pred_feat_seq_cat.squeeze(0).cpu()
+
+ @torch.inference_mode()
+ def generate_latents_with_prompt_cache(
+ self,
+ target_text: str,
+ prompt_cache: dict,
+ min_len: int = 2,
+ max_len: int = 2000,
+ inference_timesteps: int = 10,
+ cfg_value: float = 2.0,
+ retry_badcase: bool = False,
+ retry_badcase_max_times: int = 3,
+ retry_badcase_ratio_threshold: float = 6.0,
+ streaming_prefix_len: int = 3,
+ ) -> tuple[None, torch.Tensor, torch.Tensor]:
+ return next(
+ self._generate_with_prompt_cache(
+ target_text=target_text,
+ prompt_cache=prompt_cache,
+ min_len=min_len,
+ max_len=max_len,
+ inference_timesteps=inference_timesteps,
+ cfg_value=cfg_value,
+ retry_badcase=retry_badcase,
+ retry_badcase_max_times=retry_badcase_max_times,
+ retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+ streaming=False,
+ streaming_prefix_len=streaming_prefix_len,
+ latents_only=True,
+ )
+ )
+
+ @torch.inference_mode()
+ def generate_latents_with_prompt_cache_streaming(
+ self,
+ target_text: str,
+ prompt_cache: dict,
+ min_len: int = 2,
+ max_len: int = 2000,
+ inference_timesteps: int = 10,
+ cfg_value: float = 2.0,
+ retry_badcase: bool = False,
+ retry_badcase_max_times: int = 3,
+ retry_badcase_ratio_threshold: float = 6.0,
+ streaming_prefix_len: int = 3,
+ ) -> Generator[tuple[None, torch.Tensor, torch.Tensor], None, None]:
+ return self._generate_with_prompt_cache(
+ target_text=target_text,
+ prompt_cache=prompt_cache,
+ min_len=min_len,
+ max_len=max_len,
+ inference_timesteps=inference_timesteps,
+ cfg_value=cfg_value,
+ retry_badcase=retry_badcase,
+ retry_badcase_max_times=retry_badcase_max_times,
+ retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+ streaming=True,
+ streaming_prefix_len=streaming_prefix_len,
+ latents_only=True,
+ )
+
+ @torch.inference_mode()
+ def _generate_with_prompt_cache(
+ self,
+ target_text: str,
+ prompt_cache: dict,
+ min_len: int = 2,
+ max_len: int = 2000,
+ inference_timesteps: int = 10,
+ cfg_value: float = 2.0,
+ retry_badcase: bool = False,
+ retry_badcase_max_times: int = 3,
+ retry_badcase_ratio_threshold: float = 6.0,
+ streaming: bool = False,
+ streaming_prefix_len: int = 3,
+ latents_only: bool = False,
+ ) -> Generator[tuple[torch.Tensor | None, torch.Tensor, torch.Tensor | list[torch.Tensor]], None, None]:
+ if retry_badcase and streaming:
+ warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+ retry_badcase = False
+ if prompt_cache is None:
+ prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
+ text = target_text
+ else:
+ prompt_audio_feat = prompt_cache["audio_feat"]
+ prompt_text = prompt_cache["prompt_text"]
+ text = prompt_text + target_text
+
+ text_token = torch.LongTensor(self.text_tokenizer(text))
+ text_token = torch.cat(
+ [
+ text_token,
+ torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
+ ],
+ dim=-1,
+ )
+ target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
+
+ audio_length = prompt_audio_feat.size(0)
+ text_length = text_token.shape[0]
+ text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
+ audio_pad_feat = torch.zeros(
+ (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
+ dtype=torch.float32,
+ device=text_token.device,
+ )
+ text_token = torch.cat([text_token, text_pad_token])
+ audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
+ text_mask = (
+ torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
+ )
+ audio_mask = (
+ torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
+ )
+
+ text_token = text_token.unsqueeze(0).to(self.device)
+ text_mask = text_mask.unsqueeze(0).to(self.device)
+ audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
+ audio_mask = audio_mask.unsqueeze(0).to(self.device)
+
+ target_text_length = len(self.text_tokenizer(target_text))
+ retry_badcase_times = 0
+ while retry_badcase_times < retry_badcase_max_times:
+ inference_result = self._inference(
+ text_token,
+ text_mask,
+ audio_feat,
+ audio_mask,
+ min_len=min_len,
+ max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
+ inference_timesteps=inference_timesteps,
+ cfg_value=cfg_value,
+ streaming=streaming,
+ streaming_prefix_len=streaming_prefix_len,
+ )
+ if streaming:
+ patch_len = self.patch_size * self.chunk_size
+ for latent_pred, pred_audio_feat in inference_result:
+ if latents_only:
+ decode_audio = None
+ yield (decode_audio, target_text_token, latent_pred)
+ else:
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+ decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+ yield (decode_audio, target_text_token, pred_audio_feat)
+ break
+
+ latent_pred, pred_audio_feat = next(inference_result)
+ if retry_badcase and pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+ ratio = pred_audio_feat.shape[0] / target_text_length
+ print(f" Badcase detected, audio_text_ratio={ratio}, retrying...", file=sys.stderr)
+ retry_badcase_times += 1
+ continue
+ break
+
+ if not streaming:
+ if latents_only:
+ decode_audio = None
+ else:
+ decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+ patch_len = self.patch_size * self.chunk_size
+ if audio_mask.sum().item() > 0:
+ decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
+ else:
+ decode_audio = decode_audio[..., :].squeeze(1).cpu()
+ yield (decode_audio, target_text_token, pred_audio_feat)
+
+ VoxCPMModelForOmni.__name__ = "VoxCPMModelForOmni"
+ VoxCPMModelForOmni.__qualname__ = "VoxCPMModelForOmni"
+ return VoxCPMModelForOmni
+
+
+def _import_voxcpm_model_class() -> type[Any]:
+ base = _import_voxcpm_base_model_class()
+ return _make_voxcpm_model_for_omni(base)
+
+
+def _load_native_voxcpm_model(
+ model_path: str,
+ *,
+ device: torch.device,
+ dtype: str | None,
+):
+ VoxCPMModel = _import_voxcpm_model_class()
+ model_dir = resolve_voxcpm_model_dir(model_path)
+ runtime_model_path = _prepare_runtime_model_dir(model_dir, target_device=device, target_dtype=dtype)
+
+ if device.type == "npu" and hasattr(torch, "npu"):
+ torch.npu.set_device(device)
+
+ with _force_cuda_available_for_npu(device):
+ return VoxCPMModel.from_local(
+ runtime_model_path,
+ optimize=device.type == "cuda",
+ )
+
+
+def _load_native_voxcpm_latent_generator(
+ model_path: str,
+ *,
+ device: torch.device,
+ dtype: str | None,
+) -> _DirectVoxCPMLatentGenerator:
+ return _DirectVoxCPMLatentGenerator(_load_native_voxcpm_model(model_path, device=device, dtype=dtype))
+
+
+def _load_native_voxcpm_audio_vae(
+ model_path: str,
+ *,
+ device: torch.device,
+) -> _DirectVoxCPMAudioVAE:
+ AudioVAE, AudioVAEConfig = _import_voxcpm_audio_vae_classes()
+ model_dir = resolve_voxcpm_model_dir(model_path)
+ runtime_model_path = _prepare_runtime_model_dir(model_dir, target_device=device, target_dtype="float32")
+ config_dict = json.loads((Path(runtime_model_path) / "config.json").read_text())
+ audio_vae_config = config_dict.get("audio_vae_config")
+ audio_vae = AudioVAE(config=AudioVAEConfig(**audio_vae_config)) if audio_vae_config is not None else AudioVAE()
+
+ state_dict = torch.load(
+ Path(runtime_model_path) / "audiovae.pth",
+ map_location="cpu",
+ weights_only=True,
+ )["state_dict"]
+ audio_vae.load_state_dict(state_dict, strict=True)
+ audio_vae = audio_vae.to(device=device, dtype=torch.float32).eval()
+ if device.type == "npu" and hasattr(torch, "npu"):
+ torch.npu.set_device(device)
+ patch_size = int(config_dict.get("patch_size", 2))
+ return _DirectVoxCPMAudioVAE(audio_vae, patch_size=patch_size)
+
+
+class VoxCPMForConditionalGeneration(nn.Module):
+ input_modalities = "audio"
+ _LATENT_STAGES = {"latent_generator", "latent", "ar_dit"}
+ _VAE_STAGES = {"vae", "audio_vae"}
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ del prefix
+ self.vllm_config = vllm_config
+ self.model_path = vllm_config.model_config.model
+ self.model_stage = getattr(vllm_config.model_config, "model_stage", "latent_generator")
+ self.have_multimodal_outputs = True
+ self.has_preprocess = False
+ self.has_postprocess = False
+ self.enable_update_additional_information = True
+ self.requires_raw_input_tokens = True
+ self.inject_omni_request_id_into_runtime_info = True
+ self._pipeline = None
+ self._latent_stream_gens: dict[str, Any] = {}
+ self._latent_stream_terminal_pending: dict[str, int] = {}
+ self._latent_stream_completed: set[str] = set()
+ self._next_local_stream_key = 0
+ self._ar_emit_stop_token = True
+
+ def _runner_hidden_device_dtype(self) -> tuple[torch.device, torch.dtype]:
+ device = _resolve_runtime_device(self.vllm_config)
+ model_config = getattr(self.vllm_config, "model_config", None)
+ dtype = getattr(model_config, "dtype", torch.float32) if model_config is not None else torch.float32
+ return device, dtype
+
+ def _ensure_model_loaded(self):
+ if self._pipeline is not None:
+ return
+
+ target_device = _resolve_runtime_device(self.vllm_config)
+ model_dtype = getattr(self.vllm_config.model_config, "dtype", None)
+ normalized_dtype = _normalize_dtype_name(model_dtype)
+ if self.model_stage in self._LATENT_STAGES:
+ self._pipeline = _load_native_voxcpm_latent_generator(
+ self.model_path,
+ device=target_device,
+ dtype=normalized_dtype,
+ )
+ elif self.model_stage in self._VAE_STAGES:
+ self._pipeline = _load_native_voxcpm_audio_vae(
+ self.model_path,
+ device=target_device,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported VoxCPM model_stage: {self.model_stage}. "
+ "pure_voxcpm only supports split-stage latent_generator/vae inference."
+ )
+
+ logger.info("Loaded VoxCPM stage '%s' on %s", self.model_stage, _device_to_string(target_device))
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ del weights
+ self._ensure_model_loaded()
+ return set()
+
+ @staticmethod
+ def _extract_val(info: dict[str, Any], key: str, default: Any) -> Any:
+ value = info.get(key, default)
+ if isinstance(value, list):
+ return value[0] if value else default
+ return value
+
+ def _resolve_stream_request_key(self, info: dict[str, Any]) -> str:
+ request_key = info.get("__voxcpm_stream_key")
+ if request_key is not None:
+ return str(request_key)
+
+ request_key = info.get("_omni_req_id")
+ if request_key is not None:
+ request_key = str(request_key)
+ info["__voxcpm_stream_key"] = request_key
+ return request_key
+
+ request_key = f"voxcpm-local-{self._next_local_stream_key}"
+ self._next_local_stream_key += 1
+ info["__voxcpm_stream_key"] = request_key
+ return str(request_key)
+
+ def _recover_latent_from_input_ids(self, input_ids: torch.Tensor | None) -> torch.Tensor | None:
+ if input_ids is None or input_ids.numel() == 0:
+ return None
+ flat_ids = input_ids.detach().reshape(-1).to("cpu")
+ if flat_ids.numel() < 4 or int(flat_ids[0].item()) != _VOXCPM_LATENT_MAGIC:
+ return None
+ latent_dim = int(flat_ids[1].item())
+ time_dim = int(flat_ids[2].item())
+ payload = flat_ids[3:]
+ expected = latent_dim * time_dim
+ if latent_dim <= 0 or time_dim <= 0:
+ raise ValueError(f"Invalid VoxCPM latent header: latent_dim={latent_dim}, time_dim={time_dim}")
+ if int(payload.numel()) != expected:
+ raise ValueError(
+ "Invalid VoxCPM latent payload size: "
+ f"expected={expected}, actual={int(payload.numel())}, "
+ f"latent_dim={latent_dim}, time_dim={time_dim}"
+ )
+ packed = payload.to(dtype=torch.int32).to(torch.uint16)
+ return packed.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim)
+
+ def _maybe_recover_vae_infos(
+ self,
+ infos: list[dict[str, Any]],
+ input_ids: torch.Tensor | None,
+ *,
+ async_chunk: bool,
+ ) -> list[dict[str, Any]]:
+ if not async_chunk:
+ return infos
+ if any(self._extract_val(info, "latent_audio_feat", None) is not None for info in infos):
+ return infos
+ recovered = self._recover_latent_from_input_ids(input_ids)
+ if recovered is None:
+ return infos
+ return [{"latent_audio_feat": recovered}]
+
+ @staticmethod
+ def _normalize_audio_samples(samples: Any) -> np.ndarray:
+ if isinstance(samples, torch.Tensor):
+ return samples.detach().cpu().float().reshape(-1).numpy()
+ return np.asarray(samples, dtype=np.float32).reshape(-1)
+
+ @classmethod
+ def _normalize_ref_audio(cls, ref_audio: Any) -> tuple[np.ndarray, int]:
+ if isinstance(ref_audio, str):
+ raise TypeError("String ref_audio should be handled as a path before waveform normalization.")
+
+ if isinstance(ref_audio, dict):
+ sample_rate = ref_audio.get("sample_rate") or ref_audio.get("sampling_rate") or ref_audio.get("sr")
+ samples = None
+ for key in ("audio", "wav", "samples", "array", "waveform"):
+ if key in ref_audio and ref_audio[key] is not None:
+ samples = ref_audio[key]
+ break
+ if sample_rate is None or samples is None:
+ raise ValueError("ref_audio dict must contain waveform data and sample rate.")
+ return cls._normalize_audio_samples(samples), int(sample_rate)
+
+ if isinstance(ref_audio, (list, tuple)):
+ if len(ref_audio) == 1:
+ return cls._normalize_ref_audio(ref_audio[0])
+ if len(ref_audio) == 2 and np.isscalar(ref_audio[1]):
+ return cls._normalize_audio_samples(ref_audio[0]), int(ref_audio[1])
+
+ raise TypeError(f"Unsupported ref_audio format: {type(ref_audio)!r}")
+
+ @staticmethod
+ def _write_temp_prompt_wav(waveform: np.ndarray, sample_rate: int) -> str:
+ prompt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+ prompt_file.close()
+
+ wav = np.asarray(waveform, dtype=np.float32).reshape(-1)
+ wav = np.clip(wav, -1.0, 1.0)
+ pcm16 = (wav * 32767.0).astype(np.int16)
+ with wave.open(prompt_file.name, "wb") as wav_file:
+ wav_file.setnchannels(1)
+ wav_file.setsampwidth(2)
+ wav_file.setframerate(int(sample_rate))
+ wav_file.writeframes(pcm16.tobytes())
+
+ return prompt_file.name
+
+ @classmethod
+ def _resolve_prompt_inputs(cls, info: dict[str, Any]) -> tuple[str | None, str | None, str | None]:
+ prompt_text = cls._extract_val(info, "prompt_text", None)
+ prompt_wav_path = cls._extract_val(info, "prompt_wav_path", None)
+ if prompt_wav_path:
+ if prompt_text is None:
+ prompt_text = cls._extract_val(info, "ref_text", None)
+ return prompt_wav_path, prompt_text, None
+
+ ref_audio = cls._extract_val(info, "ref_audio", None)
+ ref_text = cls._extract_val(info, "ref_text", None)
+ if ref_audio is None or ref_text is None:
+ return None, None, None
+ if isinstance(ref_audio, str):
+ return ref_audio, ref_text, None
+
+ waveform, sample_rate = cls._normalize_ref_audio(ref_audio)
+ temp_prompt_wav = cls._write_temp_prompt_wav(waveform, sample_rate)
+ return temp_prompt_wav, ref_text, temp_prompt_wav
+
+ def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
+ if input_ids.numel() == 0:
+ return torch.empty((0, 1), device=input_ids.device, dtype=torch.float32)
+ return torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.float32)
+
+ def _get_vocab_size(self) -> int:
+ model_config = getattr(self.vllm_config, "model_config", None)
+ if model_config is not None:
+ getter = getattr(model_config, "get_vocab_size", None)
+ if callable(getter):
+ try:
+ return int(getter())
+ except Exception:
+ pass
+ hf_config = getattr(model_config, "hf_text_config", None)
+ if hf_config is not None and hasattr(hf_config, "vocab_size"):
+ return int(hf_config.vocab_size)
+ return 32000
+
+ def _make_empty_output(
+ self,
+ *,
+ output_key: str,
+ payload_factory: Callable[[], torch.Tensor],
+ infos: list[dict[str, Any]],
+ sample_rate: int,
+ out_device: torch.device,
+ out_dtype: torch.dtype,
+ hidden_rows: int | None = None,
+ ) -> OmniOutput:
+ if hidden_rows is None:
+ hidden_rows = len(infos)
+ return OmniOutput(
+ text_hidden_states=torch.zeros((hidden_rows, 1), device=out_device, dtype=out_dtype),
+ multimodal_outputs={
+ output_key: [payload_factory() for _ in infos],
+ "sr": [torch.tensor(sample_rate, dtype=torch.int32) for _ in infos],
+ },
+ )
+
+ def _finalize_stage_output(
+ self,
+ *,
+ output_key: str,
+ outputs: list[torch.Tensor],
+ sample_rates: list[torch.Tensor],
+ out_device: torch.device,
+ out_dtype: torch.dtype,
+ hidden_rows: int | None = None,
+ ) -> OmniOutput:
+ multimodal_outputs: dict[str, Any] = {output_key: outputs, "sr": sample_rates}
+ if hidden_rows is not None:
+ text_hidden_states = torch.zeros((hidden_rows, 1), device=out_device, dtype=out_dtype)
+ elif outputs:
+ outputs_tensor = torch.stack(outputs)
+ text_hidden_states = (
+ outputs_tensor.unsqueeze(-1)
+ if outputs_tensor.ndim == 1
+ else outputs_tensor.reshape(-1, outputs_tensor.shape[-1])
+ )
+ else:
+ text_hidden_states = torch.zeros((0, 1), device=out_device, dtype=out_dtype)
+ text_hidden_states = text_hidden_states.to(device=out_device, dtype=out_dtype)
+ return OmniOutput(
+ text_hidden_states=text_hidden_states,
+ multimodal_outputs=multimodal_outputs,
+ )
+
+ def _forward_vae_stage(
+ self,
+ infos: list[dict[str, Any]],
+ *,
+ sample_rate: int,
+ async_chunk: bool,
+ out_device: torch.device,
+ out_dtype: torch.dtype,
+ ) -> OmniOutput:
+ if all(self._extract_val(info, "latent_audio_feat", None) is None for info in infos):
+ self._ar_emit_stop_token = True
+ return self._make_empty_output(
+ output_key="model_outputs",
+ payload_factory=lambda: torch.zeros((0,), dtype=torch.float32),
+ infos=infos,
+ sample_rate=sample_rate,
+ out_device=out_device,
+ out_dtype=out_dtype,
+ )
+
+ outputs: list[torch.Tensor] = []
+ sample_rates: list[torch.Tensor] = []
+ for info in infos:
+ latent_audio_feat = self._extract_val(info, "latent_audio_feat", None)
+ audio_tensor = self._pipeline.decode(latent_audio_feat, trim_streaming_patch=async_chunk)
+ outputs.append(audio_tensor.float().cpu())
+ sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
+
+ self._ar_emit_stop_token = True
+ return self._finalize_stage_output(
+ output_key="model_outputs",
+ outputs=outputs,
+ sample_rates=sample_rates,
+ out_device=out_device,
+ out_dtype=out_dtype,
+ )
+
+ def _forward_latent_stage(
+ self,
+ infos: list[dict[str, Any]],
+ *,
+ sample_rate: int,
+ async_chunk: bool,
+ out_device: torch.device,
+ out_dtype: torch.dtype,
+ hidden_rows: int,
+ ) -> OmniOutput:
+ texts = [self._extract_val(info, "text", "") for info in infos]
+ if all(not text for text in texts):
+ self._ar_emit_stop_token = True
+ return self._make_empty_output(
+ output_key="latent_audio_feat",
+ payload_factory=lambda: torch.zeros((0,), dtype=torch.float32),
+ infos=infos,
+ sample_rate=sample_rate,
+ out_device=out_device,
+ out_dtype=out_dtype,
+ hidden_rows=hidden_rows,
+ )
+
+ outputs: list[torch.Tensor] = []
+ sample_rates: list[torch.Tensor] = []
+ last_chunk_flags: list[bool] | None = [] if async_chunk else None
+ payload_finished_flags: list[bool] | None = [] if async_chunk else None
+ for info in infos:
+ text = self._extract_val(info, "text", "")
+ cfg_value = float(self._extract_val(info, "cfg_value", 2.0))
+ inference_timesteps = int(self._extract_val(info, "inference_timesteps", 10))
+ min_len = int(self._extract_val(info, "min_len", 2))
+ max_len = int(self._extract_val(info, "max_len", self._extract_val(info, "max_new_tokens", 4096)))
+ retry_badcase = bool(self._extract_val(info, "retry_badcase", True))
+ retry_badcase_max_times = int(self._extract_val(info, "retry_badcase_max_times", 3))
+ retry_badcase_ratio_threshold = float(self._extract_val(info, "retry_badcase_ratio_threshold", 6.0))
+ streaming_prefix_len = int(self._extract_val(info, "streaming_prefix_len", 3))
+
+ request_key = self._resolve_stream_request_key(info)
+ created_temp: str | None = None
+
+ if async_chunk:
+ terminal_pending = self._latent_stream_terminal_pending.get(request_key, 0)
+ if terminal_pending > 0:
+ outputs.append(torch.zeros((0,), dtype=torch.float32))
+ assert last_chunk_flags is not None
+ last_chunk_flags.append(True)
+ assert payload_finished_flags is not None
+ payload_finished_flags.append(terminal_pending == 1)
+ if terminal_pending == 1:
+ self._latent_stream_terminal_pending.pop(request_key, None)
+ else:
+ self._latent_stream_terminal_pending[request_key] = terminal_pending - 1
+ sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
+ continue
+
+ if request_key in self._latent_stream_completed:
+ outputs.append(torch.zeros((0,), dtype=torch.float32))
+ assert last_chunk_flags is not None
+ last_chunk_flags.append(True)
+ assert payload_finished_flags is not None
+ payload_finished_flags.append(False)
+ sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
+ continue
+
+ if request_key not in self._latent_stream_gens:
+ prompt_wav_path, prompt_text, temp_prompt_wav = self._resolve_prompt_inputs(info)
+ created_temp = temp_prompt_wav
+ self._latent_stream_gens[request_key] = self._pipeline.iter_latent_chunks_streaming(
+ text=text,
+ prompt_wav_path=prompt_wav_path,
+ prompt_text=prompt_text,
+ cfg_value=cfg_value,
+ inference_timesteps=inference_timesteps,
+ min_len=min_len,
+ max_len=max_len,
+ streaming_prefix_len=streaming_prefix_len,
+ retry_badcase=False,
+ retry_badcase_max_times=retry_badcase_max_times,
+ retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+ )
+ generator = self._latent_stream_gens[request_key]
+ try:
+ chunk_latent, is_last = next(generator)
+ except StopIteration:
+ self._latent_stream_gens.pop(request_key, None)
+ self._latent_stream_terminal_pending[request_key] = 1
+ self._latent_stream_completed.add(request_key)
+ outputs.append(torch.zeros((0,), dtype=torch.float32))
+ assert last_chunk_flags is not None
+ last_chunk_flags.append(True)
+ assert payload_finished_flags is not None
+ payload_finished_flags.append(True)
+ else:
+ if is_last:
+ self._latent_stream_gens.pop(request_key, None)
+ self._latent_stream_terminal_pending[request_key] = 1
+ self._latent_stream_completed.add(request_key)
+ outputs.append(chunk_latent.detach().float().cpu())
+ assert last_chunk_flags is not None
+ last_chunk_flags.append(bool(is_last))
+ assert payload_finished_flags is not None
+ payload_finished_flags.append(False)
+ finally:
+ if created_temp is not None and os.path.exists(created_temp):
+ os.unlink(created_temp)
+ sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
+ continue
+
+ prompt_wav_path, prompt_text, temp_prompt_wav = self._resolve_prompt_inputs(info)
+ try:
+ latent_audio_feat = self._pipeline.generate_latents(
+ text=text,
+ prompt_wav_path=prompt_wav_path,
+ prompt_text=prompt_text,
+ cfg_value=cfg_value,
+ inference_timesteps=inference_timesteps,
+ min_len=min_len,
+ max_len=max_len,
+ retry_badcase=retry_badcase,
+ retry_badcase_max_times=retry_badcase_max_times,
+ retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+ )
+ outputs.append(latent_audio_feat.float().cpu())
+ finally:
+ if temp_prompt_wav is not None and os.path.exists(temp_prompt_wav):
+ os.unlink(temp_prompt_wav)
+
+ sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
+
+ self._ar_emit_stop_token = all(last_chunk_flags) if async_chunk and last_chunk_flags else True
+ output = self._finalize_stage_output(
+ output_key="latent_audio_feat",
+ outputs=outputs,
+ sample_rates=sample_rates,
+ out_device=out_device,
+ out_dtype=out_dtype,
+ hidden_rows=hidden_rows,
+ )
+ if async_chunk and payload_finished_flags is not None:
+ output.multimodal_outputs["finished"] = [
+ torch.tensor(flag, dtype=torch.bool) for flag in payload_finished_flags
+ ]
+ return output
+
+ def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> torch.Tensor:
+ del sampling_metadata
+ if isinstance(hidden_states, OmniOutput):
+ hidden_states = hidden_states.text_hidden_states
+ if hidden_states is None:
+ device, dtype = self._runner_hidden_device_dtype()
+ hidden_states = torch.zeros((0, 1), device=device, dtype=dtype)
+ if hidden_states.ndim == 1:
+ hidden_states = hidden_states.unsqueeze(-1)
+ elif hidden_states.ndim > 2:
+ hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+
+ vocab_size = self._get_vocab_size()
+ num_rows = int(hidden_states.shape[0])
+ logits = torch.zeros((num_rows, vocab_size), dtype=torch.float32, device=hidden_states.device)
+ eos_id = 2 if vocab_size > 2 else 0
+ safe_id = 1 if vocab_size > 1 and 1 != eos_id else 0
+ emit_stop = getattr(self, "_ar_emit_stop_token", True)
+ if num_rows > 0:
+ if emit_stop:
+ logits[:, eos_id] = 1.0e6
+ else:
+ logits[:, eos_id] = -1.0e9
+ logits[:, safe_id] = 1.0e6
+ return logits
+
+ @torch.no_grad()
+ def forward(
+ self,
+ input_ids: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ intermediate_tensors: Any = None,
+ inputs_embeds: torch.Tensor | None = None,
+ runtime_additional_information: list[dict[str, Any]] | None = None,
+ model_intermediate_buffer: list[dict[str, Any]] | None = None,
+ **kwargs: Any,
+ ) -> OmniOutput:
+ del positions, intermediate_tensors, inputs_embeds, kwargs
+ self._ensure_model_loaded()
+ out_device, out_dtype = self._runner_hidden_device_dtype()
+ if input_ids is not None and input_ids.device.type == out_device.type:
+ out_device = input_ids.device
+
+ infos = model_intermediate_buffer or runtime_additional_information or [{}]
+ hidden_rows = len(infos)
+ if input_ids is not None and len(input_ids.shape) > 0:
+ hidden_rows = max(hidden_rows, int(input_ids.shape[0]))
+ sample_rate = int(getattr(self._pipeline, "sample_rate", 24000))
+ async_chunk = bool(getattr(self.vllm_config.model_config, "async_chunk", False))
+ if self.model_stage in self._VAE_STAGES:
+ infos = self._maybe_recover_vae_infos(infos, input_ids, async_chunk=async_chunk)
+ return self._forward_vae_stage(
+ infos,
+ sample_rate=sample_rate,
+ async_chunk=async_chunk,
+ out_device=out_device,
+ out_dtype=out_dtype,
+ )
+ if self.model_stage in self._LATENT_STAGES:
+ return self._forward_latent_stage(
+ infos,
+ sample_rate=sample_rate,
+ async_chunk=async_chunk,
+ out_device=out_device,
+ out_dtype=out_dtype,
+ hidden_rows=hidden_rows,
+ )
+ raise ValueError(f"Unsupported VoxCPM model_stage at runtime: {self.model_stage}")
+
+ def make_empty_intermediate_tensors(
+ self, batch_size: int, dtype: torch.dtype, device: torch.device
+ ) -> IntermediateTensors:
+ del batch_size, dtype, device
+ return {}
+
+
+__all__ = ["VoxCPMForConditionalGeneration"]
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py
new file mode 100644
index 0000000000..dac7117cad
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py
@@ -0,0 +1,247 @@
+from __future__ import annotations
+
+import importlib
+import json
+import os
+import shutil
+import sys
+import tempfile
+from contextlib import contextmanager
+from hashlib import sha256
+from pathlib import Path
+from typing import Any
+from unittest.mock import patch
+
+import numpy as np
+import torch
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def _iter_voxcpm_src_candidates() -> list[Path]:
+ candidates: list[Path] = []
+ env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH")
+ if env_path:
+ candidates.append(Path(env_path).expanduser())
+
+ repo_root = Path(__file__).resolve().parents[4]
+ candidates.append(repo_root.parent / "VoxCPM" / "src")
+
+ unique_candidates: list[Path] = []
+ seen: set[str] = set()
+ for candidate in candidates:
+ candidate_key = str(candidate)
+ if candidate_key in seen:
+ continue
+ seen.add(candidate_key)
+ unique_candidates.append(candidate)
+ return unique_candidates
+
+
+def _prepend_voxcpm_src(candidate: Path) -> None:
+ candidate_str = str(candidate)
+ if candidate_str not in sys.path:
+ sys.path.insert(0, candidate_str)
+
+
+def _import_voxcpm_attrs(module_name: str, *attr_names: str) -> tuple[Any, ...]:
+ last_exc: ImportError | None = None
+ for candidate in _iter_voxcpm_src_candidates():
+ if not candidate.exists():
+ continue
+ _prepend_voxcpm_src(candidate)
+ try:
+ module = importlib.import_module(module_name)
+ return tuple(getattr(module, attr_name) for attr_name in attr_names)
+ except ImportError as exc:
+ last_exc = exc
+
+ try:
+ module = importlib.import_module(module_name)
+ return tuple(getattr(module, attr_name) for attr_name in attr_names)
+ except ImportError as exc:
+ last_exc = exc
+
+ raise ImportError(f"Failed to import {module_name}.") from last_exc
+
+
+def _import_voxcpm_base_model_class():
+ """Import upstream ``VoxCPMModel`` from ``VoxCPM/src/voxcpm`` (env, sibling tree, or pip)."""
+ try:
+ (VoxCPMModel,) = _import_voxcpm_attrs("voxcpm.model.voxcpm", "VoxCPMModel")
+ return VoxCPMModel
+ except ImportError as exc:
+ raise ImportError(
+ "Failed to import VoxCPMModel. Install the `voxcpm` package or set "
+ "`VLLM_OMNI_VOXCPM_CODE_PATH` to the VoxCPM repository `src` directory "
+ "(the parent of the `voxcpm` package that contains `model/` and `modules/`)."
+ ) from exc
+
+
+def _import_voxcpm_audio_vae_classes():
+ try:
+ return _import_voxcpm_attrs("voxcpm.modules.audiovae", "AudioVAE", "AudioVAEConfig")
+ except ImportError as exc:
+ raise ImportError(
+ "Failed to import VoxCPM AudioVAE. Install the `voxcpm` package or set "
+ "`VLLM_OMNI_VOXCPM_CODE_PATH` to the VoxCPM repository `src` directory."
+ ) from exc
+
+
+def _device_to_string(device: torch.device) -> str:
+ if device.index is None:
+ return device.type
+ return f"{device.type}:{device.index}"
+
+
+def _normalize_dtype_name(dtype: Any) -> str | None:
+ if dtype is None:
+ return None
+ if isinstance(dtype, torch.dtype):
+ mapping = {
+ torch.bfloat16: "bfloat16",
+ torch.float16: "float16",
+ torch.float32: "float32",
+ }
+ return mapping.get(dtype, str(dtype).removeprefix("torch."))
+ dtype_str = str(dtype)
+ return dtype_str.removeprefix("torch.")
+
+
+def _resolve_runtime_device(vllm_config: VllmConfig) -> torch.device:
+ try:
+ from vllm_omni.platforms import current_omni_platform
+
+ return current_omni_platform.get_torch_device()
+ except Exception:
+ pass
+
+ device = getattr(getattr(vllm_config, "device_config", None), "device", None)
+ if isinstance(device, torch.device):
+ return device
+ if device:
+ return torch.device(device)
+ return torch.device("cpu")
+
+
+def _prepare_runtime_model_dir(
+ model_path: str | Path,
+ *,
+ target_device: torch.device,
+ target_dtype: str | None,
+) -> str:
+ source_dir = Path(model_path)
+ config_path = source_dir / "config.json"
+ if not config_path.exists():
+ return str(source_dir)
+
+ config_text = config_path.read_text()
+ config_dict = json.loads(config_text)
+ desired_device = target_device.type
+ desired_dtype = target_dtype or config_dict.get("dtype")
+
+ if config_dict.get("device") == desired_device and config_dict.get("dtype") == desired_dtype:
+ return str(source_dir)
+
+ digest = sha256(f"{source_dir.resolve()}:{config_text}:{desired_device}:{desired_dtype}".encode()).hexdigest()[:16]
+ runtime_dir = Path(tempfile.gettempdir()) / "vllm_omni_voxcpm_runtime" / digest
+ runtime_dir.mkdir(parents=True, exist_ok=True)
+
+ for entry in source_dir.iterdir():
+ target = runtime_dir / entry.name
+ if entry.name == "config.json" or target.exists():
+ continue
+ try:
+ target.symlink_to(entry, target_is_directory=entry.is_dir())
+ except OSError as exc:
+ logger.warning(
+ "Falling back to copying VoxCPM runtime artifact %s into %s because symlink creation failed: %s",
+ entry,
+ runtime_dir,
+ exc,
+ )
+ if entry.is_dir():
+ shutil.copytree(entry, target, dirs_exist_ok=True)
+ else:
+ shutil.copy2(entry, target)
+
+ patched_config = dict(config_dict)
+ patched_config["device"] = desired_device
+ if desired_dtype is not None:
+ patched_config["dtype"] = desired_dtype
+ (runtime_dir / "config.json").write_text(json.dumps(patched_config, indent=2, sort_keys=True))
+ return str(runtime_dir)
+
+
+@contextmanager
+def _force_cuda_available_for_npu(device: torch.device):
+ if device.type != "npu":
+ yield
+ return
+
+ with patch("torch.cuda.is_available", return_value=True):
+ yield
+
+
+def _is_torchcodec_load_error(exc: BaseException) -> bool:
+ message = str(exc).lower()
+ return "torchcodec" in message or "load_with_torchcodec" in message
+
+
+def _load_audio_with_soundfile(
+ prompt_wav_path: str,
+ *,
+ sample_rate: int,
+) -> torch.Tensor:
+ try:
+ import soundfile as sf
+ except ImportError:
+ raise
+
+ audio_np, source_sr = sf.read(prompt_wav_path, dtype="float32", always_2d=True)
+ audio = torch.from_numpy(np.ascontiguousarray(audio_np.T))
+
+ if audio.size(0) > 1:
+ audio = audio.mean(dim=0, keepdim=True)
+
+ if int(source_sr) != int(sample_rate):
+ try:
+ import torchaudio
+ except ImportError as exc:
+ raise ImportError("torchaudio is required for resampling prompt audio.") from exc
+ audio = torchaudio.functional.resample(audio, int(source_sr), int(sample_rate))
+
+ return audio
+
+
+def _build_prompt_cache_with_soundfile(model: Any, *args: Any, **kwargs: Any) -> dict[str, Any]:
+ if args:
+ prompt_text = args[0]
+ prompt_wav_path = args[1] if len(args) > 1 else kwargs.get("prompt_wav_path")
+ else:
+ prompt_text = kwargs.get("prompt_text")
+ prompt_wav_path = kwargs.get("prompt_wav_path")
+
+ if not prompt_text or not prompt_wav_path:
+ raise ValueError("prompt_text and prompt_wav_path are required")
+
+ audio = _load_audio_with_soundfile(prompt_wav_path, sample_rate=int(model.sample_rate))
+
+ patch_len = model.patch_size * model.chunk_size
+ if audio.size(1) % patch_len != 0:
+ padding_size = patch_len - audio.size(1) % patch_len
+ audio = torch.nn.functional.pad(audio, (padding_size, 0))
+
+ audio_feat = model.audio_vae.encode(audio.to(model.device), model.sample_rate).cpu()
+ audio_feat = audio_feat.view(
+ model.audio_vae.latent_dim,
+ -1,
+ model.patch_size,
+ ).permute(1, 2, 0)
+
+ return {
+ "prompt_text": prompt_text,
+ "audio_feat": audio_feat,
+ }
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py
new file mode 100644
index 0000000000..36b4282c2d
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+import json
+import shutil
+from pathlib import Path
+
+
+def resolve_voxcpm_model_dir(model: str) -> Path:
+ model_path = Path(model).expanduser()
+ if model_path.exists():
+ return model_path
+
+ from huggingface_hub import snapshot_download
+
+ return Path(snapshot_download(repo_id=model))
+
+
+def prepare_voxcpm_hf_config_dir(model_dir: str | Path, hf_config_dir: str | Path) -> Path:
+ model_dir = Path(model_dir).expanduser()
+ hf_config_dir = Path(hf_config_dir).expanduser()
+ hf_config_dir.mkdir(parents=True, exist_ok=True)
+
+ source_config_path = model_dir / "config.json"
+ if not source_config_path.exists():
+ raise FileNotFoundError(f"VoxCPM config.json not found under {model_dir}")
+
+ config_path = hf_config_dir / "config.json"
+ shutil.copy2(source_config_path, config_path)
+
+ source_generation_config_path = model_dir / "generation_config.json"
+ if source_generation_config_path.exists():
+ shutil.copy2(source_generation_config_path, hf_config_dir / "generation_config.json")
+
+ config_dict = json.loads(config_path.read_text(encoding="utf-8"))
+ config_dict["model_type"] = "voxcpm"
+ config_dict.setdefault("architectures", ["VoxCPMForConditionalGeneration"])
+ config_path.write_text(json.dumps(config_dict, indent=2, ensure_ascii=False), encoding="utf-8")
+ return hf_config_dir
+
+
+__all__ = [
+ "prepare_voxcpm_hf_config_dir",
+ "resolve_voxcpm_model_dir",
+]
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py
new file mode 100644
index 0000000000..f4446c796e
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py
@@ -0,0 +1,185 @@
+from __future__ import annotations
+
+import os
+from collections.abc import Generator
+from typing import Any
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+
+class _DirectVoxCPMLatentGenerator:
+ def __init__(self, tts_model: Any):
+ self.tts_model = tts_model
+ self.sample_rate = int(getattr(tts_model, "sample_rate", 24000))
+
+ def generate_latents(
+ self,
+ *,
+ text: str,
+ prompt_wav_path: str | None = None,
+ prompt_text: str | None = None,
+ cfg_value: float = 2.0,
+ inference_timesteps: int = 10,
+ min_len: int = 2,
+ max_len: int = 4096,
+ retry_badcase: bool = True,
+ retry_badcase_max_times: int = 3,
+ retry_badcase_ratio_threshold: float = 6.0,
+ ) -> torch.Tensor:
+ if not isinstance(text, str) or not text.strip():
+ raise ValueError("target text must be a non-empty string")
+ if (prompt_wav_path is None) != (prompt_text is None):
+ raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
+ if prompt_wav_path is not None and not os.path.exists(prompt_wav_path):
+ raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
+
+ prompt_cache = None
+ if prompt_wav_path is not None and prompt_text is not None:
+ prompt_cache = self.tts_model.build_prompt_cache(
+ prompt_text=prompt_text,
+ prompt_wav_path=prompt_wav_path,
+ )
+
+ gen_kw = dict(
+ target_text=" ".join(text.split()),
+ prompt_cache=prompt_cache,
+ min_len=min_len,
+ max_len=max_len,
+ inference_timesteps=inference_timesteps,
+ cfg_value=cfg_value,
+ retry_badcase=retry_badcase,
+ retry_badcase_max_times=retry_badcase_max_times,
+ retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+ )
+ latent_entry = getattr(self.tts_model, "generate_latents_with_prompt_cache", None)
+ if latent_entry is not None:
+ _, _, pred_audio_feat = latent_entry(**gen_kw)
+ else:
+ try:
+ _, _, pred_audio_feat = self.tts_model.generate_with_prompt_cache(
+ **gen_kw,
+ latents_only=True,
+ )
+ except TypeError:
+ _, _, pred_audio_feat = self.tts_model.generate_with_prompt_cache(**gen_kw)
+ return pred_audio_feat.detach().cpu().to(torch.float32)
+
+ def iter_latent_chunks_streaming(
+ self,
+ *,
+ text: str,
+ prompt_wav_path: str | None = None,
+ prompt_text: str | None = None,
+ cfg_value: float = 2.0,
+ inference_timesteps: int = 10,
+ min_len: int = 2,
+ max_len: int = 4096,
+ streaming_prefix_len: int = 3,
+ retry_badcase: bool = False,
+ retry_badcase_max_times: int = 3,
+ retry_badcase_ratio_threshold: float = 6.0,
+ ) -> Generator[tuple[torch.Tensor, bool], None, None]:
+ """Yield ``(latent_window, is_last_chunk)`` for Omni async_chunk latent to VAE."""
+ if not isinstance(text, str) or not text.strip():
+ raise ValueError("target text must be a non-empty string")
+ if (prompt_wav_path is None) != (prompt_text is None):
+ raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
+ if prompt_wav_path is not None and not os.path.exists(prompt_wav_path):
+ raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
+
+ prompt_cache = None
+ if prompt_wav_path is not None and prompt_text is not None:
+ prompt_cache = self.tts_model.build_prompt_cache(
+ prompt_text=prompt_text,
+ prompt_wav_path=prompt_wav_path,
+ )
+
+ gen_kw = dict(
+ target_text=" ".join(text.split()),
+ prompt_cache=prompt_cache,
+ min_len=min_len,
+ max_len=max_len,
+ inference_timesteps=inference_timesteps,
+ cfg_value=cfg_value,
+ retry_badcase=retry_badcase,
+ retry_badcase_max_times=retry_badcase_max_times,
+ retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+ streaming_prefix_len=streaming_prefix_len,
+ )
+ stream_entry = getattr(self.tts_model, "generate_latents_with_prompt_cache_streaming", None)
+ if stream_entry is not None:
+ gen = stream_entry(**gen_kw)
+ else:
+ fallback_stream_entry = getattr(self.tts_model, "generate_with_prompt_cache_streaming", None)
+ if fallback_stream_entry is not None:
+ gen = fallback_stream_entry(**gen_kw, latents_only=True)
+ else:
+ gen = self.tts_model._generate_with_prompt_cache(streaming=True, latents_only=True, **gen_kw)
+
+ iterator = iter(gen)
+ previous = next(iterator, None)
+ while previous is not None:
+ current = next(iterator, None)
+ _, _target_tok, chunk_latent = previous
+ if not isinstance(chunk_latent, torch.Tensor):
+ chunk_latent = torch.as_tensor(chunk_latent)
+ yield chunk_latent, current is None
+ previous = current
+
+
+class _DirectVoxCPMAudioVAE:
+ def __init__(self, audio_vae: nn.Module, *, patch_size: int = 2):
+ self.audio_vae = audio_vae
+ self.sample_rate = int(getattr(audio_vae, "sample_rate", 24000))
+ self.latent_dim = int(getattr(audio_vae, "latent_dim", 64))
+ self.patch_size = int(patch_size)
+ self._chunk_size = int(getattr(audio_vae, "chunk_size", 1))
+ self._stream_audio_patch_samples = max(1, self.patch_size * self._chunk_size)
+
+ def _prepare_latents_for_decode(self, latent_audio_feat: Any) -> torch.Tensor:
+ latents = latent_audio_feat
+ if not isinstance(latents, torch.Tensor):
+ latents = torch.tensor(latents, dtype=torch.float32)
+ latents = latents.detach().to(torch.float32)
+
+ if latents.ndim == 3:
+ if latents.shape[-1] == self.latent_dim:
+ latents = rearrange(latents, "t p d -> 1 d (t p)")
+ elif latents.shape[1] == self.latent_dim:
+ latents = latents.contiguous()
+ else:
+ raise ValueError(f"Unsupported latent_audio_feat shape: {tuple(latents.shape)}")
+ elif latents.ndim == 2:
+ if latents.shape[0] == self.latent_dim:
+ latents = latents.unsqueeze(0)
+ elif latents.shape[1] == self.latent_dim:
+ latents = rearrange(latents, "t d -> 1 d t")
+ else:
+ raise ValueError(f"Unsupported latent_audio_feat shape: {tuple(latents.shape)}")
+ else:
+ raise ValueError(f"Unsupported latent_audio_feat ndim: {latents.ndim}")
+
+ return latents
+
+ @torch.no_grad()
+ def decode(self, latent_audio_feat: Any, *, trim_streaming_patch: bool = False) -> torch.Tensor:
+ latents = self._prepare_latents_for_decode(latent_audio_feat)
+ device = next(self.audio_vae.parameters()).device
+ raw = self.audio_vae.decode(latents.to(device=device, dtype=torch.float32))
+ if isinstance(raw, dict):
+ audio = raw.get("audio")
+ if audio is None:
+ audio = next(v for v in raw.values() if isinstance(v, torch.Tensor))
+ else:
+ audio = raw
+ if audio.dim() == 3:
+ stream = audio.squeeze(1)
+ elif audio.dim() == 2:
+ stream = audio
+ else:
+ stream = audio.reshape(audio.shape[0], -1)
+ if trim_streaming_patch:
+ stream = stream[..., -self._stream_audio_patch_samples :]
+ return stream.reshape(-1).detach().cpu().to(torch.float32)
diff --git a/vllm_omni/model_executor/models/voxcpm2/__init__.py b/vllm_omni/model_executor/models/voxcpm2/__init__.py
new file mode 100644
index 0000000000..77bd8dfb51
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm2/__init__.py
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from .voxcpm2_talker import VoxCPM2TalkerForConditionalGeneration
+
+__all__ = ["VoxCPM2TalkerForConditionalGeneration"]
diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_hf_compat.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_hf_compat.py
new file mode 100644
index 0000000000..cb3101b16a
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm2/minicpm4_hf_compat.py
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""fp32 RoPE + MLP matching native VoxCPM2 numerics.
+
+Exports: _MiniCPMLongRoPE, _MiniCPMMLP, _apply_rotary_pos_emb
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# ===================================================================
+# Primitives
+# ===================================================================
+
+
+def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def _apply_rotary_pos_emb(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Apply rotary embeddings in float32."""
+ orig_dtype = q.dtype
+ q, k = q.to(torch.float32), k.to(torch.float32)
+ q_embed = (q * cos) + (_rotate_half(q) * sin)
+ k_embed = (k * cos) + (_rotate_half(k) * sin)
+ return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
+
+
+# ===================================================================
+# LongRoPE — must match native computation order exactly
+# ===================================================================
+
+
+class _MiniCPMLongRoPE(nn.Module):
+ """LongRoPE matching native computation order."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ kv_channels: int | None,
+ rope_theta: float,
+ max_position_embeddings: int,
+ rope_scaling: dict[str, Any],
+ ) -> None:
+ super().__init__()
+ self.dim = kv_channels if kv_channels else hidden_size // num_attention_heads
+ self.base = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.short_factor = rope_scaling["short_factor"]
+ self.long_factor = rope_scaling["long_factor"]
+ self.original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ self.max_seq_len_cached = 0
+ self.register_buffer("cos_cached", torch.empty(0), persistent=False)
+ self.register_buffer("sin_cached", torch.empty(0), persistent=False)
+ self._set_cos_sin_cache(self.max_position_embeddings, self.inv_freq.device, torch.float32)
+
+ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+
+ ext_factors = torch.tensor(
+ self.long_factor if seq_len > self.original_max_position_embeddings else self.short_factor,
+ dtype=torch.float32,
+ device=device,
+ )
+
+ freqs = torch.mul(
+ torch.outer(t, 1.0 / ext_factors).to(device=device),
+ self.inv_freq.to(device=device).to(dtype),
+ )
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
+ self.sin_cached = emb.sin().to(dtype) * self.scaling_factor
+
+ def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ return self.cos_cached[position_ids], self.sin_cached[position_ids]
+
+
+# ===================================================================
+# MLP
+# ===================================================================
+
+
+class _MiniCPMMLP(nn.Module):
+ """SiLU-gated MLP matching native MiniCPMMLP."""
+
+ def __init__(self, hidden_size: int, intermediate_size: int) -> None:
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
new file mode 100644
index 0000000000..b87ec5aafe
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
@@ -0,0 +1,457 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""MiniCPM4 with PagedAttention + fp32 RoPE/RMSNorm for VoxCPM2.
+
+Uses vllm Attention for KV cache, keeps fp32 precision ops from
+minicpm4_hf_compat.py to match native VoxCPM2 numerics.
+"""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Iterable
+from typing import Any
+
+import torch
+import torch.nn as nn
+from vllm.config import CacheConfig, VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import make_empty_intermediate_tensors_factory
+from vllm.sequence import IntermediateTensors
+
+from .minicpm4_hf_compat import (
+ _apply_rotary_pos_emb,
+ _MiniCPMLongRoPE,
+ _MiniCPMMLP,
+)
+
+logger = init_logger(__name__)
+
+
+def _resolve_lm_cfg(config: Any) -> Any:
+ """Extract lm_config from VoxCPM2Config, converting dict to namespace if needed."""
+ lm_cfg = getattr(config, "lm_config", config)
+ if isinstance(lm_cfg, dict):
+
+ class _Cfg:
+ pass
+
+ c = _Cfg()
+ for k, v in lm_cfg.items():
+ setattr(c, k, v)
+ return c
+ return lm_cfg
+
+
+# ===================================================================
+# Attention with vllm PagedAttention backend
+# ===================================================================
+
+
+class _PagedMiniCPM4Attention(nn.Module):
+ """PagedAttention + fp32 RoPE with separate q/k/v projections."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ kv_channels: int | None,
+ layer_idx: int,
+ cache_config: CacheConfig | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = hidden_size
+ self.num_heads = num_attention_heads
+ self.head_dim = kv_channels if kv_channels else hidden_size // num_attention_heads
+ self.num_kv_heads = num_key_value_heads
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+
+ self.q_proj = nn.Linear(hidden_size, self.q_size, bias=False)
+ self.k_proj = nn.Linear(hidden_size, self.kv_size, bias=False)
+ self.v_proj = nn.Linear(hidden_size, self.kv_size, bias=False)
+ self.o_proj = nn.Linear(self.q_size, hidden_size, bias=False)
+ self._fused_qkv_weight: torch.Tensor | None = None
+
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ scale=self.head_dim**-0.5,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ prefix=f"{prefix}.attn",
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ rope_emb: _MiniCPMLongRoPE | None = None,
+ ) -> torch.Tensor:
+ """Forward: fused QKV → fp32 RoPE → PagedAttention → o_proj."""
+ if self._fused_qkv_weight is None:
+ self._fused_qkv_weight = torch.cat(
+ [
+ self.q_proj.weight,
+ self.k_proj.weight,
+ self.v_proj.weight,
+ ],
+ dim=0,
+ ).detach()
+ qkv = nn.functional.linear(hidden_states, self._fused_qkv_weight)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ if rope_emb is not None:
+ cos, sin = rope_emb(positions)
+ bsz = q.shape[0]
+ q_r = q.view(bsz, self.num_heads, self.head_dim)
+ k_r = k.view(bsz, self.num_kv_heads, self.head_dim)
+ q_r = q_r.unsqueeze(0).transpose(1, 2) # [1, heads, n_tokens, dim]
+ k_r = k_r.unsqueeze(0).transpose(1, 2) # [1, kv_heads, n_tokens, dim]
+ q_r, k_r = _apply_rotary_pos_emb(q_r, k_r, cos, sin)
+ q = q_r.transpose(1, 2).squeeze(0).reshape(bsz, -1) # [n_tokens, q_size]
+ k = k_r.transpose(1, 2).squeeze(0).reshape(bsz, -1) # [n_tokens, kv_size]
+
+ attn_output = self.attn(q, k, v)
+
+ output = self.o_proj(attn_output)
+ return output
+
+
+# ===================================================================
+# Decoder Layer
+# ===================================================================
+
+
+class _PagedMiniCPM4DecoderLayer(nn.Module):
+ """Decoder layer: PagedAttention + fp32 RMSNorm + muP scale_depth."""
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ kv_channels: int | None,
+ rms_norm_eps: float,
+ layer_idx: int,
+ num_hidden_layers: int,
+ use_mup: bool,
+ scale_depth: float,
+ cache_config: CacheConfig | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.self_attn = _PagedMiniCPM4Attention(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ kv_channels=kv_channels,
+ layer_idx=layer_idx,
+ cache_config=cache_config,
+ prefix=f"{prefix}.self_attn",
+ )
+ self.mlp = _MiniCPMMLP(hidden_size, intermediate_size)
+ self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+
+ self.use_mup = use_mup
+ self.scale_depth = scale_depth
+ self.num_hidden_layers = num_hidden_layers
+
+ def _residual_scale(self) -> float:
+ if self.use_mup:
+ return self.scale_depth / math.sqrt(self.num_hidden_layers)
+ return 1.0
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ rope_emb: _MiniCPMLongRoPE | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # Pre-norm + attention
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(positions, hidden_states, rope_emb)
+
+ scale = self._residual_scale()
+ if scale != 1.0:
+ hidden_states = residual + hidden_states * scale
+ else:
+ hidden_states = residual + hidden_states
+
+ # Pre-norm + FFN
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+
+ if scale != 1.0:
+ hidden_states = residual + hidden_states * scale
+ else:
+ hidden_states = residual + hidden_states
+
+ return hidden_states, None
+
+
+# ===================================================================
+# Full Model
+# ===================================================================
+
+
+class MiniCPM4PagedForVoxCPM2(nn.Module):
+ """PagedAttention base_lm (28 layers) for VoxCPM2 scaffold."""
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ self.config = config
+
+ lm_cfg = _resolve_lm_cfg(config)
+
+ hidden_size = lm_cfg.hidden_size
+ num_hidden_layers = lm_cfg.num_hidden_layers
+ kv_channels = getattr(lm_cfg, "kv_channels", None)
+
+ self.vocab_size = lm_cfg.vocab_size
+ self.embed_tokens = nn.Embedding(self.vocab_size, hidden_size)
+
+ rope_scaling = getattr(lm_cfg, "rope_scaling", None)
+ if isinstance(rope_scaling, dict):
+ rope_scaling_dict = rope_scaling
+ elif hasattr(rope_scaling, "__dict__"):
+ rope_scaling_dict = {
+ "short_factor": rope_scaling.short_factor,
+ "long_factor": rope_scaling.long_factor,
+ "original_max_position_embeddings": rope_scaling.original_max_position_embeddings,
+ }
+ else:
+ rope_scaling_dict = {}
+
+ no_rope = getattr(lm_cfg, "no_rope", False)
+ if not no_rope:
+ self.rope_emb = _MiniCPMLongRoPE(
+ hidden_size=hidden_size,
+ num_attention_heads=lm_cfg.num_attention_heads,
+ kv_channels=kv_channels,
+ rope_theta=getattr(lm_cfg, "rope_theta", 10000.0),
+ max_position_embeddings=getattr(lm_cfg, "max_position_embeddings", 32768),
+ rope_scaling=rope_scaling_dict,
+ )
+ else:
+ self.rope_emb = None
+
+ self.layers = nn.ModuleList(
+ [
+ _PagedMiniCPM4DecoderLayer(
+ hidden_size=hidden_size,
+ intermediate_size=lm_cfg.intermediate_size,
+ num_attention_heads=lm_cfg.num_attention_heads,
+ num_key_value_heads=lm_cfg.num_key_value_heads,
+ kv_channels=kv_channels,
+ rms_norm_eps=lm_cfg.rms_norm_eps,
+ layer_idx=i,
+ num_hidden_layers=num_hidden_layers,
+ use_mup=getattr(lm_cfg, "use_mup", False),
+ scale_depth=getattr(lm_cfg, "scale_depth", 1.0),
+ cache_config=cache_config,
+ prefix=f"{prefix}.layers.{i}",
+ )
+ for i in range(num_hidden_layers)
+ ]
+ )
+
+ self.norm = RMSNorm(hidden_size, eps=lm_cfg.rms_norm_eps)
+
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], hidden_size
+ )
+
+ use_mup = getattr(lm_cfg, "use_mup", False)
+ self._scale_emb = getattr(lm_cfg, "scale_emb", 1.0) if use_mup else 1.0
+ self._compiled_layers: set[int] = set()
+
+ def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
+ return self.embed_tokens(input_ids) * self._scale_emb
+
+ def forward(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs: Any,
+ ) -> torch.Tensor | IntermediateTensors:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.embed_input_ids(input_ids)
+
+ residual = None
+ for layer in self.layers:
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ self.rope_emb,
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+ def precompute_fused_qkv(self) -> None:
+ """Materialize fused QKV weights before CUDA Graph capture."""
+ for layer in self.layers:
+ attn = layer.self_attn
+ if attn._fused_qkv_weight is None:
+ attn._fused_qkv_weight = torch.cat(
+ [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight],
+ dim=0,
+ ).detach()
+
+ def compile_selective(self) -> list[str]:
+ """Compile the full model forward as one graph.
+
+ Earlier versions compiled ``layer.mlp`` + ``layer.self_attn.o_proj``
+ (PR #2690) and then the whole ``layer`` (perf/voxcpm2-streaming-vae).
+ Both still paid one Dynamo dispatch per layer per decode step.
+ V3 profiling showed 1,332 per-layer dispatches (~28 layers × ~47
+ decode steps) costing ~726 ms of CPU self-time for a long prompt.
+
+ Compiling ``forward`` at the model level lets Dynamo unroll the
+ 28-layer Python loop inside the graph. Graph breaks at
+ PagedAttention produce sub-graphs but Dynamo memoises the whole
+ trace once, so the per-step dispatch drops from 28 to just a few.
+ """
+ if self._compiled_layers:
+ return []
+ # Null the fused-qkv caches so the compile sees the real weight layout.
+ for layer in self.layers:
+ layer.self_attn._fused_qkv_weight = None
+ self.forward = torch.compile(self.forward, mode="default", fullgraph=False)
+ # Mark every layer as compiled so idempotent callers don't double-wrap.
+ self._compiled_layers.update(range(len(self.layers)))
+ return ["forward (whole model)"]
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ """Load weights from native checkpoint (base_lm. prefix pre-stripped)."""
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded: set[str] = set()
+
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ param = params_dict.get(name)
+ if param is None:
+ continue
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded.add(name)
+
+ return loaded
+
+
+# ===================================================================
+# Residual LM with PagedAttention (no RoPE, 8 layers)
+# ===================================================================
+
+
+class MiniCPM4PagedResidualLM(nn.Module):
+ """PagedAttention residual LM (8 layers, no RoPE) for VoxCPM2."""
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ self.config = config
+
+ lm_cfg = _resolve_lm_cfg(config)
+
+ hidden_size = lm_cfg.hidden_size
+ num_hidden_layers = getattr(config, "residual_lm_num_layers", 8)
+ kv_channels = getattr(lm_cfg, "kv_channels", None)
+
+ self.rope_emb = None
+
+ self.layers = nn.ModuleList(
+ [
+ _PagedMiniCPM4DecoderLayer(
+ hidden_size=hidden_size,
+ intermediate_size=lm_cfg.intermediate_size,
+ num_attention_heads=lm_cfg.num_attention_heads,
+ num_key_value_heads=lm_cfg.num_key_value_heads,
+ kv_channels=kv_channels,
+ rms_norm_eps=lm_cfg.rms_norm_eps,
+ layer_idx=i,
+ num_hidden_layers=num_hidden_layers,
+ use_mup=getattr(lm_cfg, "use_mup", False),
+ scale_depth=getattr(lm_cfg, "scale_depth", 1.0),
+ cache_config=cache_config,
+ prefix=f"{prefix}.layers.{i}",
+ )
+ for i in range(num_hidden_layers)
+ ]
+ )
+
+ self.norm = RMSNorm(hidden_size, eps=lm_cfg.rms_norm_eps)
+ self._compiled_layers: set[int] = set()
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+ residual = None
+ for layer in self.layers:
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ self.rope_emb,
+ )
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+ def precompute_fused_qkv(self) -> None:
+ """Materialize fused QKV weights before CUDA Graph capture."""
+ for layer in self.layers:
+ attn = layer.self_attn
+ if attn._fused_qkv_weight is None:
+ attn._fused_qkv_weight = torch.cat(
+ [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight],
+ dim=0,
+ ).detach()
+
+ def compile_selective(self) -> list[str]:
+ """Compile the full residual model forward as one graph (same strategy as base_lm)."""
+ if self._compiled_layers:
+ return []
+ for layer in self.layers:
+ layer.self_attn._fused_qkv_weight = None
+ self.forward = torch.compile(self.forward, mode="default", fullgraph=False)
+ self._compiled_layers.update(range(len(self.layers)))
+ return ["forward (whole residual)"]
+
+ def load_weights_from_native(self, native_residual_lm: nn.Module) -> int:
+ """Load weights from native residual_lm. Returns param count."""
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded = 0
+ for name, param in native_residual_lm.named_parameters():
+ if "rotary_emb" in name:
+ continue
+ target = params_dict.get(name)
+ if target is None:
+ continue
+ weight_loader = getattr(target, "weight_loader", default_weight_loader)
+ weight_loader(target, param.data)
+ loaded += 1
+ return loaded
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py
new file mode 100644
index 0000000000..231a51bbca
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Dynamic import utilities for the native VoxCPM2 package.
+
+Supports three discovery modes (first match wins):
+1. ``VLLM_OMNI_VOXCPM_CODE_PATH`` env var (explicit source tree)
+2. Sibling ``../VoxCPM/src`` relative to the vllm-omni repo root
+3. pip-installed ``voxcpm`` package (>= 2.0)
+"""
+
+from __future__ import annotations
+
+import importlib
+import os
+import sys
+from pathlib import Path
+from typing import Any
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def _iter_voxcpm2_src_candidates() -> list[Path]:
+ """Yield candidate source directories for VoxCPM2."""
+ candidates: list[Path] = []
+ env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH")
+ if env_path:
+ candidates.append(Path(env_path).expanduser())
+
+ repo_root = Path(__file__).resolve().parents[4]
+ candidates.append(repo_root.parent / "VoxCPM" / "src")
+
+ seen: set[str] = set()
+ unique: list[Path] = []
+ for c in candidates:
+ key = str(c)
+ if key not in seen:
+ seen.add(key)
+ unique.append(c)
+ return unique
+
+
+def _prepend_src(candidate: Path) -> None:
+ candidate_str = str(candidate)
+ if candidate_str not in sys.path:
+ sys.path.insert(0, candidate_str)
+
+
+def _import_voxcpm2_attrs(module_name: str, *attr_names: str) -> tuple[Any, ...]:
+ """Import attributes from the voxcpm package, trying source tree first."""
+ last_exc: ImportError | None = None
+
+ for candidate in _iter_voxcpm2_src_candidates():
+ if not candidate.exists():
+ continue
+ _prepend_src(candidate)
+ try:
+ mod = importlib.import_module(module_name)
+ return tuple(getattr(mod, name) for name in attr_names)
+ except (ImportError, AttributeError) as exc:
+ last_exc = ImportError(str(exc))
+ continue
+
+ try:
+ mod = importlib.import_module(module_name)
+ return tuple(getattr(mod, name) for name in attr_names)
+ except (ImportError, AttributeError) as exc:
+ last_exc = ImportError(str(exc))
+
+ raise ImportError(
+ f"Could not import {attr_names} from {module_name}. "
+ f"Install voxcpm>=2.0: pip install voxcpm. "
+ f"Or set VLLM_OMNI_VOXCPM_CODE_PATH to the VoxCPM source tree. "
+ f"Last error: {last_exc}"
+ )
+
+
+def import_voxcpm2_core():
+ """Import the VoxCPM core class used to load the native TTS model."""
+ (VoxCPM,) = _import_voxcpm2_attrs("voxcpm.core", "VoxCPM")
+ return VoxCPM
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
new file mode 100644
index 0000000000..3724528898
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
@@ -0,0 +1,1289 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""VoxCPM2 AR talker — PagedAttention pipeline with per-request state.
+
+Architecture:
+ MiniCPM4PagedForVoxCPM2 (base_lm, 28 layers, PagedAttention + fp32 RoPE)
+ → FSQ → MiniCPM4PagedResidualLM (8 layers, PagedAttention, no RoPE)
+ → LocDiT (CFM solver) → AudioVAE → 48kHz waveform
+"""
+
+from __future__ import annotations
+
+import copy
+import dataclasses
+import logging
+import math
+import os
+import time
+from collections.abc import Iterable
+from typing import Any
+
+import librosa
+import torch
+import torch.nn as nn
+from vllm.config import VllmConfig
+from vllm.forward_context import get_forward_context, override_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.models.utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ maybe_prefix,
+)
+from vllm.sequence import IntermediateTensors
+
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+
+from .minicpm4_paged import MiniCPM4PagedForVoxCPM2, MiniCPM4PagedResidualLM
+from .voxcpm2_import_utils import import_voxcpm2_core
+
+logger = init_logger(__name__)
+
+_ENABLE_PROFILING = os.environ.get("VOXCPM2_PROFILE", "0") == "1"
+
+# Lower bound for the _active_states leak-warn threshold. The effective
+# threshold is max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * max_batch_size) so small
+# deployments still get a usable floor instead of a tiny noisy one.
+_ACTIVE_STATE_LEAK_WARN_MIN = 512
+
+
+def is_cjk_char(c: str) -> bool:
+ """Check if a character is a CJK ideograph."""
+ cp = ord(c)
+ return (
+ 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs
+ or 0x3400 <= cp <= 0x4DBF # Extension A
+ or 0xF900 <= cp <= 0xFAFF # Compatibility Ideographs
+ or 0x20000 <= cp <= 0x2A6DF # Extension B
+ or 0x2A700 <= cp <= 0x2B73F # Extension C
+ or 0x2B740 <= cp <= 0x2B81F # Extension D
+ or 0x2F800 <= cp <= 0x2FA1F # Compatibility Supplement
+ )
+
+
+def build_cjk_split_map(tokenizer: Any) -> dict[int, list[int]]:
+ """Build {multichar_cjk_token_id: [single_char_ids]} from tokenizer vocab."""
+ vocab = tokenizer.get_vocab()
+ split_map: dict[int, list[int]] = {}
+ for token, token_id in vocab.items():
+ clean = token.replace("\u2581", "")
+ if len(clean) >= 2 and all(is_cjk_char(c) for c in clean):
+ char_ids = tokenizer.convert_tokens_to_ids(list(clean))
+ if all(cid != tokenizer.unk_token_id for cid in char_ids):
+ split_map[token_id] = char_ids
+ return split_map
+
+
+def split_multichar_chinese(token_ids: list[int], split_map: dict[int, list[int]]) -> list[int]:
+ """Replace multichar Chinese token IDs with single-char IDs (idempotent)."""
+ result: list[int] = []
+ for tid in token_ids:
+ expansion = split_map.get(tid)
+ if expansion is not None:
+ result.extend(expansion)
+ else:
+ result.append(tid)
+ return result
+
+
+def build_voxcpm2_prompt(
+ hf_config: Any,
+ tokenizer: Any,
+ split_map: dict[int, list[int]],
+ text: str,
+ ref_audio: Any | None = None,
+ ref_sr: int | None = None,
+ ref_text: str | None = None,
+) -> dict[str, Any]:
+ """Build a VoxCPM2 prefill prompt whose ``prompt_token_ids`` length matches
+ the talker-side prefill length.
+
+ Used by both online serving (``serving_speech._build_voxcpm2_prompt``) and
+ the offline example, so the talker-side length assertion never fires.
+ """
+ ids = split_multichar_chinese(tokenizer.encode(text, add_special_tokens=True), split_map)
+ bos = tokenizer.bos_token_id
+ if ids and ids[0] == bos:
+ ids = ids[1:]
+ prefill_len = len(ids) + 1 # + audio_start
+ additional: dict[str, Any] = {"text_token_ids": [ids]}
+ if ref_audio is not None:
+ vae = hf_config.audio_vae_config
+ patch_samples = hf_config.patch_size * math.prod(vae["encoder_rates"])
+ ref_len = math.ceil(math.ceil(len(ref_audio) * vae["sample_rate"] / ref_sr) / patch_samples)
+ if ref_text is not None:
+ additional["prompt_audio"] = [[ref_audio, ref_sr]]
+ additional["prompt_text"] = [ref_text]
+ ref_ids = split_multichar_chinese(tokenizer.encode(ref_text, add_special_tokens=True), split_map)
+ if ref_ids and ref_ids[0] == bos:
+ ref_ids = ref_ids[1:]
+ prefill_len += ref_len + len(ref_ids)
+ else:
+ additional["reference_audio"] = [[ref_audio, ref_sr]]
+ prefill_len += ref_len + 2 # ref_start / ref_end
+ return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional}
+
+
+def _encode_raw_audio(
+ tts: nn.Module,
+ samples: list[float] | torch.Tensor,
+ sr: int,
+ padding_mode: str = "right",
+) -> torch.Tensor:
+ """Encode raw audio samples using the native VoxCPM2 AudioVAE.
+
+ Mirrors ``VoxCPM2Model._encode_wav`` but accepts in-memory samples
+ instead of a file path (needed for the OpenAI speech API).
+ """
+ if isinstance(samples, list):
+ audio = torch.tensor(samples, dtype=torch.float32)
+ else:
+ audio = samples.float()
+ if audio.ndim == 1:
+ audio = audio.unsqueeze(0)
+
+ encode_sr = tts._encode_sample_rate
+ if sr != encode_sr:
+ audio_np = audio.squeeze(0).numpy()
+ audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=encode_sr)
+ audio = torch.from_numpy(audio_np).unsqueeze(0)
+
+ patch_len = tts.patch_size * tts.chunk_size
+ if audio.size(1) % patch_len != 0:
+ padding_size = patch_len - audio.size(1) % patch_len
+ pad = (padding_size, 0) if padding_mode == "left" else (0, padding_size)
+ audio = torch.nn.functional.pad(audio, pad)
+
+ feat = tts.audio_vae.encode(audio.to(tts.device), encode_sr).cpu()
+ return feat.view(tts.audio_vae.latent_dim, -1, tts.patch_size).permute(1, 2, 0)
+
+
+# ===================================================================
+# Per-request state
+# ===================================================================
+
+
+@dataclasses.dataclass
+class _RequestState:
+ request_id: str
+ curr_embed_for_next: torch.Tensor | None = None
+ prev_feat_embed: torch.Tensor | None = None
+ curr_prefix_feat_cond: torch.Tensor | None = None
+ last_audio_patch_gpu: torch.Tensor | None = None
+ precomputed_stop_logits: torch.Tensor | None = None
+ # Rolling tail of previously-decoded latents used as VAE receptive-field context.
+ # Shape (n_pad_frames, feat_dim) on GPU. None before first decode.
+ decode_pad: torch.Tensor | None = None
+ # Audio chunks already emitted (CPU float32), concatenated for cumulative output.
+ audio_chunks: list[torch.Tensor] = dataclasses.field(default_factory=list)
+ decode_step_count: int = 0
+ request_start_time: float = 0.0
+ prefill_completed: bool = False
+ prefill_text: str = ""
+ prompt_cache: dict | None = None
+ prefill_masks: tuple | None = None
+ is_stopping: bool = False
+ last_decoded_audio: torch.Tensor | None = None
+
+
+@dataclasses.dataclass
+class _CapturedGraph:
+ graph: torch.cuda.CUDAGraph
+ input_embeds: torch.Tensor
+ positions: torch.Tensor
+ output: torch.Tensor
+
+
+# ===================================================================
+# Profiling timer
+# ===================================================================
+
+
+class _PerfTimer:
+ __slots__ = ("_enabled", "_timers", "_counts", "_starts", "_pairs")
+
+ def __init__(self, enabled: bool = False):
+ self._enabled = enabled
+ self._timers: dict[str, float] = {}
+ self._counts: dict[str, int] = {}
+ self._starts: dict[str, torch.cuda.Event] = {}
+ self._pairs: list[tuple[str, torch.cuda.Event, torch.cuda.Event]] = []
+
+ def start(self, name: str) -> None:
+ if not self._enabled:
+ return
+ evt = torch.cuda.Event(enable_timing=True)
+ evt.record()
+ self._starts[name] = evt
+
+ def stop(self, name: str) -> None:
+ if not self._enabled or name not in self._starts:
+ return
+ start_evt = self._starts.pop(name)
+ end_evt = torch.cuda.Event(enable_timing=True)
+ end_evt.record()
+ self._pairs.append((name, start_evt, end_evt))
+
+ def _resolve(self) -> None:
+ if not self._pairs:
+ return
+ torch.cuda.synchronize()
+ for name, s, e in self._pairs:
+ self._timers[name] = self._timers.get(name, 0.0) + s.elapsed_time(e)
+ self._counts[name] = self._counts.get(name, 0) + 1
+ self._pairs.clear()
+
+ def breakdown(self) -> str:
+ if not self._enabled:
+ return ""
+ self._resolve()
+ if not self._timers:
+ return ""
+ total = self._timers.get("decode_step", sum(self._timers.values()))
+ lines = [
+ "=== VoxCPM2 Decode Step Breakdown ===",
+ f"{'Component':<30} | {'ms':>10} | {'%':>6} | {'N':>5} | {'avg':>8}",
+ "-" * 70,
+ ]
+ for name in sorted(self._timers):
+ t, c = self._timers[name], self._counts[name]
+ lines.append(f"{name:<30} | {t:>10.2f} | {t / total * 100:>5.1f}% | {c:>5} | {t / c:>8.3f}")
+ lines.append(f"{'TOTAL':<30} | {total:>10.2f} |")
+ return "\n".join(lines)
+
+ def reset(self) -> None:
+ self._timers.clear()
+ self._counts.clear()
+ self._starts.clear()
+ self._pairs.clear()
+
+
+# ===================================================================
+# CFM pre-allocated buffers + optimized Euler solver
+# ===================================================================
+
+
+class _CFMBufferManager:
+ def __init__(
+ self,
+ device: torch.device,
+ dtype: torch.dtype,
+ feat_dim: int,
+ patch_size: int,
+ dit_hidden_size: int,
+ max_batch_size: int = 1,
+ sway_sampling_coef: float = 1.0,
+ ):
+ n = 2 * max_batch_size # CFG doubles the batch
+ self.x_in = torch.zeros(n, feat_dim, patch_size, device=device, dtype=dtype)
+ self.mu_in = torch.zeros(n, dit_hidden_size, device=device, dtype=dtype)
+ self.t_in = torch.zeros(n, device=device, dtype=dtype)
+ self.dt_in = torch.zeros(n, device=device, dtype=dtype)
+ self.cond_in = torch.zeros(n, feat_dim, patch_size, device=device, dtype=dtype)
+ self.noise = torch.zeros(max_batch_size, feat_dim, patch_size, device=device, dtype=dtype)
+ self._sway_coef = sway_sampling_coef
+ self._device = device
+ self._dtype = dtype
+ self.t_span_10 = self._make_t_span(10)
+
+ def _make_t_span(self, n: int) -> torch.Tensor:
+ t = torch.linspace(1, 0, n + 1, device=self._device, dtype=self._dtype)
+ return t + self._sway_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
+
+ def get_t_span(self, n: int) -> torch.Tensor:
+ return self.t_span_10 if n == 10 else self._make_t_span(n)
+
+
+def _optimized_solve_euler(
+ cfm_module: nn.Module,
+ mu: torch.Tensor,
+ patch_size: int,
+ cond: torch.Tensor,
+ n_timesteps: int,
+ cfg_value: float,
+ buffers: _CFMBufferManager,
+ use_cfg_zero_star: bool = True,
+ cfg_cutoff_ratio: float = 1.0,
+ perf: _PerfTimer | None = None,
+) -> torch.Tensor:
+ estimator = cfm_module.estimator
+ mean_mode = getattr(cfm_module, "mean_mode", False)
+ b = mu.size(0)
+
+ buffers.noise[:b].normal_()
+ x = buffers.noise[:b].clone()
+
+ t_span = buffers.get_t_span(n_timesteps)
+ t, dt = t_span[0], t_span[0] - t_span[1]
+ zero_init_steps = max(1, int(len(t_span) * 0.04))
+ cfg_cutoff_step = max(zero_init_steps + 1, int(len(t_span) * cfg_cutoff_ratio))
+
+ for step in range(1, len(t_span)):
+ if use_cfg_zero_star and step <= zero_init_steps:
+ dphi_dt = torch.zeros_like(x)
+ elif step <= cfg_cutoff_step:
+ buffers.x_in[:b].copy_(x)
+ buffers.x_in[b : 2 * b].copy_(x)
+ buffers.mu_in[:b].copy_(mu)
+ buffers.mu_in[b : 2 * b].zero_()
+ # Broadcast the 0-dim GPU scalar directly instead of
+ # ``.fill_(t.item())`` — ``.item()`` forces a GPU->CPU sync.
+ buffers.t_in[: 2 * b].copy_(t)
+ if mean_mode:
+ buffers.dt_in[: 2 * b].copy_(dt)
+ else:
+ buffers.dt_in.zero_()
+ buffers.cond_in[:b].copy_(cond[:b])
+ buffers.cond_in[b : 2 * b].copy_(cond[:b])
+
+ if perf:
+ perf.start(" cfm.estimator_cfg")
+ raw_out = estimator(
+ buffers.x_in[: 2 * b],
+ buffers.mu_in[: 2 * b],
+ buffers.t_in[: 2 * b],
+ buffers.cond_in[: 2 * b],
+ buffers.dt_in[: 2 * b],
+ )
+ if perf:
+ perf.stop(" cfm.estimator_cfg")
+
+ dphi_dt, cfg_dphi_dt = raw_out[:b], raw_out[b : 2 * b]
+ if use_cfg_zero_star:
+ pos = dphi_dt.reshape(b, -1)
+ neg = cfg_dphi_dt.reshape(b, -1)
+ st = torch.sum(pos * neg, 1, keepdim=True) / (torch.sum(neg**2, 1, keepdim=True) + 1e-8)
+ st = st.view(b, *([1] * (len(dphi_dt.shape) - 1)))
+ else:
+ st = 1.0
+ dphi_dt = cfg_dphi_dt * st + cfg_value * (dphi_dt - cfg_dphi_dt * st)
+ else:
+ buffers.x_in[:b].copy_(x)
+ buffers.mu_in[:b].copy_(mu)
+ # Broadcast the 0-dim GPU scalar; ``.fill_(t.item())`` would sync.
+ buffers.t_in[:b].copy_(t)
+ if mean_mode:
+ buffers.dt_in[:b].copy_(dt)
+ else:
+ buffers.dt_in[:b].zero_()
+ buffers.cond_in[:b].copy_(cond[:b])
+ if perf:
+ perf.start(" cfm.estimator_nocfg")
+ dphi_dt = estimator(
+ buffers.x_in[:b], buffers.mu_in[:b], buffers.t_in[:b], buffers.cond_in[:b], buffers.dt_in[:b]
+ )
+ if perf:
+ perf.stop(" cfm.estimator_nocfg")
+
+ x = x - dt * dphi_dt
+ t = t - dt
+ if step < len(t_span) - 1:
+ dt = t - t_span[step + 1]
+ return x
+
+
+# ===================================================================
+# Main talker model
+# ===================================================================
+
+
+class VoxCPM2TalkerForConditionalGeneration(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.vllm_config = vllm_config
+ self.config = vllm_config.model_config.hf_config
+
+ self.have_multimodal_outputs = True
+ self.has_preprocess = True
+ self.has_postprocess = True
+
+ self.model = MiniCPM4PagedForVoxCPM2(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"),
+ )
+ self.residual_model = MiniCPM4PagedResidualLM(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "residual_model"),
+ )
+ self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
+
+ self._tts: nn.Module | None = None
+ self._device = "cuda"
+ self._side_dtype = torch.bfloat16
+
+ self._patch_size = getattr(self.config, "patch_size", 4)
+ self._feat_dim = getattr(self.config, "feat_dim", 64)
+ self._sample_rate = getattr(self.config, "sample_rate", 48000)
+
+ self._inference_timesteps = 10
+ self._cfg_value = 2.0
+ self._cfg_cutoff_ratio = 1.0
+ # Number of trailing latent frames to keep as VAE receptive-field context
+ # for sliding-window streaming decode. 12 matches the nanovllm reference
+ # implementation and covers the longest VAE decoder receptive field.
+ self._n_decode_pad_frames = 12
+ self._enable_torch_compile = True
+ self._compile_vae = True
+ self._max_decode_steps = 2000
+ self._max_batch_size = getattr(vllm_config.scheduler_config, "max_num_seqs", 4)
+
+ self._perf = _PerfTimer(enabled=_ENABLE_PROFILING)
+ self._cfm_buffers: _CFMBufferManager | None = None
+ self._enable_cuda_graph = True
+ self._scaffold_graphs: dict[int, _CapturedGraph] = {}
+ self._residual_graphs: dict[int, _CapturedGraph] = {}
+ self._max_cached_graphs = self._max_batch_size
+ self._cuda_graph_pool: tuple | None = None
+ self._cuda_graph_warmup_steps = 0
+ self._cuda_graph_warmup_threshold = 3
+
+ self._multichar_zh_split: dict[int, list[int]] | None = None
+
+ self._active_states: dict[str, _RequestState] = {}
+ self._current_request_id: str | None = None
+ self._pending_requests: list[tuple[str, bool, torch.Tensor | None, int]] = []
+ self._results_queue: list[tuple[str, torch.Tensor | None]] = []
+ self._audio_queue: list[tuple[str, Any]] = []
+ self._deferred_cleanup_ids: set[str] = set()
+ self._active_state_warn_threshold = max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * self._max_batch_size)
+ # one-shot by design: fires at most once per process to avoid log spam.
+ self._active_state_warned = False
+
+ @property
+ def tts(self) -> nn.Module:
+ assert self._tts is not None, "Model not loaded yet"
+ return self._tts
+
+ # -------------------- request state management --------------------
+
+ def _get_or_create_state(self, request_id: str) -> _RequestState:
+ state = self._active_states.get(request_id)
+ if state is None:
+ state = _RequestState(request_id=request_id)
+ self._active_states[request_id] = state
+ if len(self._active_states) > self._active_state_warn_threshold and not self._active_state_warned:
+ logger.warning(
+ "VoxCPM2: _active_states size=%d exceeds threshold %d "
+ "(max_batch_size=%d); possible cleanup path leak",
+ len(self._active_states),
+ self._active_state_warn_threshold,
+ self._max_batch_size,
+ )
+ self._active_state_warned = True
+ return state
+
+ def _switch_to_request(self, request_id: str) -> _RequestState:
+ if request_id != self._current_request_id:
+ self._current_request_id = request_id
+ return self._get_or_create_state(request_id)
+
+ def _cleanup_request(self, request_id: str) -> None:
+ self._active_states.pop(request_id, None)
+ if self._current_request_id == request_id:
+ self._current_request_id = None
+
+ def on_requests_finished(self, finished_req_ids: set[str] | list[str]) -> None:
+ # Defer cleanup: on_requests_finished is called before forward(),
+ # so we must not delete state that the current step may still need.
+ self._deferred_cleanup_ids.update(finished_req_ids)
+
+ def _flush_deferred_cleanup(self) -> None:
+ for req_id in self._deferred_cleanup_ids:
+ self._cleanup_request(req_id)
+ self._deferred_cleanup_ids.clear()
+
+ def _build_prompt_cache(
+ self,
+ ref_audio: Any = None,
+ prompt_audio: Any = None,
+ prompt_text: str | None = None,
+ ) -> dict | None:
+ """Build prompt cache, handling both file paths and raw audio data.
+
+ The OpenAI speech API sends decoded audio as [samples_list, sr]
+ via ``_resolve_ref_audio``, while offline usage sends file paths.
+ """
+ tts = self.tts
+
+ def _is_raw_audio(v: Any) -> bool:
+ import numbers
+
+ return (
+ isinstance(v, (list, tuple))
+ and len(v) == 2
+ and isinstance(v[1], numbers.Integral)
+ and isinstance(v[0], (list, torch.Tensor))
+ )
+
+ if not _is_raw_audio(ref_audio) and not _is_raw_audio(prompt_audio):
+ return tts.build_prompt_cache(
+ prompt_text=prompt_text,
+ prompt_wav_path=prompt_audio,
+ reference_wav_path=ref_audio,
+ )
+
+ cache: dict[str, Any] = {}
+ if ref_audio is not None:
+ if _is_raw_audio(ref_audio):
+ samples, sr = ref_audio
+ cache["ref_audio_feat"] = _encode_raw_audio(tts, samples, sr)
+ else:
+ cache["ref_audio_feat"] = tts._encode_wav(ref_audio, padding_mode="right")
+
+ if prompt_audio is not None and prompt_text is not None:
+ cache["prompt_text"] = prompt_text
+ if _is_raw_audio(prompt_audio):
+ samples, sr = prompt_audio
+ cache["audio_feat"] = _encode_raw_audio(tts, samples, sr, padding_mode="left")
+ else:
+ cache["audio_feat"] = tts._encode_wav(prompt_audio, padding_mode="left")
+
+ has_ref = "ref_audio_feat" in cache
+ has_prompt = "audio_feat" in cache
+ if has_ref and has_prompt:
+ cache["mode"] = "ref_continuation"
+ elif has_ref:
+ cache["mode"] = "reference"
+ else:
+ cache["mode"] = "continuation"
+
+ return cache
+
+ # -------------------- compile setup --------------------
+
+ def _setup_cfm_buffers(self) -> None:
+ if self._cfm_buffers is not None:
+ return
+ tts = self.tts
+ dit_hidden = tts.lm_to_dit_proj.out_features + tts.res_to_dit_proj.out_features
+ self._cfm_buffers = _CFMBufferManager(
+ device=torch.device(self._device),
+ dtype=self._side_dtype,
+ feat_dim=self._feat_dim,
+ patch_size=self._patch_size,
+ dit_hidden_size=dit_hidden,
+ max_batch_size=self._max_batch_size,
+ )
+
+ def _setup_torch_compile(self) -> None:
+ if not self._enable_torch_compile:
+ return
+ tts = self.tts
+ estimator = tts.feat_decoder.estimator
+ if hasattr(estimator, "_compiled"):
+ return
+
+ targets: list[str] = []
+
+ try:
+ tts.feat_decoder.estimator = torch.compile(estimator, mode="reduce-overhead", fullgraph=False)
+ tts.feat_decoder.estimator._compiled = True
+ targets.append("LocDiT")
+ except Exception as e:
+ logger.warning("torch.compile LocDiT failed: %s", e)
+
+ try:
+ if not hasattr(tts.feat_encoder, "_compiled"):
+ tts.feat_encoder = torch.compile(tts.feat_encoder, mode="reduce-overhead", fullgraph=False)
+ tts.feat_encoder._compiled = True
+ targets.append("feat_encoder")
+ except Exception as e:
+ logger.warning("torch.compile feat_encoder failed: %s", e)
+
+ if self._compile_vae:
+ try:
+ if not hasattr(tts.audio_vae, "_compiled"):
+ tts.audio_vae.decode = torch.compile(tts.audio_vae.decode, mode="reduce-overhead", fullgraph=False)
+ tts.audio_vae._compiled = True
+ targets.append("AudioVAE")
+ except Exception as e:
+ logger.warning("torch.compile AudioVAE failed: %s", e)
+
+ if not self._enable_cuda_graph:
+ if not getattr(self.model, "_selective_compiled", False):
+ try:
+ targets.extend(f"scaffold.{t}" for t in self.model.compile_selective())
+ self.model._selective_compiled = True
+ except Exception as e:
+ logger.warning("scaffold compile failed: %s", e)
+
+ if not getattr(self.residual_model, "_selective_compiled", False):
+ try:
+ targets.extend(f"residual.{t}" for t in self.residual_model.compile_selective())
+ self.residual_model._selective_compiled = True
+ except Exception as e:
+ logger.warning("residual compile failed: %s", e)
+ else:
+ self.model.precompute_fused_qkv()
+ self.residual_model.precompute_fused_qkv()
+ targets.append("scaffold+residual (CUDA Graph, skipping compile)")
+
+ if not getattr(self, "_projections_compiled", False):
+ try:
+ self._compiled_dit_proj = torch.compile(self._dit_proj_fn, mode="default", fullgraph=True)
+ self._compiled_stop_fn = torch.compile(self._stop_fn, mode="default", fullgraph=True)
+ self._projections_compiled = True
+ targets.append("projections")
+ except Exception as e:
+ self._compiled_dit_proj = self._compiled_stop_fn = None
+ logger.warning("projections compile failed: %s", e)
+
+ if targets:
+ logger.info("VoxCPM2: torch.compile applied to: %s", ", ".join(targets))
+
+ def _dit_proj_fn(self, lm_h: torch.Tensor, res_h: torch.Tensor) -> torch.Tensor:
+ tts = self.tts
+ return torch.cat([tts.lm_to_dit_proj(lm_h), tts.res_to_dit_proj(res_h)], dim=-1)
+
+ def _stop_fn(self, lm_h: torch.Tensor) -> torch.Tensor:
+ tts = self.tts
+ return tts.stop_head(tts.stop_actn(tts.stop_proj(lm_h)))
+
+ def _get_cuda_graph_pool(self) -> tuple:
+ if self._cuda_graph_pool is None:
+ self._cuda_graph_pool = torch.cuda.graph_pool_handle()
+ return self._cuda_graph_pool
+
+ @staticmethod
+ def _nullify_volatile_metadata(ctx: Any) -> Any:
+ """Set ``scheduler_metadata`` to None on all attention layers.
+
+ This is the only tensor FA3 reallocates each step (variable shape).
+ All other metadata tensors are persistent model-runner buffers.
+ Setting it to None makes FA3 use default scheduling (~0.1ms cost).
+ """
+ if not isinstance(ctx.attn_metadata, dict):
+ return ctx
+
+ ctx = copy.copy(ctx)
+ new_meta: dict[str, Any] = {}
+ for layer_name, meta in ctx.attn_metadata.items():
+ if getattr(meta, "scheduler_metadata", None) is not None:
+ meta = copy.copy(meta)
+ meta.scheduler_metadata = None
+ new_meta[layer_name] = meta
+ ctx.attn_metadata = new_meta
+ return ctx
+
+ def _capture_graph(
+ self,
+ model: nn.Module,
+ batch_size: int,
+ label: str,
+ is_residual: bool = False,
+ ) -> _CapturedGraph:
+ """Capture a CUDA Graph for *model* at *batch_size*."""
+ hidden_size = self.config.hidden_size
+ dtype = self._side_dtype
+ dev = torch.device(self._device)
+ pool = self._get_cuda_graph_pool()
+
+ model.precompute_fused_qkv()
+
+ g = _CapturedGraph(
+ graph=torch.cuda.CUDAGraph(),
+ input_embeds=torch.zeros(batch_size, hidden_size, device=dev, dtype=dtype),
+ positions=torch.zeros(batch_size, device=dev, dtype=torch.long),
+ output=torch.zeros(batch_size, hidden_size, device=dev, dtype=dtype),
+ )
+
+ if is_residual:
+ call_kwargs = dict(positions=g.positions, inputs_embeds=g.input_embeds)
+ else:
+ call_kwargs = dict(input_ids=None, positions=g.positions, inputs_embeds=g.input_embeds)
+
+ ctx = get_forward_context()
+ patched_ctx = self._nullify_volatile_metadata(ctx)
+
+ with override_forward_context(patched_ctx):
+ for _ in range(3):
+ _ = model(**call_kwargs)
+
+ with torch.cuda.graph(g.graph, pool=pool):
+ g.output = model(**call_kwargs)
+
+ logger.info("CUDA Graph captured for %s (batch_size=%d)", label, batch_size)
+ return g
+
+ def _replay_graph(
+ self,
+ g: _CapturedGraph,
+ inputs_embeds: torch.Tensor,
+ positions: torch.Tensor,
+ batch_size: int,
+ ) -> torch.Tensor:
+ """Copy fresh inputs into static buffers, then replay.
+
+ No metadata copy needed: persistent buffers (seq_lens, slot_mapping,
+ etc.) are updated in-place by the model runner. scheduler_metadata
+ was nullified at capture time so no kernel references it.
+ """
+ g.input_embeds[:batch_size].copy_(inputs_embeds[:batch_size])
+ g.positions[:batch_size].copy_(positions[:batch_size])
+ g.graph.replay()
+ return g.output[:batch_size].clone()
+
+ # -------------------- vllm hooks --------------------
+
+ def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs: Any,
+ ) -> torch.Tensor | IntermediateTensors:
+ self._perf.start("forward_total")
+ dev = input_ids.device
+
+ num_reqs = len(self._pending_requests)
+ num_decode = sum(1 for _, is_p, _, n in self._pending_requests if not is_p and n == 1)
+ is_all_decode = num_decode == num_reqs and num_reqs > 0
+
+ tts_compiled = getattr(self.tts.feat_decoder.estimator, "_compiled", False) if self._tts is not None else False
+ graph_ready = tts_compiled and self._cuda_graph_warmup_steps >= self._cuda_graph_warmup_threshold
+ if num_decode > 0:
+ self._cuda_graph_warmup_steps += 1
+
+ can_use_graph = (
+ self._enable_cuda_graph and graph_ready and intermediate_tensors is None and inputs_embeds is not None
+ )
+
+ if can_use_graph and is_all_decode and num_reqs <= self._max_cached_graphs:
+ self._perf.start("scaffold_fwd")
+ if num_reqs not in self._scaffold_graphs:
+ self._scaffold_graphs[num_reqs] = self._capture_graph(self.model, num_reqs, "scaffold")
+ scaffold_hidden = self._replay_graph(self._scaffold_graphs[num_reqs], inputs_embeds, positions, num_reqs)
+ self._perf.stop("scaffold_fwd")
+
+ else:
+ self._perf.start("scaffold_fwd")
+ model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
+ self._perf.stop("scaffold_fwd")
+ if isinstance(model_output, IntermediateTensors):
+ return model_output
+ scaffold_hidden = model_output
+ if isinstance(scaffold_hidden, tuple):
+ scaffold_hidden = scaffold_hidden[0]
+
+ # Phase 1: per-request FSQ + residual input
+ token_offset = 0
+ residual_inputs: list[torch.Tensor] = []
+ residual_positions: list[torch.Tensor] = []
+ req_metas: list[tuple] = []
+
+ for req_id, is_prefill, _req_embeds, n in self._pending_requests:
+ state = self._switch_to_request(req_id)
+ req_hidden = scaffold_hidden[token_offset : token_offset + n]
+ req_pos = positions[token_offset : token_offset + n]
+
+ if is_prefill:
+ res_input, meta = self._prepare_residual_prefill(state, req_hidden, dev)
+ elif state.prefill_completed:
+ res_input, meta = self._prepare_residual_decode(state, req_hidden, dev)
+ else:
+ token_offset += n
+ self._results_queue.append((req_id, None))
+ self._audio_queue.append((req_id, None))
+ continue
+
+ residual_inputs.append(res_input)
+ residual_positions.append(req_pos)
+ req_metas.append((state, is_prefill, meta))
+ token_offset += n
+
+ # Phase 2: batch residual_model forward
+ if residual_inputs:
+ batch_in = torch.cat(residual_inputs, dim=0)
+ batch_pos = torch.cat(residual_positions, dim=0)
+
+ residual_batch_size = batch_in.shape[0]
+ use_residual_graph = (
+ self._enable_cuda_graph
+ and is_all_decode
+ and graph_ready
+ and residual_batch_size == num_reqs # 1 token per request
+ and residual_batch_size <= self._max_cached_graphs
+ )
+
+ self._perf.start("residual_fwd")
+ if use_residual_graph:
+ if residual_batch_size not in self._residual_graphs:
+ self._residual_graphs[residual_batch_size] = self._capture_graph(
+ self.residual_model, residual_batch_size, "residual", is_residual=True
+ )
+ batch_out = self._replay_graph(
+ self._residual_graphs[residual_batch_size], batch_in, batch_pos, residual_batch_size
+ )
+ else:
+ batch_out = self.residual_model(batch_pos, batch_in)
+ self._perf.stop("residual_fwd")
+
+ # Phase 3: per-request LocDiT + update
+ offset = 0
+ for idx, (state, is_prefill, meta) in enumerate(req_metas):
+ n = residual_inputs[idx].shape[0]
+ res_out = batch_out[offset : offset + n]
+ offset += n
+
+ if is_prefill:
+ self._finish_prefill(state, meta, res_out, dev)
+ else:
+ self._finish_decode(state, meta, res_out, dev)
+
+ self._results_queue.append((state.request_id, state.precomputed_stop_logits))
+ self._audio_queue.append((state.request_id, self._collect_audio(state)))
+
+ self._pending_requests.clear()
+ self._flush_deferred_cleanup()
+ self._perf.stop("forward_total")
+ return scaffold_hidden
+
+ # -------------------- prefill / decode helpers --------------------
+
+ def _prepare_residual_prefill(self, state: _RequestState, base_lm_out: torch.Tensor, dev: Any):
+ tts = self.tts
+ text_mask, feat_mask, feat, feat_embed = state.prefill_masks
+ state.prefill_masks = None
+
+ tts_len = text_mask.shape[1]
+ scaffold_len = base_lm_out.shape[0]
+ assert scaffold_len == tts_len, (
+ f"voxcpm2 prefill length mismatch: scaffold_len={scaffold_len} tts_len={tts_len}; "
+ "caller must pad prompt_token_ids to the full prefill length "
+ "(see serving_speech._build_voxcpm2_prompt or the offline example)."
+ )
+ enc_out = base_lm_out.unsqueeze(0)
+
+ prefix_feat_cond = (
+ feat[:, -1, ...]
+ if feat.shape[1] > 0
+ else torch.zeros(1, self._patch_size, self._feat_dim, device=dev, dtype=self._side_dtype)
+ )
+ enc_outputs = tts.fsq_layer(enc_out) * feat_mask.unsqueeze(-1) + enc_out * text_mask.unsqueeze(-1)
+ lm_hidden = enc_outputs[:, -1, :]
+
+ residual_input = tts.fusion_concat_proj(torch.cat([enc_outputs, feat_mask.unsqueeze(-1) * feat_embed], dim=-1))
+ meta = {"lm_hidden": lm_hidden, "prefix_feat_cond": prefix_feat_cond}
+ return residual_input.squeeze(0), meta
+
+ def _prepare_residual_decode(self, state: _RequestState, base_lm_out: torch.Tensor, dev: Any):
+ tts = self.tts
+ state.decode_step_count += 1
+
+ if state.decode_step_count >= self._max_decode_steps:
+ logger.warning("MAX_DECODE_STEPS for %s (%d), forcing stop", state.request_id, state.decode_step_count)
+ state.is_stopping = True
+
+ h = base_lm_out.unsqueeze(0) if base_lm_out.ndim == 1 else base_lm_out
+ lm_h = tts.fsq_layer(h)
+ if lm_h.ndim == 1:
+ lm_h = lm_h.unsqueeze(0)
+
+ prev = state.prev_feat_embed.to(self._side_dtype)
+ if prev.ndim == 1:
+ prev = prev.unsqueeze(0)
+ res_input = tts.fusion_concat_proj(torch.cat([lm_h, prev], dim=-1))
+ return res_input, {"new_lm_hidden": lm_h}
+
+ def _run_cfm(self, dit_h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+ if self._cfm_buffers is not None:
+ return _optimized_solve_euler(
+ self.tts.feat_decoder,
+ dit_h,
+ self._patch_size,
+ cond,
+ self._inference_timesteps,
+ self._cfg_value,
+ self._cfm_buffers,
+ cfg_cutoff_ratio=self._cfg_cutoff_ratio,
+ perf=self._perf,
+ ).transpose(1, 2)
+ return self.tts.feat_decoder(
+ mu=dit_h,
+ patch_size=self._patch_size,
+ cond=cond,
+ n_timesteps=self._inference_timesteps,
+ cfg_value=self._cfg_value,
+ ).transpose(1, 2)
+
+ def _finish_prefill(self, state: _RequestState, meta: dict, res_out: torch.Tensor, dev: Any):
+ tts = self.tts
+ lm_hidden = meta["lm_hidden"]
+ prefix_feat_cond = meta["prefix_feat_cond"]
+ residual_hidden = res_out[-1:, :]
+
+ state.precomputed_stop_logits = tts.stop_head(tts.stop_actn(tts.stop_proj(lm_hidden))).detach()
+ dit_h = torch.cat([tts.lm_to_dit_proj(lm_hidden), tts.res_to_dit_proj(residual_hidden)], dim=-1)
+
+ self._setup_cfm_buffers()
+ if self._enable_torch_compile:
+ self._setup_torch_compile()
+
+ pred_feat = self._run_cfm(dit_h, prefix_feat_cond.transpose(1, 2).contiguous())
+
+ with torch.no_grad():
+ curr_embed = tts.enc_to_lm_proj(tts.feat_encoder(pred_feat.unsqueeze(1))).squeeze(1)
+
+ state.curr_embed_for_next = curr_embed.detach()
+ state.prev_feat_embed = curr_embed.detach()
+ state.curr_prefix_feat_cond = pred_feat[0].detach()
+ state.last_audio_patch_gpu = pred_feat.detach()
+ state.decode_step_count = 0
+ state.request_start_time = time.perf_counter()
+ state.prefill_completed = True
+
+ if logger.isEnabledFor(logging.DEBUG):
+ # Only compute the norm (which forces a GPU->CPU sync) if we will log it.
+ logger.debug("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item())
+ self._perf.reset()
+
+ def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor, dev: Any):
+ self._perf.start("decode_step")
+ tts = self.tts
+
+ lm_h = meta["new_lm_hidden"]
+ res_h = res_out.unsqueeze(0) if res_out.ndim == 1 else res_out
+
+ dit_proj = getattr(self, "_compiled_dit_proj", None) or self._dit_proj_fn
+ stop_fn = getattr(self, "_compiled_stop_fn", None) or self._stop_fn
+
+ dit_h = dit_proj(lm_h, res_h)
+ pfc = state.curr_prefix_feat_cond.to(self._side_dtype)
+ if pfc.ndim == 2:
+ pfc = pfc.unsqueeze(0)
+
+ pred_feat = self._run_cfm(dit_h, pfc.transpose(1, 2).contiguous())
+ next_embed = tts.enc_to_lm_proj(tts.feat_encoder(pred_feat.unsqueeze(1))).squeeze(1)
+
+ state.precomputed_stop_logits = stop_fn(lm_h).detach()
+ state.curr_embed_for_next = next_embed.detach()
+ state.prev_feat_embed = next_embed.detach()
+ state.curr_prefix_feat_cond = pred_feat[0].detach()
+ state.last_audio_patch_gpu = pred_feat.detach()
+
+ self._perf.stop("decode_step")
+ if _ENABLE_PROFILING and state.decode_step_count % 20 == 0:
+ logger.info("Step %d[%s]:\n%s", state.decode_step_count, state.request_id, self._perf.breakdown())
+
+ # -------------------- audio collection --------------------
+
+ def _collect_audio(self, state: _RequestState) -> torch.Tensor | None:
+ """Per-step sliding-window VAE decode (nanovllm pattern).
+
+ Each decode step feeds ``[decode_pad, new_patch]`` through the VAE
+ and slices out only the audio region corresponding to the new patch.
+ The pad buffer (last ``_n_decode_pad_frames`` latent frames) provides
+ the receptive-field context needed by the VAE's transposed convolutions,
+ eliminating boundary artifacts between chunks.
+
+ Returns the delta audio chunk (not cumulative) so the output processor
+ can stream each chunk to the client independently.
+ """
+ patch = state.last_audio_patch_gpu
+ if patch is None:
+ return None
+ state.last_audio_patch_gpu = None
+
+ # patch shape: (patch_size, feat_dim) or (1, patch_size, feat_dim)
+ new_latent = patch.reshape(-1, self._feat_dim).to(torch.float32)
+ n_new = new_latent.shape[0] # = patch_size (typically 4)
+
+ self._perf.start("vae_decode")
+
+ # Build VAE input: [pad_frames | new_latent]
+ if state.decode_pad is not None:
+ vae_input = torch.cat([state.decode_pad, new_latent], dim=0)
+ pad_frames = state.decode_pad.shape[0]
+ else:
+ vae_input = new_latent
+ pad_frames = 0
+
+ # VAE decode: (1, feat_dim, T_frames) -> (1, 1, T_samples)
+ feat = vae_input.unsqueeze(0).transpose(1, 2).contiguous()
+ with torch.no_grad():
+ audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1)
+
+ # Slice out only the new audio (after the pad region).
+ # Each latent frame maps to decoder_chunk_size audio samples.
+ dcs = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_input.shape[0]))
+ new_audio = audio[pad_frames * dcs : (pad_frames + n_new) * dcs].detach().cpu().float()
+
+ # Roll the pad buffer: keep last N latent frames as context for next step.
+ all_latents = vae_input # [pad + new]
+ state.decode_pad = all_latents[-self._n_decode_pad_frames :].detach()
+
+ state.audio_chunks.append(new_audio)
+ state.last_decoded_audio = new_audio
+ self._perf.stop("vae_decode")
+ return new_audio
+
+ # -------------------- compute_logits --------------------
+
+ def compute_logits(
+ self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None
+ ) -> torch.Tensor | None:
+ if isinstance(hidden_states, OmniOutput):
+ hidden_states = hidden_states.text_hidden_states
+ if hidden_states is None:
+ return None
+
+ bsz = hidden_states.shape[0]
+ logits = torch.full(
+ (bsz, self.config.vocab_size), float("-inf"), device=hidden_states.device, dtype=hidden_states.dtype
+ )
+
+ if self._results_queue:
+ for i, (req_id, stop_logits) in enumerate(self._results_queue):
+ if i >= bsz:
+ break
+ state = self._active_states.get(req_id)
+ if stop_logits is not None:
+ if state is not None and state.is_stopping:
+ logits[i, 0] = 0.0
+ logits[i, 1] = 1.0
+ state.precomputed_stop_logits = None
+ else:
+ logits[i, 0] = stop_logits[0, 0]
+ logits[i, 1] = stop_logits[0, 1]
+ if state is not None:
+ state.is_stopping = bool(stop_logits[0, 1] > stop_logits[0, 0])
+ state.precomputed_stop_logits = None
+ elif state and state.prefill_completed:
+ logits[i, 1] = 1.0
+ else:
+ logits[i, 0] = 1.0
+ self._results_queue.clear()
+ else:
+ logits[:, 0] = 1.0
+ return logits
+
+ # -------------------- omni output --------------------
+
+ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput:
+ if isinstance(model_outputs, OmniOutput):
+ return model_outputs
+
+ mm: dict[str, Any] = {}
+ if self._audio_queue:
+ audio_by_req = {rid: audio for rid, audio in self._audio_queue}
+ order = [r for r, _ in self._audio_queue]
+ mm["model_outputs"] = [audio_by_req.get(r) for r in order]
+ mm["sr"] = [torch.tensor(self._sample_rate, dtype=torch.int32) for _ in order]
+ self._audio_queue.clear()
+
+ return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs=mm)
+
+ # -------------------- Chinese token splitting --------------------
+
+ def _get_multichar_zh_split(self) -> dict[int, list[int]]:
+ """Lazy-build {multichar_chinese_token_id: [char_id, ...]} map."""
+ if self._multichar_zh_split is not None:
+ return self._multichar_zh_split
+ base_tokenizer = self.tts.text_tokenizer.tokenizer
+ self._multichar_zh_split = build_cjk_split_map(base_tokenizer)
+ logger.info("VoxCPM2: built multichar Chinese split map (%d entries)", len(self._multichar_zh_split))
+ return self._multichar_zh_split
+
+ # -------------------- preprocess / postprocess --------------------
+
+ def preprocess(
+ self, input_ids: torch.Tensor, input_embeds: torch.Tensor | None, **info_dict: Any
+ ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
+ additional = info_dict.get("additional_information")
+ if isinstance(additional, dict):
+ merged = {k: v for k, v in info_dict.items() if k != "additional_information"}
+ for k, v in additional.items():
+ merged.setdefault(k, v)
+ info_dict = merged
+
+ span_len = int(input_ids.shape[0])
+ dev = input_ids.device
+ req_id = info_dict.get("request_id", "default")
+ is_prefill = span_len > 1
+
+ if is_prefill:
+ # Do not evict state here: _pending_requests is a per-step prefix,
+ # not the full batch. Cleanup is driven by on_requests_finished ->
+ # _flush_deferred_cleanup (fed by vLLM scheduler._free_request via
+ # gpu_ar_model_runner.py).
+ real = info_dict.get("text_token_ids")
+ token_ids = input_ids.tolist() if real is None else real[0]
+ # Fail-fast: unsplit multichar Chinese IDs in input_ids means the
+ # serving layer didn't pre-split. Silent fixup here would cause
+ # input_ids/embeds length mismatch (scheduler slot count is fixed).
+ split_map = self._get_multichar_zh_split()
+ if split_map and any(tid in split_map for tid in token_ids):
+ raise ValueError(
+ "VoxCPM2 preprocess received unsplit multichar Chinese "
+ "token IDs. The serving layer must send prompt_token_ids "
+ "with single-char CJK IDs (see _voxcpm2_encode)."
+ )
+ if token_ids and token_ids[0] == self.config.bos_token_id:
+ token_ids = token_ids[1:]
+
+ state = self._get_or_create_state(req_id)
+ state.prefill_text = ""
+ state.decode_pad = None
+ state.audio_chunks = []
+ state.prefill_completed = False
+ state.decode_step_count = 0
+ state.precomputed_stop_logits = None
+ state.last_audio_patch_gpu = None
+ state.curr_embed_for_next = None
+ state.prev_feat_embed = None
+ state.curr_prefix_feat_cond = None
+ state.is_stopping = False
+ state.last_decoded_audio = None
+
+ # Voice clone / continuation
+ ref_audio = info_dict.get("reference_audio") or info_dict.get("ref_audio")
+ prompt_audio = info_dict.get("prompt_audio")
+ prompt_text = info_dict.get("prompt_text")
+ if isinstance(ref_audio, list):
+ ref_audio = ref_audio[0] if ref_audio else None
+ if isinstance(prompt_audio, list):
+ prompt_audio = prompt_audio[0] if prompt_audio else None
+ if isinstance(prompt_text, list):
+ prompt_text = prompt_text[0] if prompt_text else None
+
+ state.prompt_cache = None
+ if ref_audio or (prompt_audio and prompt_text):
+ try:
+ state.prompt_cache = self._build_prompt_cache(
+ ref_audio=ref_audio,
+ prompt_audio=prompt_audio,
+ prompt_text=prompt_text,
+ )
+ except Exception as e:
+ logger.warning("build_prompt_cache failed: %s", e)
+
+ inputs = self._build_prefill_inputs(token_ids, dev, req_id)
+ tts = self.tts
+ feat_embed = tts.enc_to_lm_proj(tts.feat_encoder(inputs["audio_feat"]))
+ text_embed = self.model.embed_input_ids(inputs["text_token"].to(dev))
+ text_mask, feat_mask = inputs["text_mask"], inputs["audio_mask"]
+ embeds = (text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed).squeeze(0)
+ state.prefill_masks = (text_mask, feat_mask, inputs["audio_feat"], feat_embed)
+ else:
+ state = self._active_states.get(req_id)
+ curr = state.curr_embed_for_next if state else None
+ if curr is not None:
+ embeds = curr.to(dev, dtype=self._side_dtype).reshape(1, -1)
+ else:
+ embeds = torch.zeros(1, self.config.hidden_size, device=dev, dtype=self._side_dtype)
+
+ self._pending_requests.append((req_id, is_prefill, embeds, span_len))
+ return input_ids, embeds, {}
+
+ def postprocess(self, hidden_states: torch.Tensor, **info: Any) -> dict[str, Any]:
+ req_id = info.get("request_id", self._current_request_id or "default")
+ if _ENABLE_PROFILING:
+ state = self._active_states.get(req_id)
+ if state and state.decode_step_count > 0:
+ logger.info(
+ "REQUEST DONE[%s]: %d steps, %.2fs\n%s",
+ req_id,
+ state.decode_step_count,
+ time.perf_counter() - state.request_start_time,
+ self._perf.breakdown(),
+ )
+ return {}
+
+ # -------------------- build prefill inputs --------------------
+
+ def _build_prefill_inputs(self, token_ids: list[int], dev: Any, req_id: str = "default") -> dict:
+ tts = self.tts
+ dtype = self._side_dtype
+ state = self._active_states.get(req_id)
+ cache = state.prompt_cache if state else None
+ mode = cache.get("mode", "continuation") if cache else "zero_shot"
+
+ if cache and mode in ("continuation", "ref_continuation"):
+ prompt_text = cache.get("prompt_text", "")
+ prompt_ids = list(tts.text_tokenizer(prompt_text)) if prompt_text else []
+ all_ids = prompt_ids + token_ids
+ else:
+ all_ids = token_ids
+
+ text_token = torch.tensor(all_ids, dtype=torch.int32)
+ text_token = torch.cat([text_token, torch.tensor([tts.audio_start_token], dtype=torch.int32)], dim=-1)
+ text_len = text_token.shape[0]
+ latent_dim = tts.audio_vae.latent_dim
+ ps = self._patch_size
+
+ if mode in ("zero_shot", "continuation"):
+ audio_feat = cache["audio_feat"] if cache else torch.empty((0, ps, latent_dim), dtype=torch.float32)
+ a_len = audio_feat.size(0)
+ text_token = torch.cat([text_token, torch.zeros(a_len, dtype=torch.int32)])
+ audio_feat = torch.cat([torch.zeros((text_len, ps, latent_dim), dtype=torch.float32), audio_feat])
+ text_mask = torch.cat([torch.ones(text_len, dtype=torch.int32), torch.zeros(a_len, dtype=torch.int32)])
+ audio_mask = torch.cat([torch.zeros(text_len, dtype=torch.int32), torch.ones(a_len, dtype=torch.int32)])
+ elif mode == "reference":
+ ref = cache["ref_audio_feat"]
+ rt, rf, rtm, ram = tts._make_ref_prefix(ref, text_token.device)
+ text_token = torch.cat([rt.cpu(), text_token])
+ audio_feat = torch.cat([rf.cpu(), torch.zeros((text_len, ps, latent_dim), dtype=torch.float32)])
+ text_mask = torch.cat([rtm.cpu(), torch.ones(text_len, dtype=torch.int32)])
+ audio_mask = torch.cat([ram.cpu(), torch.zeros(text_len, dtype=torch.int32)])
+ else: # ref_continuation
+ ref = cache["ref_audio_feat"]
+ prompt = cache["audio_feat"]
+ p_len = prompt.size(0)
+ rt, rf, rtm, ram = tts._make_ref_prefix(ref, text_token.device)
+ text_token = torch.cat([rt.cpu(), text_token, torch.zeros(p_len, dtype=torch.int32)])
+ audio_feat = torch.cat([rf.cpu(), torch.zeros((text_len, ps, latent_dim), dtype=torch.float32), prompt])
+ ones_t = torch.ones(text_len, dtype=torch.int32)
+ zeros_p = torch.zeros(p_len, dtype=torch.int32)
+ zeros_t = torch.zeros(text_len, dtype=torch.int32)
+ ones_p = torch.ones(p_len, dtype=torch.int32)
+ text_mask = torch.cat([rtm.cpu(), ones_t, zeros_p])
+ audio_mask = torch.cat([ram.cpu(), zeros_t, ones_p])
+
+ return {
+ "text_token": text_token.unsqueeze(0).to(dev),
+ "audio_feat": audio_feat.unsqueeze(0).to(dev).to(dtype),
+ "text_mask": text_mask.unsqueeze(0).to(dev),
+ "audio_mask": audio_mask.unsqueeze(0).to(dev),
+ }
+
+ # -------------------- weight loading --------------------
+
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"base_lm.": "model."})
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ def _base_lm_only(ws):
+ for name, tensor in ws:
+ if name.startswith("base_lm."):
+ yield name, tensor
+
+ loader = AutoWeightsLoader(self)
+ loaded = loader.load_weights(_base_lm_only(weights), mapper=self.hf_to_vllm_mapper)
+
+ model_path = self.vllm_config.model_config.model
+ VoxCPM = import_voxcpm2_core()
+ native = VoxCPM.from_pretrained(model_path, load_denoiser=False, optimize=False)
+ self._tts = native.tts_model.to("cuda")
+ self._side_dtype = self._tts.fusion_concat_proj.weight.dtype
+ self._device = "cuda"
+ self._patch_size = self._tts.patch_size
+ self._feat_dim = self._tts.feat_dim
+
+ n = self.residual_model.load_weights_from_native(self._tts.residual_lm)
+ for name, _ in self.residual_model.named_parameters():
+ loaded.add(f"residual_model.{name}")
+ logger.info("VoxCPM2: loaded %d params into paged residual_model", n)
+
+ del self._tts.base_lm
+ self._tts.base_lm = None
+ del self._tts.residual_lm
+ self._tts.residual_lm = None
+ torch.cuda.empty_cache()
+
+ logger.info(
+ "Loaded VoxCPM2 (patch=%d, feat_dim=%d, dtype=%s)", self._patch_size, self._feat_dim, self._side_dtype
+ )
+ return loaded
diff --git a/vllm_omni/model_executor/models/voxtral_tts/configuration_voxtral_tts.py b/vllm_omni/model_executor/models/voxtral_tts/configuration_voxtral_tts.py
index d32a882e78..0f22c764a0 100644
--- a/vllm_omni/model_executor/models/voxtral_tts/configuration_voxtral_tts.py
+++ b/vllm_omni/model_executor/models/voxtral_tts/configuration_voxtral_tts.py
@@ -48,6 +48,15 @@ def _remap_mistral_audio_args(self, config_dict: dict) -> dict:
audio_tokenizer_args = config_dict["multimodal"].pop("audio_tokenizer_args", None)
audio_config = {}
if encoder_args is not None:
+ # Default n_decoding_steps if not provided
+ acoustic_args = encoder_args.get("acoustic_transformer_args", {})
+ if acoustic_args.get("n_decoding_steps") is None:
+ logger.warning(
+ "n_decoding_steps not provided in acoustic_transformer_args, defaulting to 7. "
+ "Please add 'n_decoding_steps' to params.json under acoustic_transformer_args."
+ )
+ acoustic_args["n_decoding_steps"] = 7
+
audio_config = {
"sampling_rate": encoder_args["audio_encoding_args"]["sampling_rate"],
"codec_args": audio_tokenizer_args,
diff --git a/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py b/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py
index a4d58df5b1..ff053342db 100644
--- a/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py
+++ b/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py
@@ -49,7 +49,7 @@ def __init__(
self.acoustic_embeddings_levels = self.acoustic_transformer.acoustic_embeddings_levels
self.cfg_alpha = 1.2
- self.n_steps = 8
+ self.n_steps = self.acoustic_transformer.acoustic_transformer_args.n_decoding_steps
# Graph storage
self.graphs: dict[int, CUDAGraph] = {}
@@ -73,7 +73,7 @@ def _warmup_and_capture(self, device: torch.device, dtype: torch.dtype, hidden_d
)
# Pre-create persistent buffers
- self.timesteps = torch.linspace(0, 1, self.n_steps, device=device, dtype=dtype)
+ self.timesteps = torch.linspace(0, 1, self.n_steps + 1, device=device, dtype=dtype)
self.fake_eos_one = torch.tensor(1.0, dtype=dtype, device=device)
self.fake_eos_zero = torch.tensor(0.0, dtype=dtype, device=device)
diff --git a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py
index b5d1161733..cd67e4f074 100644
--- a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py
+++ b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py
@@ -108,6 +108,7 @@ class AcousticTransformerArgs:
use_biases: bool = False
norm_eps: float = 1e-5
sigma: float = 1e-5 # was 0.01 in beta version
+ n_decoding_steps: int | None = None # Number of Euler ODE steps for flow matching
@dataclass
@@ -436,14 +437,13 @@ def __init__(
self._empty_audio_token_id = AudioSpecialTokens.id(AudioSpecialTokens.empty_audio)
# Flow matching constants
- # TODO(chenyo): hardcoded, need to fix
- self._acoustic_decode_iters = 8
+ self._n_steps = args.n_decoding_steps
# TODO(chenyo): hardcoded, need to fix
self._cfg_alpha = 1.2
self._noise_scale = 1.0
self.register_buffer(
"_timesteps",
- torch.linspace(0, 1, self._acoustic_decode_iters),
+ torch.linspace(0, 1, self._n_steps + 1),
persistent=False,
)
@@ -864,6 +864,29 @@ def get_replacement(item_idx: int):
),
]
+ def _apply_hf_processor_mm_only(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ tokenization_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ """
+ Apply the HF processor on the multi-modal data only.
+
+ Issue: Voxtral TTS use Mistral Tokenizer with custom audio encoder. It doesn't
+ inherit Transformers ProcessorMixin and can't use call_hf_processor_mm_only.
+
+ Solution: Override this method to call _apply_hf_processor_text_mm directly.
+ """
+ mm_counts = mm_items.get_all_counts()
+ _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
+ prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
+ mm_items=mm_items,
+ hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+ tokenization_kwargs=tokenization_kwargs,
+ )
+ return mm_processed_data
+
def _cached_apply_hf_processor(
self,
inputs: ProcessorInputs,
diff --git a/vllm_omni/model_executor/models/whisper_utils.py b/vllm_omni/model_executor/models/whisper_utils.py
new file mode 100644
index 0000000000..5aa2fc8a3a
--- /dev/null
+++ b/vllm_omni/model_executor/models/whisper_utils.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright (c) 2022 OpenAI
+#
+# Shared Whisper encoder primitives used by multiple model implementations.
+# Originally from the OpenAI Whisper codebase.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding."""
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class Conv1d(nn.Conv1d):
+ """Conv1d with automatic dtype casting for mixed precision inference."""
+
+ def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
+ return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
+
+
+class ConvTranspose1d(nn.ConvTranspose1d):
+ def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
+ return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
+
+
+class Linear(nn.Linear):
+ """Linear layer with automatic dtype casting for mixed precision inference."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
diff --git a/vllm_omni/model_executor/stage_configs/bagel.yaml b/vllm_omni/model_executor/stage_configs/bagel.yaml
index d1031b574a..75f7c8a063 100644
--- a/vllm_omni/model_executor/stage_configs/bagel.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel.yaml
@@ -52,14 +52,9 @@ stage_args:
engine_args:
model_stage: dit
max_num_seqs: 1
- gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 1
omni_kv_config:
need_recv_cache: true
engine_input_source: [0]
@@ -76,10 +71,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -109,4 +100,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
index 4919395cad..7a0d851f0f 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
@@ -45,14 +45,9 @@ stage_args:
engine_args:
model_stage: dit
max_num_seqs: 1
- gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 1
omni_kv_config:
need_recv_cache: true
engine_input_source: [0]
@@ -69,10 +64,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -109,4 +100,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml b/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
index 2c1d84af49..b2d4b07b13 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
@@ -9,14 +9,9 @@ stage_args:
engine_args:
model_stage: dit
max_num_seqs: 1
- gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 1
final_output: true
final_output_type: image
@@ -27,6 +22,3 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_think.yaml b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
index c4cf32c707..2575e6736d 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_think.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
@@ -49,14 +49,9 @@ stage_args:
engine_args:
model_stage: dit
max_num_seqs: 1
- gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 1
omni_kv_config:
need_recv_cache: true
engine_input_source: [0]
@@ -70,9 +65,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
shared_memory_connector:
@@ -83,4 +75,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
index 632c227f36..4599f8b059 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
@@ -45,14 +45,9 @@ stage_args:
max_batch_size: 1
engine_args:
model_stage: dit
- gpu_memory_utilization: 0.45
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 1
parallel_config:
ulysses_degree: 2
# ring_degree: 2
@@ -67,9 +62,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
shared_memory_connector:
name: SharedMemoryConnector
@@ -78,4 +70,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml b/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml
new file mode 100644
index 0000000000..b7d0aeeb74
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml
@@ -0,0 +1,46 @@
+# Stage config for Ming-flash-omni-2.0
+# Stage 0: Thinker (Multimodal understanding + text generation)
+# Stage 1a: Image Generator (Text embeddings -> PIL image)
+# Stage 1b: Talker (Text embeddings -> audio waveform)
+
+async_chunk: false
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ engine_args:
+ model_stage: thinker
+ model_arch: MingFlashOmniForConditionalGeneration
+ # tokenizer_subdir: talker/llm
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4 # Use 4 GPUs for MoE model
+ # pipeline_parallel_size: 4
+ hf_config_name: llm_config
+ compilation_config:
+ pass_config:
+ # there's a version mismatch regarding vllm and flashinfer
+ # disable fuse allreduce for now
+ fuse_allreduce_rms: false
+ final_output: true # Can output text directly
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ max_tokens: 2048
+ repetition_penalty: 1.05
+ seed: 42
+ detokenize: true
+
+ # Future Stage 1a: Image Generator (Optional - not yet implemented)
+ # Future Stage 1b: Talker/TTS (Optional - not yet implemented)
diff --git a/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml
index ca7e9850ae..13419ef107 100644
--- a/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml
+++ b/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml
@@ -63,9 +63,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -82,4 +79,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
index 0724146aa7..131a0d1cd7 100644
--- a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
+++ b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
@@ -67,14 +67,9 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
index 7259daa9ea..4a54f8188a 100644
--- a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
+++ b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
@@ -71,9 +71,6 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
####
# same as Qwen2.5_omni version
# Distributed connectors configuration (optional)
@@ -108,7 +105,5 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
index 0b0b278592..90f80c22d7 100644
--- a/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
+++ b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
@@ -71,10 +71,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 16
-
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
@@ -93,4 +89,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/glm_image.yaml b/vllm_omni/model_executor/stage_configs/glm_image.yaml
index 3cc23e1e25..05ac84a7a0 100644
--- a/vllm_omni/model_executor/stage_configs/glm_image.yaml
+++ b/vllm_omni/model_executor/stage_configs/glm_image.yaml
@@ -70,11 +70,6 @@ stage_args:
# Top-level runtime config
runtime:
enabled: true
- defaults:
- window_size: -1 # Trigger downstream only after full upstream completion
- max_inflight: 1 # Process serially within each stage
-
edges:
- from: 0 # AR → Diffusion: trigger after AR completes
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
index 719c73a9fc..7bd66c403f 100644
--- a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
+++ b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
@@ -70,14 +70,9 @@ stage_args:
# Top-level runtime config with MultiConnector support
runtime:
enabled: true
- defaults:
- window_size: -1 # Trigger downstream only after full upstream completion
- max_inflight: 1 # Process serially within each stage
-
edges:
- from: 0 # AR → Diffusion
to: 1
- window_size: -1
# OmniConnector configuration for efficient inter-stage tensor transfer
connectors:
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
new file mode 100644
index 0000000000..b68b184ec3
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
@@ -0,0 +1,41 @@
+# Stage config for HunyuanImage-3.0 Image-to-Text (I2T / image understanding).
+# Single LLM stage: AR model reads image + text prompt, generates text output.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ requires_multimodal_data: true
+ engine_args:
+ model_stage: AR
+ max_num_seqs: 1
+ model_arch: HunyuanImage3ForCausalMM
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.95
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ hf_overrides:
+ rope_parameters:
+ mrope_section: [0, 32, 32]
+ rope_type: default
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 0.95
+ top_k: 1024
+ max_tokens: 2048
+ stop_token_ids: [127957, 128026] # <|endoftext|>,
+ detokenize: True
+
+runtime:
+ enabled: true
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
new file mode 100644
index 0000000000..413e0f09cb
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
@@ -0,0 +1,74 @@
+# Stage config for HunyuanImage-3.0 Image+Text-to-Image (image editing).
+# Stage 0: AR (HunyuanImage3ForConditionalGeneration) — reads (image, text), emits latent tokens
+# Stage 1: Diffusion (HunyuanImage3Pipeline / DiT + VAE) — denoise + decode latents → image
+
+stage_args:
+ # Stage 0: AR Model
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ requires_multimodal_data: true # AR needs the original image
+ engine_args:
+ model_stage: AR
+ model_arch: HunyuanImage3ForCausalMM
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.95
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # AR outputs latent for DiT
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ hf_overrides:
+ rope_parameters:
+ mrope_section: [0, 32, 32]
+ rope_type: default
+ is_comprehension: false # Generation task, not comprehension
+ final_output: false # AR is not the final output
+ default_sampling_params:
+ temperature: 0.6
+ top_p: 0.95
+ top_k: 1024
+ max_tokens: 4096
+ stop_token_ids: [127957] # <|endoftext|>
+ detokenize: false
+
+ # Stage 1: Diffusion (DiT + VAE)
+ # Receives latents from AR stage, performs denoising + VAE decode
+ - stage_id: 1
+ stage_type: diffusion
+ runtime:
+ process: true
+ devices: "4,5,6,7"
+ max_batch_size: 1
+ requires_multimodal_data: true # May need condition images
+ engine_args:
+ model_stage: dit
+ model_arch: HunyuanImage3ForCausalMM
+ enforce_eager: true
+ trust_remote_code: true
+ distributed_executor_backend: "mp"
+ parallel_config:
+ tensor_parallel_size: 4
+ enable_expert_parallel: true
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0] # Input from AR stage
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hunyuan_image3.ar2diffusion
+ final_output: true
+ final_output_type: image
+ default_sampling_params:
+ num_inference_steps: 50
+ guidance_scale: 2.5
+
+# Top-level runtime config
+runtime:
+ enabled: true
+ edges:
+ - from: 0 # AR → Diffusion
+ to: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
index 51110c2858..586b601bc5 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
@@ -11,13 +11,9 @@ stage_args:
max_batch_size: 1
engine_args:
model_stage: dit
- gpu_memory_utilization: 0.9
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
quantization: "fp8"
parallel_config:
tensor_parallel_size: 2
@@ -34,6 +30,3 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
similarity index 80%
rename from vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml
rename to vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
index 0b812ff376..1d8c7f4812 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
@@ -11,13 +11,9 @@ stage_args:
engine_args:
max_num_seqs: 1
model_stage: dit
- gpu_memory_utilization: 0.65
enforce_eager: true
trust_remote_code: true
- engine_output_type: image
distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
parallel_config:
tensor_parallel_size: 4
enable_expert_parallel: true
@@ -33,6 +29,3 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
similarity index 95%
rename from vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml
rename to vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
index e029c38362..41ed74ba62 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
@@ -39,6 +39,3 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
new file mode 100644
index 0000000000..a0a1a0dc1c
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
@@ -0,0 +1,42 @@
+# Stage config for HunyuanImage-3.0 Text-to-Text (T2T / pure text generation).
+# Single LLM stage: AR model reads text prompt only, generates text output.
+# Sampling params aligned with official generation_config.json.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ requires_multimodal_data: false
+ engine_args:
+ model_stage: AR
+ max_num_seqs: 1
+ model_arch: HunyuanImage3ForCausalMM
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.95
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ hf_overrides:
+ rope_parameters:
+ mrope_section: [0, 32, 32]
+ rope_type: default
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 0.95
+ top_k: 1024
+ max_tokens: 2048
+ stop_token_ids: [127957, 128026] # <|endoftext|>,
+ detokenize: True
+
+runtime:
+ enabled: true
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml
deleted file mode 100644
index 6f4ba306a5..0000000000
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml
+++ /dev/null
@@ -1,85 +0,0 @@
-# Stage config for running Hunyuan-Image3.0 for multi-stage omni runtime.
-# Stage 0: AR Model (vLLM implementation)
-
-# The following config has been verified on 8x L40S-48G GPU.
-modes:
- - mode: text-to-image
- stages: [1]
- - mode: image-to-text
- stages: [0]
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: AR
- max_num_seqs: 1
- model_arch: HunyuanImage3ForCausalMM
- worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.3
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 8
- pipeline_parallel_size: 1
- hf_overrides:
- rope_parameters:
- mrope_section: [0, 32, 32]
- rope_type: default
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: diffusion
- runtime:
- process: true
- devices: "0,1,2,3,4,5,6,7"
- max_batch_size: 1
- engine_args:
- model_stage: diffusion
- gpu_memory_utilization: 0.9
- enforce_eager: true
- engine_output_type: image
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- vae_use_slicing: false
- vae_use_tiling: false
- cache_backend: null
- cache_config: null
- enable_cache_dit_summary: false
- parallel_config:
- pipeline_parallel_size: 1
- data_parallel_size: 1
- tensor_parallel_size: 8
- enable_expert_parallel: false
- sequence_parallel_size: 1
- ulysses_degree: 1
- ring_degree: 1
- cfg_parallel_size: 1
- vae_patch_parallel_size: 1
- use_hsdp: false
- hsdp_shard_size: -1
- hsdp_replicate_size: 1
- final_output: true
- final_output_type: image
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
diff --git a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
index b3c6bbbaf0..2fa1b982af 100644
--- a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
+++ b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
@@ -74,10 +74,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
@@ -93,4 +89,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/omnivoice.yaml b/vllm_omni/model_executor/stage_configs/omnivoice.yaml
index 49f11e9674..546e3b3dc2 100644
--- a/vllm_omni/model_executor/stage_configs/omnivoice.yaml
+++ b/vllm_omni/model_executor/stage_configs/omnivoice.yaml
@@ -10,10 +10,8 @@ stage_args:
engine_args:
model_stage: dit
model_class_name: "OmniVoicePipeline"
- gpu_memory_utilization: 0.5
enforce_eager: true
trust_remote_code: true
- engine_output_type: audio
distributed_executor_backend: "mp"
dtype: "float32"
final_output: true
diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 0a307b4477..0000000000
--- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,107 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x H100-80G GPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- async_scheduling: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml
deleted file mode 100644
index b318aebe36..0000000000
--- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml
+++ /dev/null
@@ -1,151 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 1x H100-80G GPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- # Distributed connector configuration (optional)
- output_connectors:
- to_stage_1: mooncake_connector
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- # Distributed connector configuration (optional)
- input_connectors:
- from_stage_0: mooncake_connector
- output_connectors:
- to_stage_2: mooncake_connector
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.3
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- # Distributed connector configuration (optional)
- input_connectors:
- from_stage_1: mooncake_connector
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- # Distributed connectors configuration (optional)
- # More connectors will be supported in the future.
- connectors:
- # Mooncake connector for cross-node/intra-node communication
- mooncake_connector:
- name: MooncakeStoreConnector
- extra:
- host: "127.0.0.1"
- metadata_server: "http://10.90.67.86:8080/metadata"
- master: "10.90.67.86:50051"
- segment: 512000000 # 512MB
- localbuf: 64000000 # 64MB
- proto: "tcp"
-
- # Mori RDMA connector for cross-node/intra-node communication
- mori_connector:
- name: MoriTransferEngineConnector
- extra:
- host: "auto"
- zmq_port: 50051
- device_name: ""
- memory_pool_size: 536870912 # 512 MB
- memory_pool_device: "cpu"
-
- # Yuanrong connector for cross-node/intra-node communication
- yuanrong_connector:
- name: YuanrongConnector
- extra:
- host: "127.0.0.1"
- port: "35000"
-
- # SharedMemory connector for intra-node communication
- # Alternative SHM connector with different threshold
- shared_memory_connector:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536 # 64KB threshold
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 0ce4f0c94f..0000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 32
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
deleted file mode 100644
index 38626fc081..0000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+++ /dev/null
@@ -1,117 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
- final_output: true
- final_output_type: text
- is_comprehension: true
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
- engine_input_source: [0]
- # final_output: true
- # final_output_type: text
- # Distributed connector configuration
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 51200 # [TODO] if max_num_batch_tokens < max_num_seqs * 800, there will be precision problem.
- hf_config_name: thinker_config
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-runtime:
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- # Align with Omni: small chunks with sufficient context overlap.
- codec_chunk_frames: 25 # code2wav decode chunk size
- codec_left_context_frames: 25 # code2wav left context size
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml
deleted file mode 100644
index 6c2d2a7669..0000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml
+++ /dev/null
@@ -1,143 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings -> 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes -> audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- # Distributed connector configuration
- output_connectors:
- to_stage_1: mooncake_connector
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
- # Distributed connector configuration
- input_connectors:
- from_stage_0: mooncake_connector
- output_connectors:
- to_stage_2: mooncake_connector
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- # Distributed connector configuration
- input_connectors:
- from_stage_1: mooncake_connector
-
-# Top-level runtime config: default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
- # Distributed connectors configuration
- connectors:
- # Mooncake connector for cross-node/intra-node communication
- mooncake_connector:
- name: MooncakeStoreConnector
- extra:
- host: "127.0.0.1"
- metadata_server: "http://10.90.67.86:8080/metadata"
- master: "10.90.67.86:50051"
- segment: 512000000 # 512MB
- localbuf: 64000000 # 64MB
- proto: "tcp"
-
- # SharedMemory connector for intra-node communication
- shared_memory_connector:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536 # 64KB threshold
-
- edges:
- - from: 0
- to: 1
- window_size: -1
- - from: 1
- to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
deleted file mode 100644
index 75b2bab3a2..0000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
+++ /dev/null
@@ -1,100 +0,0 @@
-# Same as qwen3_tts.yaml with batched talker and code2wav.
-# Stage 0: max_num_seqs 4, stage 1: max_num_seqs 4.
-# max_num_seqs must be a power of two to align with CUDA graph capture sizes
-# (stage 0) and must match --batch-size in end2end.py / benchmark scripts.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- model_stage: qwen3_tts
- max_num_seqs: 4
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 4
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- # Must be divisible by num_code_groups and cover (left_context + chunk).
- max_num_batched_tokens: 65536
- # Flat codec prompt can exceed 32k tokens (Q * frames); align with max_tokens below.
- max_model_len: 65536
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- # Distributed connector configuration
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- # Frame-aligned codec streaming transport.
- codec_streaming: true
- # Connector polling / timeout (unit: loop count, sleep interval in seconds).
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- # Match the decoder sliding attention window to avoid chunk-boundary noise.
- codec_chunk_frames: 25
- codec_left_context_frames: 72
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
similarity index 94%
rename from vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
rename to vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
index a0d38eb4b9..4ca8d11ad7 100644
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
+++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
@@ -17,7 +17,6 @@ stage_args:
enable_prefix_caching: false
engine_output_type: latent
gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
max_num_batched_tokens: 512
max_model_len: 4096
custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
@@ -49,7 +48,6 @@ stage_args:
enable_prefix_caching: false
engine_output_type: audio
gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
# Must be divisible by num_code_groups and cover (left_context + chunk).
# Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep headroom past 32k.
max_num_batched_tokens: 65536
@@ -74,9 +72,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -96,4 +91,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/voxcpm.yaml b/vllm_omni/model_executor/stage_configs/voxcpm.yaml
new file mode 100644
index 0000000000..a5f324f660
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/voxcpm.yaml
@@ -0,0 +1,69 @@
+# VoxCPM two-stage (latent → VAE) without async_chunk: one-shot latent then decode.
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ is_comprehension: true
+ runtime:
+ devices: "0"
+ max_batch_size: 1
+ engine_args:
+ dtype: bfloat16
+ model_stage: latent_generator
+ model_arch: VoxCPMForConditionalGeneration
+ # Optional persistent HF-compatible config dir for native VoxCPM models.
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: latent
+ gpu_memory_utilization: 0.7
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 4096
+ max_model_len: 4096
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 4096
+ stop_token_ids: [2]
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.0
+ final_output: false
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ devices: "0"
+ max_batch_size: 1
+ engine_args:
+ dtype: float32
+ model_stage: vae
+ model_arch: VoxCPMForConditionalGeneration
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.15
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 8192
+ max_model_len: 4096
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 1
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.0
diff --git a/vllm_omni/model_executor/stage_configs/voxcpm2.yaml b/vllm_omni/model_executor/stage_configs/voxcpm2.yaml
new file mode 100644
index 0000000000..7cc93d6b26
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/voxcpm2.yaml
@@ -0,0 +1,36 @@
+# VoxCPM2 AR pipeline with per-request state batching.
+# Uses native MiniCPM4 base_lm + per-request StaticKVCache.
+# max_batch_size > 1 supported via KV cache save/restore.
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ is_comprehension: true
+ runtime:
+ devices: "0"
+ max_batch_size: 4
+ engine_args:
+ dtype: bfloat16
+ model_stage: latent_generator
+ model_arch: VoxCPM2TalkerForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.9
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 4096
+ max_model_len: 4096
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 4096
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.0
+ stop_token_ids: [1]
+ final_output: true
+ final_output_type: audio
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
similarity index 60%
rename from vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
rename to vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
index cd82d91b71..c6fd177a35 100644
--- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
+++ b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
@@ -1,3 +1,5 @@
+# VoxCPM two-stage streaming (align with qwen3_tts.yaml async_chunk pattern).
+# Stage0 (latent_generator) emits latent in time chunks; Stage1 (VAE) decodes as chunks arrive.
async_chunk: true
stage_args:
- stage_id: 0
@@ -5,42 +7,48 @@ stage_args:
is_comprehension: true
runtime:
devices: "0"
+ max_batch_size: 1
engine_args:
- model_stage: qwen3_tts
- max_num_seqs: 1
- model_arch: Qwen3TTSTalkerForConditionalGeneration
+ dtype: bfloat16
+ model_stage: latent_generator
+ model_arch: VoxCPMForConditionalGeneration
+ # Optional persistent HF-compatible config dir for native VoxCPM models.
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
enforce_eager: true
trust_remote_code: true
- async_scheduling: false
+ async_scheduling: true
enable_prefix_caching: false
engine_output_type: latent
- gpu_memory_utilization: 0.3
+ gpu_memory_utilization: 0.7
distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
+ max_num_batched_tokens: 4096
max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk
default_sampling_params:
- temperature: 0.9
- top_k: 50
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
max_tokens: 4096
+ stop_token_ids: [2]
seed: 42
detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
+ repetition_penalty: 1.0
+ final_output: false
+ output_connectors:
+ to_stage_1: voxcpm_shm
- stage_id: 1
stage_type: llm
runtime:
devices: "0"
+ max_batch_size: 1
engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3TTSCode2Wav
+ dtype: float32
+ model_stage: vae
+ model_arch: VoxCPMForConditionalGeneration
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
@@ -48,35 +56,30 @@ stage_args:
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio
- gpu_memory_utilization: 0.2
+ gpu_memory_utilization: 0.15
distributed_executor_backend: "mp"
- max_num_batched_tokens: 65536
- max_model_len: 65536
+ max_num_batched_tokens: 8192
+ max_model_len: 4096
engine_input_source: [0]
final_output: true
final_output_type: audio
- # Distributed connector configuration
input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
+ from_stage_0: voxcpm_shm
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
- max_tokens: 65536
+ max_tokens: 128
seed: 42
detokenize: true
repetition_penalty: 1.0
+
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
- connector_of_shared_memory:
+ voxcpm_shm:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 65536
@@ -87,10 +90,9 @@ runtime:
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
# Align with Omni: small chunks with sufficient context overlap.
- codec_chunk_frames: 25
- codec_left_context_frames: 72
+ codec_chunk_frames: 1
+ codec_left_context_frames: 1
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
index 31cccb9ccf..b0d9a81e78 100644
--- a/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
+++ b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
@@ -82,10 +82,6 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
@@ -102,4 +98,3 @@ runtime:
edges:
- from: 0 # language_model → acoustic_transformer: trigger only after receiving full input (-1)
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_input_processors/bagel.py b/vllm_omni/model_executor/stage_input_processors/bagel.py
index bfcff0ea0f..52cc14d3aa 100644
--- a/vllm_omni/model_executor/stage_input_processors/bagel.py
+++ b/vllm_omni/model_executor/stage_input_processors/bagel.py
@@ -82,6 +82,8 @@ def expand_cfg_prompts(
neg_prompt = _get_negative_prompt(prompt, sampling_params)
if "image" in modalities:
+ if not neg_prompt:
+ return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
@@ -166,6 +168,8 @@ def expand_cfg_prompts_think(
companion_params = {"max_tokens": 1}
if "image" in modalities:
+ if not neg_prompt:
+ return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
@@ -287,9 +291,10 @@ def _get_negative_prompt(
) -> str:
"""Resolve the negative prompt for CFG from prompt or sampling params.
- An empty string is treated the same as absent (falls through to
- the Bagel default token pair), because an empty negative prompt is
- not meaningful for CFG guidance.
+ Returns the negative prompt string when one is supplied, otherwise an
+ empty string. Callers decide how to treat the empty case: text2img
+ skips the cfg_text companion entirely, while img2img substitutes it
+ into the cfg_text prompt template.
"""
neg = prompt.get("negative_prompt")
if neg:
@@ -300,4 +305,4 @@ def _get_negative_prompt(
if neg:
return neg
- return "<|im_start|><|im_end|>"
+ return ""
diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py
new file mode 100644
index 0000000000..89a7a28f6c
--- /dev/null
+++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Stage input processor for HunyuanImage3: AR → Diffusion transition.
+
+In IT2I (image editing) mode:
+ - Stage 0 (AR) receives (image + edit instruction), generates CoT/latent tokens
+ - Stage 1 (DiT) receives the AR output + original image, denoises → edited image
+
+The ar2diffusion function bridges these two stages, following the same
+signature pattern as glm_image.ar2diffusion.
+"""
+
+from typing import Any
+
+import torch
+from vllm.inputs import TextPrompt
+from vllm.logger import init_logger
+
+from vllm_omni.inputs.data import OmniTokensPrompt
+
+logger = init_logger(__name__)
+
+
+def ar2diffusion(
+ stage_list: list[Any],
+ engine_input_source: list[int],
+ prompt: OmniTokensPrompt | TextPrompt | list | None = None,
+ requires_multimodal_data: bool = False,
+) -> list[dict[str, Any]]:
+ """Process AR stage outputs to create Diffusion stage inputs.
+
+ Args:
+ stage_list: List of stage clients (set by orchestrator).
+ engine_input_source: List of source stage IDs (from YAML).
+ prompt: Original user prompt (may contain multimodal data).
+ requires_multimodal_data: Whether to forward multimodal data.
+
+ Returns:
+ List of dicts, each consumable by the HunyuanImage3 diffusion pipeline.
+ """
+ if not engine_input_source:
+ raise ValueError("engine_input_source cannot be empty")
+
+ source_stage_id = engine_input_source[0]
+ if source_stage_id >= len(stage_list):
+ raise IndexError(f"Invalid source stage_id: {source_stage_id}")
+
+ if stage_list[source_stage_id].engine_outputs is None:
+ raise RuntimeError(f"Stage {source_stage_id} has no outputs yet")
+
+ ar_outputs = stage_list[source_stage_id].engine_outputs
+ diffusion_inputs = []
+
+ # Normalize prompt to list
+ if not isinstance(prompt, list):
+ prompt = [prompt] if prompt is not None else [{}]
+
+ for i, ar_output in enumerate(ar_outputs):
+ output = ar_output.outputs[0]
+ generated_token_ids = output.token_ids
+ generated_text = getattr(output, "text", "") or ""
+
+ # Get original prompt info
+ original_prompt = prompt[i] if i < len(prompt) else {}
+ if isinstance(original_prompt, dict):
+ pass
+ elif hasattr(original_prompt, "_asdict"):
+ original_prompt = original_prompt._asdict()
+ elif hasattr(original_prompt, "__dict__"):
+ original_prompt = vars(original_prompt)
+ else:
+ original_prompt = {}
+
+ height = original_prompt.get("height", 1024)
+ width = original_prompt.get("width", 1024)
+ text_prompt = original_prompt.get("prompt", "")
+
+ logger.info(
+ "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d",
+ i,
+ len(generated_token_ids),
+ len(generated_text),
+ height,
+ width,
+ )
+
+ token_tensor = torch.tensor(generated_token_ids, dtype=torch.long)
+
+ diffusion_input: dict[str, Any] = {
+ "prompt": text_prompt,
+ "height": height,
+ "width": width,
+ "extra": {
+ "ar_token_ids": token_tensor,
+ "ar_generated_text": generated_text,
+ },
+ }
+
+ # Forward multimodal data (original image for IT2I conditioning)
+ mm_data = original_prompt.get("multi_modal_data")
+ if mm_data:
+ pil_image = mm_data.get("image")
+ if pil_image is None:
+ images = mm_data.get("images")
+ if images:
+ pil_image = images[0] if isinstance(images, list) else images
+ if pil_image is not None:
+ diffusion_input["pil_image"] = pil_image
+
+ # Forward multimodal output from AR (if any)
+ if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output:
+ mm_output = ar_output.multimodal_output
+ if isinstance(mm_output, dict):
+ diffusion_input["extra"]["ar_multimodal_output"] = mm_output
+
+ # Forward sampling params
+ for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]:
+ if key in original_prompt:
+ diffusion_input[key] = original_prompt[key]
+
+ diffusion_inputs.append(diffusion_input)
+
+ return diffusion_inputs
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
index f4828fddaa..699e4b194a 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
@@ -3,6 +3,8 @@
# Copyright 2025 The Qwen team.
"""Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition."""
+import logging
+from dataclasses import dataclass, field
from typing import Any
import torch
@@ -18,6 +20,12 @@
extract_speaker_from_request,
)
+logger = logging.getLogger(__name__)
+
+# Pooling output layer keys: "0" = word embedding, "24" = accept_hidden_layer
+_EMBED_LAYER_KEY = "0"
+_HIDDEN_LAYER_KEY = "24"
+
def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int:
im_start_token_id = 151644
@@ -84,6 +92,191 @@ def _validate_stage_inputs(stage_list, engine_input_source):
return stage.engine_outputs
+# =========================
+# PD disaggregation helpers
+# =========================
+
+
+def _get_prefill_stage(stage_list: list[Any], source_stage_id: int) -> Any | None:
+ if source_stage_id <= 0:
+ return None
+ source_stage = stage_list[source_stage_id]
+ if not getattr(source_stage, "is_decode_only", False):
+ return None
+ prev_stage = stage_list[source_stage_id - 1]
+ if getattr(prev_stage, "is_prefill_only", False) and prev_stage.engine_outputs is not None:
+ return prev_stage
+ return None
+
+
+def _merge_pd_embeddings(
+ decode_emb: torch.Tensor,
+ decode_hid: torch.Tensor,
+ prefill_mm: dict[str, Any],
+ device: torch.device,
+ expected_total: int | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Merge prefill prompt embeddings with decode generated embeddings.
+
+ In PD mode the prefill engine processes the prompt and the decode engine
+ generates tokens starting from position 1. This function concatenates
+ them, removing the overlapping token(s):
+
+ merged = prefill[:P] + decode[overlap:]
+
+ where overlap = P + D - expected_total.
+ """
+ try:
+ p_emb = prefill_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float)
+ p_hid = prefill_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float)
+ except (KeyError, AttributeError, TypeError) as exc:
+ available_keys = list(prefill_mm.keys()) if isinstance(prefill_mm, dict) else type(prefill_mm).__name__
+ logger.error(
+ "_merge_pd_embeddings: failed to extract prefill embeddings (%s). "
+ "Expected keys %r and %r, got: %s. "
+ "Falling back to decode-only embeddings – talker user-segment will be degraded.",
+ exc,
+ _EMBED_LAYER_KEY,
+ _HIDDEN_LAYER_KEY,
+ available_keys,
+ )
+ return decode_emb, decode_hid
+
+ if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0:
+ return decode_emb, decode_hid
+
+ raw_total = p_emb.shape[0] + decode_emb.shape[0]
+ overlap = max(0, raw_total - expected_total) if expected_total is not None else 0
+
+ merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0)
+ merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0)
+ return merged_emb, merged_hid
+
+
+def _get_prefill_multimodal_output(prefill_stage: Any, output_index: int) -> dict[str, Any] | None:
+ """Return multimodal_output dict from the PD prefill stage for a given batch index."""
+ try:
+ prefill_eos = prefill_stage.engine_outputs
+ prefill_eo = prefill_eos[min(output_index, len(prefill_eos) - 1)]
+ return prefill_eo.outputs[0].multimodal_output
+ except Exception:
+ return None
+
+
+def _resolve_tts_token_embedding(
+ key: str,
+ *,
+ thinker_mm: dict[str, Any],
+ prefill_mm: dict[str, Any] | None,
+ device: torch.device,
+) -> torch.Tensor | None:
+ """Return TTS BOS/EOS/PAD embedding tensors for the talker projection path.
+
+ Values are taken from the current thinker (decode) ``multimodal_output``; in
+ PD mode, missing keys may be filled from the paired prefill stage output.
+ """
+ val = thinker_mm.get(key)
+ if val is None and prefill_mm is not None:
+ val = prefill_mm.get(key)
+ return val.detach().to(device=device, dtype=torch.float) if val is not None else None
+
+
+# =========================
+# Streaming input helpers
+# =========================
+
+
+@dataclass
+class _Thinker2TalkerStreamingState:
+ last_prompt_len: int = 0
+ last_output_len: int = 0
+ merged_sequences: list[int] = field(default_factory=list)
+
+
+@dataclass
+class _Qwen3OmniStreamingState:
+ thinker2talker: _Thinker2TalkerStreamingState = field(default_factory=_Thinker2TalkerStreamingState)
+ talker2code2wav_last_seq_len: int = 0
+
+
+def _get_qwen3_streaming_state(
+ request_id: str,
+ streaming_context: Any | None,
+) -> _Qwen3OmniStreamingState:
+ bridge_states = getattr(streaming_context, "bridge_states", None)
+ per_model_state = bridge_states.setdefault("qwen3_omni", {})
+ state = per_model_state.get(request_id)
+ if state is None:
+ state = _Qwen3OmniStreamingState()
+ per_model_state[request_id] = state
+ return state
+
+
+def _get_streaming_talker_tokens(
+ request_id: str,
+ prompt_token_ids: list[int],
+ output_token_ids: list[int],
+ new_prompt_len_snapshot: int | None = None,
+ streaming_context: Any | None = None,
+ *,
+ clear_state: bool = False,
+) -> tuple[list[int], list[int], list[int], list[int]]:
+ """Return streaming token slices and merged token views for thinker->talker.
+ e.g. For the second streaming input request:
+ merged_sequences: [input_prompt 1, output_tokens 1[:-1], input_prompt 2, output_tokens 2]
+ thinker_input_ids: [input_prompt 1, output_tokens 1[:-1], input_prompt 2]
+ Returns:
+ inc_prompt: prompt token delta for this segment.
+ inc_output: output token delta for this segment.
+ merged_sequences: full thinker_sequences to send downstream.
+ thinker_input_ids: full thinker_input_ids paired with merged_sequences.
+ """
+ state = _get_qwen3_streaming_state(request_id, streaming_context).thinker2talker
+ if new_prompt_len_snapshot:
+ prompt_token_ids = prompt_token_ids[:-new_prompt_len_snapshot]
+ cur_prompt_len = len(prompt_token_ids)
+ cur_output_len = len(output_token_ids)
+
+ inc_prompt = prompt_token_ids[state.last_prompt_len :]
+ inc_output = output_token_ids[state.last_output_len :]
+ delta_sequences = inc_prompt + inc_output
+ cached_sequences = state.merged_sequences
+
+ merged_sequences = cached_sequences + delta_sequences
+ thinker_input_ids = cached_sequences + inc_prompt
+
+ # Persist history for next segment. Drop the latest sampled token to keep
+ # thinker_input_ids / thinker_sequences alignment with next-step append.
+ cached_sequences.extend(delta_sequences[:-1])
+
+ state.last_prompt_len = cur_prompt_len
+ state.last_output_len = cur_output_len
+
+ if clear_state:
+ state.last_prompt_len = 0
+ state.last_output_len = 0
+ state.merged_sequences.clear()
+
+ return inc_prompt, inc_output, merged_sequences, thinker_input_ids
+
+
+def _get_streaming_codec_delta_len(
+ cur_seq_len: int,
+ request_id: str,
+ talker_output: Any,
+ streaming_context: Any | None = None,
+) -> int:
+ """Return newly added seq_len for talker->code2wav in streaming mode."""
+ state = _get_qwen3_streaming_state(request_id, streaming_context)
+ prev_seq_len = state.talker2code2wav_last_seq_len
+ seq_len = cur_seq_len - prev_seq_len
+ state.talker2code2wav_last_seq_len = cur_seq_len + 1
+ if bool(getattr(talker_output, "finished", False)):
+ # Final segment: clear history to avoid cross-session carry-over.
+ state.talker2code2wav_last_seq_len = 0
+ return seq_len
+
+
# =========================
# Thinker -> Talker
# =========================
@@ -111,8 +304,8 @@ def thinker2talker_async_chunk(
all_token_ids = _ensure_list(all_token_ids)
prompt_token_ids = _ensure_list(prompt_token_ids)
talker_additional_info = {
- "thinker_prefill_embeddings": pooling_output.get("0").detach().cpu(),
- "thinker_hidden_states": pooling_output.get("24").detach().cpu(),
+ "thinker_prefill_embeddings": pooling_output.get(_EMBED_LAYER_KEY).detach().cpu(),
+ "thinker_hidden_states": pooling_output.get(_HIDDEN_LAYER_KEY).detach().cpu(),
"thinker_sequences": all_token_ids,
"thinker_input_ids": prompt_token_ids,
# Provide thinker-side TTS token embeddings for talker projection
@@ -161,7 +354,7 @@ def thinker2talker_async_chunk(
if output_token_ids:
talker_additional_info["override_keys"] = ["thinker_decode_embeddings", "thinker_output_token_ids"]
- talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu()
+ talker_additional_info["thinker_decode_embeddings"] = pooling_output.get(_EMBED_LAYER_KEY).detach().cpu()
talker_additional_info["thinker_output_token_ids"] = output_token_ids
else:
# When prefilling a chunked thinker, thinker_hidden_states needs to be updated.
@@ -176,6 +369,7 @@ def thinker2talker(
engine_input_source: list[int],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
+ streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""
Process thinker outputs to create talker inputs.
@@ -185,6 +379,9 @@ def thinker2talker(
2. Split hidden states into: prompt embeddings + generated embeddings
3. Package for talker with additional information
+ In PD disaggregation mode, merges prefill-stage prompt embeddings with
+ decode-stage generated embeddings before handing off to the talker.
+
Args:
stage_list: List of stage objects
engine_input_source: Source stage IDs (typically [0] for thinker)
@@ -199,21 +396,66 @@ def thinker2talker(
device = torch.device(current_platform.device_type)
+ # PD disaggregation: look up the preceding prefill stage (if any)
+ source_stage_id = engine_input_source[0]
+ prefill_stage = _get_prefill_stage(stage_list, source_stage_id)
+
# Process each thinker output
for i, thinker_output in enumerate(thinker_outputs):
output = thinker_output.outputs[0]
+ req_id = str(getattr(thinker_output, "request_id", f"idx-{i}"))
+ prompt_token_ids = _ensure_list(thinker_output.prompt_token_ids)
+ output_ids = _ensure_list(output.token_ids)
+ is_streaming_session = bool(getattr(streaming_context, "enabled", False))
+ if is_streaming_session:
+ prompt_token_ids, output_ids, thinker_sequences, thinker_input_ids = _get_streaming_talker_tokens(
+ req_id,
+ prompt_token_ids,
+ output_ids,
+ getattr(streaming_context, "new_prompt_len_snapshot", None),
+ streaming_context,
+ clear_state=bool(getattr(thinker_output, "finished", False)),
+ )
+ else:
+ thinker_sequences = prompt_token_ids + output_ids
+ thinker_input_ids = prompt_token_ids
+ # For streaming input, just send incremental prefill and hidden states tensor to talker
+ # Equally applicable to non-streaming cases.
+ new_seq_length = len(prompt_token_ids + output_ids) - 1
+ thinker_mm = output.multimodal_output
+ # Full thinker embedding sequence for the talker: single thinker engine in the
+ # non-PD path; after optional merge with prefill-side tensors in PD mode.
+ thinker_emb = thinker_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
+ thinker_hid = thinker_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
+
+ prefill_mm: dict[str, Any] | None = None
+ if prefill_stage is not None:
+ prefill_mm = _get_prefill_multimodal_output(prefill_stage, i)
+
+ if prefill_mm is not None:
+ expected_total = len(prompt_token_ids) + len(output_ids)
+ try:
+ thinker_emb, thinker_hid = _merge_pd_embeddings(
+ thinker_emb, thinker_hid, prefill_mm, device, expected_total=expected_total
+ )
+ except Exception as exc:
+ logger.warning("[PD] Could not merge prefill embeddings: %s", exc)
info = {
- "thinker_prefill_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float),
- "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float),
- "thinker_sequences": (
- thinker_output.prompt_token_ids + output.token_ids
- ), # the thinker_sequences is the whole ids
- "thinker_input_ids": thinker_output.prompt_token_ids,
+ "thinker_prefill_embeddings": thinker_emb,
+ "thinker_hidden_states": thinker_hid,
+ "thinker_sequences": thinker_sequences, # the thinker_sequences is the whole ids
+ "thinker_input_ids": thinker_input_ids,
# Provide thinker-side TTS token embeddings for talker projection
- "tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float),
- "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float),
- "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float),
+ "tts_bos_embed": _resolve_tts_token_embedding(
+ "tts_bos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
+ ),
+ "tts_eos_embed": _resolve_tts_token_embedding(
+ "tts_eos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
+ ),
+ "tts_pad_embed": _resolve_tts_token_embedding(
+ "tts_pad_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
+ ),
}
speaker = extract_speaker_from_prompt(prompt, index=i)
if speaker is not None:
@@ -314,6 +556,7 @@ def talker2code2wav(
engine_input_source: list[int],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
+ streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""
Process talker outputs to create code2wav inputs.
@@ -335,9 +578,14 @@ def talker2code2wav(
talker_outputs = _validate_stage_inputs(stage_list, engine_input_source)
code2wav_inputs: list[OmniTokensPrompt] = []
# Process each talker output
- for talker_output in talker_outputs:
+ for i, talker_output in enumerate(talker_outputs):
output = talker_output.outputs[0]
- seq_len = len(output.token_ids) - 1
+ req_id = str(getattr(talker_output, "request_id", f"idx-{i}"))
+ cur_seq_len = len(output.token_ids) - 1
+ seq_len = cur_seq_len
+ is_streaming_session = bool(getattr(streaming_context, "enabled", False))
+ if is_streaming_session:
+ seq_len = _get_streaming_codec_delta_len(cur_seq_len, req_id, talker_output, streaming_context)
# Extract codec codes from talker output
# Expected shape: [8, seq_len] (8-layer RVQ codes)
codec_codes = (
diff --git a/vllm_omni/model_executor/stage_input_processors/voxcpm.py b/vllm_omni/model_executor/stage_input_processors/voxcpm.py
new file mode 100644
index 0000000000..c2fcf521bf
--- /dev/null
+++ b/vllm_omni/model_executor/stage_input_processors/voxcpm.py
@@ -0,0 +1,128 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from vllm.inputs import TextPrompt
+
+from vllm_omni.inputs.data import OmniTokensPrompt
+
+_VOXCPM_LATENT_MAGIC = 131071
+
+
+def _serialize_latent_to_codes(latent: Any) -> list[int]:
+ latent_tensor = latent if isinstance(latent, torch.Tensor) else torch.as_tensor(latent)
+ latent_tensor = latent_tensor.detach().cpu().contiguous()
+ if latent_tensor.ndim == 3:
+ if latent_tensor.shape[0] != 1:
+ raise ValueError(f"Expected batch=1 latent tensor, got shape={tuple(latent_tensor.shape)}")
+ latent_tensor = latent_tensor.squeeze(0)
+ if latent_tensor.ndim != 2:
+ raise ValueError(f"Unsupported latent_audio_feat shape for async chunk: {tuple(latent_tensor.shape)}")
+ latent_dim, time_dim = int(latent_tensor.shape[0]), int(latent_tensor.shape[1])
+ packed = latent_tensor.to(torch.bfloat16).contiguous().view(torch.uint16).reshape(-1).to(torch.int32)
+ return [_VOXCPM_LATENT_MAGIC, latent_dim, time_dim, *packed.tolist()]
+
+
+def _coerce_finished_flag(value: Any) -> bool:
+ """Normalize VoxCPM async-chunk finished markers to a Python bool."""
+ if value is None:
+ return False
+ if isinstance(value, torch.Tensor):
+ if value.numel() != 1:
+ raise ValueError(f"finished tensor must be scalar, got shape={tuple(value.shape)}")
+ return bool(value.detach().cpu().item())
+ if isinstance(value, (list, tuple)):
+ if not value:
+ return False
+ if len(value) != 1:
+ raise ValueError(f"finished container must have one element, got len={len(value)}")
+ return _coerce_finished_flag(value[0])
+ return bool(value)
+
+
+def latent2vae(
+ stage_list: list[Any],
+ engine_input_source: list[int],
+ prompt: OmniTokensPrompt | TextPrompt | None = None,
+ requires_multimodal_data: bool = False,
+) -> list[OmniTokensPrompt]:
+ del prompt, requires_multimodal_data
+
+ if not engine_input_source:
+ raise ValueError("engine_input_source cannot be empty")
+
+ source_stage_id = engine_input_source[0]
+ if source_stage_id >= len(stage_list):
+ raise IndexError(f"Invalid stage_id: {source_stage_id}")
+
+ source_outputs = stage_list[source_stage_id].engine_outputs
+ if source_outputs is None:
+ raise RuntimeError(f"Stage {source_stage_id} has no outputs yet")
+
+ vae_inputs: list[OmniTokensPrompt] = []
+ for source_output in source_outputs:
+ output = source_output.outputs[0]
+ multimodal_output = getattr(output, "multimodal_output", None)
+ if not isinstance(multimodal_output, dict) or "latent_audio_feat" not in multimodal_output:
+ raise ValueError(
+ "VoxCPM latent stage output missing 'latent_audio_feat'. "
+ f"request_id={getattr(source_output, 'request_id', None)}"
+ )
+
+ additional_information = {
+ "latent_audio_feat": multimodal_output["latent_audio_feat"],
+ }
+ if "sr" in multimodal_output:
+ additional_information["sample_rate"] = [int(multimodal_output["sr"])]
+
+ vae_inputs.append(
+ OmniTokensPrompt(
+ prompt_token_ids=[0],
+ additional_information=additional_information,
+ multi_modal_data=None,
+ mm_processor_kwargs=None,
+ )
+ )
+
+ return vae_inputs
+
+
+def latent2vae_async_chunk(
+ transfer_manager: Any,
+ pooling_output: dict[str, Any] | None,
+ request: Any,
+ is_finished: bool = False,
+) -> dict[str, Any] | None:
+ """Stage-0 latent → stage-1 VAE under ``async_chunk`` (connector payload)."""
+ # Kept for callback signature compatibility with OmniChunkTransferAdapter.
+ _ = transfer_manager
+ finished_request = _coerce_finished_flag(is_finished)
+ if callable(getattr(request, "is_finished", None)):
+ finished_request = finished_request or _coerce_finished_flag(request.is_finished())
+ if not isinstance(pooling_output, dict):
+ if finished_request:
+ return {
+ "code_predictor_codes": [],
+ "finished": torch.tensor(True, dtype=torch.bool),
+ }
+ return None
+
+ latent = pooling_output.get("latent_audio_feat")
+ if isinstance(latent, torch.Tensor) and latent.numel() == 0:
+ latent = None
+
+ if latent is None:
+ if finished_request:
+ return {
+ "code_predictor_codes": [],
+ "finished": torch.tensor(True, dtype=torch.bool),
+ }
+ return None
+
+ serialized_codes = _serialize_latent_to_codes(latent)
+ out: dict[str, Any] = {
+ "code_predictor_codes": serialized_codes,
+ "finished": torch.tensor(finished_request, dtype=torch.bool),
+ }
+ return out
diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py
index 9a7bb67065..c02c0c1427 100644
--- a/vllm_omni/outputs.py
+++ b/vllm_omni/outputs.py
@@ -9,6 +9,33 @@
from vllm_omni.inputs.data import OmniPromptType
+@dataclass
+class OmniConnectorOutput:
+ """Communication results from Model Runner to Scheduler.
+
+ Carries transfer readiness signals so the Scheduler can make scheduling
+ decisions without ever calling connector.put()/get() directly.
+
+ Attributes:
+ chunk_ready_req_ids: Request IDs with newly arrived chunks this cycle.
+ chunk_finished_req_ids: Request IDs whose final chunk has arrived.
+ request_metadata: Lightweight scheduling metadata keyed by request ID
+ (e.g. next_stage_prompt_len, code_predictor_codes, left_context_size).
+ Full payloads are owned by the Model Runner's local cache.
+ kv_sent_req_ids: Request IDs whose KV cache was successfully sent.
+ stage_recv_req_ids: Request IDs that received batch stage inputs.
+ has_pending_kv_work: True if the mixin has pending, active, or
+ completed KV transfers that the scheduler should account for.
+ """
+
+ chunk_ready_req_ids: set[str] = field(default_factory=set)
+ chunk_finished_req_ids: set[str] = field(default_factory=set)
+ request_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
+ kv_sent_req_ids: list[str] = field(default_factory=list)
+ stage_recv_req_ids: set[str] = field(default_factory=set)
+ has_pending_kv_work: bool = False
+
+
class OmniModelRunnerOutput(ModelRunnerOutput):
"""Model runner output for omni models.
@@ -24,6 +51,7 @@ class OmniModelRunnerOutput(ModelRunnerOutput):
# IDs of requests whose KV cache has been extracted from GPU/NPU to CPU.
# The Scheduler can safely free the block tables for these requests.
kv_extracted_req_ids: list[str] | None = None
+ omni_connector_output: OmniConnectorOutput | None = None
@dataclass
@@ -72,6 +100,9 @@ class OmniRequestOutput:
# memory usage info
peak_memory_mb: float = 0.0
+ # error handling
+ error: str | None = None
+
@classmethod
def from_pipeline(
cls,
diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py
index eafff821a2..f6c483a92f 100644
--- a/vllm_omni/patch.py
+++ b/vllm_omni/patch.py
@@ -1,6 +1,8 @@
import sys
+from functools import cached_property
from aenum import extend_enum
+from vllm.config import ModelConfig as _OriginalModelConfig
from vllm.inputs import TokensPrompt as _OriginalTokensPrompt
from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding as _OriginalMRotaryEmbedding,
@@ -10,12 +12,63 @@
from vllm.v1.engine import EngineCoreRequest as _OriginalEngineCoreRequest
from vllm.v1.request import Request as _OriginalRequest
from vllm.v1.request import RequestStatus
+from vllm.v1.request import StreamingUpdate as _OriginalStreamingUpdate
import vllm_omni.logger # noqa: F401
from vllm_omni.engine import OmniEngineCoreOutput, OmniEngineCoreOutputs, OmniEngineCoreRequest
from vllm_omni.inputs.data import OmniTokensPrompt
from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding
-from vllm_omni.request import OmniRequest
+from vllm_omni.request import OmniRequest, OmniStreamingUpdate
+
+# =============================================================================
+# Patch ModelConfig.is_mm_prefix_lm to support omni-specific models
+# =============================================================================
+# WHY: HunyuanImage-3.0 requires bidirectional attention for image tokens
+# (cond_token_attn_type: "joint_full" in config.json). vLLM gates this on
+# is_mm_prefix_lm, which checks an internal MM_PREFIX_LM_MODELS list that
+# does not include "hunyuan_image_3_moe" (the upstream HF model_type).
+#
+# WHY NOT model-level: is_mm_prefix_lm is checked in vLLM core (scheduler,
+# attention backend selection) before model code runs — no model-level hook.
+#
+# SCOPE: Only affects model_type in _OMNI_MM_PREFIX_LM_MODELS (currently
+# just "hunyuan_image_3_moe"). All other models fall through to the
+# original vLLM implementation unchanged.
+#
+# FRAGILITY: Relies on is_mm_prefix_lm being a cached_property on
+# ModelConfig. The __dict__ access + __set_name__ dance works around a
+# pydantic dataclass issue in vllm 0.19.0+. If vLLM changes
+# is_mm_prefix_lm to a regular method or removes it, this will break.
+#
+# TODO: Upstream a configurable MM_PREFIX_LM_MODELS or a model_config flag
+# so this patch can be removed.
+_OMNI_MM_PREFIX_LM_MODELS = ("hunyuan_image_3_moe",)
+# Access via __dict__ to avoid triggering cached_property.__get__ which fails
+# with "Cannot use cached_property instance without calling __set_name__" in
+# pydantic dataclasses (vllm 0.19.0+).
+_cp = _OriginalModelConfig.__dict__["is_mm_prefix_lm"]
+_original_is_mm_prefix_lm = _cp.func if hasattr(_cp, "func") else _cp.fget
+
+
+def _patched_is_mm_prefix_lm(self):
+ if _original_is_mm_prefix_lm(self):
+ return True
+ model_type = getattr(self.hf_config, "model_type", "")
+ return model_type in _OMNI_MM_PREFIX_LM_MODELS
+
+
+_patched_cp = cached_property(_patched_is_mm_prefix_lm)
+_patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm")
+_OriginalModelConfig.is_mm_prefix_lm = _patched_cp
+
+# Sanity check: verify the patch is active. If vLLM changes the descriptor
+# type or __set_name__ semantics, this will fail loudly at import time
+# rather than silently falling back to unpatched behavior.
+_installed = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm")
+assert _installed is _patched_cp, (
+ "is_mm_prefix_lm patch failed to install — bidirectional attention "
+ "for HunyuanImage3 will not work. Check vLLM ModelConfig changes."
+)
# =============================================================================
# Patch GlmImageTextConfig to expose mrope_section in rope_parameters
@@ -63,5 +116,7 @@ def _patched_glm_image_text_config_init(self, *args, **kwargs):
module.MRotaryEmbedding = OmniMRotaryEmbedding
if hasattr(module, "Request") and module.Request == _OriginalRequest:
module.Request = OmniRequest
+ if hasattr(module, "StreamingUpdate") and module.StreamingUpdate == _OriginalStreamingUpdate:
+ module.StreamingUpdate = OmniStreamingUpdate
if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest:
module.EngineCoreRequest = OmniEngineCoreRequest
diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py
index 8f1e66747d..b69731a67d 100644
--- a/vllm_omni/platforms/interface.py
+++ b/vllm_omni/platforms/interface.py
@@ -64,7 +64,7 @@ def get_default_stage_config_path(cls) -> str:
@classmethod
def get_diffusion_model_impl_qualname(cls, op_name: str) -> str:
if op_name == "hunyuan_fused_moe":
- return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ return "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
raise NotImplementedError(f"Unsupported diffusion model op: {op_name}")
@classmethod
diff --git a/vllm_omni/platforms/musa/platform.py b/vllm_omni/platforms/musa/platform.py
index fe1ccc6d0b..64a70a9beb 100644
--- a/vllm_omni/platforms/musa/platform.py
+++ b/vllm_omni/platforms/musa/platform.py
@@ -39,7 +39,7 @@ def get_default_stage_config_path(cls) -> str:
def get_diffusion_model_impl_qualname(cls, op_name: str) -> str:
# MUSA uses default implementations for diffusion ops
if op_name == "hunyuan_fused_moe":
- return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ return "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
return super().get_diffusion_model_impl_qualname(op_name)
@classmethod
diff --git a/vllm_omni/platforms/npu/platform.py b/vllm_omni/platforms/npu/platform.py
index c40dd6fea1..53ffe6775a 100644
--- a/vllm_omni/platforms/npu/platform.py
+++ b/vllm_omni/platforms/npu/platform.py
@@ -69,6 +69,9 @@ def get_diffusion_attn_backend_cls(
# Try FLASH_ATTN if mindiesd is available, otherwise fall back to SDPA
if find_spec("mindiesd"):
+ # Configure ASCEND_CUSTOM_OPP_PATH for mindiesd custom ops upon import
+ import mindiesd # noqa: F401
+
logger.info("Defaulting to diffusion attention backend FLASH_ATTN")
return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path()
diff --git a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
similarity index 94%
rename from vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml
rename to vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
index 053e8a8cca..0fd03949d1 100644
--- a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml
+++ b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
@@ -33,6 +33,3 @@ stage_args:
# Runtime defaults
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 8f7af161d6..0000000000
--- a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,97 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # haven't supported talker ACL graph on NPU
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "2" # Example: use a different NPU than the previous stage; use "0" if single NPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 2638c99cd4..0000000000
--- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,99 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 5x A2/A3-64G NPUs.
-stage_args:
- - stage_id: 0
- runtime:
- devices: "0,1"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- hf_config_name: thinker_config
- tensor_parallel_size: 2
- # profiler_config:
- # profiler: torch
- # torch_profiler_dir: ./perf
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- runtime:
- devices: "2"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true # haven't supported talker ACL graph on NPU
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- runtime:
- devices: "2"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml
deleted file mode 100644
index 9aa20baecf..0000000000
--- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0,1"
- engine_args:
- max_num_seqs: 10
- model_stage: thinker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 2
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "2"
- engine_args:
- max_num_seqs: 10
- model_stage: talker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
- engine_input_source: [0]
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.0
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "2"
- engine_args:
- max_num_seqs: 10
- model_stage: code2wav
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 51200 # [TODO] if max_num_batched_tokens < max_num_seqs * 800, there will be precision problem.
- hf_config_name: thinker_config
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml
similarity index 59%
rename from vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
rename to vllm_omni/platforms/npu/stage_configs/voxcpm.yaml
index 3f412fc4dc..dcd1f40517 100644
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
+++ b/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml
@@ -1,42 +1,47 @@
-async_chunk: false
stage_args:
- stage_id: 0
stage_type: llm
is_comprehension: true
runtime:
devices: "0"
+ max_batch_size: 1
engine_args:
- model_stage: qwen3_tts
- max_num_seqs: 1
- model_arch: Qwen3TTSTalkerForConditionalGeneration
+ dtype: bfloat16
+ model_stage: latent_generator
+ model_arch: VoxCPMForConditionalGeneration
+ # Optional persistent HF-compatible config dir for native VoxCPM models.
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
+ enforce_eager: true
trust_remote_code: true
async_scheduling: false
enable_prefix_caching: false
engine_output_type: latent
- gpu_memory_utilization: 0.3
+ gpu_memory_utilization: 0.75
distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
+ max_num_batched_tokens: 4096
max_model_len: 4096
default_sampling_params:
- temperature: 0.9
- top_k: 50
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
max_tokens: 4096
seed: 42
detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
+ repetition_penalty: 1.0
+ final_output: false
- stage_id: 1
stage_type: llm
runtime:
devices: "0"
+ max_batch_size: 1
engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3TTSCode2Wav
+ dtype: float32
+ model_stage: vae
+ model_arch: VoxCPMForConditionalGeneration
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
@@ -44,21 +49,19 @@ stage_args:
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio
- gpu_memory_utilization: 0.2
+ gpu_memory_utilization: 0.1
distributed_executor_backend: "mp"
- max_num_batched_tokens: 65536
- max_model_len: 65536
+ max_num_batched_tokens: 8192
+ max_model_len: 4096
engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae
final_output: true
final_output_type: audio
- tts_args:
- max_instructions_length: 500
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
- max_tokens: 65536
+ max_tokens: 1
seed: 42
detokenize: true
repetition_penalty: 1.0
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
similarity index 64%
rename from benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml
rename to vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
index ca441d286d..87843634cb 100644
--- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml
+++ b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
@@ -1,5 +1,3 @@
-# Qwen3-TTS batch_size=1 config (streaming with async_chunk)
-# 2-stage pipeline: Talker -> Code2Wav
async_chunk: true
stage_args:
- stage_id: 0
@@ -7,87 +5,85 @@ stage_args:
is_comprehension: true
runtime:
devices: "0"
+ max_batch_size: 1
engine_args:
- max_num_seqs: 1
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
+ dtype: bfloat16
+ model_stage: latent_generator
+ model_arch: VoxCPMForConditionalGeneration
+ # Optional persistent HF-compatible config dir for native VoxCPM models.
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
+ enforce_eager: true
trust_remote_code: true
- async_scheduling: true
+ async_scheduling: false
enable_prefix_caching: false
engine_output_type: latent
- gpu_memory_utilization: 0.3
+ gpu_memory_utilization: 0.75
distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
+ max_num_batched_tokens: 4096
max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk
output_connectors:
to_stage_1: connector_of_shared_memory
default_sampling_params:
- temperature: 0.9
- top_k: 50
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
max_tokens: 4096
seed: 42
detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
+ repetition_penalty: 1.0
+ final_output: false
- stage_id: 1
stage_type: llm
runtime:
devices: "0"
+ max_batch_size: 1
engine_args:
- max_num_seqs: 1
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
+ dtype: float32
+ model_stage: vae
+ model_arch: VoxCPMForConditionalGeneration
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
- async_scheduling: true
+ async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio
- gpu_memory_utilization: 0.3
+ gpu_memory_utilization: 0.1
distributed_executor_backend: "mp"
max_num_batched_tokens: 8192
- max_model_len: 32768
+ max_model_len: 4096
engine_input_source: [0]
- final_output: true
- final_output_type: audio
input_connectors:
from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
+ final_output: true
+ final_output_type: audio
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
- max_tokens: 65536
+ max_tokens: 1
seed: 42
detokenize: true
repetition_penalty: 1.0
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 65536
- codec_streaming: true
+ codec_streaming: false
connector_get_sleep_s: 0.01
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 25
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
index 138948064b..ffb997048b 100644
--- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
+++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
@@ -149,7 +149,15 @@ def execute_model(
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
- return make_empty_encoder_model_runner_output(scheduler_output)
+
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
+ output = make_empty_encoder_model_runner_output(scheduler_output)
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+ return output
if not num_scheduled_tokens:
if (
@@ -163,10 +171,20 @@ def execute_model(
# dummy run to ensure coordinate_batch_across_dp
# is called into to avoid out of sync issues.
self._dummy_run(1)
+
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
if not has_kv_transfer_group():
- # Return empty ModelRunnerOutput if no work to do.
- return EMPTY_MODEL_RUNNER_OUTPUT
- return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+ output = EMPTY_MODEL_RUNNER_OUTPUT
+ else:
+ output = self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+
+ return output
if self.cache_config.kv_sharing_fast_prefill:
assert not self.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect "
diff --git a/vllm_omni/platforms/rocm/platform.py b/vllm_omni/platforms/rocm/platform.py
index 4479e54f2a..7b0e09c128 100644
--- a/vllm_omni/platforms/rocm/platform.py
+++ b/vllm_omni/platforms/rocm/platform.py
@@ -16,6 +16,34 @@ class RocmOmniPlatform(OmniPlatform, RocmPlatform):
Inherits all ROCm-specific implementations from vLLM's RocmPlatform,
and adds Omni-specific interfaces from OmniPlatform.
+
+
+ NOTE: AR Attention Backend Overriding Logic:
+ ------------------------------------------
+ Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm.
+ However, the compatibility of ROCM_ATTN with Omni is not guaranteed.
+ Therefore, we still use TRITON_ATTN as the default attention backend,
+ when the selected_backend is not specified.
+
+ So the behaviour of the attention backend overriding logic currently lives in
+ extract_stage_metadata in `vllm_omni/engine/stage_init_utils.py`
+
+ ```
+ if current_omni_platform.is_rocm():
+ print(f"engine_args: {str(engine_args)}")
+ if engine_args.get("attention_backend") is None:
+ from vllm._aiter_ops import rocm_aiter_ops
+
+ if rocm_aiter_ops.is_enabled():
+ engine_args["attention_backend"] = "ROCM_AITER_FA"
+ # Before vLLM v0.19.0, the default attention backend is TRITON_ATTN for ROCm.
+ # Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm.
+ # However, the compatibility of ROCM_ATTN with Omni is not guaranteed.
+ # Therefore, we still use TRITON_ATTN as the default attention backend,
+ # when the selected_backend is not specified.
+ engine_args["attention_backend"] = "TRITON_ATTN"
+ ```
+
"""
_omni_enum = OmniPlatformEnum.ROCM
diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 35e8193545..0000000000
--- a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x H100-80G GPU.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- runtime:
- process: true
- devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 0ca150bee6..0000000000
--- a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,97 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
- - stage_id: 0
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
index 0fc8a25ea5..7b27f6a443 100644
--- a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
@@ -67,10 +67,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -83,4 +79,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
similarity index 93%
rename from vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml
rename to vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
index 8f969ced5f..4e0005f82a 100644
--- a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
@@ -78,6 +78,3 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 7dbedb29a5..0000000000
--- a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config is verified with 2 * Intel Arc Pro B60 XPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9 # thinker weight is around 16.74GB for Qwen2.5-Omni-7B
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.5 # talker weight is 6.03GB for Qwen2.5-Omni-7B
- enforce_eager: false
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.3 # code2wav weight is around 1.46GB for Qwen2.5-Omni-7B
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 49914bebc4..0000000000
--- a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config is verified with 8 * Intel Arc Pro B60 XPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0,1,2,3"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9 # thinker weight is around 61.08GB for Qwen3-Omni-30B-A3B-Instruct
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 4
- max_cudagraph_capture_size: 0
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "4"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6 # talker weight is around 8.5GB for Qwen3-Omni-30B-A3B-Instruct
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- max_cudagraph_capture_size: 0
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "4"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3 # code2wav weight is around 0.4GB for Qwen3-Omni-30B-A3B-Instruct
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- max_cudagraph_capture_size: 0
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
index 10051c1eda..0820ab6320 100644
--- a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
@@ -88,9 +88,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -108,4 +105,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/quantization/component_config.py b/vllm_omni/quantization/component_config.py
index 7986da8850..f9286079be 100644
--- a/vllm_omni/quantization/component_config.py
+++ b/vllm_omni/quantization/component_config.py
@@ -23,6 +23,31 @@
)
+# Pre-quantized checkpoints (modelopt FP8/FP4/MXFP8) only quantize the
+# Thinker LM. Vision and audio encoder weights remain in BF16 with no
+# corresponding scale tensors in the checkpoint.
+PRE_QUANTIZED_METHODS: frozenset[str] = frozenset({"modelopt", "modelopt_fp4", "modelopt_mxfp8"})
+
+
+def resolve_encoder_quant_config(
+ quant_config: QuantizationConfig | None,
+) -> QuantizationConfig | None:
+ """Resolve quantization config for vision / audio encoders.
+
+ Returns *None* for pre-quantized methods so that FP8 kernels are never
+ applied to BF16 encoder weights (which lack scale tensors). All other
+ configs — including ``ComponentQuantizationConfig`` and ``None`` — are
+ returned as-is so the caller can handle them.
+ """
+ if (
+ quant_config is not None
+ and not isinstance(quant_config, ComponentQuantizationConfig)
+ and quant_config.get_name() in PRE_QUANTIZED_METHODS
+ ):
+ return None
+ return quant_config
+
+
class ComponentQuantizationConfig(QuantizationConfig):
"""Routes quantization to different configs by layer prefix."""
diff --git a/vllm_omni/request.py b/vllm_omni/request.py
index 3ec325316f..48cbf9b31d 100644
--- a/vllm_omni/request.py
+++ b/vllm_omni/request.py
@@ -1,8 +1,11 @@
from collections.abc import Callable
+from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
import torch
+from vllm.multimodal.inputs import MultiModalFeatureSpec
+from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
if TYPE_CHECKING:
@@ -92,3 +95,34 @@ def from_engine_core_request(
resumable=request.resumable,
reasoning_ended=request.reasoning_ended,
)
+
+
+@dataclass
+class OmniStreamingUpdate:
+ """
+ Override: add additional information
+ Lightweight data for streaming session continuation.
+
+ Contains only the fields needed to update an existing streaming session
+ with new input data.
+ """
+
+ mm_features: list[MultiModalFeatureSpec] | None
+ prompt_token_ids: list[int] | None
+ max_tokens: int
+ arrival_time: float
+ sampling_params: SamplingParams | None
+ additional_information: AdditionalInformationPayload | None = None
+
+ @classmethod
+ def from_request(cls, request: "Request") -> "OmniStreamingUpdate | None":
+ if not request.resumable:
+ return None
+ return cls(
+ mm_features=request.mm_features,
+ prompt_token_ids=request.prompt_token_ids,
+ max_tokens=request.max_tokens,
+ arrival_time=request.arrival_time,
+ sampling_params=request.sampling_params,
+ additional_information=request.additional_information,
+ )
diff --git a/vllm_omni/transformers_utils/configs/__init__.py b/vllm_omni/transformers_utils/configs/__init__.py
index 59b23f9149..598ac3a965 100644
--- a/vllm_omni/transformers_utils/configs/__init__.py
+++ b/vllm_omni/transformers_utils/configs/__init__.py
@@ -17,6 +17,13 @@
"FishSpeechConfig": "vllm_omni.transformers_utils.configs.fish_speech",
"FishSpeechSlowARConfig": "vllm_omni.transformers_utils.configs.fish_speech",
"FishSpeechFastARConfig": "vllm_omni.transformers_utils.configs.fish_speech",
+ "VoxCPMConfig": "vllm_omni.transformers_utils.configs.voxcpm",
+ "VoxCPM2Config": "vllm_omni.transformers_utils.configs.voxcpm2",
+ "BailingMoeV2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "BailingMM2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "MingFlashOmniConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "Qwen3VLMoeVisionConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "WhisperEncoderConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
}
__all__ = [
@@ -27,6 +34,13 @@
"FishSpeechConfig",
"FishSpeechSlowARConfig",
"FishSpeechFastARConfig",
+ "VoxCPMConfig",
+ "VoxCPM2Config",
+ "BailingMoeV2Config",
+ "BailingMM2Config",
+ "MingFlashOmniConfig",
+ "Qwen3VLMoeVisionConfig",
+ "WhisperEncoderConfig",
]
@@ -47,3 +61,6 @@ def __dir__():
# run as soon as `vllm_omni.transformers_utils.configs` is imported.
from vllm_omni.transformers_utils.configs import fish_speech as _fish_speech # noqa: F401, E402
from vllm_omni.transformers_utils.configs import mammoth_moda2 as _mammoth_moda2 # noqa: F401, E402
+from vllm_omni.transformers_utils.configs import ming_flash_omni as _ming_flash_omni # noqa: F401, E402
+from vllm_omni.transformers_utils.configs import voxcpm as _voxcpm # noqa: F401, E402
+from vllm_omni.transformers_utils.configs import voxcpm2 as _voxcpm2 # noqa: F401, E402
diff --git a/vllm_omni/transformers_utils/configs/ming_flash_omni.py b/vllm_omni/transformers_utils/configs/ming_flash_omni.py
new file mode 100644
index 0000000000..dd13b682de
--- /dev/null
+++ b/vllm_omni/transformers_utils/configs/ming_flash_omni.py
@@ -0,0 +1,302 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration for Ming-flash-omni-2.0 model"""
+
+import os
+from typing import Any, ClassVar
+
+from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class BailingMoeV2Config(PretrainedConfig):
+ model_type = "bailing_moe_v2"
+
+ def __init__(
+ self,
+ vocab_size=30592,
+ hidden_size=1024,
+ intermediate_size=None,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=0,
+ hidden_act="silu",
+ use_qkv_bias=False,
+ use_qk_norm=False,
+ use_bias=True,
+ rms_norm_eps=1e-05,
+ norm_head=False,
+ tie_word_embeddings=False,
+ embedding_dropout=0.0,
+ attention_dropout=0.0,
+ output_dropout=0.0,
+ initializer_range=0.02,
+ max_position_embeddings=16384,
+ rope_theta=10000.0,
+ use_cache=True,
+ use_sliding_window=False,
+ sliding_window=81920,
+ max_window_layers=28,
+ rope_scaling=None,
+ mrope_section=None,
+ pad_token_id=126081,
+ num_experts=16,
+ num_shared_experts=1,
+ num_experts_per_tok=2,
+ n_group=8,
+ topk_group=4,
+ routed_scaling_factor=2.5,
+ moe_intermediate_size=None,
+ first_k_dense_replace=0,
+ head_dim=None,
+ output_router_logits=False,
+ partial_rotary_factor=0.5,
+ router_type="topN",
+ _attn_implementation="flash_attention_2",
+ use_interleaved_frame_timestamp=True,
+ # Multimodal token IDs
+ image_patch_token=157157,
+ video_patch_token=157175,
+ audio_patch_token=157168,
+ image_start_token=157158,
+ video_start_token=157160,
+ audio_start_token=157169,
+ image_end_token=157159,
+ video_end_token=157161,
+ audio_end_token=157170,
+ # Position encoding parameters
+ spatial_merge_size=2,
+ tokens_per_second=2,
+ **kwargs,
+ ):
+ self.num_hidden_layers = num_hidden_layers
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.use_qkv_bias = use_qkv_bias
+ self.use_bias = use_bias
+ self.norm_head = norm_head
+ self.rms_norm_eps = rms_norm_eps
+ self.embedding_dropout = embedding_dropout
+ self.attention_dropout = attention_dropout
+ self.output_dropout = output_dropout
+ self.initializer_range = initializer_range
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.use_cache = use_cache
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window
+ self.max_window_layers = max_window_layers
+ self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
+ self.use_qk_norm = use_qk_norm # arg unused; QK norm is always applied
+
+ # By default, match the value of `mrope_section`
+ # to `apply_3d_rotary_pos_emb` in Ming's repo:
+ # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/modeling_bailing_moe_v2.py
+ if mrope_section is None:
+ mrope_section = (rope_scaling or {}).get("mrope_section", [8, 12, 12])
+ # Ensure mrope_section is stored inside rope_scaling
+ if rope_scaling is not None and isinstance(rope_scaling, dict):
+ rope_scaling = dict(rope_scaling)
+ rope_scaling.setdefault("mrope_section", mrope_section)
+ self.rope_scaling = rope_scaling
+
+ # NOTE: Expose rope_parameters["mrope_section"]
+ # This refers to the pattern used for GLM-Image in vllm_omni/patch.py
+ rope_type = (rope_scaling or {}).get("type", (rope_scaling or {}).get("rope_type", ""))
+ if rope_type in ("video_rope", "3D", "mrope"):
+ self.rope_parameters = {"mrope_section": mrope_section}
+ else:
+ self.rope_parameters = None
+
+ # MoE configs
+ self.num_experts = num_experts
+ self.num_shared_experts = num_shared_experts
+ self.num_experts_per_tok = num_experts_per_tok
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.moe_intermediate_size = moe_intermediate_size
+ self.first_k_dense_replace = first_k_dense_replace
+ self.output_router_logits = output_router_logits
+ self.routed_scaling_factor = routed_scaling_factor
+ self.partial_rotary_factor = partial_rotary_factor
+ self.router_type = router_type
+ self.use_interleaved_frame_timestamp = use_interleaved_frame_timestamp
+ self._attn_implementation = _attn_implementation
+
+ # Multimodal token IDs and position encoding
+ self.image_patch_token = image_patch_token
+ self.video_patch_token = video_patch_token
+ self.audio_patch_token = audio_patch_token
+ self.image_start_token = image_start_token
+ self.video_start_token = video_start_token
+ self.audio_start_token = audio_start_token
+ self.image_end_token = image_end_token
+ self.video_end_token = video_end_token
+ self.audio_end_token = audio_end_token
+ self.spatial_merge_size = spatial_merge_size
+ self.tokens_per_second = tokens_per_second
+
+ super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLMoeVisionConfig(PretrainedConfig):
+ """Configuration class for Qwen3 MoE Vision Transformer"""
+
+ model_type = "qwen3_moe_vit"
+
+ def __init__(
+ self,
+ depth=27,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=3584,
+ num_position_embeddings=2304,
+ deepstack_visual_indexes=[8, 16, 24],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.num_position_embeddings = num_position_embeddings
+ self.initializer_range = initializer_range
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if "vision_config" in config_dict:
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class WhisperEncoderConfig(PretrainedConfig):
+ """Configuration class for Whisper audio encoder"""
+
+ model_type = "whisper_encoder"
+
+ def __init__(
+ self,
+ whisper_encoder_config: dict[str, Any] | None = None,
+ ds_kernel_size=3,
+ ds_stride=2,
+ norm_query_embeds=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.whisper_encoder_config = whisper_encoder_config or {}
+ self.ds_kernel_size = ds_kernel_size
+ self.ds_stride = ds_stride
+ self.norm_query_embeds = norm_query_embeds
+
+
+class BailingMM2Config(PretrainedConfig):
+ model_type = "bailingmm_moe_v2_lite"
+ is_composition = True
+ sub_configs: ClassVar = {"llm_config": AutoConfig}
+
+ def __init__(
+ self,
+ mlp_depth=1,
+ llm_config: BailingMoeV2Config | None = None,
+ vision_config: Qwen3VLMoeVisionConfig | None = None,
+ audio_config: WhisperEncoderConfig | None = None,
+ **kwargs,
+ ):
+ self.audio_config = WhisperEncoderConfig(**audio_config) if isinstance(audio_config, dict) else audio_config
+ self.vision_config = (
+ Qwen3VLMoeVisionConfig(**vision_config) if isinstance(vision_config, dict) else vision_config
+ )
+ self.llm_config = BailingMoeV2Config(**llm_config) if isinstance(llm_config, dict) else llm_config
+ self.mlp_depth = mlp_depth
+ super().__init__(**kwargs)
+
+ def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002
+ return self.llm_config
+
+
+class MingFlashOmniConfig(PretrainedConfig):
+ """Configuration class for unified Ming-flash-omni-2.0 model"""
+
+ model_type = "ming_flash_omni"
+ is_composition = True
+ sub_configs: ClassVar = {"thinker_config": BailingMM2Config}
+
+ def __init__(
+ self,
+ thinker_config: BailingMM2Config | None = None,
+ image_gen_config: dict[str, Any] | None = None,
+ talker_config: dict[str, Any] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if isinstance(thinker_config, dict):
+ self.thinker_config = BailingMM2Config(**thinker_config)
+ else:
+ self.thinker_config = thinker_config or BailingMM2Config()
+
+ # Image generation config (for future implementation)
+ self.image_gen_config = image_gen_config
+
+ # Talker config (for future implementation)
+ self.talker_config = talker_config
+
+ def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002
+ return self.thinker_config.get_text_config()
+
+
+# Register model_type -> config class for AutoConfig
+AutoConfig.register(BailingMoeV2Config.model_type, BailingMoeV2Config)
+AutoConfig.register(BailingMM2Config.model_type, BailingMM2Config)
+AutoConfig.register(MingFlashOmniConfig.model_type, MingFlashOmniConfig)
+
+# Register tokenizer mapping for composition configs so that
+# AutoTokenizer.from_pretrained can resolve the tokenizer class
+AutoTokenizer.register(BailingMM2Config, fast_tokenizer_class=PreTrainedTokenizerFast)
+AutoTokenizer.register(MingFlashOmniConfig, fast_tokenizer_class=PreTrainedTokenizerFast)
diff --git a/vllm_omni/transformers_utils/configs/voxcpm.py b/vllm_omni/transformers_utils/configs/voxcpm.py
new file mode 100644
index 0000000000..0267838915
--- /dev/null
+++ b/vllm_omni/transformers_utils/configs/voxcpm.py
@@ -0,0 +1,68 @@
+from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
+
+
+class VoxCPMConfig(PretrainedConfig):
+ model_type = "voxcpm"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ vocab_size: int = 32000,
+ hidden_size: int = 1024,
+ intermediate_size: int = 4096,
+ max_position_embeddings: int = 4096,
+ num_attention_heads: int = 16,
+ num_hidden_layers: int = 24,
+ num_key_value_heads: int = 16,
+ rms_norm_eps: float = 1e-6,
+ rope_theta: float = 10000.0,
+ rope_scaling: dict | None = None,
+ lm_config: dict | None = None,
+ encoder_config: dict | None = None,
+ dit_config: dict | None = None,
+ audio_vae_config: dict | None = None,
+ patch_size: int = 2,
+ feat_dim: int = 64,
+ residual_lm_num_layers: int = 6,
+ scalar_quantization_latent_dim: int = 256,
+ scalar_quantization_scale: int = 9,
+ max_length: int = 4096,
+ device: str = "cuda",
+ dtype: str = "bfloat16",
+ dit_mean_mode: bool = False,
+ **kwargs,
+ ):
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.max_position_embeddings = max_position_embeddings
+ self.num_attention_heads = num_attention_heads
+ self.num_hidden_layers = num_hidden_layers
+ self.num_key_value_heads = num_key_value_heads
+ self.rms_norm_eps = rms_norm_eps
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+
+ self.lm_config = lm_config or {}
+ self.encoder_config = encoder_config or {}
+ self.dit_config = dit_config or {}
+ self.audio_vae_config = audio_vae_config
+
+ self.patch_size = patch_size
+ self.feat_dim = feat_dim
+ self.residual_lm_num_layers = residual_lm_num_layers
+ self.scalar_quantization_latent_dim = scalar_quantization_latent_dim
+ self.scalar_quantization_scale = scalar_quantization_scale
+ self.max_length = max_length
+ self.device = device
+ self.dtype = dtype
+ self.dit_mean_mode = dit_mean_mode
+
+
+AutoConfig.register("voxcpm", VoxCPMConfig)
+
+__all__ = ["VoxCPMConfig"]
diff --git a/vllm_omni/transformers_utils/configs/voxcpm2.py b/vllm_omni/transformers_utils/configs/voxcpm2.py
new file mode 100644
index 0000000000..c625284bd6
--- /dev/null
+++ b/vllm_omni/transformers_utils/configs/voxcpm2.py
@@ -0,0 +1,153 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+
+from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+
+
+class VoxCPM2Config(PretrainedConfig):
+ """Configuration for VoxCPM2 native AR integration.
+
+ The HuggingFace checkpoint stores LM parameters inside a nested
+ ``lm_config`` dict. This class hoists them to top-level attributes
+ so that vllm's ``MiniCPMModel`` can consume them directly.
+
+ vllm's MiniCPM **always** applies muP scaling (scale_emb, scale_depth,
+ dim_model_base). VoxCPM2 was trained with ``use_mup=false``, so we
+ neutralise the scalings:
+ * ``scale_emb = 1.0``
+ * ``scale_depth = sqrt(num_hidden_layers)`` (cancels the division)
+ * ``dim_model_base = hidden_size`` (makes scale_width = 1.0)
+ """
+
+ model_type = "voxcpm2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ # -- top-level VoxCPM2 params --
+ architecture: str = "voxcpm2",
+ lm_config: dict | None = None,
+ encoder_config: dict | None = None,
+ dit_config: dict | None = None,
+ audio_vae_config: dict | None = None,
+ patch_size: int = 4,
+ feat_dim: int = 64,
+ residual_lm_num_layers: int = 8,
+ residual_lm_no_rope: bool = True,
+ scalar_quantization_latent_dim: int = 512,
+ scalar_quantization_scale: int = 9,
+ max_length: int = 8192,
+ device: str = "cuda",
+ dtype: str = "bfloat16",
+ # -- LM defaults (overridden by lm_config if present) --
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ vocab_size: int = 73448,
+ hidden_size: int = 2048,
+ intermediate_size: int = 6144,
+ max_position_embeddings: int = 32768,
+ num_attention_heads: int = 16,
+ num_hidden_layers: int = 28,
+ num_key_value_heads: int = 2,
+ rms_norm_eps: float = 1e-5,
+ rope_theta: float = 10000.0,
+ rope_scaling: dict | None = None,
+ **kwargs,
+ ):
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ self.architecture = architecture
+
+ # -- VoxCPM2-specific fields --
+ self.lm_config = lm_config or {}
+ self.encoder_config = encoder_config or {}
+ self.dit_config = dit_config or {}
+ self.audio_vae_config = audio_vae_config or {}
+ self.patch_size = patch_size
+ self.feat_dim = feat_dim
+ self.residual_lm_num_layers = residual_lm_num_layers
+ self.residual_lm_no_rope = residual_lm_no_rope
+ self.scalar_quantization_latent_dim = scalar_quantization_latent_dim
+ self.scalar_quantization_scale = scalar_quantization_scale
+ self.max_length = max_length
+ self.device = device
+ self.dtype = dtype
+
+ # -- Hoist LM parameters to top-level for MiniCPMModel --
+ lm = self.lm_config
+ self.vocab_size = lm.get("vocab_size", vocab_size)
+ self.hidden_size = lm.get("hidden_size", hidden_size)
+ self.intermediate_size = lm.get("intermediate_size", intermediate_size)
+ self.max_position_embeddings = lm.get("max_position_embeddings", max_position_embeddings)
+ self.num_attention_heads = lm.get("num_attention_heads", num_attention_heads)
+ self.num_hidden_layers = lm.get("num_hidden_layers", num_hidden_layers)
+ self.num_key_value_heads = lm.get("num_key_value_heads", num_key_value_heads)
+ self.rms_norm_eps = lm.get("rms_norm_eps", rms_norm_eps)
+ self.rope_theta = lm.get("rope_theta", rope_theta)
+
+ # MiniCPM-specific: kv_channels overrides head_dim when set.
+ kv_channels = lm.get("kv_channels")
+ if kv_channels is not None:
+ self.head_dim = kv_channels
+ else:
+ self.head_dim = self.hidden_size // self.num_attention_heads
+
+ # MiniCPM requires hidden_act; VoxCPM2 uses SiLU.
+ self.hidden_act = "silu"
+ self.hidden_act_param = 0.0
+ self.tie_word_embeddings = False
+ self.num_experts = 0
+
+ # -- muP scaling --
+ # Native VoxCPM2 MiniCPM gates scale_depth behind use_mup:
+ # use_mup=True → residual += h * (scale_depth / sqrt(N))
+ # use_mup=False → residual += h (plain add, no scaling)
+ # But vllm's MiniCPMModel ALWAYS applies scale_depth / sqrt(N).
+ # Native applies scale_emb externally; vllm applies it in embed_input_ids.
+ use_mup = lm.get("use_mup", False)
+ self.scale_emb = lm.get("scale_emb", 1.0)
+ if use_mup:
+ self.scale_depth = lm.get("scale_depth", 1.0)
+ self.dim_model_base = lm.get("dim_model_base", self.hidden_size)
+ else:
+ # Neutralize: scale_depth/sqrt(N) = 1.0, scale_width = 1.0
+ self.scale_depth = math.sqrt(self.num_hidden_layers)
+ self.dim_model_base = self.hidden_size
+
+ # -- RoPE scaling (longrope) --
+ raw_rope = lm.get("rope_scaling", rope_scaling)
+ if raw_rope is not None:
+ self.rope_scaling = dict(raw_rope)
+ # HF expects "rope_type" not "type"
+ if "type" in self.rope_scaling:
+ self.rope_scaling["rope_type"] = self.rope_scaling.pop("type")
+ # longrope requires "factor" (used by HF validation)
+ if "factor" not in self.rope_scaling:
+ self.rope_scaling["factor"] = 1.0
+ rope_config_validation(self)
+
+ # vllm's MiniCPMAttention reads config.rope_parameters (a dict
+ # with rope_type, theta, scaling factors, etc.). HF transformers
+ # only auto-computes this for known model_types; for custom
+ # types we must build it manually.
+ if not getattr(self, "rope_parameters", None):
+ rp = dict(self.rope_scaling)
+ rp["rope_theta"] = self.rope_theta
+ self.rope_parameters = rp
+ else:
+ self.rope_scaling = None
+
+ def get_text_config(self, **kwargs):
+ """Return self as the text config — LM attributes are top-level."""
+ return self
+
+
+AutoConfig.register("voxcpm2", VoxCPM2Config)
+
+__all__ = ["VoxCPM2Config"]
diff --git a/vllm_omni/transformers_utils/processors/__init__.py b/vllm_omni/transformers_utils/processors/__init__.py
new file mode 100644
index 0000000000..52ca657539
--- /dev/null
+++ b/vllm_omni/transformers_utils/processors/__init__.py
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+
+from vllm_omni.transformers_utils.processors.ming import (
+ MingFlashOmniProcessor,
+ MingWhisperFeatureExtractor,
+)
+
+__all__ = [
+ "MingFlashOmniProcessor",
+ "MingWhisperFeatureExtractor",
+]
diff --git a/vllm_omni/transformers_utils/processors/ming.py b/vllm_omni/transformers_utils/processors/ming.py
new file mode 100644
index 0000000000..7f414b7268
--- /dev/null
+++ b/vllm_omni/transformers_utils/processors/ming.py
@@ -0,0 +1,430 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import numpy as np
+import torch
+from transformers import AutoFeatureExtractor, AutoProcessor
+from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+DEFAULT_IMAGE_PATCH_TOKEN = ""
+DEFAULT_IM_START_TOKEN = ""
+DEFAULT_IM_END_TOKEN = " "
+DEFAULT_VID_START_TOKEN = ""
+DEFAULT_VID_END_TOKEN = " "
+DEFAULT_FRAME_PATCH_TOKEN = ""
+
+DEFAULT_AUDIO_PATCH_TOKEN = ""
+DEFAULT_AU_START_TOKEN = ""
+DEFAULT_AU_END_TOKEN = " "
+
+# High-level placeholders used in user prompts
+PLACEHOLDER_IMAGE_TOKEN_IN_TEXT = ""
+PLACEHOLDER_VIDEO_TOKEN_IN_TEXT = ""
+PLACEHOLDER_AUDIO_TOKEN_IN_TEXT = ""
+
+# Chat template constants
+USER_PREFIX = "HUMAN "
+ASSISTANT_PREFIX = "ASSISTANT "
+SYSTEM_PROMPT_NOTHINK = "SYSTEM 你是一个友好的AI助手。\n\ndetailed thinking off"
+SYSTEM_PROMPT_THINK = "SYSTEM 你是一个友好的AI助手。\n\ndetailed thinking on"
+
+
+_NORM_FACTOR_FOR_DTYPE = {
+ torch.int8: 2**7,
+ torch.int16: 2**15,
+ torch.int32: 2**31,
+ torch.int64: 2**63,
+ torch.float32: 1,
+ torch.float64: 1,
+}
+
+
+def _normalize_audio_tensor(
+ waveform: torch.Tensor,
+ sample_rate: int,
+ target_sample_rate: int = 16000,
+) -> torch.Tensor:
+ """Normalize waveform to float32, mono, and optionally resample."""
+ norm_factor = _NORM_FACTOR_FOR_DTYPE.get(waveform.dtype, 1)
+ waveform = waveform.to(torch.float32) / norm_factor
+
+ # Remove channel dimension
+ while len(waveform.shape) > 1:
+ waveform = waveform[0]
+
+ # Resample if needed
+ if sample_rate != target_sample_rate:
+ import torchaudio
+
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
+ waveform = resampler(waveform.unsqueeze(0)).squeeze(0)
+
+ return waveform
+
+
+class MingWhisperFeatureExtractor(FeatureExtractionMixin):
+ """Whisper log-mel feature extractor for Ming-flash-omni-2.0.
+
+ Produces audio_feats in the time-first packed format.
+
+ Adapted from Ming's WhisperAudioEncoder
+ https://github.com/inclusionAI/Ming/blob/070dc3c13f95d97952ab7d22030df0c9e28a5122/modeling_whisper_encoder.py
+ and HF transformers WhisperFeatureExtractor
+ https://github.com/huggingface/transformers/blob/f842abaca95a7dbf3fc6e16122e7409109bc1431/src/transformers/models/whisper/feature_extraction_whisper.py#L33
+ """
+
+ model_input_names = ["audio_feats", "audio_feats_lengths"]
+
+ def __init__(self, feature_size: int = 128, sampling_rate: int = 16000, **kwargs):
+ # feature_size == n_mels; stored so to_dict() serialises it correctly.
+ self.feature_size = feature_size
+ self.sampling_rate = sampling_rate
+ super().__init__(**kwargs)
+
+ @property
+ def n_mels(self) -> int:
+ return self.feature_size
+
+ def __call__(
+ self,
+ audios: tuple | list,
+ return_tensors: str | None = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """Preprocess audio(s) into Whisper log-mel spectrograms"""
+ import whisper
+
+ if not isinstance(audios, list):
+ audios = [audios]
+
+ audio_feat_list = []
+ for waveform, sr in audios:
+ if isinstance(waveform, np.ndarray):
+ waveform = torch.from_numpy(waveform)
+ waveform = _normalize_audio_tensor(waveform, sr, target_sample_rate=self.sampling_rate)
+ mel = whisper.log_mel_spectrogram(waveform, n_mels=self.n_mels)
+ audio_feat_list.append(mel.transpose(0, 1)) # [T, n_mels]
+
+ audio_feats_lengths = torch.tensor([[feat.shape[0] for feat in audio_feat_list]], dtype=torch.long)
+ # Two stride-2 convolutions in series:
+ # 1. WhisperAudioEncoder conv2: kernel=3, stride=2, padding=1
+ # (conv1 has stride=1 and does not change T)
+ # 2. AudioProjector Conv1d: kernel=3, stride=2, padding=1
+ # Combined: T → ((T-1)//2 + 1 - 1)//2 + 1
+ # See also: AudioProjector.compute_output_length()
+ encoder_feats_lengths = ((audio_feats_lengths - 3 + 2 * 1) // 2 + 1 - 3 + 2 * 1) // 2 + 1
+ audio_feats = torch.cat(audio_feat_list, dim=0).unsqueeze(0) # [1, T_total, n_mels]
+
+ data = {
+ # [1, T_total, n_mels], all audio clips concatenated
+ "audio_feats": audio_feats.numpy(),
+ # [1, n_audios], actual frame count
+ "audio_feats_lengths": audio_feats_lengths.numpy(),
+ # [1, n_audios]
+ "encoder_feats_lengths": encoder_feats_lengths,
+ }
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+class MingFlashOmniProcessor(ProcessorMixin):
+ """Top-level multimodal processor for Ming-flash-omni 2.0.
+
+ Adapted from Ming's BailingMM2Processor
+ https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py
+
+ Subprocessors include:
+ - Qwen2VLImageProcessor (image/video)
+ - MingWhisperFeatureExtractor (modified audio processor using Whisper's log-mel spectrogram)
+ """
+
+ attributes = ["image_processor", "audio_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ audio_processor_class = "AutoFeatureExtractor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ audio_processor=None,
+ tokenizer=None,
+ merge_size: int = 2,
+ **kwargs,
+ ):
+ # Enforce that all sub-processors exist
+ # Keep None defaults in the signature for HF ProcessorMixin compatibility
+ if image_processor is None:
+ raise ValueError("MingFlashOmniProcessor requires `image_processor`.")
+ if audio_processor is None:
+ raise ValueError("MingFlashOmniProcessor requires `audio_processor`.")
+ if tokenizer is None:
+ raise ValueError("MingFlashOmniProcessor requires `tokenizer`.")
+
+ self.spatial_merge_size = merge_size
+ self.image_token = PLACEHOLDER_IMAGE_TOKEN_IN_TEXT
+ self.video_token = PLACEHOLDER_VIDEO_TOKEN_IN_TEXT
+ self.audio_token = PLACEHOLDER_AUDIO_TOKEN_IN_TEXT
+ super().__init__(
+ image_processor=image_processor,
+ audio_processor=audio_processor,
+ tokenizer=tokenizer,
+ )
+
+ # Fall back to the tokenizer's own chat_template.
+ if self.chat_template is None:
+ self.chat_template = getattr(tokenizer, "chat_template", None)
+
+ def __call__(
+ self,
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
+ images: Any | None = None,
+ videos: Any | None = None,
+ audios: tuple[np.ndarray, int] | list[tuple[np.ndarray, int]] | None = None,
+ **kwargs,
+ ) -> BatchFeature:
+ # This should always be parallel implementations that mirror
+ # `_get_prompt_updates` logic in Ming processor, and vice versa.
+ # Ensure text is a list
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list):
+ raise ValueError("text must be a string or list of strings")
+
+ data: dict[str, Any] = {}
+
+ if images is not None:
+ image_outputs = self.image_processor(
+ images=images,
+ videos=None,
+ return_tensors="pt",
+ **kwargs.get("images_kwargs", {}),
+ )
+ data.update(image_outputs)
+ if "image_grid_thw" in image_outputs:
+ text = self._expand_image_tokens(text, image_outputs["image_grid_thw"])
+
+ if videos is not None:
+ video_outputs = self.image_processor(
+ images=None,
+ videos=videos,
+ return_tensors="pt",
+ **kwargs.get("videos_kwargs", {}),
+ )
+ if "pixel_values" in video_outputs:
+ video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values")
+ if "image_grid_thw" in video_outputs:
+ video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw")
+ data.update(video_outputs)
+ if "video_grid_thw" in video_outputs:
+ text = self._expand_video_tokens(text, video_outputs["video_grid_thw"])
+
+ if audios is not None:
+ audio_outputs = self.audio_processor(
+ audios,
+ return_tensors="pt",
+ **kwargs.get("audio_kwargs", {}),
+ )
+ data.update(audio_outputs)
+ if "encoder_feats_lengths" in audio_outputs:
+ text = self._expand_audio_tokens(text, audio_outputs["encoder_feats_lengths"])
+
+ text_outputs = self.tokenizer(
+ text,
+ return_tensors="pt",
+ **kwargs.get("text_kwargs", {}),
+ )
+ data.update(text_outputs)
+ return BatchFeature(data=data)
+
+ def _expand_image_tokens(
+ self,
+ text: list[str],
+ image_grid_thw: torch.Tensor,
+ special_token: str = PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
+ ) -> list[str]:
+ merge_size = self.spatial_merge_size
+ num_patches_per_image = torch.prod(image_grid_thw, dim=1) // (merge_size**2)
+ prompt_strings = []
+ image_index = 0
+ for sample in text:
+ num_images = sample.count(special_token)
+ if num_images > 0:
+ for i in range(image_index, num_images + image_index):
+ num_patches = int(num_patches_per_image[i].item())
+ img_text = (
+ DEFAULT_IM_START_TOKEN + (DEFAULT_IMAGE_PATCH_TOKEN * num_patches) + DEFAULT_IM_END_TOKEN + "\n"
+ )
+ sample = sample.replace(special_token, img_text, 1)
+ image_index += num_images
+ prompt_strings.append(sample)
+ return prompt_strings
+
+ def _expand_video_tokens(
+ self,
+ text: list[str],
+ video_grid_thw: torch.Tensor,
+ special_token: str = PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
+ ) -> list[str]:
+ merge_size = self.spatial_merge_size
+ num_patches_per_video = torch.prod(video_grid_thw, dim=1) // (merge_size**2)
+ prompt_strings = []
+ video_index = 0
+ for sample in text:
+ num_videos = sample.count(special_token)
+ if num_videos > 0:
+ for i in range(video_index, num_videos + video_index):
+ num_patches = int(num_patches_per_video[i].item())
+ video_text = (
+ DEFAULT_VID_START_TOKEN
+ + (DEFAULT_FRAME_PATCH_TOKEN * num_patches)
+ + DEFAULT_VID_END_TOKEN
+ + "\n"
+ )
+ sample = sample.replace(special_token, video_text, 1)
+ video_index += num_videos
+ prompt_strings.append(sample)
+ return prompt_strings
+
+ def _expand_audio_tokens(
+ self,
+ text: list[str],
+ encoder_feats_lengths: torch.Tensor,
+ special_token: str = PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
+ ) -> list[str]:
+ prompt_strings = []
+ for sample, lengths_tensor in zip(text, encoder_feats_lengths):
+ for length in lengths_tensor:
+ num_patches = int(length.item())
+ if num_patches == 0:
+ continue
+ audio_text = DEFAULT_AU_START_TOKEN + (DEFAULT_AUDIO_PATCH_TOKEN * num_patches) + DEFAULT_AU_END_TOKEN
+ if special_token in sample:
+ sample = sample.replace(special_token, audio_text, 1)
+ else:
+ sample = sample + audio_text + "\n"
+ prompt_strings.append(sample)
+ return prompt_strings
+
+ def apply_system_template(
+ self,
+ sys_prompt_exp: str | None = None,
+ use_cot_system_prompt: bool = False,
+ ) -> str:
+ sys_prompt = SYSTEM_PROMPT_THINK if use_cot_system_prompt else SYSTEM_PROMPT_NOTHINK
+ if sys_prompt_exp is not None:
+ sys_prompt = sys_prompt.replace("你是一个友好的AI助手。", sys_prompt_exp)
+ return sys_prompt
+
+ def apply_chat_template(
+ self,
+ conversation: list[dict[str, Any]],
+ sys_prompt_exp: str | None = None,
+ use_cot_system_prompt: bool = False,
+ **kwargs,
+ ) -> str:
+ eos = self.tokenizer.eos_token
+ text = self.apply_system_template(sys_prompt_exp, use_cot_system_prompt) + eos
+
+ for idx, message in enumerate(conversation):
+ assert message["role"] in ["HUMAN", "ASSISTANT"], (
+ f"Invalid role: {message['role']}. Must be 'HUMAN' or 'ASSISTANT'"
+ )
+ if idx == len(conversation) - 1:
+ assert message["role"] == "HUMAN", "Last message must be from HUMAN"
+
+ text += USER_PREFIX if message["role"] == "HUMAN" else ASSISTANT_PREFIX
+
+ content = message["content"]
+ if isinstance(content, str):
+ # text-only
+ text += content
+ elif isinstance(content, list):
+ # structured content with multimodal elements
+ # Count existing placeholders from text items only
+ image_placeholders = 0
+ video_placeholders = 0
+ audio_placeholders = 0
+ for content_item in content:
+ if content_item.get("type", "text") == "text":
+ t = content_item.get("text", "")
+ image_placeholders += t.count(PLACEHOLDER_IMAGE_TOKEN_IN_TEXT)
+ video_placeholders += t.count(PLACEHOLDER_VIDEO_TOKEN_IN_TEXT)
+ audio_placeholders += t.count(PLACEHOLDER_AUDIO_TOKEN_IN_TEXT)
+
+ if video_placeholders > 1:
+ raise ValueError("Video count must be at most 1 per message!")
+
+ # Insert placeholders only for media items not already covered
+ for content_item in content:
+ content_type = content_item.get("type", "text")
+
+ if content_type == "image":
+ image_data = content_item.get("image")
+ if image_data is not None:
+ from PIL import Image as PILImage
+
+ num_images = 1 if isinstance(image_data, (str, PILImage.Image)) else len(image_data)
+ for _ in range(num_images):
+ if image_placeholders > 0:
+ image_placeholders -= 1
+ else:
+ text += PLACEHOLDER_IMAGE_TOKEN_IN_TEXT
+
+ elif content_type == "video":
+ if video_placeholders > 0:
+ video_placeholders -= 1
+ else:
+ text += PLACEHOLDER_VIDEO_TOKEN_IN_TEXT
+ elif content_type == "audio":
+ audio_data = content_item.get("audio")
+ if audio_data is not None:
+ num_audios = 1 if isinstance(audio_data, str) else len(audio_data)
+ for _ in range(num_audios):
+ if audio_placeholders > 0:
+ audio_placeholders -= 1
+ else:
+ text += PLACEHOLDER_AUDIO_TOKEN_IN_TEXT
+
+ elif content_type == "text":
+ text += content_item.get("text", "")
+
+ # Add EOS token after each message except the last one
+ text += eos
+
+ text += ASSISTANT_PREFIX
+ return text
+
+ def batch_decode(self, *args, **kwargs):
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ names = (
+ self.tokenizer.model_input_names
+ + self.image_processor.model_input_names
+ + self.audio_processor.model_input_names
+ )
+ return list(dict.fromkeys(names))
+
+
+AutoFeatureExtractor.register("MingWhisperFeatureExtractor", MingWhisperFeatureExtractor)
+AutoProcessor.register("MingFlashOmniProcessor", MingFlashOmniProcessor)
diff --git a/vllm_omni/utils/audio.py b/vllm_omni/utils/audio.py
new file mode 100644
index 0000000000..cc25c17947
--- /dev/null
+++ b/vllm_omni/utils/audio.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Audio utility functions shared across models and entrypoints."""
+
+import numpy as np
+import torch
+from torchaudio.functional import melscale_fbanks
+
+
+def mel_filter_bank(
+ sr: int,
+ n_fft: int,
+ n_mels: int,
+ fmin: float = 0.0,
+ fmax: float | None = None,
+) -> torch.Tensor:
+ """Compute a mel filterbank matrix.
+
+ Drop-in replacement for ``librosa.filters.mel`` using
+ ``torchaudio.functional.melscale_fbanks``.
+
+ Args:
+ sr: Sample rate of the audio.
+ n_fft: FFT window size.
+ n_mels: Number of mel bands.
+ fmin: Minimum frequency (Hz).
+ fmax: Maximum frequency (Hz). Defaults to ``sr / 2``.
+
+ Returns:
+ Tensor of shape ``(n_mels, n_fft // 2 + 1)``.
+ """
+ if fmax is None:
+ fmax = float(sr) / 2.0
+ # Use mel_scale='slaney' and norm='slaney' to match librosa's
+ # default behaviour (Slaney 1998 frequency mapping with area
+ # normalization).
+ return melscale_fbanks(
+ n_freqs=n_fft // 2 + 1,
+ f_min=float(fmin),
+ f_max=float(fmax),
+ n_mels=n_mels,
+ sample_rate=sr,
+ mel_scale="slaney",
+ norm="slaney",
+ ).T
+
+
+def peak_normalize(
+ audio: np.ndarray,
+ db_level: float = -6.0,
+) -> np.ndarray:
+ """Normalize audio so peak amplitude reaches a target dB level.
+
+ Drop-in replacement for ``sox.Transformer().norm(db_level=...)``.
+
+ Args:
+ audio: Input waveform as a 1-D numpy array.
+ db_level: Target peak amplitude in dBFS.
+
+ Returns:
+ Normalized waveform with the same dtype as *audio*.
+ """
+ peak = np.abs(audio).max()
+ if peak == 0:
+ return audio
+ target = 10.0 ** (db_level / 20.0)
+ return audio * (target / peak)
diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py
new file mode 100644
index 0000000000..66d4e6ffe0
--- /dev/null
+++ b/vllm_omni/utils/mm_outputs.py
@@ -0,0 +1,93 @@
+"""Utilities for handling multimodal outputs / building multimodal output
+payloads, most of which are shared by the prefix cache / no prefix cache path.
+"""
+
+import torch
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]:
+ """Pre-copies multimodal tensor to CPU once (not per-request) to avoid
+ redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU.
+
+ In the case of prefix caching, the multimodal outputs provided will
+ only contain the passthrough data.
+
+ Args:
+ multimodal_outputs: Multimodal dict mapping strings to objects.
+ """
+ # Pre-copy multimodal tensors to CPU once (not per-request) to avoid
+ # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU.
+ mm_cpu: dict[str, object] = {}
+ # Currently there are some cases where this is true at the
+ # moment, which should be fixed.
+ if not isinstance(multimodal_outputs, dict):
+ logger.warning("Multimodal outputs are not a dict and will not be passed")
+
+ if multimodal_outputs:
+ for k, v in multimodal_outputs.items():
+ if isinstance(v, torch.Tensor):
+ mm_cpu[k] = v.detach().to("cpu").contiguous()
+ elif isinstance(v, dict):
+ sub_dict: dict[str, torch.Tensor] = {}
+ for sk, sv in v.items():
+ if isinstance(sv, torch.Tensor):
+ sub_dict[str(sk)] = sv.detach().to("cpu").contiguous()
+ if sub_dict:
+ mm_cpu[k] = sub_dict
+ elif isinstance(v, list) and len(v) > 0:
+ cpu_list = []
+ for elem in v:
+ if isinstance(elem, torch.Tensor):
+ cpu_list.append(elem.detach().to("cpu").contiguous())
+ else:
+ cpu_list.append(elem)
+ mm_cpu[k] = cpu_list
+ elif v is not None:
+ mm_cpu[k] = v
+ return mm_cpu
+
+
+def to_payload_element(
+ element: object, idx: int, start: int, end: int, pass_lists_through: bool = False, seq_len: int | None = None
+):
+ """Build an mm payload element corresponding to one request index
+ from an element containing 0 or more CPU tensors.
+
+ Args:
+ element: The object to be added to the payload.
+ idx: The index of the request.
+ start: The start index corresponding to the request idx.
+ end: The end index corresponding to the request idx.
+ pass_lists_through: bool Whether or not lists should be treated as
+ passthrough data; this should be False in normal cases, but True
+ if we need to avoid splitting nonempty lists prior to calling
+ postprocess, which is the case for prefix cache.
+ seq_len: Optional sequence length (i.e., dim 0 of hidden states).
+ This should be set to None in the prefix caching case, because
+ the condition that would be executed here is the same as the
+ criteria for being added to the multimodal outputs cache.
+ """
+ # Prefix cache won't hit this case because this is the condition
+ # for being a mm_cache_key in the multimodal outputs tensor.
+ if seq_len is not None and isinstance(element, torch.Tensor) and element.shape[0] == seq_len:
+ return element[start:end].contiguous()
+ # Every other case is shared between prefix cache (passthrough data)
+ # and running a model without prefix caching.
+ elif isinstance(element, dict):
+ return {sk: sv[start:end].contiguous() for sk, sv in element.items()}
+ elif isinstance(element, list):
+ # For lists, clone tensors to avoid cross-request aliasing
+ if pass_lists_through:
+ return [elem.clone() if isinstance(elem, torch.Tensor) else elem for elem in element]
+ element = element[idx] if idx < len(element) else element[0]
+ if isinstance(element, torch.Tensor):
+ element = element.clone()
+ return element
+ elif isinstance(element, torch.Tensor):
+ # List-derived tensor payloads are request-invariant; clone to
+ # avoid accidental cross-request aliasing on downstream mutation.
+ return element.clone()
+ return element
diff --git a/vllm_omni/version.py b/vllm_omni/version.py
index e5f0b6b661..296bebc8e2 100644
--- a/vllm_omni/version.py
+++ b/vllm_omni/version.py
@@ -5,12 +5,12 @@
and written to _version.py during package build.
"""
+import warnings
+
try:
# Import auto-generated version from _version.py (created by setuptools_scm)
from ._version import __version__, __version_tuple__
except ImportError as e:
- import warnings
-
warnings.warn(
f"Failed to import version from _version.py: {e}\n"
"This typically happens in development mode before building.\n"
@@ -22,4 +22,37 @@
__version__ = "dev"
__version_tuple__ = (0, 0, "dev")
+
+def warn_if_misaligned_vllm_version():
+ """Warn if vLLM and vllm-omni versions don't match (major.minor)."""
+ # Import vllm lazily since import order may be sensitive with current monkeypatching,
+ # but we want to check this before potentially breaking imports run.
+ from vllm import __version__ as vllm_version
+ from vllm import __version_tuple__ as vllm_version_tuple
+
+ omni_ver: tuple[str | int, ...] = __version_tuple__[:2]
+ vllm_ver: tuple[str | int, ...] = vllm_version_tuple[:2]
+ # Skip if either version is dev (0, 0)
+ if omni_ver == (0, 0) or vllm_ver == (0, 0):
+ return
+
+ # Compare major.minor
+ if omni_ver != vllm_ver:
+ warnings.warn(
+ "vLLM and vLLM-Omni appear to have mismatched major/minor versions:\n"
+ f" --> vLLM-Omni version {__version__}\n"
+ f" --> vLLM version {vllm_version}\n"
+ "This will likely cause compatibility issues.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+
+
__all__ = ["__version__", "__version_tuple__"]
+
+# Run version check automatically when this module is imported
+try:
+ warn_if_misaligned_vllm_version()
+except ModuleNotFoundError:
+ # vLLM not installed (e.g., documentation builds)
+ pass
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index 01ec23acb4..f37b2224ef 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -39,7 +39,9 @@
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.outputs import OmniModelRunnerOutput
+from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = init_logger(__name__)
@@ -60,7 +62,7 @@ class ExecuteModelState(NamedTuple):
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None
-class GPUARModelRunner(OmniGPUModelRunner):
+class GPUARModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
"""Autoregressive GPU model runner that returns hidden states per request.
Follows the v0.12 two-phase execute/sample flow from GPUModelRunner, and
@@ -138,6 +140,125 @@ def _sampling_metadata_for_model_sampler(self, sampling_metadata):
return sampling_metadata
return replace(sampling_metadata, output_token_ids=output_token_ids)
+ def capture_model(self) -> int:
+ result = super().capture_model()
+ self._capture_talker_mtp_graphs()
+ return result
+
+ def _capture_talker_mtp_graphs(self) -> None:
+ from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper
+
+ if not self.has_talker_mtp or not isinstance(self.talker_mtp, CUDAGraphWrapper):
+ return
+
+ from vllm.compilation.monitor import set_cudagraph_capturing_enabled
+ from vllm.distributed.parallel_state import graph_capture
+
+ capture_sizes = self.compilation_config.cudagraph_capture_sizes
+ num_warmups = self.compilation_config.cudagraph_num_of_warmups
+ capture_sizes = sorted(capture_sizes, reverse=True)
+ logger.info("Capturing talker_mtp graphs for sizes %s", capture_sizes)
+
+ set_cudagraph_capturing_enabled(True)
+ try:
+ with torch.inference_mode(), graph_capture(device=self.device):
+ for bsz in capture_sizes:
+ _, batch_desc, _, _, _ = self._determine_batch_execution_and_padding(
+ num_tokens=bsz,
+ num_reqs=bsz,
+ num_scheduled_tokens_np=np.ones(bsz, dtype=np.int32),
+ max_num_scheduled_tokens=1,
+ use_cascade_attn=False,
+ )
+ n = batch_desc.num_tokens
+ ids = self.talker_mtp_input_ids.gpu[:n]
+ emb = self.talker_mtp_inputs_embeds.gpu[:n]
+ hid = self.last_talker_hidden.gpu[:n]
+ ts = self.text_step.gpu[:n]
+
+ for _ in range(num_warmups):
+ with set_forward_context(
+ None,
+ self.vllm_config,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
+ batch_descriptor=batch_desc,
+ ):
+ self.talker_mtp(ids, emb, hid, ts)
+
+ with set_forward_context(
+ None,
+ self.vllm_config,
+ cudagraph_runtime_mode=CUDAGraphMode.FULL,
+ batch_descriptor=batch_desc,
+ ):
+ self.talker_mtp(ids, emb, hid, ts)
+ torch.cuda.synchronize()
+
+ logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes))
+ except RuntimeError as e:
+ raise RuntimeError(
+ f"talker_mtp graph capture failed for a model that declared talker_mtp_graph_safe=True: {e}"
+ ) from e
+ finally:
+ set_cudagraph_capturing_enabled(False)
+
+ def _maybe_update_prefix_cache(
+ self,
+ hidden_states: torch.Tensor,
+ multimodal_outputs: dict,
+ num_tokens_unpadded: int,
+ num_tokens_padded: int,
+ ):
+ """If prefix caching is enabled and it's the last pipeline parallelism rank,
+ retrieve the hidden states & multimodal outputs from the prefix cache based
+ on our batch slot mappings.
+ """
+ # Cache hidden states if we've enabled hidden state prefix caching
+ # unless this isn't the last pipeline parallelism rank.
+ if self.omni_prefix_cache is not None and get_pp_group().is_last_rank:
+ # If this happens, it generally means the model is not following the correct
+ # interface yet and is therefore currently not compatible with prefix cache.
+ if multimodal_outputs is not None and not isinstance(multimodal_outputs, dict):
+ logger.warning_once(
+ "prefix caching expects mm outputs to be a dict, but got %s",
+ type(multimodal_outputs),
+ )
+
+ self.omni_prefix_cache.update_omni_tensor_prefix_cache(
+ hidden_states=hidden_states,
+ multimodal_outputs=multimodal_outputs,
+ num_tokens_unpadded=num_tokens_unpadded,
+ slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu,
+ num_tokens_padded=num_tokens_padded,
+ )
+
+ def _maybe_get_combined_prefix_cache_tensors(
+ self,
+ hidden_states: torch.Tensor,
+ multimodal_outputs: dict,
+ num_scheduled_tokens: dict[str, int],
+ ) -> tuple[dict[str, torch.Tensor] | None, dict | None]:
+ """If prefix caching is enabled, extract the merged hidden states and multimodal outputs for
+ all requests in the batch (including those that aren't a hit on Prefix cache).
+ """
+ # Prior to applying the post-processing func, extract
+ # the prefix cached hidden states and multimodal states.
+ combined_hidden_states, combined_multimodal_outputs = None, None
+ if self.omni_prefix_cache is not None:
+ combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states(
+ query_start_loc=self.query_start_loc.cpu,
+ input_batch=self.input_batch,
+ hidden_states=hidden_states,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
+ combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states(
+ query_start_loc=self.query_start_loc.cpu,
+ input_batch=self.input_batch,
+ multimodal_outputs=multimodal_outputs,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
+ return combined_hidden_states, combined_multimodal_outputs
+
@torch.inference_mode()
def execute_model(
self,
@@ -199,30 +320,49 @@ def execute_model(
# Update persistent batch states.
deferred_state_corrections_fn = self._update_states(scheduler_output)
+ # Notify model of finished requests for state cleanup
+ if scheduler_output.finished_req_ids and hasattr(self.model, "on_requests_finished"):
+ self.model.on_requests_finished(scheduler_output.finished_req_ids)
+
if has_ec_transfer() and not get_ec_transfer().is_consumer:
with self.maybe_get_ec_connector_output(
scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
- return make_empty_encoder_model_runner_output(scheduler_output)
+
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
+ output = make_empty_encoder_model_runner_output(scheduler_output)
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+ return output
if not num_scheduled_tokens:
if (
self.parallel_config.distributed_executor_backend == "external_launcher"
and self.parallel_config.data_parallel_size > 1
):
- # this is a corner case when both external launcher
- # and DP are enabled, num_scheduled_tokens could be
- # 0, and has_unfinished_requests in the outer loop
- # returns True. before returning early here we call
- # dummy run to ensure coordinate_batch_across_dp
- # is called into to avoid out of sync issues.
self._dummy_run(1)
+
+ # Capture KV extraction results before early return;
+ # sample_tokens() is skipped on this path so the IDs
+ # would otherwise be silently overwritten next step.
+ kv_ids = self.kv_extracted_req_ids
+ self.kv_extracted_req_ids = None
+
if not has_kv_transfer_group():
- # Return empty ModelRunnerOutput if no work to do.
- return EMPTY_MODEL_RUNNER_OUTPUT
- return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+ output = EMPTY_MODEL_RUNNER_OUTPUT
+ else:
+ output = self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+
+ if kv_ids:
+ output = copy(output)
+ output.kv_extracted_req_ids = kv_ids
+
+ return output
if self.cache_config.kv_sharing_fast_prefill:
assert not self.num_prompt_logprobs, (
@@ -394,6 +534,15 @@ def execute_model(
hidden_states, multimodal_outputs = self.extract_multimodal_outputs(model_output)
+ # Cache hidden states & multimodal outputs if we've enabled hidden state
+ # prefix caching unless this isn't the last pipeline parallelism rank.
+ self._maybe_update_prefix_cache(
+ hidden_states=hidden_states,
+ multimodal_outputs=multimodal_outputs,
+ num_tokens_unpadded=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded,
+ )
+
if not self.broadcast_pp_output:
# Common case.
if not get_pp_group().is_last_rank:
@@ -507,6 +656,23 @@ def _sample(
return super()._sample(logits, spec_decode_metadata)
+ @staticmethod
+ def _resolve_req_hidden_states(
+ hidden_states_cpu: torch.Tensor,
+ combined_hidden_states: dict[str, torch.Tensor] | None,
+ rid: str,
+ start: int,
+ end: int,
+ ):
+ if combined_hidden_states is not None:
+ # We always have all request IDs for prefix cache, even for
+ # partial cache misses, so this should never happen.
+ if rid not in combined_hidden_states:
+ raise RuntimeError("Request IDs in the batch are missing from the merged states!")
+ return combined_hidden_states[rid]
+ # Prefix caching is disabled
+ return hidden_states_cpu[start:end]
+
@torch.inference_mode()
def sample_tokens(
self,
@@ -515,6 +681,13 @@ def sample_tokens(
kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None)
self.kv_extracted_req_ids = None
+ # Used for prefix cache
+ combined_hidden_states = None
+ combined_multimodal_outputs = None
+ # Used when we don't use prefix cache; prefix cache builds the payloads
+ # internally since it already needs to do this for the cached tensors
+ mm_cpu = {}
+
if self.execute_model_state is None:
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
@@ -546,6 +719,7 @@ def sample_tokens(
slot_mappings, # OMNI: unpack slot_mappings for drafter
) = self.execute_model_state
self.execute_model_state = None
+ seq_len = hidden_states.shape[0]
# Apply structured output bitmasks if present.
if grammar_output is not None:
@@ -667,65 +841,73 @@ def propose_draft_token_ids(sampled_token_ids):
dtype=np.int32,
)
+ # Prior to applying the post-processing func, extract
+ # the prefix cached hidden states and multimodal states.
+ if self.omni_prefix_cache is not None:
+ (
+ combined_hidden_states,
+ combined_multimodal_outputs,
+ ) = self._maybe_get_combined_prefix_cache_tensors(
+ hidden_states,
+ multimodal_outputs,
+ scheduler_output.num_scheduled_tokens,
+ )
+ # Otherwise we don't have the mm CPU data yet, so we still need to build it
+ if self.omni_prefix_cache is None:
+ mm_cpu = build_mm_cpu(multimodal_outputs)
+
self._process_additional_information_updates(
- hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output
+ hidden_states,
+ multimodal_outputs,
+ num_scheduled_tokens_np,
+ scheduler_output,
+ combined_hidden_states,
+ combined_multimodal_outputs,
)
- # Pre-copy multimodal tensors to CPU once (not per-request) to avoid
- # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU.
- mm_cpu: dict[str, object] = {}
- if isinstance(multimodal_outputs, dict) and multimodal_outputs:
- for k, v in multimodal_outputs.items():
- try:
- if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
- mm_cpu[k] = v.detach().to("cpu").contiguous()
- elif isinstance(v, dict):
- sub_dict: dict[str, torch.Tensor] = {}
- for sk, sv in v.items():
- if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]:
- sub_dict[str(sk)] = sv.detach().to("cpu").contiguous()
- if sub_dict:
- mm_cpu[k] = sub_dict
- elif isinstance(v, list):
- if len(v) == 0:
- continue
- cpu_list = []
- for elem in v:
- if isinstance(elem, torch.Tensor):
- cpu_list.append(elem.detach().to("cpu").contiguous())
- else:
- cpu_list.append(elem)
- mm_cpu[k] = cpu_list
- except Exception as e:
- logger.error(f"Error in merge multimodal outputs: {e}")
-
pooler_output: list[dict[str, object]] = []
for rid in req_ids_output_copy:
idx = req_id_to_index_output_copy[rid]
start = int(self.query_start_loc.cpu[idx])
sched = int(num_scheduled_tokens_np[idx])
end = start + sched
- hidden_slice = hidden_states_cpu[start:end]
- payload: dict[str, object] = {"hidden": hidden_slice}
- if mm_cpu:
- mm_payload: dict[str, object] = {}
- for k, v in mm_cpu.items():
- if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
- mm_payload[k] = v[start:end].contiguous()
- elif isinstance(v, dict):
- mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()}
- elif isinstance(v, list):
- element = v[idx] if idx < len(v) else v[0]
- # Clone tensors to avoid cross-request aliasing
- if isinstance(element, torch.Tensor):
- element = element.clone()
- mm_payload[k] = element
- elif isinstance(v, torch.Tensor):
- # List-derived tensor payloads are request-invariant; clone to
- # avoid accidental cross-request aliasing on downstream mutation.
- mm_payload[k] = v.clone()
- else:
- mm_payload[k] = v
+ # If prefix cache is enabled, we have already split everything
+ # by request and converted the states to CPU tensors
+ req_hidden_states = self._resolve_req_hidden_states(
+ hidden_states_cpu,
+ combined_hidden_states,
+ rid,
+ start,
+ end,
+ )
+ payload: dict[str, object] = {"hidden": req_hidden_states}
+
+ mm_payload: dict[str, object] = {}
+ if combined_multimodal_outputs or mm_cpu:
+ if combined_multimodal_outputs:
+ # Prefix cache enabled; all items have already been processed
+ # and split apart for each request as needed, and all tensors
+ # have already been detached to the CPU. The only exception is
+ # lists, which we keep as passthrough data for consistent behavior
+ # in postprocess.
+ for mm_key in combined_multimodal_outputs.keys():
+ value = combined_multimodal_outputs[mm_key][rid]
+ if isinstance(value, list):
+ mm_payload[mm_key] = value[idx] if idx < len(value) else value[0]
+ else:
+ mm_payload[mm_key] = value
+
+ else:
+ # Prefix cache disabled; we still need to process the data
+ for mm_key, mm_val in mm_cpu.items():
+ mm_payload[mm_key] = to_payload_element(
+ element=mm_val,
+ idx=idx,
+ start=start,
+ end=end,
+ pass_lists_through=False,
+ seq_len=seq_len,
+ )
payload.update(mm_payload)
pooler_output.append(payload)
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py
index d95b676f6d..f10115c8e9 100644
--- a/vllm_omni/worker/gpu_generation_model_runner.py
+++ b/vllm_omni/worker/gpu_generation_model_runner.py
@@ -39,11 +39,12 @@
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = logging.getLogger(__name__)
-class GPUGenerationModelRunner(OmniGPUModelRunner):
+class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
"""Generation model runner for vLLM-Omni (non-autoregressive).
- Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue.
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index 35e1598435..d1c15eac64 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -20,6 +20,7 @@
from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
+from vllm_omni.core.prefix_cache import OmniTensorPrefixCache
from vllm_omni.engine.serialization import deserialize_additional_information
from vllm_omni.model_executor.layers.rotary_embedding.mrope import OmniMRotaryEmbedding as MRotaryEmbedding
from vllm_omni.model_executor.models.output_templates import OmniOutput
@@ -43,6 +44,9 @@ def __init__(self, *args, **kwargs):
self.model_intermediate_buffer: dict[str, dict[str, Any]] = {}
self._omni_num_scheduled_tokens_np: np.ndarray | None = None
self._omni_last_model_output: object | None = None
+ # The Omni tensor prefix cache will be allocated
+ # when we initialize the metadata builders if enabled
+ self.omni_prefix_cache = None
def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes):
"""Override to fix scheduler_metadata buffer size for FA3 + CUDA graph.
@@ -70,6 +74,16 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes):
device=sm.device,
)
+ # Initialize the wrapper for both multimodal output tensors
+ # and for hidden states to be passed between stages
+ if self.cache_config.enable_prefix_caching:
+ self.omni_prefix_cache = OmniTensorPrefixCache(
+ num_blocks=kv_cache_config.num_blocks,
+ block_size=self.cache_config.block_size,
+ hidden_size=self.model_config.get_hidden_size(),
+ hs_dtype=self.dtype,
+ )
+
@instrument(span_name="Loading (GPU)")
def load_model(self, *args, **kwargs) -> None:
super().load_model(*args, **kwargs)
@@ -83,11 +97,9 @@ def load_model(self, *args, **kwargs) -> None:
self.has_talker_mtp = True
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
- # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
- # have a separate .talker sub-module. TTS models' code predictor
- # has internal AR loops / torch.multinomial — not graph-safe.
has_separate_talker = getattr(self.model, "talker", None) is not None
- if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ talker_mtp_graph_safe = getattr(self.model, "talker_mtp_graph_safe", False)
+ if cudagraph_mode.has_full_cudagraphs() and (has_separate_talker or talker_mtp_graph_safe):
self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
# TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
hidden_size = int(
@@ -236,6 +248,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput"):
The SamplingMetadata is updated and copied to the GPU if there is a
new/resumed/paused/finished request in the batch.
"""
+ # Used for prefix cache
+ if self.omni_prefix_cache is not None:
+ self.omni_prefix_cache.reset_prefix_cached_new_req_ids()
+
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
@@ -292,10 +308,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput"):
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
if req_id in self.requests:
+ self._update_streaming_input_additional_info(new_req_data, req_id)
req_state = self._update_streaming_request(req_id, new_req_data)
reqs_to_add.append(req_state)
continue
+ # Since this is the first time the request has been scheduled,
+ # num_computed_tokens > 0 means that we have a hit in prefix
+ # caching; mark it so that we can manage the hidden states
+ # later on as needed.
+ if self.omni_prefix_cache is not None and new_req_data.num_computed_tokens > 0:
+ self.omni_prefix_cache.add_prefix_cached_new_req_id(req_id)
+
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
@@ -1012,6 +1036,8 @@ def _process_additional_information_updates(
multimodal_outputs: object,
num_scheduled_tokens_np: np.ndarray,
scheduler_output: "SchedulerOutput",
+ combined_hidden_states: dict[str, torch.Tensor] | None = None,
+ combined_multimodal_outputs: dict[str, object] | None = None,
) -> None:
"""Process model-provided per-request updates and merge into model_intermediate_buffer."""
try:
@@ -1020,21 +1046,31 @@ def _process_additional_information_updates(
if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_infos = self.model_intermediate_buffer.get(req_id, {})
- start_offset = int(self.query_start_loc.cpu[req_index])
- sched_tokens = int(num_scheduled_tokens_np[req_index])
- s, e = start_offset, start_offset + sched_tokens
- # only consider to store data into update dict.
- hidden_states_slice = hidden_states[s:e]
+ if combined_hidden_states:
+ # Combined hidden states contains all hidden states for every request
+ hidden_states_slice = combined_hidden_states[req_id]
+ else:
+ start_offset = int(self.query_start_loc.cpu[req_index])
+ sched_tokens = int(num_scheduled_tokens_np[req_index])
+ s, e = start_offset, start_offset + sched_tokens
+ # only consider to store data into update dict.
+ hidden_states_slice = hidden_states[s:e]
+
+ if combined_multimodal_outputs:
+ # NOTE this is a bit ugly, but the mm data is structured as a list of
+ # keys mapping to request IDs, and if enabled, we will always have all
+ # request IDs in every subdict, including for cache misses.
+ mm_out = {k: v[req_id] for k, v in combined_multimodal_outputs.items()}
+ else:
+ mm_out = multimodal_outputs
update_dict = self.model.postprocess(
- hidden_states_slice, multimodal_outputs=multimodal_outputs, **req_infos
+ hidden_states_slice,
+ multimodal_outputs=mm_out,
+ **req_infos,
)
self._update_intermediate_buffer(req_id, update_dict)
except Exception as e:
- logger.error(
- f"Error merging for requests:{self.input_batch.req_ids} "
- f"additional information update: {e}, with the multimodal_outputs "
- f"as {multimodal_outputs}"
- )
+ logger.error(f"Error merging for requests:{self.input_batch.req_ids} additional information update: {e}")
import traceback
traceback.print_exc()
@@ -1243,6 +1279,7 @@ def _preprocess(
span_len = int(e) - int(s)
# call the custom process function
+ req_infos["request_id"] = req_id
embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None
req_input_ids, req_embeds, update_dict = self.model.preprocess(
input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos
@@ -1378,3 +1415,30 @@ def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None:
def _merge_additional_information_update(self, req_id, upd):
logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer")
return self._update_intermediate_buffer(req_id, upd)
+
+ def _update_streaming_input_additional_info(self, new_req_data, req_id):
+ # For streaming input prefill case only. Update buffer from last segment input
+ cached_additional_info = self.model_intermediate_buffer.get(req_id, {})
+ if cached_additional_info:
+ payload_info = getattr(new_req_data, "additional_information", None)
+ inc_info = deserialize_additional_information(payload_info)
+ if isinstance(inc_info, dict) and inc_info:
+ merged_info = dict(cached_additional_info)
+ for key, value in inc_info.items():
+ accumulated_keys: set[str] = set()
+ if hasattr(self, "model") and hasattr(self.model, "streaming_accumulated_keys"):
+ accumulated_keys = self.model.streaming_accumulated_keys
+ if key in accumulated_keys and isinstance(value, torch.Tensor):
+ inc_tensor = value.detach().to("cpu").contiguous()
+ old_tensor = merged_info.get(key)
+ if old_tensor is None:
+ merged_info[key] = inc_tensor
+ else:
+ merged_info[key] = torch.cat((old_tensor, inc_tensor), dim=0)
+ continue
+
+ # Default for other keys: latest value.
+ merged_info[key] = value
+ merged_info["num_processed_tokens"] = 0
+ self.model_intermediate_buffer[req_id] = merged_info
+ setattr(self.requests[req_id], "additional_information_cpu", merged_info)
diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py
new file mode 100644
index 0000000000..e0df3ba3d7
--- /dev/null
+++ b/vllm_omni/worker/omni_connector_model_runner_mixin.py
@@ -0,0 +1,2125 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unified data-plane communication mixin for Model Runners.
+
+All connector.put()/get() calls are consolidated here. Background I/O
+threads handle async_chunk and full_payload_mode transfers; KV cache is delegated to
+the existing OmniKVTransferManager (to be absorbed later).
+
+The mixin reports transfer results via OmniConnectorOutput so that the
+Scheduler can make scheduling decisions without ever touching a connector.
+"""
+
+from __future__ import annotations
+
+import importlib
+import inspect
+import os
+import threading
+from collections import defaultdict, deque
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Any
+
+import torch
+from vllm.distributed.parallel_state import get_tp_group
+from vllm.logger import init_logger
+
+from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
+from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
+from vllm_omni.outputs import OmniConnectorOutput
+from vllm_omni.worker.payload_span import (
+ THINKER_DECODE_EMBEDDINGS_KEY,
+ THINKER_DECODE_TOKEN_END_KEY,
+ THINKER_DECODE_TOKEN_START_KEY,
+ THINKER_OUTPUT_TOKEN_IDS_KEY,
+ get_tensor_span,
+ merge_tensor_spans,
+)
+
+if TYPE_CHECKING:
+ from vllm_omni.distributed.omni_connectors.connectors.base import (
+ OmniConnectorBase,
+ )
+ from vllm_omni.distributed.omni_connectors.kv_transfer_manager import (
+ OmniKVTransferManager,
+ )
+
+logger = init_logger(__name__)
+
+
+class OmniConnectorModelRunnerMixin:
+ """Unified data-plane communication mixin for Model Runners.
+
+ Provides three transfer modes through a single pair of bg I/O threads:
+ - **full_payload_mode**: ``recv_full_payload_inputs`` / ``send_full_payload_outputs``
+ - **Streaming (async_chunk)**: ``recv_chunk`` / ``send_chunk``
+ - **KV cache**: ``send_kv_cache`` / ``recv_kv_cache`` (delegates to
+ the existing ``OmniKVTransferManager``)
+
+ The mixin owns connector instances and background threads. It never
+ touches scheduling queues -- readiness is communicated to the Scheduler
+ via ``OmniConnectorOutput``.
+ """
+
+ # ------------------------------------------------------------------ #
+ # Init / Shutdown
+ # ------------------------------------------------------------------ #
+
+ def init_omni_connectors(
+ self,
+ vllm_config: Any,
+ model_config: Any,
+ kv_transfer_manager: OmniKVTransferManager | None = None,
+ ) -> None:
+ """Initialize connectors and background threads.
+
+ Args:
+ vllm_config: Full vLLM config object.
+ model_config: Stage-level model config with connector settings.
+ kv_transfer_manager: Existing KV transfer manager to delegate to.
+ """
+ self._omni_connector: OmniConnectorBase | None = self._create_connector(model_config)
+ self._kv_transfer_manager = kv_transfer_manager
+
+ self._async_chunk: bool = getattr(model_config, "async_chunk", False)
+ self._model_mode: str = getattr(model_config, "worker_type", "ar")
+ stage_id = getattr(model_config, "stage_id", 0)
+ if isinstance(stage_id, str):
+ stage_id = int(stage_id)
+ self._stage_id: int = stage_id if isinstance(stage_id, int) else 0
+
+ self._custom_process_func_path, self._custom_process_func = self._load_custom_func(model_config)
+ self._custom_process_supports_is_finished = self._custom_process_supports_is_finished_kwarg()
+ logger.info(
+ "[Stage-%s] init_omni_connectors: async_chunk=%s, custom_process_func=%s, connector=%s, func_path=%s",
+ self._stage_id,
+ self._async_chunk,
+ self._custom_process_func,
+ type(self._omni_connector).__name__ if self._omni_connector else None,
+ self._custom_process_func_path,
+ )
+
+ # -- next stage ID (from connector config or default stage_id + 1) --
+ self._next_stage_id: int = self._resolve_next_stage_id(model_config)
+
+ # -- heterogeneous TP rank support --
+ rank_cfg = self._parse_rank_mapping(model_config)
+ self._from_tp: int = rank_cfg["from_tp"]
+ self._to_tp: int = rank_cfg["to_tp"]
+ self._local_rank: int = rank_cfg["local_rank"]
+ if self._kv_transfer_manager is not None:
+ self._kv_transfer_manager.kv_send_key_builder = self.get_rank_aware_kv_send_keys
+ self._kv_transfer_manager.kv_recv_key_builder = self.get_rank_aware_kv_keys
+ self._kv_transfer_manager.kv_payload_merger = self._merge_rank_sharded_kv_payloads
+ self._kv_transfer_manager.kv_payload_slicer = self._slice_rank_sharded_kv_payload
+
+ # -- chunk index tracking (ported from OmniChunkTransferAdapter) --
+ self._put_req_chunk: dict[str, int] = defaultdict(int)
+ self._get_req_chunk: dict[str, int] = defaultdict(int)
+ # Send-side async accumulation / staging buffer. Receive-side payload
+ # ownership lives in ``_local_stage_payload_cache``.
+ self._send_side_request_payload: dict[str, dict[str, Any]] = {}
+ self._code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list)
+ self._request_ids_mapping: dict[str, str] = {}
+
+ # -- async I/O state (shared by chunk + full_payload_mode) --
+ self._pending_load_reqs: dict[str, Any] = {}
+ self._finished_load_reqs: set[str] = set()
+ self._pending_save_reqs: dict[str, deque] = {}
+ self._pending_save_counts: dict[str, int] = defaultdict(int)
+ self._deferred_send_cleanup: set[str] = set()
+ # -- per-cycle output accumulator --
+ self._chunk_ready_req_ids: set[str] = set()
+ self._chunk_finished_req_ids: set[str] = set()
+ self._stage_recv_req_ids: set[str] = set()
+ self._full_payload_pending_broadcast_req_ids: set[str] = set()
+ self._async_chunk_updated_req_ids: set[str] = set()
+
+ # -- Model Runner local payload cache (RFC §2.4) --
+ # Full stage payloads land here first on the recv side. We
+ # intentionally do not write connector recv results straight into
+ # `model_intermediate_buffer`: runner-owned runtime state is
+ # materialized later by `_sync_local_stage_payloads()` on the
+ # model thread. This keeps recv timing separate from execute-step
+ # visibility and avoids mixing connector I/O with model runtime
+ # ownership.
+ self._local_stage_payload_cache: dict[str, dict[str, Any]] = {}
+ # Lightweight scheduling metadata pending delivery to the Scheduler.
+ self._local_request_metadata: dict[str, dict[str, Any]] = {}
+
+ # -- persistent set of request IDs whose chunk stream is complete --
+ # Prevents re-registration after the finish sentinel has been received.
+ self._chunk_stream_completed: set[str] = set()
+
+ # -- full_payload_mode: accumulate latest pooler_output per request,
+ # send only when the request finishes (next-cycle flush) --
+ self._pending_full_payload_send: dict[str, tuple[Any, Any]] = {}
+
+ # -- KV sent accumulator --
+ self._kv_sent_req_ids: list[str] = []
+
+ # -- KV transfer lifecycle (absorbed from scheduler) --
+ # Requests marked for KV transfer: {req_id: {seq_len, block_ids}}
+ self._kv_pending_transfers: dict[str, dict[str, Any]] = {}
+ # Requests whose KV transfer has been submitted but not yet acked
+ self._kv_active_transfers: set[str] = set()
+ # Requests whose KV transfer is complete (acked by kv_extracted_req_ids)
+ self._kv_completed_transfers: set[str] = set()
+ # Dedup guard: requests that have already triggered KV transfer
+ self._kv_triggered_requests: set[str] = set()
+
+ self._lock = threading.Lock()
+ self._stop_event = threading.Event()
+ self._work_available = threading.Event()
+
+ # Start background threads only when there's a connector
+ self._recv_thread: threading.Thread | None = None
+ self._save_thread: threading.Thread | None = None
+ if self._omni_connector is not None:
+ self._recv_thread = threading.Thread(
+ target=self._recv_loop,
+ daemon=True,
+ name="omni-mixin-recv",
+ )
+ self._recv_thread.start()
+ self._save_thread = threading.Thread(
+ target=self._save_loop,
+ daemon=True,
+ name="omni-mixin-save",
+ )
+ self._save_thread.start()
+
+ def shutdown_omni_connectors(self) -> None:
+ """Stop background threads and release connector resources."""
+ self._stop_event.set()
+ if self._recv_thread is not None:
+ self._recv_thread.join(timeout=5)
+ if self._save_thread is not None:
+ self._save_thread.join(timeout=5)
+ if self._omni_connector is not None:
+ try:
+ self._omni_connector.close()
+ except Exception:
+ pass
+
+ def cleanup_finished_request(self, req_id: str) -> None:
+ """Clean up per-request state after a request is fully finished.
+
+ Call this when a request is freed from the model runner to prevent
+ memory leaks in the mixin's tracking dicts/sets. The external
+ request ID is resolved before cleaning up ``_put_req_chunk`` which
+ is keyed by external ID.
+ """
+ ext_id = self._request_ids_mapping.pop(req_id, None)
+ send_req_id = ext_id if ext_id is not None else req_id
+
+ with self._lock:
+ if self._pending_save_counts.get(send_req_id, 0):
+ self._deferred_send_cleanup.add(send_req_id)
+ else:
+ self._put_req_chunk.pop(send_req_id, None)
+ self._send_side_request_payload.pop(send_req_id, None)
+ self._code_prompt_token_ids.pop(send_req_id, None)
+ self._kv_pending_transfers.pop(req_id, None)
+ self._kv_active_transfers.discard(req_id)
+ self._kv_completed_transfers.discard(req_id)
+ self._kv_triggered_requests.discard(req_id)
+ self._cleanup_recv_delivery_state(req_id)
+
+ def drop_inactive_request_delivery_state(self, req_id: str) -> None:
+ """Clear recv-side state for inactive requests."""
+ ext_id = self._request_ids_mapping.pop(req_id, None)
+ if hasattr(self, "_lock"):
+ with self._lock:
+ self._drop_send_side_payload_state(req_id, ext_id)
+ else:
+ self._drop_send_side_payload_state(req_id, ext_id)
+ self._cleanup_recv_delivery_state(req_id)
+
+ def _drop_send_side_payload_state(self, req_id: str, ext_id: str | None) -> None:
+ if ext_id is not None:
+ self._send_side_request_payload.pop(ext_id, None)
+ self._send_side_request_payload.pop(req_id, None)
+
+ def _cleanup_recv_delivery_state(self, req_id: str) -> None:
+ """Clear recv-side delivery-cycle state."""
+ if hasattr(self, "_lock"):
+ with self._lock:
+ self._clear_recv_delivery_state(req_id)
+ else:
+ self._clear_recv_delivery_state(req_id)
+
+ def _clear_recv_delivery_state(self, req_id: str) -> None:
+ self._get_req_chunk.pop(req_id, None)
+ self._pending_load_reqs.pop(req_id, None)
+ self._finished_load_reqs.discard(req_id)
+ self._chunk_ready_req_ids.discard(req_id)
+ self._chunk_finished_req_ids.discard(req_id)
+ self._chunk_stream_completed.discard(req_id)
+ self._stage_recv_req_ids.discard(req_id)
+ self._full_payload_pending_broadcast_req_ids.discard(req_id)
+ self._async_chunk_updated_req_ids.discard(req_id)
+ self._local_stage_payload_cache.pop(req_id, None)
+ self._local_request_metadata.pop(req_id, None)
+
+ def prune_inactive_requests(self, active_req_ids: Any) -> set[str]:
+ """Drop connector state for requests that no longer exist locally.
+
+ Preempted / unscheduled requests are expected to stay in
+ ``self.requests`` and therefore remain untouched. This only prunes
+ stale request IDs that have already fallen out of the active request
+ map, preventing background recv/send bookkeeping from outliving the
+ request lifecycle.
+ """
+ if active_req_ids is None:
+ return set()
+
+ active_req_ids = set(active_req_ids)
+ pending_req_ids = set(getattr(self, "_pending_load_reqs", {}).keys())
+ received_req_ids = set(getattr(self, "_stage_recv_req_ids", set()))
+ received_req_ids.update(getattr(self, "_full_payload_pending_broadcast_req_ids", set()))
+ received_req_ids.update(getattr(self, "_local_request_metadata", {}).keys())
+ # Pending recv requests may not yet be in the caller's active set
+ # (e.g. WAITING_FOR_CHUNK requests live in the coordinator's internal
+ # queues, not in model runner self.requests). Protect them so that
+ # legitimate waiting requests are not pruned.
+ #
+ # Likewise, a full payload can arrive on the background recv thread
+ # after the scheduler_output snapshot for the current execute_model()
+ # cycle was already materialized. Those requests may briefly live only
+ # in recv-side buffers/local cache until the next scheduler cycle wakes
+ # them up; pruning them here drops the payload before stage_recv can be
+ # published.
+ active_req_ids.update(pending_req_ids)
+ active_req_ids.update(received_req_ids)
+ stale_req_ids: set[str] = set()
+
+ # NOTE: _pending_load_reqs is excluded from the scan list because
+ # all its entries are unconditionally protected above. The mixin
+ # cannot distinguish a legitimately-waiting pending recv from an
+ # orphaned one (only the coordinator/scheduler knows).
+ #
+ # Requests with freshly received full payloads / local stage payloads
+ # are also protected above. Their scheduler wake-up may lag the recv
+ # thread by one execute_model() cycle, especially when the request was
+ # added after the current scheduler_output snapshot.
+ #
+ # Orphaned pending recv entries (e.g. from upstream stage crash)
+ # are handled by OmniSchedulingCoordinator.collect_timed_out_request_ids()
+ # which detects wait-time violations. The scheduler then removes the
+ # request from its queues, sets FINISHED_ERROR, and calls _free_request()
+ # which ultimately triggers cleanup_finished_request() here.
+ for attr_name in (
+ "_request_ids_mapping",
+ "_get_req_chunk",
+ "_finished_load_reqs",
+ "_chunk_ready_req_ids",
+ "_chunk_finished_req_ids",
+ "_chunk_stream_completed",
+ "_stage_recv_req_ids",
+ "_full_payload_pending_broadcast_req_ids",
+ "_async_chunk_updated_req_ids",
+ "_local_stage_payload_cache",
+ "_local_request_metadata",
+ "_kv_pending_transfers",
+ "_kv_active_transfers",
+ "_kv_completed_transfers",
+ "_kv_triggered_requests",
+ ):
+ state = getattr(self, attr_name, None)
+ if isinstance(state, dict):
+ stale_req_ids.update(req_id for req_id in state if req_id not in active_req_ids)
+ elif isinstance(state, set):
+ stale_req_ids.update(req_id for req_id in state if req_id not in active_req_ids)
+
+ for req_id in stale_req_ids:
+ self.cleanup_finished_request(req_id)
+
+ return stale_req_ids
+
+ # ------------------------------------------------------------------ #
+ # Local payload cache (RFC §2.4 – Model Runner ownership)
+ # ------------------------------------------------------------------ #
+
+ def put_local_stage_payload(self, req_id: str, payload: dict[str, Any]) -> None:
+ """Store a full stage payload in the local cache."""
+ self._local_stage_payload_cache[req_id] = payload
+
+ def get_local_stage_payload(self, req_id: str) -> dict[str, Any] | None:
+ """Read a stage payload without removing it."""
+ return self._local_stage_payload_cache.get(req_id)
+
+ def pop_local_stage_payload(self, req_id: str) -> dict[str, Any] | None:
+ """Remove and return a stage payload (consume after use)."""
+ return self._local_stage_payload_cache.pop(req_id, None)
+
+ def put_local_request_metadata(self, req_id: str, metadata: dict[str, Any]) -> None:
+ """Store lightweight scheduling metadata for a request."""
+ self._local_request_metadata[req_id] = metadata
+
+ def get_local_request_metadata(self, req_id: str) -> dict[str, Any] | None:
+ """Retrieve scheduling metadata for a request."""
+ return self._local_request_metadata.get(req_id)
+
+ # ------------------------------------------------------------------ #
+ # Scheduling metadata extraction
+ # ------------------------------------------------------------------ #
+
+ _SCHEDULING_METADATA_KEYS = (
+ "next_stage_prompt_len",
+ "code_predictor_codes",
+ "left_context_size",
+ )
+
+ @classmethod
+ def _extract_scheduling_metadata(cls, payload: dict[str, Any]) -> dict[str, Any]:
+ """Extract only the fields the scheduler needs from a full payload."""
+ return {k: payload[k] for k in cls._SCHEDULING_METADATA_KEYS if k in payload}
+
+ _NON_CONSUMABLE_PAYLOAD_KEYS = {
+ "finished",
+ "override_keys",
+ "next_stage_prompt_len",
+ "left_context_size",
+ THINKER_OUTPUT_TOKEN_IDS_KEY,
+ THINKER_DECODE_TOKEN_START_KEY,
+ THINKER_DECODE_TOKEN_END_KEY,
+ }
+
+ @staticmethod
+ def _payload_value_has_content(value: Any) -> bool:
+ if value is None:
+ return False
+ if isinstance(value, torch.Tensor):
+ return value.numel() > 0
+ if isinstance(value, (list, tuple, dict, set)):
+ return len(value) > 0
+ return True
+
+ @classmethod
+ def _payload_is_consumable(cls, payload: dict[str, Any] | None) -> bool:
+ """Return True when an async payload can drive a real forward step.
+
+ Metadata-only wake-ups should not transition WAITING_FOR_CHUNK requests
+ back to schedulable state. In particular, a widened token horizon without
+ any newly visible thinker decode embeds should not force a placeholder-only
+ talker decode step.
+ """
+ if not isinstance(payload, dict) or not payload:
+ return False
+
+ decode_embeddings = payload.get(THINKER_DECODE_EMBEDDINGS_KEY)
+ if isinstance(decode_embeddings, torch.Tensor):
+ if decode_embeddings.ndim == 0:
+ return True
+ return decode_embeddings.numel() > 0 and decode_embeddings.shape[0] > 0
+
+ if "code_predictor_codes" in payload:
+ code_predictor_codes = payload.get("code_predictor_codes")
+ if isinstance(code_predictor_codes, torch.Tensor):
+ return code_predictor_codes.numel() > 0
+ # Codec code 0 is valid; non-empty code payloads are consumable.
+ if hasattr(code_predictor_codes, "__len__"):
+ return len(code_predictor_codes) > 0
+ else:
+ return code_predictor_codes is not None
+
+ for key, value in payload.items():
+ if key in cls._NON_CONSUMABLE_PAYLOAD_KEYS:
+ continue
+ if cls._payload_value_has_content(value):
+ return True
+ return False
+
+ @staticmethod
+ def _get_local_tp_group() -> Any | None:
+ """Return the local TP group when tensor parallelism is initialized."""
+ try:
+ return get_tp_group()
+ except Exception:
+ return None
+
+ def _recv_ordinary_stage_result(
+ self,
+ connector: OmniConnectorBase,
+ from_stage: str,
+ to_stage: str,
+ connector_get_key: str,
+ ) -> Any:
+ """Receive one ordinary non-KV stage payload on the local leader rank only."""
+ tp_group = self._get_local_tp_group()
+ if tp_group is None or getattr(tp_group, "world_size", 1) <= 1:
+ return connector.get(from_stage, to_stage, connector_get_key)
+ if not self.is_data_transfer_rank():
+ return None
+ return connector.get(from_stage, to_stage, connector_get_key)
+
+ def _recv_full_payload_result(
+ self,
+ connector: OmniConnectorBase,
+ from_stage: str,
+ to_stage: str,
+ connector_get_key: str,
+ ) -> Any:
+ """Receive one full-payload transfer on the local leader rank only."""
+ return self._recv_ordinary_stage_result(
+ connector,
+ from_stage,
+ to_stage,
+ connector_get_key,
+ )
+
+ def _recv_async_chunk_result(
+ self,
+ connector: OmniConnectorBase,
+ from_stage: str,
+ to_stage: str,
+ connector_get_key: str,
+ ) -> Any:
+ """Receive one ordinary async chunk on the local leader rank only."""
+ return self._recv_ordinary_stage_result(
+ connector,
+ from_stage,
+ to_stage,
+ connector_get_key,
+ )
+
+ @staticmethod
+ def _snapshot_payload(payload: Any) -> Any:
+ if isinstance(payload, dict):
+ return dict(payload)
+ return payload
+
+ def _broadcast_tp_payload_packet(self, packet: Any) -> Any:
+ """Broadcast one ordinary payload packet from TP rank 0 when TP is active."""
+ tp_group = self._get_local_tp_group()
+ if tp_group is None or getattr(tp_group, "world_size", 1) <= 1:
+ return packet
+ leader_packet = packet if self.is_data_transfer_rank() else None
+ return tp_group.broadcast_object(leader_packet, src=0)
+
+ def _apply_staged_payloads_locked(self, staged_payloads: dict[str, Any]) -> None:
+ for req_id, payload in staged_payloads.items():
+ self._local_stage_payload_cache[req_id] = self._snapshot_payload(payload)
+
+ def _collect_full_payload_results_locked(self) -> dict[str, Any] | None:
+ if not self._full_payload_pending_broadcast_req_ids:
+ return None
+ results: dict[str, Any] = {}
+ missing_req_ids: list[str] = []
+ for req_id in tuple(self._full_payload_pending_broadcast_req_ids):
+ payload = self._local_stage_payload_cache.get(req_id)
+ if payload is None:
+ missing_req_ids.append(req_id)
+ continue
+ results[req_id] = self._snapshot_payload(payload)
+ self._full_payload_pending_broadcast_req_ids.discard(req_id)
+ if missing_req_ids:
+ logger.warning(
+ "[Stage-%s] _collect_full_payload_results_locked: "
+ "pending full-payload reqs missing from local cache: %s",
+ self._stage_id,
+ missing_req_ids,
+ )
+ return results or None
+
+ def _collect_async_chunk_fanout_packet_locked(self) -> dict[str, Any] | None:
+ payload_req_ids = set(self._async_chunk_updated_req_ids)
+ payload_req_ids.update(self._finished_load_reqs)
+ payload_req_ids.update(self._chunk_finished_req_ids)
+ payload_req_ids.update(self._local_request_metadata)
+ if not (
+ payload_req_ids or self._finished_load_reqs or self._chunk_finished_req_ids or self._local_request_metadata
+ ):
+ return None
+
+ staged_payloads = {
+ req_id: self._snapshot_payload(self._local_stage_payload_cache[req_id])
+ for req_id in payload_req_ids
+ if req_id in self._local_stage_payload_cache
+ }
+ packet = {
+ "staged_payloads": staged_payloads,
+ "request_metadata": dict(self._local_request_metadata),
+ "newly_finished": set(self._finished_load_reqs),
+ "chunk_finished": set(self._chunk_finished_req_ids),
+ }
+
+ self._async_chunk_updated_req_ids.clear()
+ self._finished_load_reqs.clear()
+ self._chunk_finished_req_ids.clear()
+ self._local_request_metadata.clear()
+
+ for req_id in packet["chunk_finished"]:
+ if req_id not in self._local_stage_payload_cache:
+ continue
+ ext_req_id = self._request_ids_mapping.get(req_id, req_id)
+ self._send_side_request_payload.pop(ext_req_id, None)
+ if ext_req_id != req_id:
+ self._send_side_request_payload.pop(req_id, None)
+
+ return packet
+
+ def _apply_async_chunk_fanout_packet(self, packet: dict[str, Any]) -> None:
+ staged_payloads = packet.get("staged_payloads", {})
+ chunk_finished = set(packet.get("chunk_finished", ()))
+ with self._lock:
+ self._apply_staged_payloads_locked(staged_payloads)
+ for req_id in chunk_finished:
+ self._pending_load_reqs.pop(req_id, None)
+ self._chunk_stream_completed.add(req_id)
+
+ # ------------------------------------------------------------------ #
+ # full_payload_mode (recv_full_payload_inputs / send_full_payload_outputs)
+ # ------------------------------------------------------------------ #
+
+ def recv_full_payload_inputs(self, scheduler_output: Any) -> dict[str, Any] | None:
+ """Check for incoming full_payload_mode stage inputs (non-blocking).
+
+ Returns a dict mapping ``request_id -> engine_inputs`` for data
+ that has arrived, or ``None`` if nothing is ready. Stores full
+ payloads in the local cache and extracts scheduling metadata.
+ """
+ with self._lock:
+ results = self._collect_full_payload_results_locked() if self.is_data_transfer_rank() else None
+ results = self._broadcast_tp_payload_packet(results)
+ if not results:
+ return None
+ with self._lock:
+ self._stage_recv_req_ids.update(results.keys())
+ for req_id in results:
+ self._pending_load_reqs.pop(req_id, None)
+ self._apply_staged_payloads_locked(results)
+ for req_id, payload in results.items():
+ self._local_request_metadata[req_id] = self._extract_scheduling_metadata(payload)
+ logger.info(
+ "[Stage-%s] recv_full_payload_inputs: consumed %s reqs: %s, stage_recv_req_ids now=%s",
+ self._stage_id,
+ len(results),
+ list(results.keys()),
+ self._stage_recv_req_ids,
+ )
+ return results
+
+ @staticmethod
+ def _is_all_zero_tensor(t: Any) -> bool:
+ """Return True if *t* is a torch.Tensor whose elements are all zero."""
+ return isinstance(t, torch.Tensor) and t.numel() > 0 and not t.any()
+
+ def accumulate_full_payload_output(
+ self,
+ req_id: str,
+ pooler_output: Any,
+ request: Any,
+ ) -> None:
+ """Accumulate pooler_output for a request across steps (full_payload_mode).
+
+ Per-token tensors (2-D+, matching trailing dims) are concatenated
+ along dim-0. Scalar / global tensors (1-D or 0-D) are replaced
+ with the latest value.
+
+ All-zero tensors (e.g. ``code_predictor_codes`` emitted during
+ prefill) are dropped so that they do not pollute downstream stages
+ with garbage / noise frames.
+
+ The data is actually sent when ``flush_full_payload_outputs`` is called
+ with the finished request IDs from the next scheduler cycle.
+ """
+ # ---- Filter out all-zero tensors from the incoming pooler_output ----
+ filtered: dict[str, Any] = {}
+ dropped_zero_keys: list[tuple[str, tuple[int, ...]]] = []
+ for k, v in pooler_output.items():
+ if self._is_all_zero_tensor(v):
+ dropped_zero_keys.append((k, tuple(v.shape)))
+ continue # skip prefill zero-filled placeholders
+ filtered[k] = v
+ if dropped_zero_keys:
+ logger.info(
+ "[Stage-%s] accumulate_full_payload_output: req=%s dropped_zero_keys=%s",
+ self._stage_id,
+ req_id,
+ dropped_zero_keys,
+ )
+ pooler_output = filtered
+
+ existing = self._pending_full_payload_send.get(req_id)
+ if existing is None:
+ self._pending_full_payload_send[req_id] = (pooler_output, request)
+ return
+
+ prev_output, _ = existing
+ merged: dict[str, Any] = {}
+ for k in set(prev_output) | set(pooler_output):
+ v_new = pooler_output.get(k)
+ v_old = prev_output.get(k)
+ if v_new is None:
+ merged[k] = v_old
+ elif v_old is None:
+ merged[k] = v_new
+ elif (
+ isinstance(v_new, torch.Tensor)
+ and isinstance(v_old, torch.Tensor)
+ and v_new.dim() >= 2
+ and v_old.dim() >= 2
+ and v_new.shape[1:] == v_old.shape[1:]
+ ):
+ merged[k] = torch.cat([v_old, v_new], dim=0)
+ else:
+ merged[k] = v_new
+ self._pending_full_payload_send[req_id] = (merged, request)
+
+ def flush_full_payload_outputs(self, finished_req_ids: set[str]) -> None:
+ """Send accumulated full_payload outputs for requests that just finished."""
+ logger.info(
+ "[Stage-%s] flush_full_payload_outputs: finished_req_ids=%s, pending=%s",
+ self._stage_id,
+ finished_req_ids,
+ list(self._pending_full_payload_send.keys()),
+ )
+ to_send: dict[str, tuple[Any, Any]] = {}
+ for req_id in finished_req_ids:
+ entry = self._pending_full_payload_send.pop(req_id, None)
+ if entry is not None:
+ to_send[req_id] = entry
+ logger.info("[Stage-%s] flush_full_payload_outputs: to_send=%s", self._stage_id, list(to_send.keys()))
+ if to_send:
+ self.send_full_payload_outputs(scheduler_output=None, outputs=to_send)
+
+ def send_full_payload_outputs(
+ self,
+ scheduler_output: Any,
+ outputs: dict[str, tuple[Any, Any] | Any],
+ ) -> list[str]:
+ """Send full_payload stage outputs to the next stage via connector.
+
+ Args:
+ outputs: Mapping of ``req_id`` to either a
+ ``(pooling_output, request)`` tuple (preferred) or a raw
+ payload dict. When a tuple is supplied the request object
+ is forwarded to ``custom_process_stage_input_func``.
+
+ Returns list of request IDs successfully enqueued.
+ """
+ if self._omni_connector is None:
+ logger.info("[Stage-%s] send_full_payload_outputs: connector is None, skip", self._stage_id)
+ return []
+ if not self.is_data_transfer_rank():
+ logger.info(
+ "[Stage-%s] send_full_payload_outputs: not data_transfer_rank (rank=%s), skip",
+ self._stage_id,
+ self._local_rank,
+ )
+ return list(outputs.keys())
+ sent_ids: list[str] = []
+ next_stage_id = self._next_stage_id
+ for req_id, value in outputs.items():
+ if isinstance(value, tuple) and len(value) == 2:
+ raw_output, request = value
+ else:
+ raw_output, request = value, None
+
+ payload = raw_output
+ if self._custom_process_func is not None:
+ payload = self._build_custom_process_payload(
+ request_id=req_id,
+ request=request,
+ pooling_output=raw_output,
+ )
+ if payload is None:
+ continue
+ if payload is None:
+ logger.info("[Stage-%s] send_full_payload_outputs: payload is None for %s", self._stage_id, req_id)
+ continue
+ if isinstance(payload, dict):
+ code_predictor_codes = payload.get("code_predictor_codes")
+ if isinstance(code_predictor_codes, torch.Tensor):
+ code_len = int(code_predictor_codes.numel())
+ elif hasattr(code_predictor_codes, "__len__"):
+ code_len = len(code_predictor_codes)
+ else:
+ code_len = None
+ logger.info(
+ "[Stage-%s] send_full_payload_outputs: req=%s payload_keys=%s code_len=%s left_context_size=%s",
+ self._stage_id,
+ req_id,
+ sorted(payload.keys()),
+ code_len,
+ payload.get("left_context_size"),
+ )
+
+ external_req_id = self._resolve_external_req_id(request, req_id)
+ chunk_id = self._put_req_chunk[req_id]
+ self._put_req_chunk[req_id] += 1
+ connector_put_key = f"{external_req_id}_{self._stage_id}_{chunk_id}"
+
+ logger.info(
+ "[Stage-%s] send_full_payload_outputs: enqueue req=%s put_key=%s next_stage=%s",
+ self._stage_id,
+ req_id,
+ connector_put_key,
+ next_stage_id,
+ )
+ task = {
+ "stage_id": self._stage_id,
+ "next_stage_id": next_stage_id,
+ "put_key": connector_put_key,
+ "data": payload,
+ "request_id": req_id,
+ }
+ with self._lock:
+ self._pending_save_reqs.setdefault(req_id, deque()).append(task)
+ self._pending_save_counts[req_id] += 1
+ sent_ids.append(req_id)
+ if sent_ids:
+ self._work_available.set()
+ return sent_ids
+
+ def recv_stage_inputs(self, scheduler_output: Any) -> dict[str, Any] | None:
+ """Compatibility wrapper for ``recv_full_payload_inputs``."""
+ return self.recv_full_payload_inputs(scheduler_output)
+
+ def accumulate_batch_output(
+ self,
+ req_id: str,
+ pooler_output: Any,
+ request: Any,
+ ) -> None:
+ """Compatibility wrapper for ``accumulate_full_payload_output``."""
+ self.accumulate_full_payload_output(req_id, pooler_output, request)
+
+ def flush_batch_outputs(self, finished_req_ids: set[str]) -> None:
+ """Compatibility wrapper for ``flush_full_payload_outputs``."""
+ self.flush_full_payload_outputs(finished_req_ids)
+
+ def send_stage_outputs(
+ self,
+ scheduler_output: Any,
+ outputs: dict[str, tuple[Any, Any] | Any],
+ ) -> list[str]:
+ """Compatibility wrapper for ``send_full_payload_outputs``."""
+ return self.send_full_payload_outputs(scheduler_output, outputs)
+
+ # ------------------------------------------------------------------ #
+ # Streaming chunk mode (recv_chunk / send_chunk)
+ # ------------------------------------------------------------------ #
+
+ def register_chunk_recv(self, request: Any) -> None:
+ """Register a request for async chunk retrieval by the bg thread.
+
+ Stage-0 has no upstream producer so this is a no-op there.
+ Skips requests whose batch data has already been received to
+ prevent the bg thread from polling for non-existent chunks.
+ """
+ if self._stage_id == 0:
+ return
+ request_id = request.request_id
+ self._request_ids_mapping[request_id] = getattr(
+ request,
+ "external_req_id",
+ request_id,
+ )
+ with self._lock:
+ if request_id in self._stage_recv_req_ids:
+ return
+ # Don't re-register if the finish sentinel was already received
+ if request_id in self._chunk_stream_completed:
+ return
+ self._pending_load_reqs[request_id] = request
+ self._work_available.set()
+
+ def recv_chunk(self) -> dict[str, Any]:
+ """Collect chunks received by the bg thread since last call.
+
+ Returns a dict ``{request_id: chunk_payload}`` for newly arrived
+ chunks. Empty dict when nothing is ready.
+
+ This method reads from ``_finished_load_reqs`` without clearing
+ it -- ``get_omni_connector_output()`` is the sole consumer that
+ drains and resets ``_finished_load_reqs`` at the end of each
+ ``execute_model`` cycle.
+
+ Returns **shallow copies** of the cached payloads so that the
+ caller can read them without racing against the background recv
+ thread, which may concurrently mutate the live cache entries via
+ ``dict.update()``.
+ """
+ with self._lock:
+ finished = set(self._finished_load_reqs)
+ if not finished:
+ return {}
+ # Snapshot the payloads under the lock to avoid racing with
+ # _poll_single_request which does existing.update(payload_data)
+ # on the same dict objects.
+ result = {}
+ for rid in finished:
+ payload = self._local_stage_payload_cache.get(rid)
+ result[rid] = dict(payload) if isinstance(payload, dict) else payload
+
+ self._chunk_ready_req_ids.update(finished)
+ return result
+
+ def send_chunk(
+ self,
+ request: Any,
+ pooling_output: Any | None = None,
+ ) -> bool:
+ """Derive and enqueue one chunk for async sending.
+
+ Payload extraction runs in the caller thread (via
+ ``custom_process_stage_input_func``); the actual
+ ``connector.put()`` is done by the background save thread.
+ Non-KV data is identical across TP ranks; only rank 0 sends.
+ """
+ if self._omni_connector is None:
+ logger.warning("[Stage-%s] send_chunk: connector is None", self._stage_id)
+ return False
+ if not self.is_data_transfer_rank():
+ return True
+ raw_req_id = getattr(request, "request_id", None) or getattr(request, "req_id", None)
+ request_id = self._resolve_external_req_id(request, raw_req_id)
+ # Cache the internal→external mapping so that finish sentinels can
+ # resolve the external ID even after the request is freed.
+ if raw_req_id and raw_req_id != request_id:
+ self._request_ids_mapping.setdefault(raw_req_id, request_id)
+ chunk_id = self._put_req_chunk[request_id]
+
+ payload_data = self._build_custom_process_payload(
+ request_id=request_id,
+ request=request,
+ pooling_output=pooling_output,
+ )
+ if payload_data is None:
+ if chunk_id == 0:
+ logger.warning(
+ "[Stage-%s] send_chunk: payload is None for req=%s chunk=%s (process_func=%s)",
+ self._stage_id,
+ request_id,
+ chunk_id,
+ self._custom_process_func,
+ )
+ return False
+
+ self._put_req_chunk[request_id] += 1
+ next_stage_id = self._next_stage_id
+ connector_put_key = f"{request_id}_{self._stage_id}_{chunk_id}"
+
+ if chunk_id == 0:
+ logger.info(
+ "[Stage-%s] send_chunk: first chunk enqueued, req=%s key=%s",
+ self._stage_id,
+ request_id,
+ connector_put_key,
+ )
+
+ task = {
+ "stage_id": self._stage_id,
+ "next_stage_id": next_stage_id,
+ "put_key": connector_put_key,
+ "data": payload_data,
+ "request_id": request_id,
+ }
+ with self._lock:
+ self._pending_save_reqs.setdefault(request_id, deque()).append(task)
+ self._pending_save_counts[request_id] += 1
+ self._work_available.set()
+ return True
+
+ # ------------------------------------------------------------------ #
+ # KV cache (delegates to OmniKVTransferManager)
+ # ------------------------------------------------------------------ #
+
+ def send_kv_cache(
+ self,
+ finished_reqs: dict[str, dict[str, Any]],
+ kv_caches: list[torch.Tensor],
+ block_size: int,
+ cache_dtype: str,
+ request_id_resolver: Any | None = None,
+ ) -> list[str]:
+ """Send KV cache for finished requests.
+
+ Delegates to the existing ``OmniKVTransferManager``.
+ """
+ if self._kv_transfer_manager is None:
+ return list(finished_reqs.keys()) if finished_reqs else []
+ result = self._kv_transfer_manager.handle_finished_requests_kv_transfer(
+ finished_reqs=finished_reqs,
+ kv_caches=kv_caches,
+ block_size=block_size,
+ cache_dtype=cache_dtype,
+ request_id_resolver=request_id_resolver,
+ )
+ if result:
+ self._kv_sent_req_ids.extend(result)
+ return result
+
+ def recv_kv_cache(
+ self,
+ request_id: str,
+ target_device: torch.device | None = None,
+ ) -> tuple[dict[str, Any] | None, int]:
+ """Receive KV cache for a request.
+
+ Delegates to the existing ``OmniKVTransferManager``.
+ """
+ if self._kv_transfer_manager is None:
+ return None, 0
+ return self._kv_transfer_manager.receive_kv_cache_for_request(
+ request_id=request_id,
+ target_device=target_device,
+ )
+
+ def receive_cfg_companion_kv_payloads(
+ self,
+ cfg_request_ids: dict[str, str],
+ target_device: torch.device | None = None,
+ ) -> dict[str, tuple[dict[str, Any] | None, int]]:
+ """Receive raw CFG companion KV payloads keyed by role."""
+ return {
+ role: self.recv_kv_cache(companion_rid, target_device=target_device)
+ for role, companion_rid in cfg_request_ids.items()
+ }
+
+ def receive_multi_kv_cache(
+ self,
+ req: Any,
+ cfg_kv_collect_func: Any | None = None,
+ target_device: torch.device | None = None,
+ ) -> bool:
+ """Receive primary and optional companion KV caches for a request.
+
+ The mixin owns the runner-facing orchestration: primary KV receive,
+ companion payload fetch, and applying any model-specific CFG fields back
+ onto ``req.sampling_params``.
+ """
+ if self._kv_transfer_manager is None:
+ return False
+
+ request_id = getattr(req, "request_id", None) or (
+ req.request_ids[0] if hasattr(req, "request_ids") and req.request_ids else None
+ )
+ if not request_id:
+ logger.warning("Request has no ID, cannot receive KV cache")
+ return False
+
+ active_requests = getattr(self, "requests", None)
+ if active_requests is not None and request_id not in active_requests:
+ logger.info("Skip receiving KV cache for inactive request %s", request_id)
+ return False
+
+ primary_ok = False
+ data, _size = self.recv_kv_cache(request_id, target_device=target_device)
+ if data:
+ self._kv_transfer_manager.apply_kv_cache_to_request(req, data)
+ primary_ok = True
+
+ cfg_ids = getattr(getattr(req, "sampling_params", None), "cfg_kv_request_ids", None)
+ if cfg_ids and cfg_kv_collect_func:
+ try:
+ cfg_role_payloads = self.receive_cfg_companion_kv_payloads(
+ cfg_ids,
+ target_device=target_device,
+ )
+ cfg_kvs = cfg_kv_collect_func(request_id, cfg_role_payloads)
+ if cfg_kvs and hasattr(req, "sampling_params") and req.sampling_params is not None:
+ for key, value in cfg_kvs.items():
+ setattr(req.sampling_params, key, value)
+ logger.info("Applied CFG KV caches: %s", list(cfg_kvs.keys()))
+ except Exception:
+ logger.exception("Failed to collect CFG KV caches for %s", request_id)
+
+ return primary_ok
+
+ # ------------------------------------------------------------------ #
+ # Rank-aware KV transfer routing
+ # ------------------------------------------------------------------ #
+
+ def get_rank_aware_kv_keys(
+ self,
+ req_id: str,
+ from_stage: int,
+ to_stage: int | None = None,
+ chunk_id: int = 0,
+ ) -> list[str]:
+ """Build recv-side connector keys for all remote ranks this rank needs.
+
+ For heterogeneous TP receive, the local rank is the target rank and must
+ fetch one or more source-rank shards keyed as ``from_rank -> to_rank``.
+ """
+ remote_ranks = self.get_kv_remote_ranks()
+ return [
+ self.get_kv_connector_key(
+ req_id=req_id,
+ from_stage=from_stage,
+ chunk_id=chunk_id,
+ from_rank=remote_rank,
+ to_rank=self._local_rank,
+ )
+ for remote_rank in remote_ranks
+ ]
+
+ def get_kv_target_ranks_for_send(self) -> list[int]:
+ """Determine which target ranks this local rank should send KV shards to."""
+ self._validate_kv_tp_topology()
+ if self._from_tp == self._to_tp:
+ return [self._local_rank]
+ if self._from_tp > self._to_tp:
+ tp_ratio = self._from_tp // self._to_tp
+ return [self._local_rank // tp_ratio]
+ tp_ratio = self._to_tp // self._from_tp
+ base_rank = self._local_rank * tp_ratio
+ return [base_rank + i for i in range(tp_ratio)]
+
+ def get_rank_aware_kv_send_keys(
+ self,
+ req_id: str,
+ from_stage: int,
+ to_stage: int | None = None,
+ chunk_id: int = 0,
+ ) -> list[str]:
+ """Build send-side connector keys for this rank's KV shard(s)."""
+ target_ranks = self.get_kv_target_ranks_for_send()
+ return [
+ self.get_kv_connector_key(
+ req_id=req_id,
+ from_stage=from_stage,
+ chunk_id=chunk_id,
+ from_rank=self._local_rank,
+ to_rank=target_rank,
+ )
+ for target_rank in target_ranks
+ ]
+
+ @staticmethod
+ def _merge_rank_sharded_kv_payloads(payloads: list[dict[str, Any]]) -> dict[str, Any] | None:
+ """Merge multiple source-rank KV shards for one target rank."""
+ payloads = [payload for payload in payloads if isinstance(payload, dict)]
+ if not payloads:
+ return None
+ if len(payloads) == 1:
+ return payloads[0]
+
+ merged = dict(payloads[0])
+ layer_blocks = merged.get("layer_blocks")
+ if not isinstance(layer_blocks, dict):
+ return merged
+
+ def _merge_tensor_lists(name: str) -> list[torch.Tensor | None]:
+ merged_list: list[torch.Tensor | None] = []
+ cache_lists = [payload.get("layer_blocks", {}).get(name, []) for payload in payloads]
+ max_len = max((len(cache_list) for cache_list in cache_lists), default=0)
+ for idx in range(max_len):
+ tensors = [cache_list[idx] for cache_list in cache_lists if idx < len(cache_list)]
+ tensors = [tensor for tensor in tensors if isinstance(tensor, torch.Tensor)]
+ if not tensors:
+ merged_list.append(None)
+ elif len(tensors) == 1:
+ merged_list.append(tensors[0])
+ else:
+ merged_list.append(torch.cat(tensors, dim=-2).contiguous())
+ return merged_list
+
+ merged["layer_blocks"] = {
+ "key_cache": _merge_tensor_lists("key_cache"),
+ "value_cache": _merge_tensor_lists("value_cache"),
+ }
+ metadata = dict(merged.get("metadata", {}))
+ metadata["merged_remote_rank_count"] = len(payloads)
+ merged["metadata"] = metadata
+ return merged
+
+ def _slice_rank_sharded_kv_payload(self, payload: dict[str, Any] | None) -> dict[str, Any] | None:
+ """Slice a duplicated source-rank KV shard for ``from_tp < to_tp`` cases."""
+ if payload is None or self._from_tp >= self._to_tp:
+ return payload
+
+ tp_ratio = self._to_tp // self._from_tp
+ shard_index = self._local_rank % tp_ratio
+ layer_blocks = payload.get("layer_blocks") if isinstance(payload, dict) else None
+ if not isinstance(layer_blocks, dict):
+ return payload
+
+ def _slice_tensor_list(name: str) -> list[torch.Tensor | None]:
+ sliced: list[torch.Tensor | None] = []
+ for tensor in layer_blocks.get(name, []):
+ if not isinstance(tensor, torch.Tensor) or tensor.ndim < 2:
+ sliced.append(tensor)
+ continue
+ head_dim = tensor.shape[-2]
+ if head_dim % tp_ratio != 0:
+ sliced.append(tensor)
+ continue
+ per_rank = head_dim // tp_ratio
+ start = shard_index * per_rank
+ sliced.append(tensor.narrow(-2, start, per_rank).contiguous())
+ return sliced
+
+ payload = dict(payload)
+ payload["layer_blocks"] = {
+ "key_cache": _slice_tensor_list("key_cache"),
+ "value_cache": _slice_tensor_list("value_cache"),
+ }
+ metadata = dict(payload.get("metadata", {}))
+ metadata["sliced_for_local_rank"] = self._local_rank
+ payload["metadata"] = metadata
+ return payload
+
+ def should_replicate_payload(self) -> bool:
+ """Whether non-KV payloads should be replicated across ranks.
+
+ Data payloads (stage inputs, chunks) are identical after all-gather,
+ so only rank 0 transfers them. KV payloads are rank-specific and
+ all ranks participate.
+ """
+ return self._local_rank != 0
+
+ def get_kv_rank_mapping(self) -> dict[str, Any]:
+ """Return the current rank mapping configuration.
+
+ Useful for debugging and for downstream code that needs to know
+ the TP topology without re-parsing model config.
+ """
+ return {
+ "from_tp": self._from_tp,
+ "to_tp": self._to_tp,
+ "local_rank": self._local_rank,
+ "remote_ranks": self.get_kv_remote_ranks(),
+ "is_data_transfer_rank": self.is_data_transfer_rank(),
+ }
+
+ # ------------------------------------------------------------------ #
+ # KV transfer lifecycle (RFC – mixin-owned)
+ # ------------------------------------------------------------------ #
+
+ def mark_kv_transfer(
+ self,
+ req_id: str,
+ seq_len: int,
+ block_ids: list[int],
+ custom_metadata: dict[str, Any] | None = None,
+ ) -> None:
+ """Mark a request as needing KV cache transfer.
+
+ Called by the scheduler when a transfer trigger fires. The mixin
+ owns the lifecycle from this point: pending → active → completed.
+ """
+ if req_id in self._kv_pending_transfers:
+ return
+ self._kv_triggered_requests.add(req_id)
+ transfer = {
+ "seq_len": seq_len,
+ "block_ids": block_ids,
+ }
+ if custom_metadata is not None:
+ transfer["custom_metadata"] = custom_metadata
+ self._kv_pending_transfers[req_id] = transfer
+
+ def drain_pending_kv_transfers(self) -> dict[str, dict[str, Any]]:
+ """Drain pending KV transfers and move them to active.
+
+ Returns ``{req_id: {seq_len, block_ids}}`` for the model runner
+ to submit to ``send_kv_cache``.
+ """
+ if not self._kv_pending_transfers:
+ return {}
+ pending = dict(self._kv_pending_transfers)
+ self._kv_active_transfers.update(pending.keys())
+ self._kv_pending_transfers.clear()
+ return pending
+
+ def ack_kv_transfers(self, req_ids: list[str] | set[str]) -> None:
+ """Acknowledge completed KV transfers (from kv_extracted_req_ids).
+
+ Moves requests from active to completed so the scheduler can
+ safely free their blocks.
+ """
+ for req_id in req_ids:
+ self._kv_active_transfers.discard(req_id)
+ self._kv_completed_transfers.add(req_id)
+
+ def drain_completed_kv_transfers(self) -> set[str]:
+ """Drain and return completed KV transfer request IDs.
+
+ The scheduler calls this to know which requests' blocks can be freed.
+ """
+ completed = set(self._kv_completed_transfers)
+ self._kv_completed_transfers.clear()
+ return completed
+
+ def is_kv_transfer_triggered(self, req_id: str) -> bool:
+ """Check if a request has already triggered KV transfer."""
+ return req_id in self._kv_triggered_requests
+
+ def has_pending_kv_work(self) -> bool:
+ """True if any KV transfers are pending, active, or awaiting ack."""
+ return bool(self._kv_pending_transfers or self._kv_active_transfers or self._kv_completed_transfers)
+
+ # Output aggregation
+ # ------------------------------------------------------------------ #
+
+ def _empty_output_with_connector_signals(self) -> Any:
+ """Return a minimal ModelRunnerOutput carrying pending connector signals.
+
+ Used by early-return paths (e.g. ``num_scheduled_tokens == 0``)
+ that still need to deliver ``omni_connector_output`` to the
+ Scheduler so that WAITING_FOR_INPUT / WAITING_FOR_CHUNK
+ transitions are not lost.
+ """
+ from vllm_omni.outputs import OmniModelRunnerOutput
+
+ output = OmniModelRunnerOutput(req_ids=[], req_id_to_index={})
+ output.omni_connector_output = self.get_omni_connector_output()
+ return output
+
+ def get_omni_connector_output(self) -> OmniConnectorOutput:
+ """Collect and reset transfer results for this execute_model cycle.
+
+ ``request_metadata`` carries only lightweight scheduling metadata.
+ Full payloads remain owned by the Model Runner local cache for all
+ paths.
+ """
+ if not hasattr(self, "_lock"):
+ return OmniConnectorOutput()
+
+ tp_group = self._get_local_tp_group()
+ if self._async_chunk and tp_group is not None and getattr(tp_group, "world_size", 1) > 1:
+ if self.is_data_transfer_rank():
+ with self._lock:
+ fanout_packet = self._collect_async_chunk_fanout_packet_locked()
+ else:
+ fanout_packet = None
+ fanout_packet = self._broadcast_tp_payload_packet(fanout_packet)
+ if fanout_packet is None:
+ newly_finished = set()
+ chunk_finished = set()
+ request_metadata = {}
+ else:
+ if not self.is_data_transfer_rank():
+ self._apply_async_chunk_fanout_packet(fanout_packet)
+ newly_finished = set(fanout_packet["newly_finished"])
+ chunk_finished = set(fanout_packet["chunk_finished"])
+ request_metadata = dict(fanout_packet["request_metadata"])
+ else:
+ with self._lock:
+ newly_finished = set(self._finished_load_reqs)
+ self._finished_load_reqs.clear()
+ chunk_finished = set(self._chunk_finished_req_ids)
+ self._chunk_finished_req_ids.clear()
+ request_metadata = dict(self._local_request_metadata)
+ self._local_request_metadata.clear()
+ # _send_side_request_payload is the async accumulation buffer for
+ # future recv chunks. Clearing it on every consumable wake-up drops
+ # intermediate
+ # thinker decode spans before the model side can consume them.
+ # Only terminal chunk_finished requests may release that buffer.
+ for req_id in chunk_finished:
+ if req_id not in self._local_stage_payload_cache:
+ continue
+ ext_req_id = self._request_ids_mapping.get(req_id, req_id)
+ self._send_side_request_payload.pop(ext_req_id, None)
+ if ext_req_id != req_id:
+ self._send_side_request_payload.pop(req_id, None)
+ self._chunk_ready_req_ids.update(newly_finished)
+
+ output = OmniConnectorOutput(
+ chunk_ready_req_ids=set(self._chunk_ready_req_ids),
+ chunk_finished_req_ids=chunk_finished,
+ request_metadata=request_metadata,
+ kv_sent_req_ids=list(self._kv_sent_req_ids),
+ stage_recv_req_ids=set(self._stage_recv_req_ids),
+ has_pending_kv_work=self.has_pending_kv_work(),
+ )
+ if output.stage_recv_req_ids or chunk_finished or newly_finished:
+ logger.info(
+ "[Stage-%s] get_omni_connector_output: stage_recv=%s, chunk_finished=%s, chunk_ready=%s",
+ self._stage_id,
+ output.stage_recv_req_ids,
+ chunk_finished,
+ output.chunk_ready_req_ids,
+ )
+ self._chunk_ready_req_ids.clear()
+ self._kv_sent_req_ids.clear()
+ self._stage_recv_req_ids.clear()
+ return output
+
+ @staticmethod
+ def _connector_output_has_signals(output: OmniConnectorOutput) -> bool:
+ return bool(
+ output.chunk_ready_req_ids
+ or output.chunk_finished_req_ids
+ or output.request_metadata
+ or output.kv_sent_req_ids
+ or output.stage_recv_req_ids
+ or output.has_pending_kv_work
+ )
+
+ def attach_omni_connector_output(self, result: Any | None) -> Any:
+ omni_output = self.get_omni_connector_output()
+ if not self._connector_output_has_signals(omni_output):
+ return result
+
+ from copy import copy
+
+ from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT
+
+ wrapped = copy(result if result is not None else EMPTY_MODEL_RUNNER_OUTPUT)
+ wrapped.omni_connector_output = omni_output
+ return wrapped
+
+ # ------------------------------------------------------------------ #
+ # Properties for compatibility with custom_process funcs that access
+ # transfer_manager.put_req_chunk / request_payload / code_prompt_token_ids
+ # ------------------------------------------------------------------ #
+
+ @property
+ def put_req_chunk(self) -> dict[str, int]:
+ return self._put_req_chunk
+
+ @property
+ def request_payload(self) -> dict[str, dict[str, Any]]:
+ return self._send_side_request_payload
+
+ @request_payload.setter
+ def request_payload(self, value: dict[str, dict[str, Any]]) -> None:
+ self._send_side_request_payload = value
+
+ @property
+ def code_prompt_token_ids(self) -> dict[str, list[list[int]]]:
+ return self._code_prompt_token_ids
+
+ @property
+ def connector(self) -> Any | None:
+ return self._omni_connector
+
+ # ------------------------------------------------------------------ #
+ # Background I/O threads
+ # ------------------------------------------------------------------ #
+
+ def _recv_loop(self) -> None:
+ """Background thread: poll connector for incoming data."""
+ _recv_poll_count = 0
+ while not self._stop_event.is_set():
+ with self._lock:
+ pending_ids = list(self._pending_load_reqs.keys())
+
+ if not pending_ids:
+ self._work_available.wait(timeout=0.01)
+ self._work_available.clear()
+ continue
+
+ _recv_poll_count += 1
+ if _recv_poll_count % 5000 == 1:
+ logger.info(
+ "[Stage-%s] _recv_loop: polling %s pending reqs: %s (poll#%s)",
+ self._stage_id,
+ len(pending_ids),
+ pending_ids[:5],
+ _recv_poll_count,
+ )
+
+ made_progress = False
+ for req_id in pending_ids:
+ if self._stop_event.is_set():
+ break
+ try:
+ made_progress = self._poll_single_request(req_id) or made_progress
+ except Exception:
+ logger.warning("Error receiving data for %s", req_id, exc_info=True)
+
+ if not made_progress and not self._stop_event.is_set():
+ self._work_available.wait(timeout=0.001)
+ self._work_available.clear()
+
+ _MAX_SEND_RETRIES = 3
+
+ def _save_loop(self) -> None:
+ """Background thread: send outgoing data via connector."""
+ while not self._stop_event.is_set():
+ task = None
+ with self._lock:
+ for req_id in list(self._pending_save_reqs.keys()):
+ dq = self._pending_save_reqs[req_id]
+ if dq:
+ task = dq.popleft()
+ if not dq:
+ del self._pending_save_reqs[req_id]
+ break
+ del self._pending_save_reqs[req_id]
+
+ if task is not None:
+ success = False
+ try:
+ success = self._send_single_request(task)
+ except Exception:
+ logger.error(
+ "Error saving data for %s",
+ task.get("request_id"),
+ exc_info=True,
+ )
+ if not success:
+ self._requeue_or_drop_failed_send(task)
+ continue
+
+ self._work_available.wait(timeout=0.01)
+ self._work_available.clear()
+
+ def _requeue_or_drop_failed_send(self, task: dict) -> None:
+ """Re-enqueue a failed send task or drop it after max retries."""
+ retry_count = task.get("_retry_count", 0) + 1
+ req_id = task.get("request_id")
+ if retry_count <= self._MAX_SEND_RETRIES:
+ task["_retry_count"] = retry_count
+ logger.warning(
+ "[Stage-%s] Re-enqueuing failed send for %s (retry %d/%d)",
+ getattr(self, "_stage_id", "?"),
+ req_id,
+ retry_count,
+ self._MAX_SEND_RETRIES,
+ )
+ with self._lock:
+ dq = self._pending_save_reqs.setdefault(req_id, deque())
+ dq.appendleft(task)
+ else:
+ logger.error(
+ "[Stage-%s] Giving up on send for %s after %d retries",
+ getattr(self, "_stage_id", "?"),
+ req_id,
+ self._MAX_SEND_RETRIES,
+ )
+ self._decrement_pending_save_count(req_id)
+
+ # ------------------------------------------------------------------ #
+ # Chunk-level poll / send (ported from OmniChunkTransferAdapter)
+ # ------------------------------------------------------------------ #
+
+ def _poll_single_request(self, req_id: str) -> bool:
+ """Poll connector for one chunk of a request (non-blocking)."""
+ connector = self._omni_connector
+ if connector is None:
+ return False
+
+ if self._async_chunk and self._model_mode != "ar":
+ with self._lock:
+ staged_payload = self._local_stage_payload_cache.get(req_id)
+ metadata_in_flight = req_id in self._local_request_metadata
+ scheduler_wakeup_pending = req_id in self._finished_load_reqs
+ if self._payload_is_consumable(staged_payload) or metadata_in_flight or scheduler_wakeup_pending:
+ logger.debug(
+ "[Stage-%s] delaying recv for req=%s until staged async payload is handed to scheduler",
+ self._stage_id,
+ req_id,
+ )
+ return False
+
+ target_stage_id = self._stage_id - 1
+ chunk_id = self._get_req_chunk[req_id]
+ external_req_id = self._request_ids_mapping.get(req_id, req_id)
+ connector_get_key = f"{external_req_id}_{target_stage_id}_{chunk_id}"
+
+ if self._async_chunk:
+ result = self._recv_async_chunk_result(
+ connector,
+ str(target_stage_id),
+ str(self._stage_id),
+ connector_get_key,
+ )
+ else:
+ result = self._recv_full_payload_result(
+ connector,
+ str(target_stage_id),
+ str(self._stage_id),
+ connector_get_key,
+ )
+
+ if result is None:
+ return False
+
+ payload_data, _size = result
+ if not payload_data:
+ return False
+ if isinstance(payload_data, dict):
+ logger.info(
+ "[Stage-%s] recv_chunk_result: req=%s ext=%s key=%s keys=%s finished=%s",
+ self._stage_id,
+ req_id,
+ external_req_id,
+ connector_get_key,
+ sorted(payload_data.keys()),
+ bool(payload_data.get("finished")) if "finished" in payload_data else None,
+ )
+
+ self._get_req_chunk[req_id] += 1
+
+ if self._async_chunk:
+ is_finished = bool(payload_data.get("finished"))
+ incoming_payload_consumable = self._payload_is_consumable(payload_data)
+
+ if self._model_mode == "ar":
+ payload_data = self._accumulate_payload(external_req_id, payload_data)
+ payload_consumable = incoming_payload_consumable
+ else:
+ new_ids = payload_data.get("code_predictor_codes", [])
+ if not new_ids and not is_finished:
+ return False
+ payload_consumable = self._payload_is_consumable(payload_data)
+
+ with self._lock:
+ if is_finished:
+ self._chunk_finished_req_ids.add(req_id)
+ self._chunk_stream_completed.add(req_id)
+ # Local cache (RFC §2.4) — merge, don't replace, so that
+ # earlier chunk keys (e.g. thinker_prefill_embeddings from
+ # chunk 0) are not overwritten by later chunks.
+ existing = self._local_stage_payload_cache.get(req_id)
+ if existing is not None and isinstance(existing, dict) and isinstance(payload_data, dict):
+ existing.update(payload_data)
+ else:
+ self._local_stage_payload_cache[req_id] = payload_data
+ staged_payload = self._local_stage_payload_cache[req_id]
+ self._async_chunk_updated_req_ids.add(req_id)
+ self.put_local_request_metadata(req_id, self._extract_scheduling_metadata(staged_payload))
+ # A finish-only sentinel still needs one terminal wake-up so
+ # the downstream stage can sync the merged local payload and
+ # flush/finish even when the last recv carries no new
+ # consumable chunk bytes.
+ if payload_consumable or is_finished:
+ self._finished_load_reqs.add(req_id)
+ if is_finished and not payload_consumable:
+ logger.debug(
+ "[Stage-%s] finish sentinel arrived for req=%s without new consumable payload",
+ self._stage_id,
+ req_id,
+ )
+ elif not payload_consumable:
+ logger.debug(
+ "[Stage-%s] req=%s received metadata-only / non-consumable async payload; delaying wake-up",
+ self._stage_id,
+ req_id,
+ )
+ if is_finished:
+ self._pending_load_reqs.pop(req_id, None)
+ else:
+ # full_payload_mode: the complete payload arrives in a single get(),
+ # so always unregister immediately.
+ if isinstance(payload_data, dict):
+ engine_inputs = payload_data.get("engine_inputs", payload_data)
+ else:
+ engine_inputs = payload_data
+ with self._lock:
+ self._local_stage_payload_cache[req_id] = self._snapshot_payload(engine_inputs)
+ # Publish full-payload readiness only after the aligned TP broadcast
+ # path in recv_full_payload_inputs() has materialized the payload on all
+ # local ranks. Publishing metadata / stage_recv from the background recv
+ # thread can let the scheduler observe a request before the payload is
+ # actually visible to the model thread.
+ self._full_payload_pending_broadcast_req_ids.add(req_id)
+ self._pending_load_reqs.pop(req_id, None)
+ logger.info(
+ "[Stage-%s] full_payload recv complete: req=%s key=%s payload_type=%s",
+ self._stage_id,
+ req_id,
+ connector_get_key,
+ type(engine_inputs).__name__,
+ )
+
+ logger.debug("[Stage-%s] Received data for key %s", self._stage_id, connector_get_key)
+ return True
+
+ def _build_custom_process_payload(
+ self,
+ request_id: str | None,
+ request: Any | None,
+ pooling_output: Any | None,
+ ) -> Any | None:
+ """Run the custom process hook with a best-effort finished kwarg."""
+ if self._custom_process_func is None:
+ return None
+
+ kwargs = {
+ "transfer_manager": self,
+ "pooling_output": pooling_output,
+ "request": request,
+ }
+ supports_is_finished = getattr(
+ self,
+ "_custom_process_supports_is_finished",
+ self._custom_process_supports_is_finished_kwarg(),
+ )
+ is_finished_fn = getattr(request, "is_finished", None)
+ if callable(is_finished_fn):
+ try:
+ if supports_is_finished is not False:
+ kwargs["is_finished"] = bool(is_finished_fn())
+ except Exception:
+ logger.debug("request.is_finished() failed for %s", request_id, exc_info=True)
+
+ try:
+ return self._custom_process_func(**kwargs)
+ except TypeError as exc:
+ if "is_finished" not in kwargs or not self._is_unexpected_is_finished_kwarg_error(exc):
+ logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
+ return None
+ kwargs.pop("is_finished", None)
+ try:
+ return self._custom_process_func(**kwargs)
+ except Exception:
+ logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
+ return None
+ except Exception:
+ logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
+ return None
+
+ def _custom_process_supports_is_finished_kwarg(self) -> bool | None:
+ """Return whether the custom process hook accepts `is_finished`."""
+ if self._custom_process_func is None:
+ return None
+ try:
+ signature = inspect.signature(self._custom_process_func)
+ except (TypeError, ValueError):
+ return None
+
+ for param in signature.parameters.values():
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
+ return True
+
+ is_finished_param = signature.parameters.get("is_finished")
+ if is_finished_param is None:
+ return False
+ return is_finished_param.kind in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ )
+
+ @staticmethod
+ def _is_unexpected_is_finished_kwarg_error(exc: TypeError) -> bool:
+ message = str(exc)
+ return (
+ "unexpected keyword argument 'is_finished'" in message
+ or 'unexpected keyword argument "is_finished"' in message
+ or "positional-only arguments passed as keyword arguments: 'is_finished'" in message
+ )
+
+ def _send_single_request(self, task: dict) -> bool:
+ """Send one queued task via connector.put().
+
+ Returns True on success. On failure (put() raises or returns
+ ``success=False``), returns False **without** decrementing
+ ``_pending_save_counts`` so the caller can retry or clean up.
+ """
+ connector = self._omni_connector
+ if connector is None:
+ return True
+
+ request_id = task.get("request_id")
+ payload_data = task.get("data")
+ if payload_data is None and task.get("request") is not None:
+ payload_data = self._build_custom_process_payload(
+ request_id=request_id,
+ request=task.get("request"),
+ pooling_output=task.get("pooling_output"),
+ )
+ put_key = task.get("put_key")
+
+ success, _size, _metadata = connector.put(
+ from_stage=str(task["stage_id"]),
+ to_stage=str(task["next_stage_id"]),
+ put_key=put_key,
+ data=payload_data,
+ )
+ logger.info(
+ "[Stage-%s] _send_single_request: put_key=%s success=%s size=%s",
+ task["stage_id"],
+ put_key,
+ success,
+ _size,
+ )
+
+ if not success:
+ return False
+
+ self._decrement_pending_save_count(request_id)
+ return True
+
+ def _decrement_pending_save_count(self, request_id: str) -> None:
+ """Decrement pending save count and run deferred cleanup if zero."""
+ cleanup_req_id = None
+ with self._lock:
+ remaining = self._pending_save_counts.get(request_id, 0)
+ if remaining > 1:
+ self._pending_save_counts[request_id] = remaining - 1
+ elif remaining == 1:
+ self._pending_save_counts.pop(request_id, None)
+ if request_id in self._deferred_send_cleanup:
+ self._deferred_send_cleanup.remove(request_id)
+ cleanup_req_id = request_id
+ if cleanup_req_id is not None:
+ self._put_req_chunk.pop(cleanup_req_id, None)
+ self._send_side_request_payload.pop(cleanup_req_id, None)
+ self._code_prompt_token_ids.pop(cleanup_req_id, None)
+
+ # ------------------------------------------------------------------ #
+ # Payload accumulation (ported from OmniChunkTransferAdapter)
+ # ------------------------------------------------------------------ #
+
+ def _accumulate_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]:
+ """Accumulate chunk payloads (concat tensors, extend lists).
+
+ Returns a **shallow copy** of the accumulated state so callers
+ (e.g. ``_poll_single_request``) can store it in
+ ``_local_stage_payload_cache`` without aliasing the authoritative
+ ``_send_side_request_payload`` dict.
+ """
+ if req_id not in self._send_side_request_payload:
+ self._send_side_request_payload[req_id] = dict(payload_data)
+ return dict(self._send_side_request_payload[req_id])
+
+ origin = self._send_side_request_payload[req_id]
+ merged = dict(origin)
+ override_keys = payload_data.get("override_keys", ())
+ drop_decode_span = False
+ decode_span_handled = False
+ for key, value in payload_data.items():
+ if key == "finished":
+ merged[key] = value
+ continue
+ if key == THINKER_DECODE_EMBEDDINGS_KEY:
+ merged_span = merge_tensor_spans(
+ get_tensor_span(
+ origin,
+ tensor_key=THINKER_DECODE_EMBEDDINGS_KEY,
+ start_key=THINKER_DECODE_TOKEN_START_KEY,
+ end_key=THINKER_DECODE_TOKEN_END_KEY,
+ ),
+ get_tensor_span(
+ payload_data,
+ tensor_key=THINKER_DECODE_EMBEDDINGS_KEY,
+ start_key=THINKER_DECODE_TOKEN_START_KEY,
+ end_key=THINKER_DECODE_TOKEN_END_KEY,
+ ),
+ )
+ if merged_span is not None:
+ merged[key], merged[THINKER_DECODE_TOKEN_START_KEY], merged[THINKER_DECODE_TOKEN_END_KEY] = (
+ merged_span
+ )
+ decode_span_handled = True
+ continue
+ if isinstance(value, torch.Tensor) and key in origin:
+ if (
+ THINKER_DECODE_TOKEN_START_KEY in origin
+ or THINKER_DECODE_TOKEN_END_KEY in origin
+ or THINKER_DECODE_TOKEN_START_KEY in payload_data
+ or THINKER_DECODE_TOKEN_END_KEY in payload_data
+ ):
+ logger.warning(
+ "[Stage-%s] req=%s falling back to legacy thinker decode "
+ "merge due to missing/invalid/non-contiguous span "
+ "metadata",
+ self._stage_id,
+ req_id,
+ )
+ drop_decode_span = True
+ merged[key] = torch.cat([origin[key], value], dim=0)
+ continue
+ merged[key] = value
+ continue
+ if key in {THINKER_DECODE_TOKEN_START_KEY, THINKER_DECODE_TOKEN_END_KEY}:
+ if decode_span_handled or drop_decode_span:
+ continue
+ merged[key] = value
+ continue
+ if key in override_keys:
+ merged[key] = value
+ continue
+ if isinstance(value, torch.Tensor) and key in origin:
+ merged[key] = torch.cat([origin[key], value], dim=0)
+ elif isinstance(value, list) and key in origin:
+ merged[key] = origin[key] + value
+ else:
+ merged[key] = value
+
+ if drop_decode_span:
+ merged.pop(THINKER_DECODE_TOKEN_START_KEY, None)
+ merged.pop(THINKER_DECODE_TOKEN_END_KEY, None)
+ self._send_side_request_payload[req_id] = merged
+ return dict(merged)
+
+ def drop_inactive_request_runtime_state(self, req_id: str) -> None:
+ """Clear inactive request state used by both the runner and mixin.
+
+ This centralizes the model-runner-side cleanup pattern so
+ ``OmniGPUModelRunner`` can reuse it instead of open-coding the same
+ inactive-request state mutations.
+ """
+ if hasattr(self, "model_intermediate_buffer"):
+ self.model_intermediate_buffer.pop(req_id, None)
+ self.drop_inactive_request_delivery_state(req_id)
+
+ # ------------------------------------------------------------------ #
+ # Helpers
+ # ------------------------------------------------------------------ #
+
+ @staticmethod
+ def _freeze_request_attr(value: Any) -> Any:
+ if isinstance(value, list):
+ return list(value)
+ if isinstance(value, tuple):
+ return list(value)
+ if isinstance(value, torch.Tensor):
+ return value.clone()
+ raw_list = getattr(value, "_x", None)
+ if raw_list is not None:
+ return list(raw_list)
+ return value
+
+ def _snapshot_request_for_send(self, request: Any, external_req_id: str) -> Any:
+ finished = bool(getattr(request, "is_finished", lambda: False)())
+ attrs: dict[str, Any] = {}
+ try:
+ attrs.update(vars(request))
+ except TypeError:
+ pass
+
+ for name in (
+ "request_id",
+ "req_id",
+ "external_req_id",
+ "prompt_token_ids",
+ "output_token_ids",
+ "all_token_ids",
+ "additional_information",
+ "sampling_params",
+ "multi_modal_data",
+ "mm_hashes",
+ ):
+ if hasattr(request, name):
+ attrs[name] = self._freeze_request_attr(getattr(request, name))
+
+ attrs["external_req_id"] = external_req_id
+ attrs["_frozen_is_finished"] = finished
+ snapshot = SimpleNamespace(**attrs)
+ snapshot.is_finished = lambda: finished
+ return snapshot
+
+ @staticmethod
+ def _create_connector(model_config: Any) -> OmniConnectorBase | None:
+ """Create a connector from model_config, or None if unconfigured."""
+ connector_config = getattr(model_config, "stage_connector_config", None)
+ if connector_config is None:
+ return None
+
+ if not isinstance(connector_config, dict):
+ connector_config = {
+ "name": getattr(connector_config, "name", None),
+ "extra": getattr(connector_config, "extra", None),
+ }
+
+ name = connector_config.get("name")
+ if not isinstance(name, str) or not name.strip():
+ raise RuntimeError("Invalid stage connector config: missing connector name")
+ name = name.strip()
+
+ extra = connector_config.get("extra")
+ if extra is None:
+ extra = {}
+ elif not isinstance(extra, dict):
+ raise RuntimeError(f"Invalid extra config for connector {name}: expected dict, got {type(extra).__name__}")
+
+ spec = ConnectorSpec(name=name, extra=extra)
+ try:
+ return OmniConnectorFactory.create_connector(spec)
+ except Exception as exc:
+ raise RuntimeError(f"Failed to create connector {name}") from exc
+
+ @staticmethod
+ def _load_custom_func(model_config: Any) -> tuple[str | None, Any | None]:
+ """Load the connector payload builder for the downstream stage.
+
+ Preferred source is ``custom_process_next_stage_input_func``. Some
+ full_payload_mode configs (async_chunk=false) only expose the next-stage prompt builder via
+ ``custom_process_input_func`` (for example ``thinker2talker``), while the
+ connector payload builder lives beside it as ``thinker2talker_full_payload``.
+ In that case, derive the full_payload_mode builder path automatically.
+ """
+ candidates: list[str] = []
+
+ next_stage_func = getattr(model_config, "custom_process_next_stage_input_func", None)
+ if isinstance(next_stage_func, str) and next_stage_func:
+ candidates.append(next_stage_func)
+
+ if not getattr(model_config, "async_chunk", False):
+ input_func = getattr(model_config, "custom_process_input_func", None)
+ if isinstance(input_func, str) and input_func:
+ try:
+ module_path, func_name = input_func.rsplit(".", 1)
+ if func_name.endswith("_full_payload") or func_name.endswith("_batch"):
+ candidates.append(f"{module_path}.{func_name}")
+ else:
+ candidates.append(f"{module_path}.{func_name}_full_payload")
+ candidates.append(f"{module_path}.{func_name}_batch")
+ candidates.append(input_func)
+ except ValueError:
+ candidates.append(input_func)
+
+ tried: set[str] = set()
+ for func_path in candidates:
+ if func_path in tried:
+ continue
+ tried.add(func_path)
+ try:
+ module_path, func_name = func_path.rsplit(".", 1)
+ module = importlib.import_module(module_path)
+ func = getattr(module, func_name, None)
+ if callable(func):
+ if not OmniConnectorModelRunnerMixin._is_connector_payload_builder(func):
+ logger.debug(
+ "Skipping incompatible connector payload hook %s; signature=%s",
+ func_path,
+ inspect.signature(func),
+ )
+ continue
+ return func_path, func
+ except Exception:
+ logger.warning("Failed to load custom func: %s", func_path, exc_info=True)
+
+ return None, None
+
+ @staticmethod
+ def _is_connector_payload_builder(func: Any) -> bool:
+ """Whether *func* matches the mixin payload-builder contract."""
+ try:
+ signature = inspect.signature(func)
+ except (TypeError, ValueError):
+ return False
+
+ params = signature.parameters
+ if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()):
+ return True
+
+ required = {"transfer_manager", "pooling_output", "request"}
+ supported = {
+ name
+ for name, param in params.items()
+ if param.kind
+ in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ )
+ }
+ return required.issubset(supported)
+
+ def _resolve_external_req_id(self, request: Any, fallback_req_id: str) -> str:
+ """Resolve the external request ID consistently.
+
+ Checks ``_request_ids_mapping`` first (populated by
+ ``register_chunk_recv``), then falls back to the request's
+ ``external_req_id`` attribute, and finally to the given
+ ``fallback_req_id``.
+ """
+ mapped = self._request_ids_mapping.get(fallback_req_id)
+ if mapped is not None:
+ return mapped
+ if request is not None:
+ return getattr(request, "external_req_id", fallback_req_id)
+ return fallback_req_id
+
+ def _resolve_next_stage_id(self, model_config: Any) -> int:
+ """Determine the downstream stage ID from connector config.
+
+ Falls back to ``stage_id + 1`` when the config does not specify
+ a ``to_stage`` explicitly.
+ """
+ connector_config = getattr(model_config, "stage_connector_config", None)
+ if connector_config is not None:
+ if isinstance(connector_config, dict):
+ to_stage = connector_config.get("to_stage")
+ else:
+ to_stage = getattr(connector_config, "to_stage", None)
+ if isinstance(to_stage, int):
+ return to_stage
+ if isinstance(to_stage, str) and to_stage.strip():
+ return int(to_stage)
+ return self._stage_id + 1
+
+ @staticmethod
+ def _parse_rank_mapping(model_config: Any) -> dict[str, int]:
+ """Parse rank_mapping from connector config (optional).
+
+ Returns ``{"from_tp": int, "to_tp": int, "local_rank": int}``.
+ When ``rank_mapping`` is absent, assumes 1:1 homogeneous mapping.
+ """
+ connector_config = getattr(model_config, "stage_connector_config", None)
+ if connector_config is not None and not isinstance(connector_config, dict):
+ connector_config = getattr(connector_config, "__dict__", {})
+
+ rank_mapping: dict = {}
+ if isinstance(connector_config, dict):
+ rank_mapping = connector_config.get("rank_mapping", {})
+
+ from_tp = int(rank_mapping.get("from_tp", 1))
+ to_tp = int(rank_mapping.get("to_tp", 1))
+
+ local_rank = 0
+ try:
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+ except (ValueError, TypeError):
+ pass
+
+ return {"from_tp": from_tp, "to_tp": to_tp, "local_rank": local_rank}
+
+ # ------------------------------------------------------------------ #
+ # Heterogeneous TP rank support
+ # ------------------------------------------------------------------ #
+
+ def _validate_kv_tp_topology(self) -> None:
+ """Reject heterogeneous TP mappings that cannot be routed losslessly."""
+ if self._from_tp <= 0 or self._to_tp <= 0:
+ raise ValueError(f"Invalid KV TP mapping: from_tp={self._from_tp}, to_tp={self._to_tp}")
+ larger = max(self._from_tp, self._to_tp)
+ smaller = min(self._from_tp, self._to_tp)
+ if larger % smaller != 0:
+ raise ValueError(
+ f"KV TP mapping must be divisible for rank-aware routing: from_tp={self._from_tp}, to_tp={self._to_tp}"
+ )
+
+ def get_kv_remote_ranks(self) -> list[int]:
+ """Determine which remote ranks this local rank exchanges KV with.
+
+ Follows vLLM's ``TpKVTopology.get_target_remote_ranks()`` pattern:
+ - ``from_tp > to_tp``: each to-rank reads from multiple from-ranks
+ - ``from_tp < to_tp``: multiple to-ranks read from the same from-rank
+ - ``from_tp == to_tp``: 1:1 mapping
+ """
+ self._validate_kv_tp_topology()
+ if self._from_tp == self._to_tp:
+ return [self._local_rank]
+
+ if self._from_tp > self._to_tp:
+ tp_ratio = self._from_tp // self._to_tp
+ return [self._local_rank * tp_ratio + i for i in range(tp_ratio)]
+ else:
+ tp_ratio = self._to_tp // self._from_tp
+ return [self._local_rank // tp_ratio]
+
+ def is_data_transfer_rank(self) -> bool:
+ """Whether this rank should participate in data (non-KV) transfer.
+
+ Ordinary stage payloads are TP-identical, so exactly one TP rank
+ should talk to the connector. When TP is initialized, use TP rank 0
+ so the connector leader matches TP-local broadcast source rank.
+ Otherwise fall back to LOCAL_RANK==0 for the single-rank case.
+ """
+ tp_group = self._get_local_tp_group()
+ if tp_group is not None and getattr(tp_group, "world_size", 1) > 1:
+ return getattr(tp_group, "rank_in_group", 0) == 0
+ return self._local_rank == 0
+
+ def get_kv_connector_key(
+ self,
+ req_id: str,
+ from_stage: int,
+ chunk_id: int,
+ from_rank: int,
+ to_rank: int,
+ ) -> str:
+ """Build connector key that includes rank info for KV transfers."""
+ return f"{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}"
diff --git a/vllm_omni/worker/payload_span.py b/vllm_omni/worker/payload_span.py
new file mode 100644
index 0000000000..994392343a
--- /dev/null
+++ b/vllm_omni/worker/payload_span.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Helpers for explicit thinker decode span metadata."""
+
+from collections.abc import Mapping
+from typing import Any
+
+import torch
+
+THINKER_DECODE_EMBEDDINGS_KEY = "thinker_decode_embeddings"
+THINKER_OUTPUT_TOKEN_IDS_KEY = "thinker_output_token_ids"
+THINKER_DECODE_TOKEN_START_KEY = "thinker_decode_embeddings_token_start"
+THINKER_DECODE_TOKEN_END_KEY = "thinker_decode_embeddings_token_end"
+
+CACHED_THINKER_DECODE_EMBEDDINGS_KEY = "cached_thinker_decode_embeddings"
+CACHED_THINKER_DECODE_TOKEN_START_KEY = "cached_thinker_decode_embeddings_token_start"
+CACHED_THINKER_DECODE_TOKEN_END_KEY = "cached_thinker_decode_embeddings_token_end"
+
+TensorSpan = tuple[torch.Tensor, int, int]
+
+
+def get_tensor_span(payload: Mapping[str, Any], *, tensor_key: str, start_key: str, end_key: str) -> TensorSpan | None:
+ tensor = payload.get(tensor_key)
+ start = payload.get(start_key)
+ end = payload.get(end_key)
+ if not isinstance(tensor, torch.Tensor):
+ return None
+ if not isinstance(start, int) or not isinstance(end, int):
+ return None
+ if start < 0 or end < start or (end - start) != int(tensor.shape[0]):
+ return None
+ return tensor, start, end
+
+
+def merge_tensor_spans(existing_span: TensorSpan | None, incoming_span: TensorSpan | None) -> TensorSpan | None:
+ if existing_span is None or incoming_span is None:
+ return None
+
+ existing_tensor, existing_start, existing_end = existing_span
+ incoming_tensor, incoming_start, incoming_end = incoming_span
+ if incoming_tensor.device != existing_tensor.device or incoming_tensor.dtype != existing_tensor.dtype:
+ incoming_tensor = incoming_tensor.to(device=existing_tensor.device, dtype=existing_tensor.dtype)
+ if incoming_start == existing_end:
+ return torch.cat([existing_tensor, incoming_tensor], dim=0), existing_start, incoming_end
+ if incoming_start < existing_end:
+ overlap = existing_end - incoming_start
+ if overlap >= int(incoming_tensor.shape[0]):
+ return existing_tensor, existing_start, existing_end
+ trimmed_tensor = incoming_tensor[overlap:]
+ return (
+ torch.cat([existing_tensor, trimmed_tensor], dim=0),
+ existing_start,
+ existing_end + int(trimmed_tensor.shape[0]),
+ )
+ return None
+
+
+def get_tensor_span_row(span: TensorSpan | None, index: int) -> torch.Tensor | None:
+ if span is None:
+ return None
+ tensor, start, end = span
+ if index < start or index >= end:
+ return None
+ return tensor[index - start]