diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index d9a2315953..00823951dc 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -44,11 +44,19 @@ steps: agents: queue: "cpu_queue_premerge" - # L4 Test — main+NIGHTLY=1 (scheduled), or PR with label nightly-test (e.g. add label then Rebuild) + # L4 Test — main+NIGHTLY=1 (scheduled), or PR with specific label (e.g. add label then Rebuild) - label: "Upload Nightly Pipeline" depends_on: image-build key: upload-nightly-pipeline - if: '(build.branch == "main" && build.env("NIGHTLY") == "1") || (build.branch != "main" && build.pull_request.labels includes "nightly-test")' + if: >- + (build.branch == "main" && build.env("NIGHTLY") == "1") || + (build.branch != "main" && ( + build.pull_request.labels includes "nightly-test" || + build.pull_request.labels includes "omni-test" || + build.pull_request.labels includes "tts-test" || + build.pull_request.labels includes "diffusion-x2iat-test" || + build.pull_request.labels includes "diffusion-x2v-test" + )) commands: - buildkite-agent pipeline upload .buildkite/test-nightly.yml agents: diff --git a/.buildkite/test-amd-merge.yml b/.buildkite/test-amd-merge.yml index 60ba0d9d41..ac52f60b35 100644 --- a/.buildkite/test-amd-merge.yml +++ b/.buildkite/test-amd-merge.yml @@ -32,7 +32,6 @@ steps: mirror_hardwares: [amdproduction] grade: Blocking commands: - - export GPU_ARCHS=gfx942 - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - | @@ -55,7 +54,7 @@ steps: # - export GPU_ARCHS=gfx942 # - export VLLM_LOGGING_LEVEL=DEBUG # - export VLLM_WORKER_MULTIPROC_METHOD=spawn -# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py +# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model - label: "Diffusion Cache Backend Test" agent_pool: mi325_1 @@ -63,13 +62,12 @@ steps: mirror_hardwares: [amdproduction] grade: Blocking commands: - - export GPU_ARCHS=gfx942 - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - timeout 15m pytest -s -v -m "core_model and cache and diffusion and not distributed_cuda and L4" -- label: "Diffusion Sequence Parallelism Test" - agent_pool: mi325_2 +- label: "Diffusion Sequence Parallelism Test (Need 4 GPUs)" + agent_pool: mi325_4 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking @@ -77,6 +75,7 @@ steps: - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - timeout 20m pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py + - timeout 20m pytest -s -v tests/diffusion/distributed/test_ulysses_uaa_perf.py # merge-only tests - label: "Diffusion Tensor Parallelism Test" @@ -95,22 +94,14 @@ steps: commands: - timeout 20m pytest -s -v tests/diffusion/test_diffusion_worker.py -- label: "Benchmark & Engine Test" - agent_pool: mi325_2 +- label: "Engine Test" + agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - | - timeout 20m bash -c ' - set +e - pytest -s -v tests/benchmarks/test_serve_cli.py - EXIT1=\$? - pytest -s -v tests/engine/test_async_omni_engine_abort.py - EXIT2=\$? 
- exit \$((EXIT1 | EXIT2)) - ' + - timeout 20m pytest -s -v tests/engine/test_async_omni_engine_abort.py - label: "Omni Model Test Qwen2-5-Omni" agent_pool: mi325_2 @@ -121,6 +112,7 @@ steps: - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py + - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen2_5_omni.py -m "advanced_model" --run-level "advanced_model" - label: "Omni Model Test Qwen3-Omni" agent_pool: mi325_2 @@ -131,11 +123,10 @@ steps: - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py - - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "advanced_model" --run-level "advanced_model" + - timeout 30m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py tests/e2e/online_serving/test_qwen3_omni.py tests/e2e/online_serving/test_mimo_audio.py -m "advanced_model" --run-level "advanced_model" - label: "Qwen3-TTS CustomVoice E2E Test" - agent_pool: mi325_2 + agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking @@ -145,21 +136,21 @@ steps: export VLLM_LOGGING_LEVEL=DEBUG export VLLM_WORKER_MULTIPROC_METHOD=spawn export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" - pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "advanced_model" --run-level "advanced_model" && pytest -s -v tests/e2e/offline_inference/test_qwen3_tts_customvoice.py + pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py tests/e2e/offline_inference/test_qwen3_tts_customvoice.py -m "advanced_model" --run-level "advanced_model" ' - label: "Qwen3-TTS Base E2E Test" - agent_pool: mi325_2 + agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - | - timeout 20m bash -c ' + timeout 30m bash -c ' export VLLM_LOGGING_LEVEL=DEBUG export VLLM_WORKER_MULTIPROC_METHOD=spawn export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" - pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py -m "advanced_model" --run-level "advanced_model" && pytest -s -v tests/e2e/offline_inference/test_qwen3_tts_base.py + pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py tests/e2e/offline_inference/test_qwen3_tts_base.py -m "advanced_model" --run-level "advanced_model" ' - label: "Diffusion Image Edit Test" @@ -173,43 +164,58 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py -# split Bagel Model Test with H100 (Real Weights) into three tests -- label: "Bagel Text2Img Model Test" - agent_pool: mi325_1 - depends_on: amd-build - mirror_hardwares: [amdproduction] - grade: Blocking - commands: - - export GPU_ARCHS=gfx942 - - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - export VLLM_ROCM_USE_AITER_RMSNORM=0 - - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "advanced_model" --run-level "advanced_model" -k "shared_memory" -k "rocm" +# TODO: Bagel test on ROCm is very unstable. 
@tjtanaa +# Need to debug the numerical changes across large PRs before re-enabling +# # split Bagel Model Test with H100 (Real Weights) into three tests +# - label: "Bagel Text2Img Model Test (1/3)" +# agent_pool: mi325_1 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export GPU_ARCHS=gfx942 +# - export VLLM_TEST_CLEAN_GPU_MEMORY=1 +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - export VLLM_ROCM_USE_AITER_RMSNORM=0 +# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "advanced_model" --run-level "advanced_model" -k "shared_memory" -k "rocm" -- label: "Bagel Img2Img Model Test" - agent_pool: mi325_1 - depends_on: amd-build - mirror_hardwares: [amdproduction] - grade: Blocking - commands: - - export GPU_ARCHS=gfx942 - - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - export VLLM_ROCM_USE_AITER_RMSNORM=0 - - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "advanced_model" --run-level "advanced_model" -k "rocm" +# - label: "Bagel Img2Img Model Test (2/3)" +# agent_pool: mi325_1 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export GPU_ARCHS=gfx942 +# - export VLLM_TEST_CLEAN_GPU_MEMORY=1 +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - export VLLM_ROCM_USE_AITER_RMSNORM=0 +# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "advanced_model" --run-level "advanced_model" -k "rocm" + +# - label: "Bagel Online Serving Test (3/3)" +# agent_pool: mi325_1 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export GPU_ARCHS=gfx942 +# - export VLLM_TEST_CLEAN_GPU_MEMORY=1 +# - export VLLM_IMAGE_FETCH_TIMEOUT=60 +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - export VLLM_ROCM_USE_AITER_RMSNORM=0 +# - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "advanced_model" --run-level "advanced_model" -k "rocm" -- label: "Bagel Online Serving Test" +- label: "Voxtral-TTS E2E Test" agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - - export GPU_ARCHS=gfx942 - - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - - export VLLM_IMAGE_FETCH_TIMEOUT=60 - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - export VLLM_ROCM_USE_AITER_RMSNORM=0 - - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "advanced_model" --run-level "advanced_model" -k "rocm" + - | + timeout 20m bash -c ' + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py tests/e2e/offline_inference/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model" + ' diff --git a/.buildkite/test-amd-ready.yaml b/.buildkite/test-amd-ready.yaml index 6e31163acc..30bbc76941 100644 --- a/.buildkite/test-amd-ready.yaml +++ b/.buildkite/test-amd-ready.yaml @@ -9,13 +9,37 @@ steps: - export VLLM_ROCM_USE_AITER=0 - "timeout 20m pytest -v -s -m 'core_model and cpu' --cov=vllm_omni --cov-branch --cov-report=term-missing --cov-report=html --cov-report=xml" +- label: "Voxtral TTS CUDA Unit Test" + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - timeout 10m pytest -s -v 
tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py + - label: "Diffusion Model Test" - agent_pool: mi325_2 + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - timeout 30m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model" + +- label: "Diffusion Batching Test" + agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - - timeout 20m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model" + - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py -m "core_model and diffusion" --run-level "core_model" + +- label: "Custom Pipeline Test" + agent_pool: mi325_1 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - timeout 20m pytest -s -v tests/e2e/offline_inference/custom_pipeline/ -m "core_model" - label: "Diffusion Model CPU offloading Test" agent_pool: mi325_1 @@ -23,7 +47,6 @@ steps: mirror_hardwares: [amdproduction] grade: Blocking commands: - - export GPU_ARCHS=gfx942 - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - | @@ -46,7 +69,7 @@ steps: # - export GPU_ARCHS=gfx942 # - export VLLM_LOGGING_LEVEL=DEBUG # - export VLLM_WORKER_MULTIPROC_METHOD=spawn -# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py +# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model - label: "Diffusion Cache Backend Test" agent_pool: mi325_1 @@ -77,47 +100,58 @@ steps: commands: - timeout 20m pytest -s -v tests/diffusion/test_diffusion_worker.py -- label: "Benchmark & Engine Test" - agent_pool: mi325_2 +- label: "Engine Test" + agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - | - timeout 30m bash -c ' - set +e - pytest -s -v tests/benchmarks/test_serve_cli.py - EXIT1=\$? - pytest -s -v tests/engine/test_async_omni_engine_abort.py - EXIT2=\$? - exit \$((EXIT1 | EXIT2)) + timeout 15m bash -c ' + pytest -s -v tests/engine/test_async_omni_engine_abort.py ' -- label: "Omni Model Test Qwen2-5-Omni" - agent_pool: mi325_2 - depends_on: amd-build - mirror_hardwares: [amdproduction] - grade: Blocking - commands: - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - timeout 17m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py -- label: "Omni Model Test Qwen3-Omni" - agent_pool: mi325_2 +# NOTE: This test is not running anything. It is skipped and deselected. 
+# Currently it is = 1 skipped, 1 deselected, 17 warnings in 0.03s ====== +# - label: "Omni Model Test Qwen2-5-Omni" +# agent_pool: mi325_2 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py -m "core_model" --run-level "core_model" + +# - label: "Omni Model Test Qwen3-Omni" +# agent_pool: mi325_2 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - export VLLM_TEST_CLEAN_GPU_MEMORY=1 +# - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py +# - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "core_model" --run-level "core_model" + +- label: "MiMo-Audio E2E Test with H100" + agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py - - timeout 10m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "core_model" --run-level "core_model" + - | + timeout 30m bash -c ' + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + pytest -s -v tests/e2e/online_serving/test_mimo_audio.py -m "core_model" --run-level "core_model" + ' - label: "Qwen3-TTS E2E Test" - agent_pool: mi325_2 + agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking @@ -125,55 +159,82 @@ steps: - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" - - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "core_model" --run-level "core_model" + - timeout 30m pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "core_model" --run-level "core_model" -- label: "Diffusion Image Edit Test" +- label: "Voxtral-TTS E2E Test" agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - - export GPU_ARCHS=gfx942 - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py + - | + timeout 20m bash -c ' + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model" + pytest -s -v tests/e2e/offline_inference/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model" + ' -- label: "Bagel Text2Img Model Test" +- label: "Diffusion Image Edit Test" agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - export GPU_ARCHS=gfx942 - - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - export VLLM_ROCM_USE_AITER_RMSNORM=0 - - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "core_model" --run-level "core_model" -k "rocm" + - timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py -- label: "Bagel Img2Img Model Test" - agent_pool: mi325_1 - depends_on: amd-build - mirror_hardwares: [amdproduction] - grade: Blocking - commands: - - 
export GPU_ARCHS=gfx942 - - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - export VLLM_ROCM_USE_AITER_RMSNORM=0 - - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "core_model" --run-level "core_model" -k "rocm" +# TODO: Bagel test on ROCm is very unstable. @tjtanaa +# Need to debug the numerical changes across large PRs before re-enabling +# - label: "Bagel Text2Img Model Test" +# agent_pool: mi325_1 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export GPU_ARCHS=gfx942 +# - export VLLM_TEST_CLEAN_GPU_MEMORY=1 +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - export VLLM_ROCM_USE_AITER_RMSNORM=0 +# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "core_model" --run-level "core_model" -k "rocm" + +# - label: "Bagel Img2Img Model Test" +# agent_pool: mi325_1 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export GPU_ARCHS=gfx942 +# - export VLLM_TEST_CLEAN_GPU_MEMORY=1 +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - export VLLM_ROCM_USE_AITER_RMSNORM=0 +# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "core_model" --run-level "core_model" -k "rocm" -- label: "Bagel Online Serving Test" +# - label: "Bagel Online Serving Test" +# agent_pool: mi325_1 +# depends_on: amd-build +# mirror_hardwares: [amdproduction] +# grade: Blocking +# commands: +# - export GPU_ARCHS=gfx942 +# - export VLLM_TEST_CLEAN_GPU_MEMORY=1 +# - export VLLM_IMAGE_FETCH_TIMEOUT=60 +# - export VLLM_LOGGING_LEVEL=DEBUG +# - export VLLM_WORKER_MULTIPROC_METHOD=spawn +# - export VLLM_ROCM_USE_AITER_RMSNORM=0 +# - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "core_model" --run-level "core_model" -k "rocm" + +- label: "CosyVoice3-TTS E2E Test" agent_pool: mi325_1 depends_on: amd-build mirror_hardwares: [amdproduction] grade: Blocking commands: - - export GPU_ARCHS=gfx942 - - export VLLM_TEST_CLEAN_GPU_MEMORY=1 - - export VLLM_IMAGE_FETCH_TIMEOUT=60 - - export VLLM_LOGGING_LEVEL=DEBUG - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - export VLLM_ROCM_USE_AITER_RMSNORM=0 - - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "core_model" --run-level "core_model" -k "rocm" + - | + timeout 20m bash -c ' + pytest -s -v tests/e2e/online_serving/test_cosyvoice3_tts.py -m "core_model" --run-level "core_model" + ' diff --git a/.buildkite/test-merge.yml b/.buildkite/test-merge.yml index 7355e2b4c7..2a6cb6488a 100644 --- a/.buildkite/test-merge.yml +++ b/.buildkite/test-merge.yml @@ -76,24 +76,6 @@ steps: - - label: "Audio Generation Model Test" - timeout_in_minutes: 20 - depends_on: upload-merge-pipeline - commands: - - pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py - agents: - queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU - plugins: - - docker#v5.2.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - always-pull: true - propagate-environment: true - environment: - - "HF_HOME=/fsx/hf_cache" - - "HF_TOKEN" - volumes: - - "/fsx/hf_cache:/fsx/hf_cache" - - label: "Diffusion Cache Backend Test" timeout_in_minutes: 15 depends_on: upload-merge-pipeline @@ -113,7 +95,7 @@ steps: - "/fsx/hf_cache:/fsx/hf_cache" - label: "Diffusion Sequence Parallelism Test" 
- timeout_in_minutes: 20 + timeout_in_minutes: 25 depends_on: upload-merge-pipeline commands: - pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py tests/diffusion/distributed/test_ulysses_uaa_perf.py diff --git a/.buildkite/test-nightly-diffusion.yml b/.buildkite/test-nightly-diffusion.yml deleted file mode 100644 index 04b99c0a83..0000000000 --- a/.buildkite/test-nightly-diffusion.yml +++ /dev/null @@ -1,364 +0,0 @@ -# Nightly diffusion GPU tests — appended to the main nightly build via -# buildkite-agent pipeline upload .buildkite/test-nightly-diffusion.yml -# from test-nightly.yml (step key: nightly-diffusion-model-test). Top-level groups are -# foldable in the Buildkite UI (Other / Wan / Qwen-Image). -env: - VLLM_WORKER_MULTIPROC_METHOD: spawn - HF_HUB_DOWNLOAD_TIMEOUT: 300 - HF_HUB_ETAG_TIMEOUT: 60 - -steps: - - group: ":card_index_dividers: Other Model Test" - key: nightly-other-model-test-group - steps: - - label: ":full_moon: Diffusion · Other · Function Test with H100" - timeout_in_minutes: 120 - # Shared nightly vs PR label conditional; referenced below as *nightly_or_pr_label - if: &nightly_or_pr_label 'build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"' - commands: - - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -k "not test_wan22_expansion and not test_wan_2_1_vace_expansion and not test_qwen_image" -m "advanced_model and diffusion and H100" --run-level "advanced_model" - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 2 - volumeMounts: - - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - value: /root/.cache/huggingface - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate - - - label: ":full_moon: Diffusion · Other · Function Test with L4" - timeout_in_minutes: 60 - if: *nightly_or_pr_label - commands: - - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and diffusion and L4" --run-level "advanced_model" - agents: - queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU - plugins: - - docker#v5.2.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - always-pull: true - propagate-environment: true - shm-size: "8gb" - environment: - - "HF_HOME=/fsx/hf_cache" - - "HF_TOKEN" - volumes: - - "/fsx/hf_cache:/fsx/hf_cache" - - - label: ":full_moon: Diffusion · Other · Doc Test" - timeout_in_minutes: 60 - if: *nightly_or_pr_label - commands: - - export VLLM_TEST_CLEAN_GPU_MEMORY="1" - - pytest -s -v tests/examples/online_serving/test_text_to_image.py tests/examples/offline_inference/test_text_to_image.py -m "advanced_model and example and H100" --run-level "advanced_model" - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 2 - volumeMounts: - - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - 
value: /root/.cache/huggingface - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate - - - group: ":card_index_dividers: Wan Series Model Test" - key: nightly-wan-model-test-group - steps: - - label: ":full_moon: Diffusion · Wan · Function Test" - timeout_in_minutes: 90 - if: *nightly_or_pr_label - commands: - - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py tests/e2e/online_serving/test_wan_2_1_vace_expansion.py -m "advanced_model" --run-level "advanced_model" - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 2 - volumeMounts: - - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - value: /root/.cache/huggingface - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate - - - label: ":full_moon: Diffusion · Wan · Accuracy Test" - key: nightly-wan22-i2v-accuracy - timeout_in_minutes: 180 - if: *nightly_or_pr_label - commands: - - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 2 - volumeMounts: - - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - value: /root/.cache/huggingface - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate - - - group: ":card_index_dividers: Qwen-Image Series Model Test" - key: nightly-qwen-image-edit-group - steps: - - label: ":full_moon: Diffusion · Qwen-Image · Function Test with H100" - timeout_in_minutes: 120 - if: *nightly_or_pr_label - commands: - - pytest -s -v tests/e2e/online_serving/test_qwen_image*_expansion.py -m "advanced_model and diffusion and H100" --run-level "advanced_model" - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 2 - volumeMounts: - - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - value: /root/.cache/huggingface - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate - - - label: 
":full_moon: Diffusion · Qwen-Image · GEBench Accuracy Test" - key: nightly-gebench-accuracy - timeout_in_minutes: 60 - if: *nightly_or_pr_label - commands: - - pytest -s -v tests/e2e/accuracy/test_gebench_h100_smoke.py --run-level advanced_model --gebench-model Qwen/Qwen-Image-2512 --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gebench-port 8093 --accuracy-workers 1 - - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gebench_qwen-image-2512/summary*.json" - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 1 - volumeMounts: - - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - value: /root/.cache/huggingface - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate - - - label: ":full_moon: Diffusion · Qwen-Image · GEdit-Bench Accuracy Test" - key: nightly-gedit-bench-accuracy - timeout_in_minutes: 60 - if: *nightly_or_pr_label - commands: - - pytest -s -v tests/e2e/accuracy/test_gedit_bench_h100_smoke.py --run-level advanced_model --gedit-model Qwen/Qwen-Image-Edit --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gedit-port 8093 --gedit-samples-per-group 20 --accuracy-workers 1 - - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_vie_score_*.csv" - - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_summary_*.json" - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 1 - volumeMounts: - - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - value: /root/.cache/huggingface - - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE - value: "120" - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate - - - label: ":full_moon: Diffusion · Qwen-Image · Perf Test" - key: nightly-qwen-image-performance - timeout_in_minutes: 180 - if: *nightly_or_pr_label - commands: - - export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results - - export CACHE_DIT_VERSION=1.3.0 - - pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json - - buildkite-agent artifact upload "tests/dfx/perf/results/benchmark_results_*.json" - - buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log" - agents: - queue: "mithril-h100-pool" - plugins: - - kubernetes: - podSpec: - containers: - - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - resources: - limits: - nvidia.com/gpu: 4 - volumeMounts: 
- - name: devshm - mountPath: /dev/shm - - name: hf-cache - mountPath: /root/.cache/huggingface - env: - - name: HF_HOME - value: /root/.cache/huggingface - - name: HF_TOKEN - valueFrom: - secretKeyRef: - name: hf-token-secret - key: token - nodeSelector: - node.kubernetes.io/instance-type: gpu-h100-sxm - volumes: - - name: devshm - emptyDir: - medium: Memory - - name: hf-cache - hostPath: - path: /mnt/hf-cache - type: DirectoryOrCreate diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml index 06b7c14ae1..02d8cced40 100644 --- a/.buildkite/test-nightly.yml +++ b/.buildkite/test-nightly.yml @@ -7,12 +7,11 @@ steps: # Group: collapses under one heading in the Buildkite UI; child steps still run in parallel. - group: ":card_index_dividers: Omni Model Test" key: nightly-omni-test-group + depends_on: upload-nightly-pipeline + if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test" || build.pull_request.labels includes "omni-test" steps: - - label: ":full_moon: Omni · Function Test with H100" + - label: ":full_moon: Omni · Function Test" timeout_in_minutes: 90 - depends_on: upload-nightly-pipeline - # Shared nightly vs PR label conditional; referenced below as *nightly_or_pr_label - if: &nightly_or_pr_label 'build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"' commands: - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and H100 and omni" --run-level "advanced_model" agents: @@ -49,13 +48,11 @@ steps: path: /mnt/hf-cache type: DirectoryOrCreate - - label: ":full_moon: Omni · Function Test with L4" + - label: ":full_moon: Omni · Doc Test with L4" timeout_in_minutes: 90 - depends_on: upload-nightly-pipeline - if: *nightly_or_pr_label commands: - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" - - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and L4 and omni" --run-level "advanced_model" + - pytest -s -v tests/examples/ -m "advanced_model and omni and L4" --run-level "advanced_model" agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: @@ -70,13 +67,211 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" - - label: ":full_moon: Omni · Doc Test with L4" + - label: ":full_moon: Omni · Doc Test with H100" + timeout_in_minutes: 90 + commands: + - pytest -s -v tests/examples/ -m "advanced_model and omni and H100" --run-level "advanced_model" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + - label: ":full_moon: Omni · Perf Test" + key: nightly-omni-performance + timeout_in_minutes: 180 + commands: + - export BENCHMARK_DIR=tests/dfx/perf/results + - | + set +e + pytest -s -v tests/dfx/perf/scripts/run_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_omni.json + EXIT=$$? 
+ buildkite-agent artifact upload "tests/dfx/perf/results/*.json" + exit $$EXIT + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + + - group: ":card_index_dividers: TTS Model Test" + key: nightly-tts-test-group + depends_on: upload-nightly-pipeline + if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test" || build.pull_request.labels includes "tts-test" + steps: + - label: ":full_moon: TTS · Function Test" timeout_in_minutes: 90 - depends_on: upload-nightly-pipeline - if: *nightly_or_pr_label commands: - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" - - pytest -s -v tests/examples/ -m "advanced_model and omni and L4" --run-level "advanced_model" + - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and L4 and omni" --run-level "advanced_model" + agents: + queue: "gpu_1_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + - "HF_TOKEN" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: ":full_moon: TTS · Perf Test" + key: nightly-tts-performance + timeout_in_minutes: 180 + commands: + - export BENCHMARK_DIR=tests/dfx/perf/results + - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" + - | + set +e + pytest -s -v tests/dfx/perf/scripts/run_benchmark.py --test-config-file tests/dfx/perf/tests/test_tts.json + EXIT=$$? + buildkite-agent artifact upload "tests/dfx/perf/results/*.json" + exit $$EXIT + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + # Diffusion X2I suite: x2i / x2a / x2t and related non-video paths; x2v is only in "Diffusion X2V Model Test" below. 
+ - group: ":card_index_dividers: Diffusion X2I(&A&T) Model Test" + key: nightly-diffusion-x2iat-group + depends_on: upload-nightly-pipeline + if: >- + build.env("NIGHTLY") == "1" || + build.pull_request.labels includes "nightly-test" || + build.pull_request.labels includes "diffusion-x2iat-test" + steps: + - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with H100" + timeout_in_minutes: 120 + commands: + - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -k "not test_wan22_expansion and not test_wan_2_1_vace_expansion and not hunyuan" -m "advanced_model and diffusion and H100" --run-level "advanced_model" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with L4" + timeout_in_minutes: 60 + commands: + - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -k "not test_wan22_expansion and not test_wan_2_1_vace_expansion and not hunyuan" -m "advanced_model and diffusion and L4" --run-level "advanced_model" agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: @@ -91,12 +286,11 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" - - label: ":full_moon: Omni · Doc Test with H100" - timeout_in_minutes: 90 - depends_on: upload-nightly-pipeline - if: *nightly_or_pr_label + - label: ":full_moon: Diffusion X2I(&A&T) · Doc Test" + timeout_in_minutes: 60 commands: - - pytest -s -v tests/examples/ -m "advanced_model and omni and H100" --run-level "advanced_model" + - export VLLM_TEST_CLEAN_GPU_MEMORY="1" + - pytest -s -v tests/examples/*/test_text_to_image.py -m "advanced_model and example and H100" --run-level "advanced_model" agents: queue: "mithril-h100-pool" plugins: @@ -131,17 +325,91 @@ steps: path: /mnt/hf-cache type: DirectoryOrCreate - - label: ":full_moon: Omni · Perf Test" - key: nightly-omni-performance + - label: ":full_moon: Diffusion X2I(&A&T) · GEBench Accuracy Test" + timeout_in_minutes: 60 + commands: + - pytest -s -v tests/e2e/accuracy/test_gebench_h100_smoke.py --run-level advanced_model --gebench-model Qwen/Qwen-Image-2512 --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gebench-port 8093 --accuracy-workers 1 + - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gebench_qwen-image-2512/summary*.json" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: 
gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + - label: ":full_moon: Diffusion X2I(&A&T) · GEdit-Bench Accuracy Test" + timeout_in_minutes: 60 + commands: + - pytest -s -v tests/e2e/accuracy/test_gedit_bench_h100_smoke.py --run-level advanced_model --gedit-model Qwen/Qwen-Image-Edit --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gedit-port 8093 --gedit-samples-per-group 20 --accuracy-workers 1 + - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_vie_score_*.csv" + - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_summary_*.json" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE + value: "120" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + - label: ":full_moon: Diffusion X2I(&A&T) · Accuracy Test" timeout_in_minutes: 180 - depends_on: upload-nightly-pipeline - if: *nightly_or_pr_label commands: - - export BENCHMARK_DIR=tests/dfx/perf/results - - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" - - pytest -s -v tests/dfx/perf/scripts/run_benchmark.py - - buildkite-agent artifact upload "tests/dfx/perf/results/*.json" - - buildkite-agent artifact upload "tests/dfx/perf/results/*.html" + - pytest -s -v tests/e2e/accuracy/test_qwen_image*.py --run-level advanced_model agents: queue: "mithril-h100-pool" plugins: @@ -151,7 +419,65 @@ steps: - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT resources: limits: - nvidia.com/gpu: 2 + nvidia.com/gpu: 1 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE + value: "120" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + - label: ":full_moon: Diffusion X2I(&A&T) · Perf Test" + key: nightly-diffusion-x2iat-performance + timeout_in_minutes: 180 + commands: + - export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results + - export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN + - export CACHE_DIT_VERSION=1.3.0 + # [HACK]: run upload in the same command block as pytest. + # Because `exit` aborts the entire commands list. + - | + set +e + pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json + EXIT1=$$? 
+ pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json + EXIT2=$$? + pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json + EXIT3=$$? + pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json + EXIT4=$$? + buildkite-agent artifact upload "tests/dfx/perf/results/diffusion_result_*.json" + buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log" + exit $$((EXIT1 | EXIT2 | EXIT3 | EXIT4)) + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 4 volumeMounts: - name: devshm mountPath: /dev/shm @@ -176,23 +502,96 @@ steps: path: /mnt/hf-cache type: DirectoryOrCreate - # Dynamically appends steps from test-nightly-diffusion.yml into this build (same mechanism as - # pipeline.yml → test-ready.yml / test-merge.yml / test-nightly.yml). Foldable groups stay in the - # uploaded YAML (Other / Wan / Qwen-Image). - - label: ":card_index_dividers: Diffusion Model Test" - key: nightly-diffusion-model-test + # Diffusion x2v only (Wan, HunyuanVideo, …). x2i/x2a/x2t live in the X2I group above, not here. + - group: ":card_index_dividers: Diffusion X2V Model Test" + key: nightly-diffusion-x2v-group depends_on: upload-nightly-pipeline - if: *nightly_or_pr_label - commands: - - buildkite-agent pipeline upload .buildkite/test-nightly-diffusion.yml - agents: - queue: "cpu_queue_premerge" + if: >- + build.env("NIGHTLY") == "1" || + build.pull_request.labels includes "nightly-test" || + build.pull_request.labels includes "diffusion-x2v-test" + steps: + - label: ":full_moon: Diffusion X2V · Function Test" + timeout_in_minutes: 90 + commands: + - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py tests/e2e/online_serving/test_wan_2_1_vace_expansion.py tests/e2e/online_serving/test_hunyuan_video_15_expansion.py -m "advanced_model" --run-level "advanced_model" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + + - label: ":full_moon: Diffusion X2V · Accuracy Test" + timeout_in_minutes: 180 + commands: + - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: 
+ - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate - label: ":bar_chart: Testcase Statistics" key: nightly-testcase-statistics timeout_in_minutes: 120 depends_on: upload-nightly-pipeline - if: *nightly_or_pr_label + if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test" commands: - python tools/nightly/buildkite_testcase_statistics.py -o tests/dfx/perf/results/buildkite_testcase_statistics.html - buildkite-agent artifact upload "tests/dfx/perf/results/*.html" @@ -235,16 +634,18 @@ steps: key: nightly-perf-distribution depends_on: - nightly-omni-performance - - nightly-qwen-image-performance + - nightly-tts-performance + - nightly-diffusion-x2iat-performance - nightly-testcase-statistics if: build.env("NIGHTLY") == "1" commands: - pip install openpyxl - export DEFAULT_INPUT_DIR=tests/dfx/perf/results - export DEFAULT_OUTPUT_DIR=tests/dfx/perf/results + - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-tts-performance - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-omni-performance - - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-qwen-image-performance - - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . --step nightly-omni-performance + - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-diffusion-x2iat-performance + - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . 
--step nightly-testcase-statistics - python tools/nightly/generate_nightly_perf_excel.py - python tools/nightly/generate_nightly_perf_html.py - python tools/nightly/send_nightly_email.py --report-file "tests/dfx/perf/results/*.xlsx, tests/dfx/perf/results/*.html" diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml index 2f1f05463a..3ca1747fe6 100644 --- a/.buildkite/test-ready.yml +++ b/.buildkite/test-ready.yml @@ -123,7 +123,7 @@ steps: - label: "Audio Generation Model Test" depends_on: upload-ready-pipeline commands: - - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py + - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model agents: queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU plugins: @@ -194,28 +194,6 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" - - - label: "Omni Model Test" - depends_on: upload-ready-pipeline - commands: - - | - timeout 17m bash -c ' - export VLLM_LOGGING_LEVEL=DEBUG - pytest -s -v tests/e2e/online_serving/test_qwen2_5_omni.py -m "core_model" --run-level "core_model" - ' - agents: - queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU - plugins: - - docker#v5.2.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - always-pull: true - propagate-environment: true - environment: - - "HF_HOME=/fsx/hf_cache" - - "HF_TOKEN" - volumes: - - "/fsx/hf_cache:/fsx/hf_cache" - - label: "Omni Model Test with H100" depends_on: upload-ready-pipeline commands: @@ -317,6 +295,56 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "VoxCPM E2E Test" + timeout_in_minutes: 20 + depends_on: upload-ready-pipeline + commands: + - | + timeout 20m bash -c ' + pip install voxcpm + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + pytest -s -v tests/e2e/offline_inference/test_voxcpm.py -m "core_model" --run-level "core_model" + ' + agents: + queue: "gpu_1_queue" + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + - "HF_TOKEN" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + + - label: "VoxCPM2 Native AR E2E Test" + timeout_in_minutes: 20 + depends_on: upload-ready-pipeline + commands: + - | + timeout 20m bash -c ' + pip install voxcpm + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + pytest -s -v tests/e2e/offline_inference/test_voxcpm2.py -m "core_model" --run-level "core_model" + ' + agents: + queue: "gpu_1_queue" + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + - "HF_TOKEN" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "OmniVoice E2E Test" timeout_in_minutes: 20 depends_on: upload-ready-pipeline @@ -339,6 +367,33 @@ steps: volumes: - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Qwen3-TTS Base E2E Test (ModelRunner V2)" + depends_on: upload-ready-pipeline + soft_fail: + - exit_status: 1 + commands: + - | + timeout 20m bash -c ' + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_WORKER_MULTIPROC_METHOD=spawn + export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1" + export VLLM_OMNI_USE_V2_RUNNER="1" + pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py -m "core_model" --run-level "core_model" + ' + agents: + 
queue: "gpu_1_queue" + plugins: + - docker#v5.2.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + always-pull: true + propagate-environment: true + shm-size: "8gb" + environment: + - "HF_HOME=/fsx/hf_cache" + - "HF_TOKEN" + volumes: + - "/fsx/hf_cache:/fsx/hf_cache" + - label: "Voxtral-TTS E2E Test" timeout_in_minutes: 20 depends_on: upload-ready-pipeline diff --git a/.buildkite/test-template-amd-omni.j2 b/.buildkite/test-template-amd-omni.j2 index 8dc91a1172..f4c386a5fe 100644 --- a/.buildkite/test-template-amd-omni.j2 +++ b/.buildkite/test-template-amd-omni.j2 @@ -48,6 +48,9 @@ DOCKER_BUILDKIT: "1" TEST_COMMAND: |- (command rocm-smi || true) && cd {{ (step.working_dir or default_working_dir) | safe }} +{% if "mi250" in step.agent_pool %} + python3 -m pip uninstall -y amd-aiter +{% endif %} {{ indented_cmd | safe }} priority: 100 {% if step.grade and step.grade == "Blocking" %} diff --git a/.claude/skills/add-diffusion-model/SKILL.md b/.claude/skills/add-diffusion-model/SKILL.md new file mode 100644 index 0000000000..a7e0bbf9a5 --- /dev/null +++ b/.claude/skills/add-diffusion-model/SKILL.md @@ -0,0 +1,534 @@ +--- +name: add-diffusion-model +description: Add a new diffusion model (text-to-image, text-to-video, image-to-video, text-to-audio, image editing) to vLLM-Omni, including Cache-DiT acceleration and parallelism support (TP, SP/USP, CFG-Parallel, HSDP). Use when integrating a new diffusion model, porting a diffusers pipeline or a custom model repo to vllm-omni, creating a new DiT transformer adapter, adding diffusion model support, or enabling multi-GPU parallelism and cache acceleration for an existing model. +--- + +# Adding a Diffusion Model to vLLM-Omni + +## Overview + +This skill guides you through adding a new diffusion model to vLLM-Omni. The model may come from HuggingFace Diffusers (structured pipeline) or from a private/custom repo. The workflow differs significantly depending on the source. + +## Prerequisites + +Before starting, determine: + +1. **Model category**: Text-to-Image, Text-to-Video, Image-to-Video, Image Editing, Text-to-Audio, or Omni +2. **Reference source**: Diffusers pipeline, custom repo, or a combination +3. **Model HuggingFace ID** or local checkpoint path +4. **Architecture**: Scheduler, text encoder, VAE, transformer/backbone + +## Step 0: Classify the Migration Path + +Check the model's HF repo for `model_index.json`. This determines your path: + +| Scenario | How to identify | Migration path | +|----------|----------------|----------------| +| **Already supported** | `_class_name` in `model_index.json` matches a key in `_DIFFUSION_MODELS` in `registry.py` | Skip to Step 5 (test) and Step 7 (docs) | +| **Diffusers-based** | Has standard `model_index.json` with `_diffusers_version`, subfolders for `transformer/`, `vae/`, etc. | Follow **Path A** below | +| **Custom/private repo** | No diffusers `model_index.json`, weights in non-standard format, custom model code in a separate git repo | Follow **Path B** below | +| **Hybrid** | Has some diffusers components (VAE) but custom transformer/fusion | Mix of Path A and Path B | + +## Path A: Diffusers-Based Model + +For models with a standard diffusers layout. See [references/transformer-adaptation.md](references/transformer-adaptation.md) for detailed code patterns. + +### A1. Analyze `model_index.json` + +Identify components: `transformer`, `scheduler`, `vae`, `text_encoder`, `tokenizer`. + +### A2. 
Create model directory + +``` +vllm_omni/diffusion/models/your_model_name/ +├── __init__.py +├── pipeline_your_model.py +└── your_model_transformer.py +``` + +### A3. Adapt transformer + +1. Copy from diffusers source. Remove mixins (`ModelMixin`, `ConfigMixin`, `AttentionModuleMixin`). +2. Replace attention with `vllm_omni.diffusion.attention.layer.Attention` (QKV shape: `[B, seq, heads, head_dim]`). +3. Add `od_config: OmniDiffusionConfig | None = None` to `__init__`. +4. Add `load_weights()` method mapping diffusers weight names to vllm-omni names. +5. Add class attributes: `_repeated_blocks`, `_layerwise_offload_blocks_attr`. + +### A4. Adapt pipeline + +Inherit from `nn.Module`. The key contract: + +```python +class YourPipeline(nn.Module): + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + # Load VAE, text encoder, tokenizer via from_pretrained() + # Instantiate transformer (weights loaded later via weights_sources) + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, subfolder="transformer", + prefix="transformer.", fall_back_to_pt=True)] + + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + # Encode prompt → prepare latents → denoise loop → VAE decode + return DiffusionOutput(output=output) + + def load_weights(self, weights): + return AutoWeightsLoader(self).load_weights(weights) +``` + +Add post/pre-process functions in the same pipeline file. Register them in `registry.py`. + +### A5. Register, test, docs → continue at Step 4 below. + +--- + +## Path B: Custom/Private Repo Model + +For models without a diffusers pipeline — weights in custom format, model code in a private repo. Real examples: DreamID-Omni, BAGEL, HunyuanImage3. + +### B1. Understand the reference repo + +Study the original model's code to identify: +- **Model architecture files** (transformers, fusion modules, embeddings) +- **Weight format** (safetensors, `.pth`, custom checkpoint structure) +- **Weight loading helpers** (custom init functions, checkpoint loaders) +- **Pre/post-processing** (image/audio transforms, tokenization, VAE encode/decode) +- **External dependencies** (packages not on PyPI) +- **Config format** (JSON config files, hardcoded dicts) + +### B2. Decide what lives WHERE + +This is the key design decision for custom models. Follow these placement rules: + +| Code type | Where to place | Example | +|-----------|---------------|---------| +| **Pipeline orchestration** (init, forward, denoise loop) | `vllm_omni/diffusion/models//pipeline_.py` | Always required | +| **Custom transformer/backbone** (ported and adapted to vllm-omni) | `vllm_omni/diffusion/models//_transformer.py` or similar | `wan2_2.py`, `fusion.py`, `bagel_transformer.py` | +| **Custom sub-models** (VAE, fusion, autoencoder) | `vllm_omni/diffusion/models//` as separate files | `autoencoder.py`, `fusion.py` | +| **External dependency code** (original repo utilities) | **External repo**, installed via download script or pip | `dreamid_omni` package via git clone | +| **Hardcoded model configs** | Module-level dicts in pipeline file | `VIDEO_CONFIG`, `AUDIO_CONFIG` dicts | +| **Download/setup script** | `examples/offline_inference//download_.py` | `download_dreamid_omni.py` | +| **Custom `model_index.json`** | Generated by download script, placed at model root | Minimal: `{"_class_name": "YourPipeline", ...}` | + +### B3. 
Handle external dependencies + +If the model's code lives in a separate git repo: + +**Option 1: Import with graceful fallback** (recommended for models with external utils) + +```python +try: + from external_model.utils import init_vae, load_checkpoint +except ImportError: + raise ImportError( + "Failed to import from dependency 'external_model'. " + "Please run the download script first." + ) +``` + +**Option 2: Port the code directly** (preferred when feasible) + +Copy the essential model files into `vllm_omni/diffusion/models//` and adapt them. This avoids external dependencies. BAGEL does this — `autoencoder.py` and `bagel_transformer.py` are ported directly. + +**Decision criteria**: Port if the code is self-contained and won't diverge. Use external deps if the model repo is actively maintained and the code is complex. + +### B4. Handle custom weight loading + +Custom models have two patterns for weight loading: + +**Pattern 1: Bypass standard loader** (DreamID-Omni style) + +When the original model has complex custom init functions that load weights in `__init__`: + +```python +class CustomPipeline(nn.Module): + def __init__(self, *, od_config, prefix=""): + super().__init__() + model = od_config.model + # Load everything eagerly in __init__ using custom helpers + self.vae = custom_init_vae(model, device=self.device) + self.text_encoder = custom_init_text_encoder(model, device=self.device) + self.transformer = CustomFusionModel(CONFIG) + load_custom_checkpoint(self.transformer, + checkpoint_path=os.path.join(model, "model.safetensors")) + # NO weights_sources defined — bypasses standard loader + + def load_weights(self, weights): + pass # No-op — all weights loaded in __init__ +``` + +**Pattern 2: Use standard loader with custom `load_weights`** (BAGEL style) + +When weights are in safetensors format but need name remapping: + +```python +class CustomPipeline(nn.Module): + def __init__(self, *, od_config, prefix=""): + super().__init__() + # Instantiate model architecture without weights + self.bagel = BagelModel(config) + self.vae = AutoEncoder(ae_params) + + # Point loader at the safetensors in the model root + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder=None, # weights at root, not in subfolder + prefix="", + fall_back_to_pt=False, + ) + ] + + def load_weights(self, weights): + # Custom name remapping for non-diffusers weight names + params = dict(self.named_parameters()) + loaded = set() + for name, tensor in weights: + # Remap original weight names to vllm-omni module names + name = self._remap_weight_name(name) + if name in params: + default_weight_loader(params[name], tensor) + loaded.add(name) + return loaded +``` + +### B5. Create the `model_index.json` + +Custom models need a `model_index.json` at the model root for vllm-omni to discover them. For custom models, this is minimal: + +```json +{ + "_class_name": "YourModelPipeline", + "custom_key": "path/to/custom_weights.safetensors" +} +``` + +The `_class_name` must match a key in `_DIFFUSION_MODELS` in `registry.py`. Additional keys are model-specific (accessed via `od_config.model_config`). + +If the model's weights come from multiple HF repos, write a **download script** that: +1. Downloads from each repo +2. Assembles into a single directory +3. Generates `model_index.json` +4. Installs any external dependencies (git clone + `.pth` file) + +Place at: `examples/offline_inference//download_.py` + +### B6. 
Handle multi-modal inputs + +If the model accepts images, audio, or other multi-modal inputs, implement the protocol classes from `vllm_omni/diffusion/models/interface.py`: + +```python +from vllm_omni.diffusion.models.interface import SupportImageInput, SupportAudioInput + +class MyPipeline(nn.Module, SupportImageInput, SupportAudioInput): + # Protocol markers — the engine uses these to enable proper input routing + pass +``` + +Preprocessing for custom models is typically done **inside `forward()`** rather than via registered pre-process functions, since the logic is often tightly coupled to the model. + +### B7. Continue at Step 4 below. + +--- + +## Common Steps (Both Paths) + +### Step 4: Register Model in registry.py + +Edit `vllm_omni/diffusion/registry.py`: + +```python +_DIFFUSION_MODELS = { + "YourModelPipeline": ("your_model_name", "pipeline_your_model", "YourModelPipeline"), +} +_DIFFUSION_POST_PROCESS_FUNCS = { + "YourModelPipeline": "get_your_model_post_process_func", # if applicable +} +_DIFFUSION_PRE_PROCESS_FUNCS = { + "YourModelPipeline": "get_your_model_pre_process_func", # if applicable +} +``` + +The registry key is the `_class_name` from `model_index.json`. The tuple is `(folder_name, module_file, class_name)`. + +Create `__init__.py` exporting the pipeline class and any factory functions. + +### Step 5: Run, Test, Debug + +Use the appropriate existing example script: + +| Category | Script | +|----------|--------| +| Text-to-Image | `examples/offline_inference/text_to_image/text_to_image.py` | +| Text-to-Video | `examples/offline_inference/text_to_video/text_to_video.py` | +| Image-to-Video | `examples/offline_inference/image_to_video/image_to_video.py` | +| Image-to-Image | `examples/offline_inference/image_to_image/image_edit.py` | +| Text-to-Audio | `examples/offline_inference/text_to_audio/text_to_audio.py` | + +For custom/Omni models that don't fit these categories, create a dedicated example script. + +**Validation**: No errors, output is meaningful, quality matches reference implementation. + +See [references/troubleshooting.md](references/troubleshooting.md) for common errors. + +### Step 6: Add Example Scripts + +For Omni or custom models, create: +- `examples/offline_inference/your_model_name/` — offline script + README +- `examples/online_serving/your_model_name/` — server script + client +- Download script if weights require assembly from multiple sources + +### Step 7: Update Documentation + +Required updates: +1. `docs/user_guide/diffusion/parallelism_acceleration.md` — parallelism support table +2. `docs/user_guide/diffusion/teacache.md` — if TeaCache supported +3. `docs/user_guide/diffusion/cache_dit_acceleration.md` — if Cache-DiT supported +4. `examples/offline_inference/xxx/README.md` — offline example docs +5. `examples/online_serve/xxx/README.md` — online serve docs + +### Step 8: Add E2E Tests (Recommended) + +Create `tests/e2e/online_serving/test_your_model_expansion.py`. + +### Step 9: Add Cache-DiT Acceleration + +Cache-DiT accelerates inference by caching intermediate computation results across denoising steps. After your model is working correctly on a single GPU, add cache-dit support. + +See [references/cache-dit-patterns.md](references/cache-dit-patterns.md) for detailed code patterns. + +#### 9a. 
Determine your model type + +| Model Type | Description | Action | +|------------|-------------|--------| +| **Standard single-transformer** | One transformer with one `ModuleList` of blocks | No code needed — `CacheDiTBackend` auto-detects via `enable_cache_for_dit()` | +| **Multi-block-list** | One transformer with multiple block lists (e.g., `transformer_blocks` + `single_transformer_blocks`) | Write custom enabler with `BlockAdapter` | +| **Dual-transformer** | Two transformers (e.g., high-noise + low-noise) | Write custom enabler with `BlockAdapter` wrapping both | + +#### 9b. Standard models — verify automatic support + +For standard single-transformer models, test directly: + +```python +omni = Omni( + model="your-model-name", + cache_backend="cache_dit", + cache_config={ + "Fn_compute_blocks": 1, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + } +) +``` + +Check logs for "Cache-dit enabled successfully on xxx". If it works, skip to Step 9e. + +#### 9c. Custom architectures — write a custom enabler + +For multi-block-list or dual-transformer models, write a custom enabler function: + +```python +from cache_dit import BlockAdapter, ForwardPattern, ParamsModifier, DBCacheConfig + +def enable_cache_for_your_model(pipeline, cache_config): + db_cache_config = DBCacheConfig( + num_inference_steps=None, + Fn_compute_blocks=cache_config.Fn_compute_blocks, + Bn_compute_blocks=cache_config.Bn_compute_blocks, + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, + residual_diff_threshold=cache_config.residual_diff_threshold, + ) + + cache_dit.enable_cache( + BlockAdapter( + transformer=pipeline.transformer, + blocks=[ + pipeline.transformer.transformer_blocks, + pipeline.transformer.single_transformer_blocks, + ], + forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1], + params_modifiers=[ParamsModifier(...)], + ), + cache_config=db_cache_config, + ) + + def refresh_cache_context(pipeline, num_inference_steps, verbose=True): + cache_dit.refresh_context( + pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose + ) + return refresh_cache_context +``` + +#### 9d. Register the custom enabler + +Add your enabler to `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`: + +```python +CUSTOM_DIT_ENABLERS = { + "Wan22Pipeline": enable_cache_for_wan22, + "LongCatImagePipeline": enable_cache_for_longcat_image, + "YourModelPipeline": enable_cache_for_your_model, # Add here +} +``` + +#### 9e. Test Cache-DiT + +```python +omni = Omni( + model="your-model-name", + cache_backend="cache_dit", + cache_config={ + "Fn_compute_blocks": 1, "Bn_compute_blocks": 0, + "max_warmup_steps": 4, "residual_diff_threshold": 0.24, + } +) +images = omni.generate("a beautiful landscape", + OmniDiffusionSamplingParams(num_inference_steps=50)) +``` + +**Verify**: 1) logs show cache enabled, 2) 1.5-2x speedup, 3) output quality acceptable vs baseline. + +If quality degrades, lower `residual_diff_threshold` (try 0.12-0.18) or increase `max_warmup_steps` (try 6-8). + +--- + +### Step 10: Add Parallelism Support + +After the model works on a single GPU, add multi-GPU parallelism. Add each type incrementally, testing after each addition. + +See [references/parallelism-patterns.md](references/parallelism-patterns.md) for detailed code patterns and API reference. + +**Recommended order**: TP → SP/USP → CFG Parallel → HSDP + +#### 10a. 
Tensor Parallelism (TP) + +Shards DiT linear layers across GPUs. Requires code changes in the transformer. + +**What to change in the transformer**: +1. Replace `nn.Linear` with `ColumnParallelLinear` / `RowParallelLinear` / `QKVParallelLinear` +2. Update `load_weights()` to handle QKV fusion with `stacked_params_mapping` +3. Use `self.to_qkv.num_heads` (local heads) instead of total heads for split sizes + +```python +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, RowParallelLinear, ColumnParallelLinear, +) + +# Attention: QKV → RowParallel output +self.to_qkv = QKVParallelLinear(dim, head_dim, num_heads, num_kv_heads) +self.to_out = RowParallelLinear(dim, dim, input_is_parallel=True) + +# FFN: ColumnParallel → RowParallel +self.w1 = ColumnParallelLinear(dim, ffn_dim) +self.w2 = RowParallelLinear(ffn_dim, dim, input_is_parallel=True) +``` + +**Constraints**: `num_heads % tp_size == 0` and `num_kv_heads % tp_size == 0`. + +**Test**: `--tensor-parallel-size 2` + +#### 10b. Sequence Parallelism (SP / USP) + +Splits sequence tokens across GPUs. Non-intrusive via `_sp_plan` on the transformer class — no changes to `forward()`. + +**What to change in the transformer**: + +Add `_sp_plan` class attribute: + +```python +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, SequenceParallelOutput, +) + +class YourTransformer(nn.Module): + _sp_plan = { + "blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } +``` + +If inline tensor ops (e.g., `torch.cat`) exist between shard/gather points, extract them into `nn.Module` submodules so hooks can intercept them. + +For RoPE that needs splitting, add an entry for the RoPE module with `split_output=True`. + +**Test**: `--ulysses-degree 2` (offline) or `--usp 2` (online serving) + +#### 10c. CFG Parallel + +Distributes positive/negative CFG branches across 2 GPUs. Requires the pipeline to inherit `CFGParallelMixin`. + +**What to change in the pipeline**: + +```python +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin + +class YourPipeline(nn.Module, CFGParallelMixin): + def diffuse(self, ...) -> torch.Tensor: + for i, t in enumerate(timesteps): + positive_kwargs = {...} + negative_kwargs = {...} if do_true_cfg else None + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, true_cfg_scale=cfg_scale, + positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs, + ) + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg + ) + return latents +``` + +Override `predict_noise()` if your transformer call is non-standard. Override `combine_cfg_noise()` for multi-output models (e.g., video + audio). + +**Constraint**: Exactly 2 GPUs. Only for models using classifier-free guidance. + +**Test**: `--cfg-parallel-size 2` + +#### 10d. HSDP (Hybrid Sharded Data Parallel) + +Shards transformer weights via PyTorch FSDP2 to reduce per-GPU VRAM. No code changes to the forward pass — just add a class attribute. + +**What to change in the transformer**: + +```python +class YourTransformer(nn.Module): + @staticmethod + def _is_transformer_block(name: str, module) -> bool: + return "blocks" in name and name.split(".")[-1].isdigit() + + _hsdp_shard_conditions = [_is_transformer_block] +``` + +**Constraint**: Cannot combine with TP. For standalone HSDP, set `hsdp_shard_size` explicitly. 
+ +**Test**: `--use-hsdp` or `DiffusionParallelConfig(use_hsdp=True)` + +#### 10e. Update parallelism documentation + +After adding parallelism support, update: +1. `docs/user_guide/diffusion/parallelism_acceleration.md` — add your model to the support table +2. Record which parallelism methods are supported (USP, Ring, CFG, TP, HSDP, VAE-Patch) + +--- + +## Iterative Development Tips + +1. **Start minimal**: Basic generation first, no parallelism/caching +2. **Use `--enforce-eager`**: Disable torch.compile during debugging +3. **Use small models**: Test with smaller variants first +4. **Check tensor shapes**: Most errors are reshape mismatches in attention +5. **Add features incrementally**: Single GPU → TP → SP → CFG → HSDP → Cache-DiT +6. **For custom models**: Get the model running with the original code first, then progressively replace components with vllm-omni equivalents +7. **Cache-DiT before parallelism tuning**: Cache-DiT is lossy — verify quality at baseline before combining with parallelism +8. **Combine lossless + lossy**: e.g., TP + SP + Cache-DiT for maximum throughput + +## Reference Files + +- [Transformer Adaptation](references/transformer-adaptation.md) — porting transformers from diffusers +- [Custom Model Patterns](references/custom-model-patterns.md) — patterns for non-diffusers models +- [Parallelism Patterns](references/parallelism-patterns.md) — TP, SP/USP, CFG parallel, HSDP implementation details +- [Cache-DiT Patterns](references/cache-dit-patterns.md) — cache-dit acceleration for standard and custom architectures +- [Troubleshooting](references/troubleshooting.md) — common errors and fixes diff --git a/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md b/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md new file mode 100644 index 0000000000..d34ce0e0f4 --- /dev/null +++ b/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md @@ -0,0 +1,254 @@ +# Cache-DiT Patterns Reference + +## Overview + +Cache-DiT accelerates Diffusion Transformers by caching intermediate computation results across denoising steps. Adjacent steps produce similar features, so redundant computations can be skipped. + +Three caching strategies: +- **DBCache**: Dynamic block-level caching — selectively computes or caches transformer blocks based on residual differences +- **TaylorSeer**: Calibration-based prediction using Taylor expansion to estimate block outputs +- **SCM** (Step Computation Masking): Dynamic step skipping based on configurable policies + +**Typical speedup**: 1.5-2.5x depending on model and configuration. 
+ +**Official docs**: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/cache_dit + +## Architecture + +vLLM-Omni integrates cache-dit through `CacheDiTBackend`: + +| Component | Purpose | +|-----------|---------| +| `CacheDiTBackend` | Unified backend — auto-selects enabler (standard or custom) | +| `enable_cache_for_dit()` | Default enabler for standard single-transformer models | +| `CUSTOM_DIT_ENABLERS` dict | Registry of custom enablers keyed by pipeline class name | +| `BlockAdapter` | Wraps complex architectures (multi-block-list or multi-transformer) | +| `ForwardPattern` | Specifies block forward signature: `Pattern_0`, `Pattern_1`, `Pattern_2` | +| `ParamsModifier` | Per-transformer or per-block-list config customization | +| `DBCacheConfig` | Configuration for DBCache parameters | +| `cache_dit.refresh_context()` | Updates cache context when `num_inference_steps` changes | + +**Source files**: +- `vllm_omni/diffusion/cache/cache_dit_backend.py` — `CacheDiTBackend`, enablers, `CUSTOM_DIT_ENABLERS` +- `vllm_omni/diffusion/cache/` — cache backend implementations + +## Standard Models: Automatic Support + +Most DiT models follow this pattern: +- Single transformer with one `nn.ModuleList` of blocks +- Standard forward signature +- Compatible with cache-dit's automatic detection + +**Examples**: Qwen-Image, Z-Image, FLUX + +No code changes needed. `CacheDiTBackend` automatically uses `enable_cache_for_dit()`: + +```python +from vllm_omni import Omni + +omni = Omni( + model="Qwen/Qwen-Image", + cache_backend="cache_dit", + cache_config={ + "Fn_compute_blocks": 1, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + } +) +``` + +What happens automatically: + +```python +def enable_cache_for_dit(pipeline, cache_config): + db_cache_config = DBCacheConfig( + num_inference_steps=None, + Fn_compute_blocks=cache_config.Fn_compute_blocks, + Bn_compute_blocks=cache_config.Bn_compute_blocks, + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, + residual_diff_threshold=cache_config.residual_diff_threshold, + ) + + cache_dit.enable_cache(pipeline.transformer, cache_config=db_cache_config) + + def refresh_cache_context(pipeline, num_inference_steps, verbose=True): + cache_dit.refresh_context( + pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose + ) + return refresh_cache_context +``` + +## Custom Architectures: Writing Custom Enablers + +### When you need a custom enabler + +- Model has multiple block lists in one transformer (e.g., `transformer_blocks` + `single_transformer_blocks`) +- Model has two transformers (e.g., high-noise + low-noise like Wan2.2) +- Model uses non-standard block forward signature + +### Pattern 1: Multi-Block-List (LongCat-Image style) + +Single transformer with two block lists: + +```python +import cache_dit +from cache_dit import BlockAdapter, ForwardPattern, ParamsModifier, DBCacheConfig + +def enable_cache_for_your_model(pipeline, cache_config): + db_cache_config = DBCacheConfig( + num_inference_steps=None, + Fn_compute_blocks=cache_config.Fn_compute_blocks, + Bn_compute_blocks=cache_config.Bn_compute_blocks, + max_warmup_steps=cache_config.max_warmup_steps, + max_cached_steps=cache_config.max_cached_steps, + max_continuous_cached_steps=cache_config.max_continuous_cached_steps, + residual_diff_threshold=cache_config.residual_diff_threshold, + ) + + cache_dit.enable_cache( + BlockAdapter( + 
transformer=pipeline.transformer, + blocks=[ + pipeline.transformer.transformer_blocks, + pipeline.transformer.single_transformer_blocks, + ], + forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1], + params_modifiers=[ParamsModifier(...)], + ), + cache_config=db_cache_config, + ) + + def refresh_cache_context(pipeline, num_inference_steps, verbose=True): + cache_dit.refresh_context( + pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose + ) + return refresh_cache_context +``` + +For single transformer with multiple block lists, `refresh_context` works the same as standard models — call it once on the transformer. + +### Pattern 2: Dual-Transformer (Wan2.2 style) + +Two transformers with separate configs: + +```python +def enable_cache_for_dual_transformer(pipeline, cache_config): + db_cache_config = DBCacheConfig(...) + + cache_dit.enable_cache( + BlockAdapter( + transformer=[pipeline.transformer, pipeline.transformer_2], + blocks=[pipeline.transformer.blocks, pipeline.transformer_2.blocks], + forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2], + params_modifiers=[ + ParamsModifier(...), # Config for transformer 1 + ParamsModifier(...), # Config for transformer 2 + ], + ), + cache_config=db_cache_config, + ) + + def refresh_cache_context(pipeline, num_inference_steps, verbose=True): + high_steps, low_steps = _split_inference_steps(num_inference_steps) + cache_dit.refresh_context( + pipeline.transformer, num_inference_steps=high_steps, verbose=verbose + ) + cache_dit.refresh_context( + pipeline.transformer_2, num_inference_steps=low_steps, verbose=verbose + ) + return refresh_cache_context +``` + +Key difference: `refresh_context` must be called on **each transformer separately** with its own step count. + +### Choosing the ForwardPattern + +| Pattern | Block forward signature | Example models | +|---------|------------------------|----------------| +| `Pattern_0` | `block(hidden_states, **kwargs)` → residual added inside block | Default | +| `Pattern_1` | `block(hidden_states, **kwargs)` → returns `(hidden_states, ...)` tuple | FLUX-style single blocks | +| `Pattern_2` | `block(hidden_states, **kwargs)` → `(hidden_states, ...)` with different residual pattern | Wan2.2 blocks | + +Inspect your block's `forward()` return type and residual connection pattern to choose the right one. See [Cache-DiT API Reference](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/) for details. + +## Registering Custom Enablers + +Add your enabler to `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`: + +```python +CUSTOM_DIT_ENABLERS = { + "Wan22Pipeline": enable_cache_for_wan22, + "LongCatImagePipeline": enable_cache_for_longcat_image, + "YourModelPipeline": enable_cache_for_your_model, +} +``` + +The key must match `pipeline.__class__.__name__`. 
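
As a mental model of the dispatch (a sketch only — not the actual `CacheDiTBackend` code), the backend can be pictured as a lookup by pipeline class name with a fallback to the default enabler:

```python
# Sketch of the enabler selection described above; this is NOT the real
# CacheDiTBackend implementation, just the lookup rule documented here.
def select_enabler(pipeline, custom_enablers: dict, default_enabler):
    # Custom enablers are keyed by the pipeline's class name.
    return custom_enablers.get(type(pipeline).__name__, default_enabler)

# Hypothetical usage:
#   enabler = select_enabler(pipeline, CUSTOM_DIT_ENABLERS, enable_cache_for_dit)
#   refresh_fn = enabler(pipeline, cache_config)
```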
+ +## Configuration Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `Fn_compute_blocks` | 1 | Number of blocks to always compute at the front | +| `Bn_compute_blocks` | 0 | Number of blocks to always compute at the back | +| `max_warmup_steps` | 4 | Steps to run without caching at the beginning | +| `max_cached_steps` | — | Max total cached steps | +| `max_continuous_cached_steps` | — | Max consecutive cached steps | +| `residual_diff_threshold` | 0.24 | Threshold for deciding whether to cache a block | + +### Tuning for quality vs speed + +| Goal | Adjustments | +|------|-------------| +| **More speed, acceptable quality loss** | Higher `residual_diff_threshold` (0.24-0.4), lower `max_warmup_steps` (2-4) | +| **Better quality, less speed** | Lower `residual_diff_threshold` (0.12-0.18), higher `max_warmup_steps` (6-8), lower `max_continuous_cached_steps` (2) | + +## Testing + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="your-model-name", + cache_backend="cache_dit", + cache_config={ + "Fn_compute_blocks": 1, + "Bn_compute_blocks": 0, + "max_warmup_steps": 4, + "residual_diff_threshold": 0.24, + } +) +images = omni.generate( + "a beautiful landscape", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + +CLI (online serving): + +```bash +vllm serve your-model --omni --port 8098 \ + --cache-backend cache_dit \ + --cache-config '{"Fn_compute_blocks": 1, "Bn_compute_blocks": 0, "max_warmup_steps": 4}' +``` + +**Verification checklist**: +1. Logs show "Cache-dit enabled successfully on xxx" +2. Performance: 1.5-2x speedup vs no cache +3. Quality: compare output with `cache_backend=None` + +## Excluded Models + +Models listed in `_NO_CACHE_ACCELERATION` in `vllm_omni/diffusion/registry.py` do not support cache-dit (e.g., `NextStep11Pipeline`, `StableDiffusionPipeline`). Check this set before attempting to enable cache-dit. + +## Reference Implementations + +| Model | Path | Notes | +|-------|------|-------| +| Standard DiT | `cache_dit_backend.py::enable_cache_for_dit` | Default enabler, automatic | +| Wan2.2 | `cache_dit_backend.py::enable_cache_for_wan22` | Dual-transformer, auto-detects mode | +| LongCat | `cache_dit_backend.py::enable_cache_for_longcat_image` | Multi-block-list | +| BAGEL | `cache_dit_backend.py::enable_cache_for_bagel` | Complex omni model | diff --git a/.claude/skills/add-diffusion-model/references/custom-model-patterns.md b/.claude/skills/add-diffusion-model/references/custom-model-patterns.md new file mode 100644 index 0000000000..2434e0b5da --- /dev/null +++ b/.claude/skills/add-diffusion-model/references/custom-model-patterns.md @@ -0,0 +1,273 @@ +# Custom Model Patterns Reference + +Patterns for adding models that don't come from the standard diffusers pipeline format. + +## Directory Structure Comparison + +### Diffusers-based model (e.g., Wan2.2) + +``` +vllm_omni/diffusion/models/wan2_2/ +├── __init__.py # Exports pipeline + transformer + helpers +├── pipeline_wan2_2.py # Pipeline: loads components via from_pretrained() +├── pipeline_wan2_2_i2v.py # Variant pipeline for image-to-video +└── wan2_2_transformer.py # Transformer: ported from diffusers, uses Attention layer +``` + +The transformer is loaded separately via `weights_sources` + `load_weights()`. Non-transformer components (VAE, text encoder) are loaded in `__init__` via `from_pretrained()`. 
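
A minimal sketch of that split (illustrative only — the component classes, `MyTransformer`, and the `DiffusersPipelineLoader` import are placeholders, not code from `pipeline_wan2_2.py`): the VAE, text encoder, and tokenizer are loaded eagerly, while the transformer is instantiated empty and its weights are streamed in later through `weights_sources`.

```python
import torch.nn as nn
from diffusers import AutoencoderKL
from transformers import AutoModel, AutoTokenizer

class MyDiffusersStylePipeline(nn.Module):
    def __init__(self, *, od_config, prefix: str = ""):
        super().__init__()
        model = od_config.model
        # Non-transformer components: eager from_pretrained() in __init__.
        self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae")
        self.text_encoder = AutoModel.from_pretrained(model, subfolder="text_encoder")
        self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer")
        # Transformer: built without weights; the loader fills it later by
        # reading weights_sources and calling transformer.load_weights().
        self.transformer = MyTransformer(od_config=od_config)  # your ported DiT
        self.weights_sources = [
            DiffusersPipelineLoader.ComponentSource(  # import from vllm-omni's loader
                model_or_path=model,
                subfolder="transformer",
                prefix="transformer.",
                fall_back_to_pt=True,
            )
        ]
```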
+ +### Custom model with external deps (e.g., DreamID-Omni) + +``` +vllm_omni/diffusion/models/dreamid_omni/ +├── __init__.py # Exports pipeline only +├── pipeline_dreamid_omni.py # Pipeline: loads ALL weights in __init__ via custom helpers +├── fusion.py # Custom fusion architecture (video + audio cross-attention) +└── wan2_2.py # Re-implemented Wan backbone with split API + +examples/offline_inference/x_to_video_audio/ +└── download_dreamid_omni.py # Downloads weights from 3 HF repos + clones code repo +``` + +All weights loaded eagerly in `__init__`. `load_weights()` is a no-op. External dependency (`dreamid_omni` package) imported with try/except. + +### Custom model with ported code (e.g., BAGEL) + +``` +vllm_omni/diffusion/models/bagel/ +├── __init__.py +├── pipeline_bagel.py # Pipeline: instantiates models, uses weights_sources +├── bagel_transformer.py # Full LLM backbone (Qwen2-MoT) ported into vllm-omni +└── autoencoder.py # Custom VAE ported from original repo +``` + +Model code is fully ported (no external dependency). Uses `weights_sources` and `load_weights()` with custom name remapping to handle non-diffusers safetensors format. + +## Weight Loading Patterns + +### Pattern 1: Standard diffusers flow (Wan2.2, Z-Image, FLUX) + +``` +init → create transformer (empty) → set weights_sources → [loader calls load_weights()] +``` + +- `weights_sources` points to safetensors in HF subfolder (e.g., `transformer/`) +- `load_weights()` receives `(name, tensor)` pairs from the loader +- Name remapping handles diffusers→vllm-omni differences (QKV fusion, Sequential index removal) + +### Pattern 2: Custom safetensors at root (BAGEL) + +``` +init → create all models (empty) → set weights_sources(subfolder=None) → [loader calls load_weights()] +``` + +- `weights_sources` points to **root** of model directory, not a subfolder +- Weights have non-diffusers names (e.g., `bagel.language_model.model.layers.0.self_attn.q_proj.weight`) +- `load_weights()` does heavy name normalization + +```python +self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder=None, # root directory + prefix="", # no prefix stripping + fall_back_to_pt=False, + ) +] +``` + +### Pattern 3: Fully custom loading (DreamID-Omni) + +``` +init → load ALL weights eagerly via custom helpers → load_weights() = no-op +``` + +- No `weights_sources` attribute — standard loader finds nothing to iterate +- Custom init functions (e.g., `init_wan_vae_2_2()`, `load_fusion_checkpoint()`) handle downloading and loading +- `load_weights()` is `pass` +- Weights may come from multiple HF repos in different formats (`.pth`, `.safetensors`) + +Use this when: +- The original model has complex, well-tested loading code you don't want to rewrite +- Weights span multiple HF repos +- Weight format is non-standard (e.g., a single `.pth` file, not sharded safetensors) + +## model_index.json for Custom Models + +Standard diffusers `model_index.json`: +```json +{ + "_class_name": "WanPipeline", + "_diffusers_version": "0.35.0.dev0", + "scheduler": ["diffusers", "UniPCMultistepScheduler"], + "transformer": ["diffusers", "WanTransformer3DModel"], + "vae": ["diffusers", "AutoencoderKLWan"] +} +``` + +Custom model `model_index.json` (minimal): +```json +{ + "_class_name": "DreamIDOmniPipeline", + "fusion": "DreamID-Omni/dreamid_omni.safetensors" +} +``` + +The only **required** field is `_class_name` — it must match a key in `_DIFFUSION_MODELS` in `registry.py`. 
Other fields are model-specific and accessible via `od_config.model_config` dict. + +## External Dependency Management + +### Git clone + .pth injection (DreamID-Omni pattern) + +```python +def download_dependency(): + CACHE_DIR.mkdir(parents=True, exist_ok=True) + with open(LOCK_FILE, "w") as f: + fcntl.flock(f, fcntl.LOCK_EX) + if not DEPENDENCY_DIR.exists(): + subprocess.run([ + "git", "clone", "--depth", "1", + REPO_URL, "--branch", BRANCH, + str(DEPENDENCY_DIR) + ], check=True) + fcntl.flock(f, fcntl.LOCK_UN) + + # Add to Python path via .pth file + site_packages = Path(site.getsitepackages()[0]) + pth_file = site_packages / "vllm_omni_dependency.pth" + pth_file.write_text(str(DEPENDENCY_DIR)) +``` + +### Direct port (BAGEL pattern) + +Copy essential files from the original repo into `vllm_omni/diffusion/models//`. Adapt imports to use vllm-omni utilities. Benefits: no external dependency, no git clone step. Drawback: must maintain the ported code. + +## Multi-Modal Input/Output Protocols + +Custom models that handle images, audio, or video I/O should implement protocol classes: + +```python +from vllm_omni.diffusion.models.interface import ( + SupportImageInput, # Model accepts image input + SupportAudioInput, # Model accepts audio input + SupportAudioOutput, # Model produces audio output +) + +class MyPipeline(nn.Module, SupportImageInput, SupportAudioInput, SupportAudioOutput): + pass # Protocol markers enable proper engine routing +``` + +The engine checks `isinstance(pipeline, SupportImageInput)` at startup to configure input validation and warmup behavior. + +## Hardcoded Config vs Config Files + +Diffusers models use `config.json` in each subfolder. Custom models often use: + +**Module-level config dicts** (DreamID-Omni): +```python +VIDEO_CONFIG = { + "patch_size": [1, 2, 2], "model_type": "ti2v", + "dim": 3072, "ffn_dim": 14336, "num_heads": 24, "num_layers": 30, ... +} +``` + +**Loaded from custom JSON** (BAGEL): +```python +cfg_path = os.path.join(model_path, "config.json") +with open(cfg_path) as f: + bagel_cfg = json.load(f) +vae_cfg = bagel_cfg.get("vae_config", {}) +``` + +## Custom Architecture Patterns + +### Split forward API (DreamID-Omni) + +When a fusion model needs to interleave blocks from two backbones: + +```python +class WanModel(nn.Module): + def prepare_transformer_block_kwargs(self, x, t, context, ...): + # Patch embed, time embed, text embed, RoPE + return x, e, kwargs + + def post_transformer_block_out(self, x, grid_sizes, e): + # Output projection, unpatchify + return output + + def forward(self, *args, **kwargs): + raise NotImplementedError # Fusion model handles block iteration +``` + +The `FusionModel` then iterates blocks in lock-step: +```python +for video_block, audio_block in zip(self.video_model.blocks, self.audio_model.blocks): + video_out = video_block(video_hidden, ...) + audio_out = audio_block(audio_hidden, ...) + # Cross-attend between modalities + video_out = cross_attention(video_out, audio_out) + audio_out = cross_attention(audio_out, video_out) +``` + +### LLM-as-denoiser (BAGEL) + +When the backbone is a language model that also does diffusion: + +```python +class BagelModel(nn.Module): + def __init__(self): + self.language_model = Qwen2MoTForCausalLM(config) + self.vit_model = SiglipVisionModel(vit_config) +``` + +The LLM processes both text tokens and latent image tokens in a single forward pass, using KV caching for the text portion. 
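
A toy sketch of that reuse pattern in plain PyTorch (illustrative only — the dimensions, single shared projection, and lack of causal masking are made-up simplifications, not BAGEL's code): the text prompt is projected to K/V once, and each denoising step only projects the latent image tokens and attends over the concatenation.

```python
import torch
import torch.nn.functional as F

dim, n_heads, head_dim = 512, 8, 64
qkv = torch.nn.Linear(dim, 3 * dim, bias=False)

def split_heads(x):  # [B, S, dim] -> [B, heads, S, head_dim]
    b, s, _ = x.shape
    return x.view(b, s, n_heads, head_dim).transpose(1, 2)

# 1) Encode the text prompt once and cache its K/V for all steps.
text_tokens = torch.randn(1, 77, dim)
_, t_k, t_v = qkv(text_tokens).chunk(3, dim=-1)
text_kv = (split_heads(t_k), split_heads(t_v))

# 2) Each denoising step re-projects only the latent tokens and attends
#    over [cached text K/V || fresh latent K/V].
def denoise_step(latent_tokens):
    l_q, l_k, l_v = qkv(latent_tokens).chunk(3, dim=-1)
    k = torch.cat([text_kv[0], split_heads(l_k)], dim=2)
    v = torch.cat([text_kv[1], split_heads(l_v)], dim=2)
    out = F.scaled_dot_product_attention(split_heads(l_q), k, v)
    return out.transpose(1, 2).reshape(latent_tokens.shape)

latents = torch.randn(1, 1024, dim)
for _ in range(4):  # stand-in for the denoising loop
    latents = denoise_step(latents)
```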
+ +## Pre/Post Processing for Custom Models + +Custom models typically handle pre/post processing **inside `forward()`** rather than via registered functions, because the logic is tightly coupled: + +```python +def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + # Inline preprocessing + image = self._load_and_resize_image(req.prompts[0].get("multi_modal_data", {}).get("image")) + image_latent = self._vae_encode(image) + + # ... denoising loop ... + + # Inline postprocessing + pil_image = self._decode_to_pil(latents) + return DiffusionOutput(output=[pil_image]) +``` + +If pre/post functions are not registered in `_DIFFUSION_PRE_PROCESS_FUNCS` / `_DIFFUSION_POST_PROCESS_FUNCS`, the engine simply skips those steps. + +## Download Script Template + +```python +# examples/offline_inference//download_.py +from huggingface_hub import snapshot_download +import json, os + +def main(output_dir): + # Download model weights from HF + snapshot_download(repo_id="org/model-weights", local_dir=os.path.join(output_dir, "weights")) + + # Download additional components if from separate repos + snapshot_download(repo_id="org/vae-weights", local_dir=os.path.join(output_dir, "vae"), + allow_patterns=["*.safetensors"]) + + # Generate model_index.json + config = {"_class_name": "YourPipeline", "custom_key": "weights/model.safetensors"} + with open(os.path.join(output_dir, "model_index.json"), "w") as f: + json.dump(config, f, indent=2) + + # Install external code dependency (if needed) + download_dependency() + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", default="./your_model") + args = parser.parse_args() + main(args.output_dir) +``` diff --git a/.claude/skills/add-diffusion-model/references/parallelism-patterns.md b/.claude/skills/add-diffusion-model/references/parallelism-patterns.md new file mode 100644 index 0000000000..933e2d2320 --- /dev/null +++ b/.claude/skills/add-diffusion-model/references/parallelism-patterns.md @@ -0,0 +1,571 @@ +# Parallelism Patterns Reference + +## Overview + +vLLM-Omni supports multiple parallelism strategies for diffusion models. Each targets a different bottleneck: + +| Strategy | Splits | Best For | Constraint | +|----------|--------|----------|------------| +| Tensor Parallel (TP) | Model layers across GPUs | Latency reduction, large models | Requires fast GPU interconnect, `num_heads % tp == 0` | +| Sequence Parallel (SP/USP) | Sequence tokens across GPUs | Long sequences (video, high-res) | Near-linear scaling | +| CFG Parallel | Positive/negative CFG branches | Models using classifier-free guidance | Exactly 2 GPUs | +| HSDP | Weight shards via FSDP2 | VRAM reduction | Cannot combine with TP | +| VAE Patch Parallel | VAE decode spatial tiles | Large VAE outputs | Auto-enables tiling | + +**Recommended integration order**: TP → SP → CFG Parallel → HSDP + +**Official design docs**: +- TP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/tensor_parallel +- SP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/sequence_parallel +- CFG: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/cfg_parallel +- HSDP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/hsdp + +--- + +## Tensor Parallelism (TP) + +Replace standard `nn.Linear` with vLLM's parallel linear layers. This is the most invasive change but provides direct VRAM savings and compute speedup. 
+ +### Layer replacement rules + +| Pattern | vLLM Layer | When to Use | +|---------|-----------|-------------| +| Fan-out (first in FFN) | `ColumnParallelLinear` | Projection that splits output across ranks | +| Fan-in (second in FFN) | `RowParallelLinear` | Projection that gathers across ranks | +| QKV projection | `QKVParallelLinear` | Fused Q/K/V for self-attention | +| Single Q or K or V | `ColumnParallelLinear` | Separate projections (cross-attention) | +| Attention output | `RowParallelLinear` | Output projection after attention | +| Must not shard | `ReplicatedLinear` | Layers that must stay replicated | + +### MLP Block (Up-Down Pattern) + +```python +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, RowParallelLinear, +) + +class TPFeedForward(nn.Module): + def __init__(self, dim, ffn_dim): + super().__init__() + self.fc1 = ColumnParallelLinear(dim, ffn_dim, bias=False, return_bias=False) + self.fc2 = RowParallelLinear( + ffn_dim, dim, bias=False, + input_is_parallel=True, # Input already sharded from fc1 + return_bias=False, + ) + + def forward(self, x): + x, _ = self.fc1(x) + x = torch.nn.functional.gelu(x) + x, _ = self.fc2(x) + return x +``` + +### Attention Block (QKV-Out Pattern) + +```python +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm_omni.diffusion.attention.layer import Attention + +class TPSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads=None): + super().__init__() + num_kv_heads = num_kv_heads or num_heads + self.head_dim = dim // num_heads + + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.head_dim, + total_num_heads=num_heads, + total_num_kv_heads=num_kv_heads, + bias=False, + return_bias=False, + ) + self.to_out = RowParallelLinear( + dim, dim, bias=False, + input_is_parallel=True, + return_bias=False, + ) + self.attn = Attention( + num_heads=self.to_qkv.num_heads, # Local heads per GPU + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim ** 0.5), + causal=False, + num_kv_heads=self.to_qkv.num_kv_heads, # Local KV heads per GPU + ) + + def forward(self, x): + qkv, _ = self.to_qkv(x) + q, k, v = qkv.split( + [self.to_qkv.num_heads * self.head_dim, + self.to_qkv.num_kv_heads * self.head_dim, + self.to_qkv.num_kv_heads * self.head_dim], + dim=-1, + ) + B, S, _ = x.shape + q = q.view(B, S, self.to_qkv.num_heads, self.head_dim) + k = k.view(B, S, self.to_qkv.num_kv_heads, self.head_dim) + v = v.view(B, S, self.to_qkv.num_kv_heads, self.head_dim) + out = self.attn(q, k, v) + out = out.reshape(B, S, -1) + out, _ = self.to_out(out) + return out +``` + +### QKV Fusion in load_weights + +When you fuse separate Q/K/V into `QKVParallelLinear`, map diffusers' separate weight names: + +```python +stacked_params_mapping = [ + ("to_qkv", "to_q", "q"), + ("to_qkv", "to_k", "k"), + ("to_qkv", "to_v", "v"), +] + +def load_weights(self, weights): + params = dict(self.named_parameters()) + loaded = set() + for name, tensor in weights: + for fused_name, orig_name, shard_id in stacked_params_mapping: + if orig_name in name: + name = name.replace(orig_name, fused_name) + param = params[name] + param.weight_loader(param, tensor, shard_id) + loaded.add(name) + break + else: + if name in params: + param = params[name] + if hasattr(param, "weight_loader"): + param.weight_loader(param, tensor) + else: + default_weight_loader(param, tensor) + loaded.add(name) + return loaded +``` + +### RMSNorm with TP + +When RMSNorm sits between TP-sharded dimensions, use 
`DistributedRMSNorm` — it computes global RMS via all-reduce across TP ranks. See the Wan2.2 implementation for the pattern. + +### TP Constraints + +- `num_heads % tp_size == 0` +- `num_kv_heads % tp_size == 0` +- Use `self.to_qkv.num_heads` (local per-GPU count), not total heads, for split sizes + +### Testing TP + +```bash +python text_to_image.py --model Your-org/your-model \ + --tensor-parallel-size 2 --output "tp_test.png" +``` + +**Verify**: speedup, memory reduction proportional to TP size, quality matches single-GPU. + +### Reference implementations + +| Model | Path | +|-------|------| +| Z-Image | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` | +| FLUX | `vllm_omni/diffusion/models/flux/flux_transformer.py` | +| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` | + +--- + +## Sequence Parallelism (SP / USP) + +SP splits sequence tokens across GPUs using Ulysses (all-to-all) or Ring (P2P) communication. It is applied non-intrusively via the `_sp_plan` dict — no changes to `forward()` logic. + +### Approach 1: Non-Intrusive `_sp_plan` (Recommended) + +The framework automatically registers hooks to shard inputs and gather outputs at `nn.Module` boundaries. + +#### Step 1: Identify module boundaries + +Find where tensors need sharding/gathering: + +```python +class MyTransformer(nn.Module): + def __init__(self): + self.patch_embed = PatchEmbed() # Before blocks + self.pos_embed = RoPE() # RoPE may need splitting + self.blocks = nn.ModuleList([...]) # Blocks process sharded x + self.norm_out = LayerNorm() + self.proj_out = Linear() # Gather after this + + def forward(self, x): + x = self.patch_embed(x) + pos = self.pos_embed(x) + for block in self.blocks: + x = block(x, pos) + x = self.norm_out(x) + return self.proj_out(x) +``` + +#### Step 2: Handle inline operations + +`_sp_plan` hooks only work at `nn.Module` boundaries. Inline ops like `torch.cat()` must be extracted into submodules: + +```python +# BAD: Inline — hooks can't intercept +unified = torch.cat([x, cap_feats], dim=1) + +# GOOD: Extract into submodule +class UnifiedPrepare(nn.Module): + def forward(self, x, cap_feats): + return torch.cat([x, cap_feats], dim=1) + +self.unified_prepare = UnifiedPrepare() +unified = self.unified_prepare(x, cap_feats) +``` + +Common cases: `torch.cat()`, `pad_sequence()`, `tensor.reshape()`, complex preprocessing. 
+ +#### Step 3: Write `_sp_plan` + +**Pattern 1: Shard at first block, gather at output** (most common) + +```python +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, SequenceParallelOutput, +) + +class StandardTransformer(nn.Module): + _sp_plan = { + "blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } +``` + +**Pattern 2: Shard RoPE outputs separately** + +```python +class TransformerWithRoPE(nn.Module): + _sp_plan = { + "rope": { + 0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), + 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True), + }, + "blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } +``` + +**Pattern 3: Dual-stream (shard image, replicate text)** + +```python +class DualStreamTransformer(nn.Module): + _sp_plan = { + "rope_preparer": { + 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True), + 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True), + }, + "transformer_blocks.0": { + "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3), + }, + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } +``` + +### API Reference + +**SequenceParallelInput**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `split_dim` | int | Dimension to split (usually 1 for sequence) | +| `expected_dims` | int/None | Expected tensor rank for validation | +| `split_output` | bool | `False`: shard input params; `True`: shard output tensors | +| `auto_pad` | bool | Auto-pad if sequence not divisible by world_size | + +**SequenceParallelOutput**: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `gather_dim` | int | Dimension to gather (usually 1 for sequence) | +| `expected_dims` | int/None | Expected tensor rank for validation | + +**Module naming**: + +| Key | Meaning | +|-----|---------| +| `"blocks.0"` | First element of ModuleList | +| `"blocks.*"` | All elements of ModuleList | +| `"rope"` | Named submodule | + +**Dictionary value types**: + +| Key type | split_output | Description | +|----------|-------------|-------------| +| `"param_name"` (str) | False | Shard input parameter by name | +| `0, 1, ...` (int) | True | Shard output tuple by index | + +### Approach 2: Intrusive Modification (Complex Cases) + +For dynamic sharding logic that can't be expressed via `_sp_plan`: + +```python +from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather + +def forward(self, hidden_states, ...): + if self.parallel_config.sequence_parallel_size > 1: + hidden_states = sp_shard(hidden_states, dim=1) + for block in self.blocks: + hidden_states = block(hidden_states) + if self.parallel_config.sequence_parallel_size > 1: + hidden_states = sp_gather(hidden_states, dim=1) + return hidden_states +``` + +Use intrusive modification as a last resort — `_sp_plan` is preferred for maintainability. + +### UAA Mode (Experimental) + +`ulysses_mode="advanced_uaa"` handles arbitrary sequence lengths and head counts that aren't divisible by `ulysses_degree`. Uses variable all-to-all split sizes and temporary head padding. + +### Combining SP methods + +Ulysses and Ring can be combined: `ulysses_degree × ring_degree = total SP GPUs`. 
+ +```python +DiffusionParallelConfig(ulysses_degree=2, ring_degree=2) # 4 GPUs total +``` + +### Testing SP + +```bash +# Offline +python text_to_image.py --model Your-model --ulysses-degree 2 + +# Online serving +vllm serve Your-model --omni --usp 2 +``` + +### Reference implementations + +| Model | Path | +|-------|------| +| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` | +| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` | +| Z-Image | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` | + +--- + +## CFG Parallelism + +Distributes positive/negative Classifier-Free Guidance branches across 2 GPUs. + +### Implementation + +Inherit `CFGParallelMixin` and implement `diffuse()`: + +```python +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin + +class YourPipeline(nn.Module, CFGParallelMixin): + def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds, + do_true_cfg, true_cfg_scale, **kwargs): + for i, t in enumerate(timesteps): + positive_kwargs = { + "hidden_states": latents, + "encoder_hidden_states": prompt_embeds, + "timestep": t, + } + negative_kwargs = { + "hidden_states": latents, + "encoder_hidden_states": negative_embeds, + "timestep": t, + } if do_true_cfg else None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=true_cfg_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + ) + latents = self.scheduler_step_maybe_with_cfg( + noise_pred, t, latents, do_true_cfg + ) + return latents +``` + +### Customization hooks + +| Method | Override when | +|--------|-------------| +| `predict_noise()` | Non-standard transformer call (e.g., dual-transformer like Wan2.2) | +| `cfg_normalize_function()` | Custom normalization (e.g., LongCat with clamping) | +| `combine_cfg_noise()` | Multi-output models (e.g., video + audio: CFG on video, positive-only on audio) | + +**Custom predict_noise** (Wan2.2 — selects active transformer): + +```python +def predict_noise(self, current_model=None, **kwargs): + if current_model is None: + current_model = self.transformer + return current_model(**kwargs)[0] +``` + +**Custom combine_cfg_noise** (multi-output): + +```python +def combine_cfg_noise(self, positive_pred, negative_pred, scale, normalize): + video_pos, audio_pos = positive_pred + video_neg, audio_neg = negative_pred + video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize) + return (video_combined, audio_pos) +``` + +### Composite scheduler for multi-output + +When each output has its own schedule: + +```python +class VideoAudioScheduler: + def __init__(self, video_scheduler, audio_scheduler): + self.video_scheduler = video_scheduler + self.audio_scheduler = audio_scheduler + + def step(self, noise_pred, t, latents, return_dict=False, generator=None): + video_out = self.video_scheduler.step( + noise_pred[0], t[0], latents[0], return_dict=False, generator=generator + )[0] + audio_out = self.audio_scheduler.step( + noise_pred[1], t[1], latents[1], return_dict=False, generator=generator + )[0] + return ((video_out, audio_out),) +``` + +### Testing CFG Parallel + +```bash +python text_to_image.py --model Your-model \ + --cfg-parallel-size 2 --cfg-scale 4.0 \ + --negative-prompt "ugly, unclear" +``` + +**Constraint**: `guidance_scale > 1.0` and negative prompt must be provided. 
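
The same check can be exercised from the Python API. A hedged sketch — the `cfg_parallel_size` field and the `guidance_scale`/`negative_prompt` sampling-parameter names are assumptions mirroring the CLI flags above; verify them against `DiffusionParallelConfig` and `OmniDiffusionSamplingParams` before relying on them:

```python
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# Assumed field name, mirroring --cfg-parallel-size.
parallel_config = DiffusionParallelConfig(cfg_parallel_size=2)
omni = Omni(model="your-model-name", parallel_config=parallel_config)

images = omni.generate(
    "a beautiful landscape",
    OmniDiffusionSamplingParams(
        num_inference_steps=50,
        guidance_scale=4.0,               # must be > 1.0 for CFG parallel to engage
        negative_prompt="ugly, unclear",  # a negative prompt must be provided
    ),
)
```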
+ +### Reference implementations + +| Model | Path | +|-------|------| +| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/cfg_parallel.py` | +| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py` | +| Mixin base | `vllm_omni/diffusion/distributed/cfg_parallel.py` | + +--- + +## HSDP (Hybrid Sharded Data Parallel) + +Shards model weights across GPUs using PyTorch FSDP2. Reduces per-GPU VRAM without changing computation. + +### Implementation + +Add `_hsdp_shard_conditions` to the transformer class: + +```python +class YourTransformer(nn.Module): + @staticmethod + def _is_transformer_block(name: str, module) -> bool: + return "blocks" in name and name.split(".")[-1].isdigit() + + _hsdp_shard_conditions = [_is_transformer_block] +``` + +For MoE models, add additional conditions: + +```python +class MoETransformer(nn.Module): + @staticmethod + def _is_transformer_block(name, module): + return "blocks" in name and name.split(".")[-1].isdigit() + + @staticmethod + def _is_moe_expert(name, module): + return "experts" in name and name.split(".")[-1].isdigit() + + _hsdp_shard_conditions = [_is_transformer_block, _is_moe_expert] +``` + +A module is sharded if **any** condition returns `True`. + +### Constraints + +- Cannot combine with Tensor Parallelism +- For standalone HSDP (no other parallelism), `hsdp_shard_size` must be specified explicitly +- Can combine with SP: HSDP reduces memory while SP distributes sequence + +### Testing HSDP + +```python +from vllm_omni.diffusion.data import DiffusionParallelConfig + +parallel_config = DiffusionParallelConfig(use_hsdp=True, hsdp_shard_size=8) +omni = Omni(model="your-model", parallel_config=parallel_config) +``` + +Or CLI: + +```bash +vllm serve Your-model --omni --use-hsdp +``` + +**Verify**: logs show "HSDP Inference: replicate_size=..., shard_size=..." and "Sharded N modules + root". Check VRAM reduction. + +### Reference implementations + +| Model | Path | +|-------|------| +| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` | +| HSDP Core | `vllm_omni/diffusion/distributed/hsdp.py` | + +--- + +## VAE Patch Parallelism + +Shards VAE decode spatially across ranks using tiling: + +```bash +python text_to_image.py --model Your-model --vae-patch-parallel-size 4 +``` + +Auto-enables `--vae-use-tiling`. Uses `DistributedAutoencoderKLWan` or similar distributed VAE. Set `vae_patch_parallel_size` in `DiffusionParallelConfig`. + +--- + +## Combining Parallelism Methods + +Common multi-GPU recipes: + +```bash +# 4 GPUs: CFG (2) × Ulysses (2) +python text_to_image.py --model Qwen/Qwen-Image \ + --cfg-parallel-size 2 --ulysses-degree 2 + +# 8 GPUs: Ulysses (4) × Ring (2) + VAE patch (8) +python text_to_video.py --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --ulysses-degree 4 --ring-degree 2 --vae-patch-parallel-size 8 + +# 2 GPUs: HSDP + Ulysses (cannot combine HSDP with TP) +vllm serve Your-model --omni --use-hsdp --usp 2 +``` + +## Discovering Parallelism Support + +Check which parallelism methods a model supports: + +| Check | How | +|-------|-----| +| **Ulysses / Ring SP** | Transformer defines `_sp_plan`. Search: `grep -r '_sp_plan' vllm_omni/diffusion/models/` | +| **CFG Parallel** | Pipeline inherits `CFGParallelMixin`. Search: `grep -r 'CFGParallelMixin' vllm_omni/diffusion/models/` | +| **TP** | Uses `ColumnParallelLinear` / `QKVParallelLinear`. Search: `grep -r 'ParallelLinear\|QKVParallel' vllm_omni/diffusion/models//` | +| **HSDP** | Transformer defines `_hsdp_shard_conditions`. 
Search: `grep -r '_hsdp_shard_conditions' vllm_omni/diffusion/models/` | + +The canonical per-model support table is in `docs/user_guide/diffusion/parallelism_acceleration.md`. diff --git a/.claude/skills/add-diffusion-model/references/transformer-adaptation.md b/.claude/skills/add-diffusion-model/references/transformer-adaptation.md new file mode 100644 index 0000000000..6e344b6a66 --- /dev/null +++ b/.claude/skills/add-diffusion-model/references/transformer-adaptation.md @@ -0,0 +1,218 @@ +# Transformer Adaptation Reference + +## Adapting a Diffusers Transformer to vLLM-Omni + +### Step-by-step Checklist + +1. Copy the transformer class from diffusers source +2. Remove all mixin classes — inherit only from `nn.Module` +3. Replace attention dispatch with `vllm_omni.diffusion.attention.layer.Attention` +4. Replace logger with `vllm.logger.init_logger` +5. Add `od_config: OmniDiffusionConfig | None = None` to `__init__` +6. Remove training-only code (gradient checkpointing, dropout) +7. Add `load_weights()` method for weight loading from safetensors +8. Add class-level attributes for acceleration features + +### Mixin Removal + +Remove these diffusers mixins (and their imports): + +```python +# Remove all of these: +from diffusers.models.modeling_utils import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionModuleMixin +from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin + +# Replace: +class MyTransformer(ModelMixin, ConfigMixin, AttentionModuleMixin): +# With: +class MyTransformer(nn.Module): +``` + +Also remove `@register_to_config` decorators from `__init__`. + +### Attention Replacement + +The vLLM-Omni `Attention` layer wraps backend selection (FlashAttention, SDPA, SageAttn, etc.) and supports sequence parallelism hooks. 
+ +**QKV tensor shape must be `[batch, seq_len, num_heads, head_dim]`.** + +#### Self-Attention Pattern + +```python +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata + +class SelfAttentionBlock(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.to_q = nn.Linear(dim, dim) + self.to_k = nn.Linear(dim, dim) + self.to_v = nn.Linear(dim, dim) + self.to_out = nn.Linear(dim, dim) + + self.attn = Attention( + num_heads=num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim ** 0.5), + causal=False, + num_kv_heads=num_heads, + ) + + def forward(self, x, attn_mask=None): + B, S, _ = x.shape + q = self.to_q(x).view(B, S, self.num_heads, self.head_dim) + k = self.to_k(x).view(B, S, self.num_heads, self.head_dim) + v = self.to_v(x).view(B, S, self.num_heads, self.head_dim) + + attn_metadata = AttentionMetadata(attn_mask=attn_mask) + out = self.attn(q, k, v, attn_metadata=attn_metadata) + out = out.reshape(B, S, -1) + return self.to_out(out) +``` + +#### Fused QKV with TP (Advanced) + +For tensor parallelism, use vLLM's parallel linear layers: + +```python +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, RowParallelLinear +) + +class TPSelfAttention(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.head_dim, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + ) + self.to_out = RowParallelLinear(dim, dim) + + self.attn = Attention( + num_heads=num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim ** 0.5), + causal=False, + num_kv_heads=num_heads, + ) +``` + +### Logger Replacement + +```python +# Replace: +from diffusers.utils import logging +logger = logging.get_logger(__name__) + +# With: +from vllm.logger import init_logger +logger = init_logger(__name__) +``` + +### Custom Layers from vLLM-Omni + +Available utility layers: + +```python +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm +``` + +### Config Support + +```python +from vllm_omni.diffusion.data import OmniDiffusionConfig + +class MyTransformer(nn.Module): + def __init__(self, *, od_config=None, num_layers=28, hidden_size=3072, **kwargs): + super().__init__() + self.od_config = od_config + self.parallel_config = od_config.parallel_config if od_config else None + # ... build layers +``` + +The transformer config values come from `model_index.json` → `config.json` in the transformer subfolder. The pipeline uses `get_transformer_config_kwargs(od_config.tf_model_config, TransformerClass)` to filter config keys to match the `__init__` signature. + +### Weight Loading + +The `load_weights` method receives an iterable of `(name, tensor)` from safetensors files, with the prefix (e.g., `"transformer."`) already stripped by the loader. 
+ +```python +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +class MyTransformer(nn.Module): + def load_weights(self, weights): + params = dict(self.named_parameters()) + loaded = set() + for name, tensor in weights: + # Optional: remap names from diffusers to vllm-omni naming + # e.g., "ff.net.0.proj" -> "ff.net_0.proj" + + if name in params: + param = params[name] + if hasattr(param, "weight_loader"): + param.weight_loader(param, tensor) + else: + default_weight_loader(param, tensor) + loaded.add(name) + return loaded +``` + +#### QKV Fusion in load_weights + +If you fused separate Q/K/V into a `QKVParallelLinear`, you need to map diffusers' separate weight names: + +```python +stacked_params_mapping = [ + ("to_qkv", "to_q", "q"), + ("to_qkv", "to_k", "k"), + ("to_qkv", "to_v", "v"), +] + +def load_weights(self, weights): + params = dict(self.named_parameters()) + loaded = set() + for name, tensor in weights: + for fused_name, orig_name, shard_id in stacked_params_mapping: + if orig_name in name: + name = name.replace(orig_name, fused_name) + param = params[name] + param.weight_loader(param, tensor, shard_id) + loaded.add(name) + break + else: + # Normal loading + ... + return loaded +``` + +### Class-Level Attributes for Features + +```python +class MyTransformer(nn.Module): + # torch.compile: list block class names that repeat and can be compiled + _repeated_blocks = ["MyTransformerBlock"] + + # CPU offload: attribute name of the nn.ModuleList containing blocks + _layerwise_offload_blocks_attr = "blocks" + + # LoRA: mapping of fused param names to original param names + packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]} + + # Sequence parallelism plan (advanced — add after basic impl works) + _sp_plan = { + "blocks.0": SequenceParallelInput(split_dim=1), + "proj_out": SequenceParallelOutput(gather_dim=1), + } +``` diff --git a/.claude/skills/add-diffusion-model/references/troubleshooting.md b/.claude/skills/add-diffusion-model/references/troubleshooting.md new file mode 100644 index 0000000000..27acdd8d15 --- /dev/null +++ b/.claude/skills/add-diffusion-model/references/troubleshooting.md @@ -0,0 +1,178 @@ +# Troubleshooting Reference + +## Common Errors When Adding a Diffusion Model + +### ImportError / ModuleNotFoundError + +**Cause**: Missing or incorrect registration. + +**Fix checklist**: +1. Model registered in `vllm_omni/diffusion/registry.py` `_DIFFUSION_MODELS` dict +2. `__init__.py` exports the pipeline class +3. Pipeline file exists at the correct path: `vllm_omni/diffusion/models/{folder}/{file}.py` +4. Class name in registry matches the actual class name in the file + +### Shape Mismatch in Attention + +**Symptom**: `RuntimeError: shape mismatch` or `expected 4D tensor` + +**Cause**: QKV tensors not reshaped to `[batch, seq_len, num_heads, head_dim]`. + +**Fix**: Before calling `self.attn(q, k, v, ...)`, ensure: +```python +q = q.view(batch, seq_len, self.num_heads, self.head_dim) +k = k.view(batch, kv_seq_len, self.num_kv_heads, self.head_dim) +v = v.view(batch, kv_seq_len, self.num_kv_heads, self.head_dim) +``` + +After attention, reshape back: +```python +out = out.reshape(batch, seq_len, -1) +``` + +### Weight Loading Failures + +**Symptom**: `RuntimeError: size mismatch for parameter ...` or missing keys + +**Debugging**: +1. Print diffusers weight names: `safetensors.safe_open(path, "pt").keys()` +2. Print model parameter names: `dict(model.named_parameters()).keys()` +3. 
Compare and add name remappings in `load_weights()` + +**Common remappings needed**: +- `ff.net.0.proj` → `ff.net_0.proj` (PyTorch Sequential indexing) +- `.to_out.0.` → `.to_out.` (Sequential unwrapping) +- `scale_shift_table` → moved to a wrapper module + +### Black/Blank/Noisy Output + +**Possible causes**: +1. **Wrong latent normalization**: Check VAE expects latents scaled by `vae.config.scaling_factor` +2. **Wrong scheduler**: Using the wrong scheduler class or wrong `flow_shift` +3. **Missing CFG**: Some models require `guidance_scale > 1.0` with negative prompt +4. **Wrong timestep format**: Some schedulers expect float, others expect int/long +5. **Missing post-processing**: Raw VAE output may need denormalization + +**Quick test**: Run with diffusers directly using the same seed and compare latents at each step. + +### OOM (Out of Memory) + +**Solutions** (in order of preference): +1. `--enforce-eager` to disable torch.compile (saves compile memory) +2. `--enable-cpu-offload` for model-level offload +3. `--enable-layerwise-offload` for block-level offload (better for large models) +4. `--vae-use-slicing --vae-use-tiling` for VAE memory reduction +5. Reduce resolution: `--height 480 --width 832` +6. Use TP: `--tensor-parallel-size 2` + +### Different Output vs Diffusers Reference + +**Common causes**: +1. **Attention backend difference**: FlashAttention vs SDPA may produce slightly different results. Set `DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA` to match diffusers +2. **Float precision**: vLLM-Omni may use bfloat16 where diffusers uses float32 for some operations +3. **Missing normalization**: Check all LayerNorm/RMSNorm are preserved +4. **Scheduler rounding**: Some schedulers have numerical sensitivity + +### Tensor Parallel Errors + +**Symptom**: `AssertionError: not divisible` or incorrect output with TP>1 + +**Fix**: +1. Verify `num_heads % tp_size == 0` and `num_kv_heads % tp_size == 0` +2. Ensure `ColumnParallelLinear` / `RowParallelLinear` are used correctly +3. Check that norms between parallel layers use distributed norm if needed +4. Verify `load_weights` handles TP sharding for norm weights +5. Use `self.to_qkv.num_heads` (local heads per GPU) for QKV split sizes, not total heads + +**Missing `input_is_parallel=True`**: + +`RowParallelLinear` expects sharded input from `ColumnParallelLinear`: +```python +self.w1 = ColumnParallelLinear(dim, hidden_dim, return_bias=False) +self.w2 = RowParallelLinear(hidden_dim, dim, input_is_parallel=True, return_bias=False) +``` + +### Sequence Parallel Errors + +**Symptom**: Incorrect output or crashes with `--ulysses-degree N` or `--usp N` + +**Possible causes**: +1. **Inline operations between shard/gather points**: `torch.cat()`, `pad_sequence()` etc. not at `nn.Module` boundaries. Fix: extract into submodule. +2. **Wrong `split_dim`**: Check the tensor shape at the shard point. Sequence dimension is typically `dim=1` for `[B, S, D]` tensors. +3. **RoPE not sharded**: If RoPE is computed separately, add it to `_sp_plan` with `split_output=True`. +4. **Sequence not divisible by SP degree**: Use `auto_pad=True` in `SequenceParallelInput` or switch to `ulysses_mode="advanced_uaa"`. + +**Debugging**: Add `expected_dims=N` to `SequenceParallelInput`/`Output` for shape validation at runtime. + +### CFG Parallel Errors + +**Symptom**: CFG parallel not activating, no speedup + +**Fix checklist**: +1. Pipeline inherits `CFGParallelMixin` +2. `guidance_scale > 1.0` +3. Negative prompt provided (even if empty string) +4. 
`--cfg-parallel-size 2` specified +5. `diffuse()` method calls `predict_noise_maybe_with_cfg()` and `scheduler_step_maybe_with_cfg()` + +**Symptom**: Different output with CFG parallel vs sequential + +**Possible cause**: Non-deterministic scheduler. Fix: pass `generator=torch.Generator(device).manual_seed(seed)` to `scheduler_step_maybe_with_cfg()`. + +### HSDP Errors + +**Symptom**: HSDP not activating or errors during weight loading + +**Fix checklist**: +1. Transformer defines `_hsdp_shard_conditions` class attribute +2. Shard condition functions return `True` for correct modules (test with `model.named_modules()`) +3. Not combining with TP (HSDP and TP are incompatible) +4. For standalone HSDP, `hsdp_shard_size` is specified explicitly + +**Verify**: Check logs for "HSDP Inference: replicate_size=..., shard_size=..." and "Sharded N modules + root". + +### Cache-DiT Not Applied + +**Symptom**: No speedup, no cache-related log messages + +**Fix checklist**: +1. Model not in `_NO_CACHE_ACCELERATION` in `registry.py` +2. Pipeline class name matches `CUSTOM_DIT_ENABLERS` key (if using custom enabler) +3. `cache_backend="cache_dit"` specified +4. Check logs for "Cache-dit enabled successfully on xxx" + +**Verify pipeline name**: `print(pipeline.__class__.__name__)` — must match registry key. + +### Cache-DiT Quality Degradation + +**Symptom**: Artifacts or lower quality with cache-dit + +**Fix**: Reduce aggressiveness: +```python +cache_config={ + "residual_diff_threshold": 0.12, # Lower from 0.24 + "max_warmup_steps": 6, # Increase from 4 + "max_continuous_cached_steps": 2, # Reduce if higher +} +``` + +If quality is still poor, the model may need a custom enabler with per-block-list `ParamsModifier` tuning. + +### Model Not Detected / Wrong Pipeline Class + +**Symptom**: `ValueError: Model class ... not found in diffusion model registry` + +**Cause**: The model's `model_index.json` has a `_class_name` for the pipeline that doesn't match registry keys. + +**Fix**: The registry key must match the diffusers pipeline class name from `model_index.json`. If using a different name, map it in the registry: +```python +"DiffusersPipelineClassName": ("your_folder", "your_file", "YourVllmClassName"), +``` + +## Debugging Workflow + +1. **Add verbose logging**: Use `logger.info()` to print tensor shapes at each stage +2. **Compare step-by-step**: Run diffusers and vllm-omni side by side, comparing tensors after each major operation +3. **Use small configs**: Reduce `num_inference_steps=2`, small resolution for fast iteration +4. **Test transformer isolation**: Feed the same input to both diffusers and vllm-omni transformers, compare outputs +5. **Binary search for bugs**: Comment out blocks/layers to isolate where divergence starts diff --git a/.claude/skills/add-tts-model/SKILL.md b/.claude/skills/add-tts-model/SKILL.md new file mode 100644 index 0000000000..e64e7e763e --- /dev/null +++ b/.claude/skills/add-tts-model/SKILL.md @@ -0,0 +1,284 @@ +--- +name: add-tts-model +description: "Integrate a new text-to-speech model into vLLM-Omni from HuggingFace reference implementation through production-ready serving with streaming and CUDA graph acceleration. Use when adding a new TTS model, wiring stage separation for speech synthesis, enabling online voice generation serving, debugging TTS integration behavior, or building audio output pipelines." 
+--- + +# TTS Model Integration Workflow + +## Overview + +``` +HF Reference -> Stage Separation -> Online Serving -> Async Chunk -> CUDA Graph + (Phase 1) (Phase 2) (Phase 3) (Phase 4) (Phase 5) +``` + +## Phase 1: HuggingFace Reference + +**Goal**: Understand the reference implementation and verify it produces correct audio. + +### Steps + +1. **Run the reference model** end-to-end using the official HuggingFace / GitHub code +2. **Document the architecture**: + - What are the sub-models? (AR decoder, codec decoder, vocoder, etc.) + - What is the token vocabulary? (semantic codes, RVQ codebooks, special tokens) + - What is the output format? (sample rate, channels, codec type) +3. **Capture reference outputs** for comparison during integration +4. **Identify the config structure**: `config.json` fields, `model_type`, sub-model configs + +### Key Questions + +- How many codebooks? What are the codebook sizes? +- What special tokens exist? (`<|voice|>`, `<|audio_start|>`, `<|im_end|>`, etc.) +- What is the token-to-ID mapping for codec codes? +- What is the hop length / frame rate of the codec? +- Does the model support voice cloning? How? (reference audio encoding, speaker embeddings, etc.) + +### Deliverables + +- Working reference script that produces audio +- Architecture diagram / notes +- Token vocabulary mapping +- Reference audio samples for regression testing + +## Phase 2: Stage Separation (Offline Inference) + +**Goal**: Split the model into vLLM-Omni stages and get offline inference working. + +### Steps + +1. **Register the model** in `vllm_omni/model_executor/models/registry.py` +2. **Create config classes** (`configuration_.py`) with `model_type` registration +3. **Implement Stage 0** (AR model): + - Subclass appropriate base (e.g., wrap Qwen3 decoder layers) + - Implement `forward()` for autoregressive token generation + - Handle special token logic (start/stop tokens, codec token mapping) + - If dual-AR (like Fish Speech), implement Fast AR as a nested module +4. **Implement Stage 1** (Decoder): + - Load codec weights (may need lazy loading from separate checkpoint) + - Implement `forward()`: codec codes -> audio waveform + - Return `OmniOutput` with `multimodal_outputs` +5. **Create stage config YAML** defining both stages, memory allocation, and model paths +6. **Create stage input processor** for prompt building +7. **Write end2end.py** test script + +### Critical Parameters to Get Right + +| Parameter | Impact if Wrong | +|-----------|----------------| +| Hop length | Audio duration wrong, streaming noise | +| Token ID mapping | Garbage codes -> noise output | +| Codebook count/size | Shape mismatch crashes | +| Stop token | Generation never stops or stops too early | +| dtype / autocast | Numerical issues, silent quality degradation | +| Repetition penalty | Must match reference (often 1.0 for TTS) | + +### Debugging Priority (from experience) + +When audio output is wrong, check in this order: + +1. **RoPE / attention**: Are position encodings correct? Is the attention mask right? +2. **Normalization**: RMSNorm epsilon, layer norm placement (pre vs post) +3. **Hop length**: Product of all upsample rates in the codec decoder +4. **Token mapping**: Are codec IDs correctly offset from the vocabulary base? +5. **Sampling parameters**: Temperature, top_k, top_p, repetition_penalty +6. **Tensor layout**: Codebook-major vs frame-major ordering +7. 
**dtype**: Float32 for codec decoders (autocast can corrupt audio) + +### Deliverables + +- Model files in `vllm_omni/model_executor/models//` +- Stage config YAML +- Working `end2end.py` with correct audio output +- README.md in the example directory + +## Phase 3: Online Serving + +**Goal**: Expose the model via `/v1/audio/speech` API endpoint. + +### Steps + +1. **Register in `serving_speech.py`**: + - Add model stage name to `_TTS_MODEL_STAGES` set + - Add model detection flag (e.g., `_is_fish_speech`) + - Implement prompt builder method (e.g., `_build_fish_speech_prompt()`) +2. **Handle model-specific parameters**: + - Voice cloning: `ref_audio` encoding and prompt injection + - `max_new_tokens` override in sampling params + - Model-specific default values +3. **Create client scripts**: `speech_client.py`, `run_server.sh` +4. **Test all response formats**: wav, mp3, flac, pcm +5. **Add Gradio demo**: Interactive web UI with streaming support + +### Voice Cloning Pattern + +```python +import base64 +from pathlib import Path + +def build_voice_clone_prompt(ref_audio_path: str, text: str, codec) -> list: + """Build prompt with reference audio for voice cloning in serving_speech.py.""" + audio_bytes = Path(ref_audio_path).read_bytes() + codes = codec.encode(audio_bytes) # Encode on CPU using model's codec (e.g., DAC) + token_ids = [code + codec.vocab_offset for code in codes.flatten().tolist()] + return [ + {"role": "system", "content": f"<|voice|>{''.join(chr(t) for t in token_ids)}"}, + {"role": "user", "content": text}, + ] +``` + +### Deliverables + +- Updated `serving_speech.py` with model-specific prompt builder +- Client scripts and server launcher +- Gradio demo with streaming and voice cloning UI +- Documentation (offline + online serving docs) + +## Phase 4: Async Chunk (Streaming) + +**Goal**: Enable inter-stage streaming so audio chunks are produced while AR generation continues. + +### Steps + +1. **Update stage config YAML**: + ```yaml + async_chunk: true + codec_chunk_frames: 25 # frames per chunk + codec_left_context_frames: 25 # overlap for smooth boundaries + ``` +2. **Implement chunk handling in Stage 1**: + - Accept partial input (chunk of codec codes) + - Handle left context for smooth audio boundaries + - Return partial audio in `OmniOutput` +3. **Test streaming**: + - Verify audio quality matches non-streaming output + - Check for artifacts at chunk boundaries + - Measure TTFA (time to first audio) +4. **Update online serving** to support `stream=true` with PCM output + +### Streaming Architecture + +``` +Stage 0 (AR) Stage 1 (Decoder) + | | + |-- chunk 0 (25 frames) ------> decode -> audio chunk 0 -> client + |-- chunk 1 (25 frames) ------> decode -> audio chunk 1 -> client + |-- chunk 2 (25 frames) ------> decode -> audio chunk 2 -> client + ... +``` + +### Key Considerations + +- **Left context overlap**: Prevents audible artifacts at chunk boundaries +- **Hop length matters**: `context_audio_samples = context_frames * hop_length` +- **First chunk latency**: Can use larger initial chunk for better quality, then smaller chunks + +### Deliverables + +- Updated stage config with async_chunk enabled +- Smooth streaming audio without boundary artifacts +- TTFA metrics + +## Phase 5: CUDA Graph Acceleration + +**Goal**: Capture the AR loop as a CUDA graph for significant speedup. + +### Steps + +1. **Identify the hot loop**: The AR decoding loop that runs N steps per token +2. 
**Create static buffers**: + - KV caches with fixed max sequence length + - Pre-built causal masks and position tensors per step + - Static input/output tensors +3. **Implement graph capture**: + - Warm up with real data + - Capture the forward pass + - Replay with updated inputs +4. **Handle constraints**: + - Use `torch.argmax` instead of `torch.multinomial` (graph-safe) + - Fixed batch size (fall back to eager for other sizes) + - No dynamic control flow inside the graph + +### Example: Code Predictor CUDA Graph (Qwen3-TTS) + +```python +import torch + +class CodePredictorGraph: + """Captures the 16-step code predictor AR loop as a single CUDA graph.""" + + def setup_graph(self, device: torch.device, kv_heads: int = 4, head_dim: int = 64): + self.num_steps = 16 + self.kv_cache = torch.zeros(1, kv_heads, self.num_steps, head_dim, device=device) + self.positions = torch.arange(self.num_steps, device=device) + self.causal_mask = torch.tril(torch.ones(self.num_steps, self.num_steps, device=device)) + self.input_buf = torch.zeros(1, 1, kv_heads * head_dim, device=device) + self.output_buf = torch.zeros(1, self.num_steps, device=device, dtype=torch.long) + # Warm up, then: self.graph = torch.cuda.CUDAGraph(); self.graph.capture(...) + + def run_graph(self, initial_input: torch.Tensor) -> torch.Tensor: + self.input_buf.copy_(initial_input) + self.graph.replay() + return self.output_buf.clone() +``` + +### Performance Expectations + +Based on Qwen3-TTS code predictor experience: +- **3-5x speedup** for the graphed component +- Only effective for fixed batch sizes (typically batch_size=1) +- Falls back to eager mode for unsupported configurations + +### Deliverables + +- CUDA graph implementation for the AR hot loop +- Benchmark script comparing eager vs graph performance +- Documentation of constraints and fallback behavior + +## Integration Checklist + +Use this checklist when integrating a new TTS model: + +### Phase 1: HF Reference +- [ ] Reference model runs and produces correct audio +- [ ] Architecture documented (stages, codebooks, tokens, sample rate) +- [ ] Reference audio samples saved for comparison + +### Phase 2: Stage Separation +- [ ] Model registered in `registry.py` +- [ ] Config classes created with `model_type` registration +- [ ] Stage 0 (AR) implemented and generates correct tokens +- [ ] Stage 1 (Decoder) produces correct audio from tokens +- [ ] Stage config YAML created +- [ ] `end2end.py` produces audio matching reference quality +- [ ] README.md written + +### Phase 3: Online Serving +- [ ] Model added to `serving_speech.py` +- [ ] Prompt builder handles text input correctly +- [ ] Voice cloning works (if supported) +- [ ] All response formats work (wav, mp3, flac, pcm) +- [ ] Client scripts and server launcher created +- [ ] Gradio demo working +- [ ] Documentation added (offline + online docs, nav, supported models) + +### Phase 4: Async Chunk +- [ ] Stage config updated with `async_chunk: true` +- [ ] Stage 1 handles partial chunks correctly +- [ ] No audio artifacts at chunk boundaries +- [ ] Streaming via API (`stream=true`) works +- [ ] TTFA measured and acceptable + +### Phase 5: CUDA Graph +- [ ] Hot loop identified and profiled +- [ ] Static buffers allocated +- [ ] Graph captured and replays correctly +- [ ] Benchmark shows meaningful speedup +- [ ] Fallback to eager works for unsupported configs + +## References + +- [TTS audio skill](../vllm-omni-audio-tts/SKILL.md) -- supported models and usage +- [Fish Speech 
integration](../vllm-omni-audio-tts/references/fish-speech.md) -- complete example of Phases 1-3 +- [Qwen3-TTS reference](../vllm-omni-audio-tts/references/qwen-tts.md) -- complete example of all 5 phases +- [Adding a TTS model (developer guide)](https://github.com/vllm-project/vllm-omni/blob/main/docs/contributing/model/adding_tts_model.md) diff --git a/.claude/skills/readme.md b/.claude/skills/readme.md new file mode 100644 index 0000000000..b66f2ecd13 --- /dev/null +++ b/.claude/skills/readme.md @@ -0,0 +1,34 @@ +# Claude Skills for vLLM-Omni + +This directory contains Claude Code skills maintained for the `vllm-omni` +repository. These skills capture repeatable workflows for common contributor +tasks such as model integration, pull request review, and release note +generation. + +## Directory Structure + +Each skill lives in its own directory under `.claude/skills/`. A skill may +include: + +- `SKILL.md`: the main workflow and operating instructions +- `references/`: focused reference material used by the skill +- `scripts/`: small helper scripts used by the skill + +## Available Skills + +- `add-diffusion-model`: guides integration of a new diffusion model into + `vllm-omni` +- `add-omni-model`: covers addition of new omni-modality model support +- `add-tts-model`: covers integration of new TTS models and related serving + workflows +- `generate-release-note`: helps prepare release notes for repository changes +- `review-pr`: provides a structured workflow for reviewing pull requests + +## Maintenance Guidelines + +- Keep skill names short and task-oriented. +- Prefer repository-local paths, commands, and examples. +- Avoid hardcoding fast-changing support matrices unless the skill is actively + maintained alongside those changes. +- Treat skills as contributor tooling: optimize for clarity, actionability, and + low maintenance overhead. diff --git a/.claude/skills/vllm-omni-npu-upgrade/SKILL.md b/.claude/skills/vllm-omni-npu-upgrade/SKILL.md new file mode 100644 index 0000000000..1ef7ab3930 --- /dev/null +++ b/.claude/skills/vllm-omni-npu-upgrade/SKILL.md @@ -0,0 +1,300 @@ +--- +name: vllm-omni-npu-model-runner-upgrade +description: "Upgrade vllm-omni NPU model runners (OmniNPUModelRunner, NPUARModelRunner, NPUGenerationModelRunner) to align with the latest vllm-ascend NPUModelRunner while preserving omni-specific logic." +--- + +# vLLM-Omni NPU Model Runner Upgrade Skill + +## Overview + +This skill guides the process of upgrading vllm-omni's NPU model runners to align with the latest vllm-ascend codebase while preserving omni-specific enhancements. The NPU runners are designed to run omni multimodal models (like Qwen3-Omni, Bagel, MiMoAudio) on Ascend NPUs. 
+ +## File Structure + +### NPU Model Runner Files +``` +vllm-omni/vllm_omni/platforms/npu/worker/ +├── __init__.py +├── npu_model_runner.py # OmniNPUModelRunner (base class) +├── npu_ar_model_runner.py # NPUARModelRunner (autoregressive) +├── npu_ar_worker.py # AR worker +├── npu_generation_model_runner.py # NPUGenerationModelRunner (diffusion/non-AR) +└── npu_generation_worker.py # Generation worker +``` + +### GPU Reference Files (for omni-specific logic sync) +``` +vllm-omni/vllm_omni/worker/ +├── __init__.py +├── gpu_model_runner.py # OmniGPUModelRunner +├── gpu_ar_model_runner.py # GPUARModelRunner +├── gpu_ar_worker.py +├── gpu_generation_model_runner.py +├── gpu_generation_worker.py +├── mixins.py +├── base.py +└── gpu_memory_utils.py +``` + +### vllm-ascend Reference Files +``` +vllm-ascend/vllm_ascend/worker/ +├── model_runner_v1.py # NPUModelRunner (base class to copy from) +├── npu_input_batch.py +├── block_table.py +├── pcp_utils.py +└── worker.py +``` + +## Inheritance Hierarchy + +``` + GPUModelRunner (vllm) + | + +----------------+----------------+ + | | + OmniGPUModelRunner NPUModelRunner (vllm-ascend) + (vllm_omni/worker) (vllm_ascend/worker) + | | + +----------- OmniNPUModelRunner --+ + (multiple inheritance) + | + +---------------+---------------+ + | | + NPUARModelRunner NPUGenerationModelRunner + (autoregressive) (non-autoregressive/diffusion) +``` + +## Omni-Specific Comment Markers + +Omni-specific logic is marked with comment blocks: +```python +# -------------------------------------- Omni-new ------------------------------------------------- +# ... omni-specific code ... +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +Or simpler variations: +```python +# -------------------------------------- Omni-new ------------------------------------------------- +# ------------------------------------------------------------------------------------------------ +``` + +**Important**: +- Always preserve and add these markers when modifying code. +- **The reference documents (`references/omni-specific-blocks.md`) may not be up-to-date.** Always grep for `Omni-new` in the GPU implementations to find the authoritative list of omni-specific blocks. +- When you discover new omni-specific code that is not documented in the references, please update the reference files. 
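+
+To make that audit mechanical, a helper along the following lines can pull every marked block out of a runner file for comparison against the GPU implementation. This is a minimal sketch, not an existing vllm-omni utility; it assumes both the opening and closing lines contain the `Omni-new` string, so blocks closed with the plain dashed line shown above need a second pattern.
+
+```python
+from pathlib import Path
+
+MARKER = "Omni-new"
+
+
+def extract_omni_blocks(path: str) -> list[tuple[int, list[str]]]:
+    """Return (opening_line_number, body_lines) for each Omni-new bracketed block."""
+    lines = Path(path).read_text().splitlines()
+    blocks: list[tuple[int, list[str]]] = []
+    open_line: int | None = None
+    for lineno, line in enumerate(lines, start=1):
+        if MARKER in line:
+            if open_line is None:
+                open_line = lineno  # opening marker
+            else:
+                # Body sits between the two marker lines.
+                blocks.append((open_line, lines[open_line:lineno - 1]))
+                open_line = None
+    return blocks
+
+
+if __name__ == "__main__":
+    runner = "vllm_omni/platforms/npu/worker/npu_model_runner.py"
+    for start, body in extract_omni_blocks(runner):
+        print(f"{runner}:{start}: {len(body)} lines")
+```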
+ +## Key Methods Requiring Attention + +### OmniNPUModelRunner (npu_model_runner.py) + +| Method | Description | Omni-Specific Logic | +|--------|-------------|---------------------| +| `load_model` | Load model and initialize talker_mtp | Uses `ACLGraphWrapper` instead of `CUDAGraphWrapper`, initializes talker buffers | +| `_dummy_run` | Warmup/profiling run | talker_mtp dummy forward, `extract_multimodal_outputs` | +| `_model_forward` | Forward pass wrapper | Injects `model_kwargs_extra`, wraps with `OmniOutput`, NPU-specific graph updates | +| `_talker_mtp_forward` | Talker MTP forward for Qwen3-Omni | Uses `set_ascend_forward_context` | + +### NPUARModelRunner (npu_ar_model_runner.py) + +| Method | Description | Omni-Specific Logic | +|--------|-------------|---------------------| +| `__init__` | Initialize with KV transfer manager | `OmniKVTransferManager` setup | +| `execute_model` | Main inference entry | KV transfer handling, `_update_states` override, `extract_multimodal_outputs` | +| `sample_tokens` | Token sampling | Hidden states extraction, multimodal outputs processing, `OmniModelRunnerOutput` | +| `_resolve_global_request_id` | Request ID resolution | For disaggregated inference | + +### NPUGenerationModelRunner (npu_generation_model_runner.py) + +| Method | Description | Omni-Specific Logic | +|--------|-------------|---------------------| +| `_update_request_states` | Update request states for async chunk | async_chunk handling | +| `execute_model` | Generation forward | async_chunk, `seq_token_counts`, `_run_generation_model` | +| `sample_tokens` | Output processing | multimodal output packaging to `OmniModelRunnerOutput` | +| `_dummy_run` | Dummy run override | model_kwargs initialization, multimodal extraction | +| `_run_generation_model` | Run generation model | Calls `_model_forward` with sampler | + +## Upgrade Workflow + +### Step 1: Preparation + +1. **Identify target versions**(Use gh cli to check): + - We're using vllm-omni main branch + - Check the last release of vllm-omni + - Target vllm-ascend version(Just directly use the local latest vllm-ascend code) + +2. **Check GPU-side changes** (since last release): + ```bash + cd /root/vllm-workspace/vllm-omni + git log --oneline --since="" -- vllm_omni/worker/ + ``` + +3. **Read latest vllm-ascend code**: + - We don't track vllm-ascend changes - just directly use the latest code from `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py` + - Copy the relevant methods and re-insert omni-specific blocks + +### Step 2: Analyze Omni-Specific Logic + +For each NPU model runner file: + +1. **Extract existing omni-specific blocks**: + ```bash + grep -n "Omni-new" vllm_omni/platforms/npu/worker/npu_model_runner.py + ``` + +2. **Document each omni block**: + - Which method it belongs to + - What functionality it provides + - Dependencies on other omni code + +### Step 3: Update Base Class (OmniNPUModelRunner) + +**Note**: Always check the GPU implementation `gpu_model_runner.py` for any new omni logic not yet documented in references. + +1. **Read the latest vllm-ascend `NPUModelRunner.load_model`** +2. **Copy the method, keeping the structure** +3. **Re-insert omni-specific logic** (check GPU `gpu_model_runner.py` for authoritative list): + - Replace `CUDAGraphWrapper` with `ACLGraphWrapper` + - Keep talker_mtp initialization + - Preserve buffer allocations for talker + - Check for any new omni blocks added since last sync + +4. 
**Update `_dummy_run`**: + - Copy from vllm-ascend + - Compare with GPU `_dummy_run` for omni-specific blocks + - Re-insert all `Omni-new` marked code from GPU version + +5. **Update `_model_forward`**: + - Keep the omni wrapper logic + - Update NPU-specific parts (graph params, SP all-gather) + - Check GPU version for any new omni logic + +### Step 4: Update AR Model Runner + +1. **Compare with GPU `gpu_ar_model_runner.py`** for any new omni features +2. **Copy `execute_model` from vllm-ascend** +3. **Re-insert omni blocks** (reference `references/omni-specific-blocks.md`, but note it may be incomplete): + - **IMPORTANT**: Always check the GPU implementation `gpu_ar_model_runner.py` for all `Omni-new` marked code blocks + - The reference doc may not include newly added omni logic - treat it as a starting point, not exhaustive + - When discovering new omni code blocks, please update `references/omni-specific-blocks.md` + - Common omni blocks include but are not limited to: KV transfer, multimodal outputs, sampling_metadata handling, etc. + +4. **Update `sample_tokens`** (also compare with GPU implementation): + - Compare with `gpu_ar_model_runner.py`'s `sample_tokens` method + - Identify all `Omni-new` marked code blocks + - Ensure NPU version includes all omni-specific logic + +### Step 5: Update Generation Model Runner + +**Note**: Generation model runner may have unique omni logic for diffusion/non-AR models. + +1. **Compare with GPU `gpu_generation_model_runner.py`** - grep for all `Omni-new` blocks +2. **Update `execute_model`**: + - Check GPU version for all omni-specific blocks + - Keep async_chunk handling + - Keep `seq_token_counts` injection + - Update forward/context setup from vllm-ascend + - Look for any new omni logic not documented in references + +3. **Update `_dummy_run`**: + - Copy from vllm-ascend base + - Compare with GPU `_dummy_run` if exists + - Re-insert all omni-specific logic + +### Step 6: Update Imports + +Check and update imports at the top of each file: + +```python +# Common vllm-ascend imports +from vllm_ascend.ascend_forward_context import get_forward_context, set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.utils import using_paged_attention +from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params +from vllm_ascend.ops.rotary_embedding import update_cos_sin +from vllm_ascend.utils import enable_sp, lmhead_tp_enable +from vllm_ascend.worker.model_runner_v1 import SEQ_LEN_WITH_MAX_PA_WORKSPACE, NPUModelRunner + +# Omni-specific imports +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner +from vllm_omni.outputs import OmniModelRunnerOutput +from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager +``` + +### Step 7: Sync GPU-Side Omni Changes + +1. **Check recent GPU worker changes**: + ```bash + git diff .. -- vllm_omni/worker/gpu_model_runner.py + git diff .. -- vllm_omni/worker/gpu_ar_model_runner.py + ``` + +2. **Identify new omni features** that need to be ported to NPU + +3. **Apply corresponding changes** to NPU runners + +### Step 8: Validation + +1. 
**Run type checking**: + ```bash + cd /root/vllm-workspace/vllm-omni + python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py + python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py + python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py + ``` + +2. **Run import test**: + ```bash + python -c "from vllm_omni.platforms.npu.worker import *" + ``` + +3. **Run model serving test** (if hardware available): + ```bash + vllm serve --trust-remote-code + ``` + +## Common Pitfalls + +### 1. Forward Context Differences +- GPU uses `set_forward_context` +- NPU uses `set_ascend_forward_context` +- Parameters may differ slightly + +### 2. Graph Wrapper Differences +- GPU: `CUDAGraphWrapper` +- NPU: `ACLGraphWrapper` +- Constructor parameters may differ + +### 3. Buffer Creation +- GPU: `_make_buffer` returns different structure +- NPU: May need numpy=True/False parameter + +### 4. Attention Metadata +- GPU: Uses vllm attention metadata builders +- NPU: Uses `AscendCommonAttentionMetadata` + +### 5. Sampling +- GPU: Uses vllm sampler +- NPU: Uses `AscendSampler` + +## Checklist Before Commit + +- [ ] All omni-specific comment markers preserved +- [ ] New omni logic from GPU side synced +- [ ] Imports updated to latest vllm-ascend +- [ ] No `CUDAGraphWrapper` references in NPU code +- [ ] `set_ascend_forward_context` used instead of `set_forward_context` +- [ ] `ACLGraphWrapper` used for talker_mtp wrapping +- [ ] Type hints match vllm-ascend signatures +- [ ] No duplicate code blocks +- [ ] Python syntax valid (py_compile passes) + +## Reference Files for Comparison + +When upgrading, keep these files open for reference: + +1. **vllm-ascend NPUModelRunner**: `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py` +2. **vllm GPUModelRunner**: `/root/vllm-workspace/vllm/vllm/v1/worker/gpu_model_runner.py` +3. **vllm-omni OmniGPUModelRunner**: `/root/vllm-workspace/vllm-omni/vllm_omni/worker/gpu_model_runner.py` diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md b/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md new file mode 100644 index 0000000000..89067d37b2 --- /dev/null +++ b/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md @@ -0,0 +1,335 @@ +# GPU to NPU Translation Patterns + +This document provides a quick reference for translating GPU code patterns to NPU equivalents when porting omni-specific logic. 
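+
+When a piece of omni logic must run unchanged on both device types during an incremental port, a thin shim can isolate the device-specific call so the surrounding code stays identical. A minimal sketch for local debugging only (the helper is illustrative, not an existing vllm-omni or vllm-ascend API; it assumes `torch_npu`, when installed, exposes `torch.npu`):
+
+```python
+import torch
+
+
+def device_synchronize() -> None:
+    """Synchronize whichever accelerator backend is active (illustrative helper)."""
+    if hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.synchronize()  # Ascend path (torch_npu installed)
+    elif torch.cuda.is_available():
+        torch.cuda.synchronize()  # CUDA path
+```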
+ +## Import Translations + +### Forward Context +```python +# GPU +from vllm.forward_context import set_forward_context + +# NPU +from vllm_ascend.ascend_forward_context import set_ascend_forward_context +``` + +### Graph Wrapper +```python +# GPU +from vllm.compilation.cuda_graph import CUDAGraphWrapper + +# NPU +from vllm_ascend.compilation.acl_graph import ACLGraphWrapper +``` + +### Attention State +```python +# GPU (no equivalent - uses FlashAttention states directly) + +# NPU +from vllm_ascend.attention.attention_v1 import AscendAttentionState +``` + +### Utilities +```python +# GPU +# (directly use torch.cuda functions) + +# NPU +from vllm_ascend.utils import enable_sp, lmhead_tp_enable +from vllm_ascend.ops.rotary_embedding import update_cos_sin +``` + +## Context Manager Translations + +### Forward Context Setup +```python +# GPU +with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, +): + # forward pass + +# NPU +with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=cudagraph_mode, # Note: 'aclgraph' not 'cudagraph' + batch_descriptor=batch_desc, + num_actual_tokens=scheduler_output.total_num_scheduled_tokens, + model_instance=self.model, +): + # forward pass +``` + +### Graph Capture Context +```python +# GPU +from vllm.compilation.cuda_graph import graph_capture as cuda_graph_capture +with cuda_graph_capture(self.device): + # capture + +# NPU +from vllm_ascend.worker.model_runner_v1 import graph_capture +with graph_capture(self.device): + # capture +``` + +## Graph Wrapper Usage + +### Creating Graph Wrapper +```python +# GPU +if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: + self.talker_mtp = CUDAGraphWrapper( + talker_mtp, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL + ) + +# NPU +if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: + self.talker_mtp = ACLGraphWrapper( + talker_mtp, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL + ) +``` + +### Checking Graph Wrapper Type +```python +# GPU +if not isinstance(self.talker_mtp, CUDAGraphWrapper): + _cudagraph_mode = CUDAGraphMode.NONE + +# NPU +if not isinstance(self.talker_mtp, ACLGraphWrapper): + _cudagraph_mode = CUDAGraphMode.NONE +``` + +## Device Operations + +### Synchronization +```python +# GPU +torch.cuda.synchronize() + +# NPU +torch.npu.synchronize() +``` + +### Stream Operations +```python +# GPU +stream = torch.cuda.Stream(device=device) +torch.cuda.current_stream() + +# NPU +stream = torch.npu.Stream(device=device) +torch.npu.current_stream() +``` + +## Attention Metadata + +### State Setting (NPU-specific) +```python +# GPU - handled internally by attention backends + +# NPU - explicit state setting required +self.attn_state = AscendAttentionState.DecodeOnly +if self.speculative_config and self.speculative_config.method == "mtp": + if self.vllm_config.model_config.use_mla: + self.attn_state = AscendAttentionState.SpecDecoding + else: + self.attn_state = AscendAttentionState.ChunkedPrefill +``` + +### Building Attention Metadata +```python +# GPU - uses vllm attention builders + +# NPU - may need additional parameters +(attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded, + num_reqs=num_reqs, + 
num_reqs_padded=num_reqs_padded, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + num_scheduled_tokens_np=num_scheduled_tokens_np, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, +) +``` + +## Rotary Embedding + +### Update Cos/Sin Cache +```python +# GPU - typically handled inside attention + +# NPU - explicit update required before forward +from vllm_ascend.ops.rotary_embedding import update_cos_sin +update_cos_sin(positions) +``` + +## Sequence Parallelism + +### Enable SP Check +```python +# GPU - use vllm distributed utilities + +# NPU - use vllm-ascend wrapper +from vllm_ascend.utils import enable_sp + +if enable_sp(): + # sequence parallelism enabled +``` + +## Sampler + +### Sampler Type +```python +# GPU - uses vllm sampler +self.sampler = Sampler() + +# NPU - uses AscendSampler +from vllm_ascend.sample.sampler import AscendSampler +self.sampler = AscendSampler() +``` + +## Input Batch + +### Batch Class +```python +# GPU +from vllm.v1.worker.gpu_input_batch import InputBatch + +# NPU +from vllm_ascend.worker.npu_input_batch import NPUInputBatch +``` + +## Graph Parameter Updates + +### Full Graph Params Update (NPU-specific) +```python +# GPU - not needed + +# NPU - required for FULL graph mode +from vllm_ascend.compilation.acl_graph import update_full_graph_params + +forward_context = get_forward_context() +if ( + forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL + and not forward_context.capturing + and not self.use_sparse +): + update_full_graph_params( + self.attn_backend, + self.update_stream, + forward_context, + num_tokens_padded, + self.vllm_config, + self.speculative_config, + positions.shape[0], + ) +``` + +## Paged Attention Check + +```python +# GPU - not typically needed + +# NPU +from vllm_ascend.attention.utils import using_paged_attention + +if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config): + seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE +``` + +## Common Method Signature Differences + +### _dummy_run Parameters +```python +# GPU (v0.17.0) +def _dummy_run( + self, + num_tokens: int, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + is_profile: bool = False, + create_mixed_batch: bool = False, + remove_lora: bool = True, + is_graph_capturing: bool = False, + num_active_loras: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: + +# NPU (v0.17.0) - adds with_prefill, activate_lora +def _dummy_run( + self, + num_tokens: int, + with_prefill: bool = False, + cudagraph_runtime_mode: CUDAGraphMode | None = None, + force_attention: bool = False, + uniform_decode: bool = False, + is_profile: bool = False, + create_mixed_batch: bool = False, + allow_microbatching: bool = True, + skip_eplb: bool = False, + remove_lora: bool = True, + activate_lora: bool = False, + is_graph_capturing: bool = False, + num_active_loras: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: +``` + +### _model_forward Parameters +```python +# GPU - no num_tokens_padded +def _model_forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], +): + +# NPU - has num_tokens_padded as first parameter +def 
_model_forward( + self, + num_tokens_padded: int, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], +): +``` + +## Quick Reference Table + +| Feature | GPU | NPU | +|---------|-----|-----| +| Graph wrapper | `CUDAGraphWrapper` | `ACLGraphWrapper` | +| Forward context | `set_forward_context` | `set_ascend_forward_context` | +| Runtime mode param | `cudagraph_runtime_mode` | `aclgraph_runtime_mode` | +| Device sync | `torch.cuda.synchronize()` | `torch.npu.synchronize()` | +| Stream | `torch.cuda.Stream` | `torch.npu.Stream` | +| Current stream | `torch.cuda.current_stream()` | `torch.npu.current_stream()` | +| Input batch | `InputBatch` | `NPUInputBatch` | +| Sampler | `Sampler` | `AscendSampler` | +| Attention state | N/A | `AscendAttentionState` | +| RoPE update | N/A | `update_cos_sin()` | diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md b/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md new file mode 100644 index 0000000000..8c5d32ab4c --- /dev/null +++ b/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md @@ -0,0 +1,374 @@ +# Omni-Specific Code Blocks Reference + +This document catalogs omni-specific code blocks in the NPU model runners, making it easier to identify what needs to be preserved during upgrades. + +> **IMPORTANT**: This document may not be complete or up-to-date! +> +> - Always grep for `Omni-new` in the GPU implementations (`vllm_omni/worker/`) to find the authoritative list +> - New omni features may be added that are not yet documented here +> - When you discover new omni-specific blocks during an upgrade, please update this document +> - Last verified: Check git history for this file + +## OmniNPUModelRunner (npu_model_runner.py) + +### load_model - Talker MTP Initialization + +```python +def load_model(self, *args, **kwargs) -> None: + NPUModelRunner.load_model(self, *args, **kwargs) + # Initialize enable_sp cache to avoid get_current_vllm_config() error + # in _pad_for_sequence_parallelism during execute_model. + # This is a workaround for vllm-ascend not passing vllm_config to enable_sp(). + enable_sp(self.vllm_config) + # TODO move this model specific logic to a separate class + # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS. + talker_mtp = getattr(self.model, "talker_mtp", None) + if talker_mtp is not None: + self.talker_mtp = talker_mtp # type: ignore[assignment] + cudagraph_mode = self.compilation_config.cudagraph_mode + assert cudagraph_mode is not None + # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that + # have a separate .talker sub-module. TTS models' code predictor + # has internal AR loops / torch.multinomial — not graph-safe. + has_separate_talker = getattr(self.model, "talker", None) is not None + if cudagraph_mode.has_full_cudagraphs() and has_separate_talker: + # NOTE: Use ACLGraphWrapper on NPU, not CUDAGraphWrapper + self.talker_mtp = ACLGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) + # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size. 
+ hidden_size = int( + getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size") + ) + max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size) + self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32) + self.talker_mtp_inputs_embeds = self._make_buffer( + max_batch_size, hidden_size, dtype=self.dtype, numpy=False + ) + self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) + self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False) +``` + +### _dummy_run - Talker MTP Dummy Forward + +Location: Inside `set_ascend_forward_context` block, before main model forward + +```python +# ---------------------------------------Omni-new---------------------------------------------- +if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): + num_tokens_padded_talker_mtp = num_tokens_padded + if num_tokens_padded_talker_mtp == self.max_num_tokens: + num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0] + outputs = self.talker_mtp( + self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp], + self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp], + self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp], + self.text_step.gpu[:num_tokens_padded_talker_mtp], + ) + self.compilation_config.cache_dir = None +# ---------------------------------------Omni-new---------------------------------------------- +``` + +### _dummy_run - Extract Multimodal Outputs + +Location: After model forward, before dummy_compute_logits + +```python +# ---------------------------------------Omni-new---------------------------------------------- +hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) +# ---------------------------------------Omni-new---------------------------------------------- +``` + +### _model_forward - Omni Output Wrapping + +```python +def _model_forward( + self, + num_tokens_padded: int, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], +): + """Override to combine NPUModelRunner's signature with OmniGPUModelRunner's logic.""" + # Omni-specific: build and inject extra model kwargs + model_kwargs_extra = self._build_model_kwargs_extra() + + # Call the model forward (same as NPUModelRunner) + assert self.model is not None + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + **model_kwargs_extra, + ) + + # Omni-specific: wrap output if needed + if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"): + model_output = self.model.make_omni_output(model_output, **model_kwargs_extra) + + # Omni-specific: cache model output for later sample_tokens + self._omni_last_model_output = model_output + + # NPU-specific: update full graph params (keep from vllm-ascend) + forward_context = get_forward_context() + # ... NPU graph update logic ... 
+ + # NPU-specific: all-gather for sequence parallelism (keep from vllm-ascend) + if get_forward_context().sp_enabled and not isinstance(model_output, IntermediateTensors): + model_output = self._all_gather_hidden_states_and_aux(model_output) + + return model_output +``` + +--- + +## NPUARModelRunner (npu_ar_model_runner.py) + +### __init__ - KV Transfer Manager + +```python +def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + # each model stage has their own hidden size + self.hidden_size = self.model_config.hf_text_config.hidden_size + self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False) + # Initialize KV cache manager (preserve vllm_config fallback behavior) + self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config) +``` + +### execute_model - KV Transfer Before Update States + +Location: At the very beginning of execute_model + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +# [Omni] Handle KV transfer BEFORE updating states (which removes finished requests) +self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer( + finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}), + kv_caches=self.kv_caches, + block_size=self.cache_config.block_size, + cache_dtype=str(self.cache_config.cache_dtype), + request_id_resolver=self._resolve_global_request_id, +) +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +### execute_model - Custom _update_states Call + +Location: Inside synchronize_input_prep context + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +self._update_states(scheduler_output) +# ------------------------------------------------------------------------------------------------ +``` + +### execute_model - Extract Multimodal Outputs + +Location: In post process section, after hidden_states assignment + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states) + +if multimodal_outputs is not None: + keys_or_type = ( + list(multimodal_outputs.keys()) + if isinstance(multimodal_outputs, dict) + else type(multimodal_outputs) + ) + logger.debug(f"[AR] execute_model: multimodal_outputs keys = {keys_or_type}") +else: + logger.debug("[AR] execute_model: multimodal_outputs is None") +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +### execute_model - Compute Logits with sampling_metadata + +Location: In both broadcast_pp_output True and False branches + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +# Try with sampling_metadata first; fall back to without for models that don't support it +try: + logits = self.model.compute_logits( + sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata + ) +except TypeError: + logits = self.model.compute_logits(sample_hidden_states) +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +### sample_tokens - KV Extracted Req IDs + +Location: At the beginning of sample_tokens + 
+```python +# -------------------------------------- Omni-new ------------------------------------------------- +kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None) +self.kv_extracted_req_ids = None +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +### sample_tokens - Process Additional Information and Build Output + +Location: After bookkeeping sync, replacing the original output construction + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +hidden_states_cpu = hidden_states.detach().to("cpu").contiguous() +num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None) +if num_scheduled_tokens_np is None: + req_ids = self.input_batch.req_ids + num_scheduled_tokens_np = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids], + dtype=np.int32, + ) + +self._process_additional_information_updates( + hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output +) + +pooler_output: list[dict[str, object]] = [] +for rid in req_ids_output_copy: + idx = req_id_to_index_output_copy[rid] + start = int(self.query_start_loc.cpu[idx]) + sched = int(num_scheduled_tokens_np[idx]) + end = start + sched + hidden_slice = hidden_states_cpu[start:end] + payload: dict[str, object] = {"hidden": hidden_slice} + if isinstance(multimodal_outputs, dict) and multimodal_outputs: + # ... multimodal output slicing logic ... + pooler_output.append(payload) + +model_runner_output = OmniModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None), + kv_connector_output=kv_connector_output, +) +model_runner_output.kv_extracted_req_ids = kv_extracted_req_ids +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +--- + +## NPUGenerationModelRunner (npu_generation_model_runner.py) + +### execute_model - Async Chunk Update + +Location: Inside prepare input section, before synchronize_input_prep + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +if self.model_config.async_chunk and num_scheduled_tokens: + self._update_request_states(scheduler_output) +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +### execute_model - Seq Token Counts + +Location: After _preprocess call + +```python +# [Omni] Pass token counts per request for code2wav output slicing +model_kwargs["seq_token_counts"] = tokens +``` + +### execute_model - Run Generation Model + +Location: Inside forward context + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +outputs = self._run_generation_model( + num_tokens_padded=num_tokens_padded, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + model_kwargs=model_kwargs, + logits_indices=logits_indices, +) +_, multimodal_outputs = self.extract_multimodal_outputs(outputs) +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +### sample_tokens - Multimodal Output Processing + +The entire sample_tokens method body is 
omni-specific for generation models: + +```python +# -------------------------------------- Omni-new ------------------------------------------------- +pooler_output: list[object] = [] +if isinstance(multimodal_outputs, torch.Tensor): + # ... tensor handling ... +elif isinstance(multimodal_outputs, list): + # ... list handling ... +elif isinstance(multimodal_outputs, dict): + # ... dict handling per request ... +else: + raise RuntimeError("Unsupported diffusion output type") +# [Omni] Copy req_id mappings to avoid async scheduling mutation. +req_ids_output_copy = self.input_batch.req_ids.copy() +req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() +output = OmniModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + kv_connector_output=kv_connector_output, + num_nans_in_logits={}, + ec_connector_output=ec_connector_output if self.supports_mm_inputs else None, +) +# -------------------------------------- Omni-new ------------------------------------------------- +``` + +### _dummy_run - Model Kwargs Init and Multimodal Extract + +Location: Before model forward and after + +```python +model_kwargs = self._init_model_kwargs() # Before forward + +# ... forward ... + +# -------------------------------------- Omni-new ------------------------------------------------- +hidden_states, _ = self.extract_multimodal_outputs(hidden_states) +# ------------------------------------------------------------------------------------------------- +``` + +--- + +## ExecuteModelState Extension + +The `ExecuteModelState` NamedTuple is extended for omni: + +```python +class ExecuteModelState(NamedTuple): + """Ephemeral cached state transferred between execute_model() and + sample_tokens(), after execute_model() returns None.""" + + scheduler_output: SchedulerOutput + logits: torch.Tensor + spec_decode_metadata: SpecDecodeMetadata | None + spec_decode_common_attn_metadata: AscendCommonAttentionMetadata | None + hidden_states: torch.Tensor + sample_hidden_states: torch.Tensor + aux_hidden_states: list[torch.Tensor] | None + attn_metadata: PerLayerAttnMetadata + positions: torch.Tensor + ec_connector_output: ECConnectorOutput | None + cudagraph_stats: CUDAGraphStat | None + multimodal_outputs: Any # <-- Omni extension +``` + +This extended state must be imported from `npu_ar_model_runner` in `npu_generation_model_runner`. diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md b/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md new file mode 100644 index 0000000000..4f184df0ec --- /dev/null +++ b/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md @@ -0,0 +1,222 @@ +# NPU Model Runner Upgrade Workflow Checklist + +> **Note**: Reference documents (`omni-specific-blocks.md`) may not be complete. Always grep for `Omni-new` in GPU implementations to find all omni-specific code blocks. Update the reference docs when discovering new blocks. + +## Pre-Upgrade Preparation + +### 1. Version Information +- [ ] Identify current vllm-omni version: `_________` +- [ ] Identify target vllm-ascend version: `_________` +- [ ] Identify target vllm version: `_________` +- [ ] Last release date for GPU worker changes: `_________` + +### 2. 
Gather Git History +```bash +# GPU-side omni changes since last release +cd /root/vllm-workspace/vllm-omni +git log --oneline --since="YYYY-MM-DD" -- vllm_omni/worker/ + +# vllm-ascend NPUModelRunner changes +cd /root/vllm-workspace/vllm-ascend +git log --oneline .. -- vllm_ascend/worker/model_runner_v1.py +``` + +### 3. Backup Current Files +- [ ] Create backup of current NPU runners: + ```bash + cp -r vllm_omni/platforms/npu/worker vllm_omni/platforms/npu/worker.backup + ``` + +--- + +## OmniNPUModelRunner (npu_model_runner.py) + +### Read and Understand +- [ ] Read current `npu_model_runner.py` +- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py` +- [ ] Read latest `vllm_omni/worker/gpu_model_runner.py` + +### Method: load_model +- [ ] Document existing omni-specific logic +- [ ] Copy latest NPUModelRunner.load_model structure +- [ ] Re-insert: `enable_sp(self.vllm_config)` call +- [ ] Re-insert: talker_mtp detection and setup +- [ ] Replace: `CUDAGraphWrapper` → `ACLGraphWrapper` +- [ ] Re-insert: Buffer allocations (talker_mtp_input_ids, etc.) + +### Method: _dummy_run +- [ ] Document existing omni-specific logic locations +- [ ] Copy latest NPUModelRunner._dummy_run +- [ ] Re-insert: talker_mtp dummy forward block (inside context) +- [ ] Re-insert: `extract_multimodal_outputs` call +- [ ] Verify: Comment markers are present + +### Method: _model_forward +- [ ] Copy latest NPUModelRunner._model_forward structure +- [ ] Re-insert: `_build_model_kwargs_extra()` call +- [ ] Re-insert: OmniOutput wrapping logic +- [ ] Re-insert: `_omni_last_model_output` caching +- [ ] Keep: NPU graph params update +- [ ] Keep: SP all-gather logic + +### Method: _talker_mtp_forward +- [ ] Verify: Uses `set_ascend_forward_context` +- [ ] Verify: Uses `ACLGraphWrapper` check +- [ ] Sync any changes from GPU `_talker_mtp_forward` + +### Imports +- [ ] Update vllm-ascend imports to latest paths +- [ ] Verify all omni imports are present +- [ ] Remove any deprecated imports + +--- + +## NPUARModelRunner (npu_ar_model_runner.py) + +### Read and Understand +- [ ] Read current `npu_ar_model_runner.py` +- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py` execute_model +- [ ] Read latest `vllm_omni/worker/gpu_ar_model_runner.py` + +### Method: __init__ +- [ ] Sync any new initialization from GPU side +- [ ] Keep: `OmniKVTransferManager` setup +- [ ] Keep: Custom buffer allocations + +### Method: execute_model +- [ ] Document all omni blocks with line numbers +- [ ] Copy latest NPUModelRunner.execute_model structure +- [ ] Re-insert: KV transfer handling (beginning) +- [ ] Re-insert: Custom `_update_states` call +- [ ] Re-insert: `extract_multimodal_outputs` +- [ ] Re-insert: `compute_logits` with sampling_metadata try/except +- [ ] Update: ExecuteModelState to include multimodal_outputs + +### Method: sample_tokens +- [ ] Document all omni blocks +- [ ] Copy latest NPUModelRunner.sample_tokens structure +- [ ] Re-insert: `kv_extracted_req_ids` handling +- [ ] Re-insert: Hidden states CPU copy +- [ ] Re-insert: `_process_additional_information_updates` +- [ ] Re-insert: `OmniModelRunnerOutput` construction + +### ExecuteModelState +- [ ] Verify: `multimodal_outputs` field is present +- [ ] Verify: Imported/used correctly in execute_model + +### Imports +- [ ] Update all vllm-ascend imports +- [ ] Keep omni-specific imports + +--- + +## NPUGenerationModelRunner (npu_generation_model_runner.py) + +### Read and Understand +- [ ] Read current `npu_generation_model_runner.py` +- [ ] Read latest GPU 
`gpu_generation_model_runner.py` + +### Method: _update_request_states +- [ ] Verify: async_chunk handling is correct +- [ ] Sync any changes from GPU side + +### Method: execute_model +- [ ] Document all omni blocks +- [ ] Copy latest NPUModelRunner.execute_model base structure +- [ ] Re-insert: async_chunk update logic +- [ ] Re-insert: `seq_token_counts` injection +- [ ] Re-insert: `_run_generation_model` call +- [ ] Re-insert: `extract_multimodal_outputs` +- [ ] Use: ExecuteModelState from npu_ar_model_runner + +### Method: sample_tokens +- [ ] Keep: Entire omni multimodal output processing +- [ ] Update: Any new output fields needed +- [ ] Keep: `OmniModelRunnerOutput` construction + +### Method: _run_generation_model +- [ ] Sync any changes from GPU side +- [ ] Keep: `_model_forward` call with sampler + +### Method: _dummy_run +- [ ] Copy latest NPUModelRunner._dummy_run +- [ ] Re-insert: `model_kwargs = self._init_model_kwargs()` +- [ ] Re-insert: `extract_multimodal_outputs` at end + +### Imports +- [ ] Import ExecuteModelState from npu_ar_model_runner +- [ ] Update vllm-ascend imports + +--- + +## Post-Upgrade Validation + +### Syntax Validation +- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py` +- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py` +- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py` + +### Import Validation +- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner"` +- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_ar_model_runner import NPUARModelRunner"` +- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_generation_model_runner import NPUGenerationModelRunner"` + +### Comment Markers +- [ ] Grep for "Omni-new" in all three files +- [ ] Verify all omni blocks have closing markers + +### Code Review +- [ ] No `CUDAGraphWrapper` references +- [ ] All `set_forward_context` replaced with `set_ascend_forward_context` +- [ ] Parameter names correct (`aclgraph_runtime_mode` not `cudagraph_runtime_mode`) +- [ ] No duplicate code blocks +- [ ] No missing imports + +--- + +## Git Commit + +### Commit Message Template +``` +[NPU] Upgrade model runners to align with vllm-ascend vX.Y.Z + +- Update OmniNPUModelRunner with latest NPUModelRunner base +- Update NPUARModelRunner execute_model and sample_tokens +- Update NPUGenerationModelRunner for async_chunk changes +- Sync GPU-side omni changes from vX.Y.Z release +- Preserve all omni-specific logic (marked with Omni-new comments) + +Changes from vllm-ascend: +- + +Changes synced from GPU: +- +``` + +### Files to Stage +- [ ] `vllm_omni/platforms/npu/worker/npu_model_runner.py` +- [ ] `vllm_omni/platforms/npu/worker/npu_ar_model_runner.py` +- [ ] `vllm_omni/platforms/npu/worker/npu_generation_model_runner.py` +- [ ] Any other modified files + +--- + +## Troubleshooting + +### Import Errors +- Check if vllm-ascend module paths have changed +- Verify PYTHONPATH includes both vllm-ascend and vllm-omni + +### Type Errors +- Check method signatures match between GPU and NPU +- Verify NamedTuple fields match expected structure + +### Runtime Errors +- Enable debug logging: `export VLLM_LOGGING_LEVEL=DEBUG` +- Check graph capture issues: try `--enforce-eager` +- Check attention issues: verify AscendAttentionState usage + +### Performance Regression +- Compare with previous version on same model +- Check if graph capture is working: look for ACLGraph logs +- Verify SP/EP 
configurations are correct diff --git a/.gitignore b/.gitignore index 7f101a784c..35dc7571ee 100644 --- a/.gitignore +++ b/.gitignore @@ -158,7 +158,19 @@ cython_debug/ # Claude CLAUDE.md -.claude/ +/.claude/* +!.claude/skills/ +!.claude/skills/readme.md +!.claude/skills/add-diffusion-model/ +!.claude/skills/add-diffusion-model/SKILL.md +!.claude/skills/add-diffusion-model/references/ +!.claude/skills/add-diffusion-model/references/*.md +!.claude/skills/add-tts-model/ +!.claude/skills/add-tts-model/SKILL.md +!.claude/skills/review-pr/ +!.claude/skills/review-pr/SKILL.md +!.claude/skills/review-pr/references/ +!.claude/skills/review-pr/references/*.md # Codex AGENTS.md @@ -191,6 +203,7 @@ checkpoints/ # Cache directories cache/ !vllm_omni/diffusion/cache/ +!tests/diffusion/cache/ .cache/ diffusion_cache/ kv_cache/ @@ -250,3 +263,5 @@ tmp_test vllm_omni/_version.py # output files *.wav +# CI overlay yamls materialized from tests/utils.py:_CI_OVERLAYS at test time +tests/.ci_generated/ diff --git a/benchmarks/build_dataset/download_process_data_seedtts.md b/benchmarks/build_dataset/download_process_data_seedtts.md index ec16f64424..faf072303b 100644 --- a/benchmarks/build_dataset/download_process_data_seedtts.md +++ b/benchmarks/build_dataset/download_process_data_seedtts.md @@ -27,7 +27,7 @@ pip install gdown Download the dataset from Google Drive: ```bash -gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP +gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP ``` ### 4. Extract the Dataset @@ -74,7 +74,7 @@ rm meta.lst # Full setup and benchmark cd benchmarks/build_dataset pip install gdown -gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP +gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP tar -xf seedtts_testset.tar cp seedtts_testset/en/meta.lst meta.lst python extract_tts_prompts.py -i meta.lst -o top100.txt -n 100 diff --git a/benchmarks/diffusion/backends.py b/benchmarks/diffusion/backends.py index fa53f87aed..13ce7c8309 100644 --- a/benchmarks/diffusion/backends.py +++ b/benchmarks/diffusion/backends.py @@ -306,6 +306,8 @@ async def async_request_v1_videos( video_bytes = await content_response.read() output.response_body = video_bytes output.success = True + if "stage_durations" in poll_json: + output.stage_durations = poll_json["stage_durations"] or {} if "peak_memory_mb" in poll_json: output.peak_memory_mb = poll_json["peak_memory_mb"] elif "peak_memory_mb" in resp_json: diff --git a/benchmarks/diffusion/diffusion_benchmark_serving.py b/benchmarks/diffusion/diffusion_benchmark_serving.py index aad955b0d1..32ec48a698 100644 --- a/benchmarks/diffusion/diffusion_benchmark_serving.py +++ b/benchmarks/diffusion/diffusion_benchmark_serving.py @@ -558,6 +558,7 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool super().__init__(args, api_url, model) self.num_prompts = args.num_prompts self.enable_negative_prompt = enable_negative_prompt + self.num_input_images = max(1, args.num_input_images) self.random_request_config = getattr(args, "random_request_config", None) if self.random_request_config: self.random_request_config = json.loads(self.random_request_config) @@ -580,11 +581,7 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool # Random image generate if self.args.task in ["i2v", "ti2v", "ti2i", "i2i"]: - img = Image.new("RGB", (512, 512), (255, 255, 255)) - - image_path = os.path.join(tempfile.gettempdir(), "diffusion_benchmark_random_image.png") - self._random_image_path = [image_path] - img.save(image_path) + self._random_image_path = 
self._generate_random_image_paths() else: self._random_image_path = None @@ -619,6 +616,18 @@ def __getitem__(self, idx: int) -> RequestFuncInput: def get_requests(self) -> list[RequestFuncInput]: return [self[i] for i in range(len(self))] + def _generate_random_image_paths(self) -> list[str]: + image_paths: list[str] = [] + for image_idx in range(self.num_input_images): + img = Image.new("RGB", (512, 512), (255, 255, 255)) + image_path = os.path.join( + tempfile.gettempdir(), + f"diffusion_benchmark_random_image_{image_idx}.png", + ) + img.save(image_path) + image_paths.append(image_path) + return image_paths + def _compute_expected_latency_ms_from_base(req: RequestFuncInput, args, base_time_ms: float | None) -> float | None: """Compute expected execution time (ms) based on a base per-step-per-frame unit time. @@ -1115,6 +1124,15 @@ async def limited_request_func(req, session, pbar): '{"width":768,"height":768,"num_inference_steps":20,"weight":0.85}]' ), ) + parser.add_argument( + "--num-input-images", + type=int, + default=1, + help=( + "Number of synthetic input images to attach for image-conditioned tasks " + "(i2v, ti2v, ti2i, i2i) when using random dataset." + ), + ) args = parser.parse_args() diff --git a/benchmarks/qwen3-omni/README.md b/benchmarks/qwen3-omni/README.md index de27c05c2c..dc282d0525 100644 --- a/benchmarks/qwen3-omni/README.md +++ b/benchmarks/qwen3-omni/README.md @@ -9,7 +9,7 @@ cd benchmarks/build_dataset pip install gdown # Download SeedTTS test set from Google Drive -gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP +gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP # Extract tar -xf seedtts_testset.tar diff --git a/benchmarks/qwen3-tts/README.md b/benchmarks/qwen3-tts/README.md index 9c01f29aa9..a1c2ebe12f 100644 --- a/benchmarks/qwen3-tts/README.md +++ b/benchmarks/qwen3-tts/README.md @@ -35,8 +35,8 @@ MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice bash run_benchmark.sh --async-only # Use a Voice Clone model MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only -# Use bs16 config for higher throughput -STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs16.yaml bash run_benchmark.sh --async-only +# Use batch size 16 for higher throughput +BATCH_SIZE=16 bash run_benchmark.sh --async-only # Custom GPU, prompt count, concurrency levels GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh @@ -50,7 +50,8 @@ GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \ "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" \ --omni --host 127.0.0.1 --port 8000 \ - --stage-configs-path benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ + --stage-overrides '{"0":{"max_num_seqs":1,"gpu_memory_utilization":0.3,"max_num_batched_tokens":512},"1":{"max_num_seqs":1,"gpu_memory_utilization":0.3,"max_num_batched_tokens":8192}}' \ --trust-remote-code ``` @@ -84,16 +85,19 @@ python benchmarks/qwen3-tts/plot_results.py \ --output results/comparison.png ``` -## Stage Configs +## Batch-size presets -| Config | max_num_seqs | Description | -|--------|:------------:|-------------| -| `vllm_omni/configs/qwen3_tts_bs1.yaml` | 1 | Single-request processing (lowest latency) | -| `vllm_omni/configs/qwen3_tts_bs16.yaml` | 16 | High-throughput concurrent processing | +The bench script loads the bundled production deploy (`vllm_omni/deploy/qwen3_tts.yaml`) and layers per-stage budgets on top via `--stage-overrides`, driven by the `BATCH_SIZE` env var. 
Each batch size picks compatible per-stage `max_num_seqs`, `max_num_batched_tokens`, and `gpu_memory_utilization` defaults: -All configs use a 2-stage pipeline (Talker -> Code2Wav) with `async_chunk` streaming enabled. The `SharedMemoryConnector` streams codec frames (25-frame chunks with 25-frame context overlap) between stages. +| `BATCH_SIZE` | Description | +|:--:|-------------| +| `1` (default) | Single-request processing (lowest latency) | +| `4` | Moderate-throughput concurrent processing | +| `16` | High-throughput concurrent processing | -The model is specified via the CLI `--model` flag (or `MODEL` env var), so the same configs work for both the 0.6B and 1.7B model variants. +The 2-stage pipeline (Talker -> Code2Wav) runs with `async_chunk` streaming enabled via the prod deploy; the `SharedMemoryConnector` streams codec frames (25-frame chunks with 25-frame context overlap) between stages. + +The model is specified via the CLI `--model` flag (or `MODEL` env var), so the same bench script works for both the 0.6B and 1.7B model variants. ## Metrics diff --git a/benchmarks/qwen3-tts/run_benchmark.sh b/benchmarks/qwen3-tts/run_benchmark.sh index 283b6b844c..8c3e46903c 100755 --- a/benchmarks/qwen3-tts/run_benchmark.sh +++ b/benchmarks/qwen3-tts/run_benchmark.sh @@ -26,8 +26,8 @@ # # Use Voice Clone model # MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only # -# # Use batch_size=4 config: -# STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs4.yaml bash run_benchmark.sh --async-only +# # Use batch_size=4: +# BATCH_SIZE=4 bash run_benchmark.sh --async-only # # Environment variables: # GPU_DEVICE - GPU index to use (default: 0) @@ -35,9 +35,9 @@ # CONCURRENCY - Space-separated concurrency levels (default: "1 4 10") # MODEL - Model name (default: Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice) # PORT - Server port (default: 8000) -# GPU_MEM_TALKER - gpu_memory_utilization for talker stage (default: 0.3) -# GPU_MEM_CODE2WAV - gpu_memory_utilization for code2wav stage (default: 0.2) -# STAGE_CONFIG - Path to stage config YAML (default: configs/qwen3_tts_bs1.yaml) +# BATCH_SIZE - Per-stage ``max_num_seqs`` for both talker and code2wav (default: 1) +# GPU_MEM_TALKER - gpu_memory_utilization for talker stage (default: 0.3 at bs=1, else 0.2) +# GPU_MEM_CODE2WAV - gpu_memory_utilization for code2wav stage (default: 0.3 at bs=1, else 0.2) # TASK_TYPE - Task type: CustomVoice, VoiceDesign, Base (default: CustomVoice) set -euo pipefail @@ -51,14 +51,36 @@ NUM_PROMPTS="${NUM_PROMPTS:-50}" CONCURRENCY="${CONCURRENCY:-1 4 10}" MODEL="${MODEL:-Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice}" PORT="${PORT:-8000}" -GPU_MEM_TALKER="${GPU_MEM_TALKER:-0.3}" -GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV:-0.2}" +BATCH_SIZE="${BATCH_SIZE:-1}" +DEFAULT_MEM=$([ "${BATCH_SIZE}" = "1" ] && echo "0.3" || echo "0.2") +GPU_MEM_TALKER="${GPU_MEM_TALKER:-${DEFAULT_MEM}}" +GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV:-${DEFAULT_MEM}}" NUM_WARMUPS="${NUM_WARMUPS:-3}" -STAGE_CONFIG="${STAGE_CONFIG:-vllm_omni/configs/qwen3_tts_bs1.yaml}" +DEPLOY_CONFIG="vllm_omni/deploy/qwen3_tts.yaml" RESULT_DIR="${SCRIPT_DIR}/results" TIMESTAMP="$(date +%Y%m%d_%H%M%S)" TASK_TYPE="${TASK_TYPE:-CustomVoice}" +# Build --stage-overrides JSON from BATCH_SIZE + GPU_MEM_*. 
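+# Illustrative example (not emitted by the script): with BATCH_SIZE=4 and the
+# GPU_MEM_* defaults above (0.2 for both stages), the helper below prints JSON
+# along the lines of:
+#   {"0": {"max_num_seqs": 4, "gpu_memory_utilization": 0.2, "max_num_batched_tokens": 512},
+#    "1": {"max_num_seqs": 4, "gpu_memory_utilization": 0.2, "max_num_batched_tokens": 8192}}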
+STAGE_OVERRIDES=$( + BATCH_SIZE="${BATCH_SIZE}" \ + GPU_MEM_TALKER="${GPU_MEM_TALKER}" \ + GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV}" \ + python - <<'PYEOF' +import json, os +bs = int(os.environ["BATCH_SIZE"]) +mem_t = float(os.environ["GPU_MEM_TALKER"]) +mem_c = float(os.environ["GPU_MEM_CODE2WAV"]) +# Prefill budget grows with batch size on both stages. +talker_batched = 512 if bs <= 4 else 4096 +code2wav_batched = 8192 if bs <= 4 else 32768 +print(json.dumps({ + "0": {"max_num_seqs": bs, "gpu_memory_utilization": mem_t, "max_num_batched_tokens": talker_batched}, + "1": {"max_num_seqs": bs, "gpu_memory_utilization": mem_c, "max_num_batched_tokens": code2wav_batched}, +})) +PYEOF +) + # Parse args RUN_ASYNC=true RUN_HF=true @@ -75,41 +97,27 @@ mkdir -p "${RESULT_DIR}" echo "============================================================" echo " Qwen3-TTS Benchmark" echo "============================================================" -echo " GPU: ${GPU_DEVICE}" -echo " Model: ${MODEL}" -echo " Prompts: ${NUM_PROMPTS}" -echo " Concurrency: ${CONCURRENCY}" -echo " Port: ${PORT}" -echo " Stage config: ${STAGE_CONFIG}" -echo " Results: ${RESULT_DIR}" -echo " Task type: ${TASK_TYPE}" +echo " GPU: ${GPU_DEVICE}" +echo " Model: ${MODEL}" +echo " Prompts: ${NUM_PROMPTS}" +echo " Concurrency: ${CONCURRENCY}" +echo " Port: ${PORT}" +echo " Deploy config: ${DEPLOY_CONFIG}" +echo " Batch size: ${BATCH_SIZE}" +echo " GPU mem T/C: ${GPU_MEM_TALKER} / ${GPU_MEM_CODE2WAV}" +echo " Results: ${RESULT_DIR}" +echo " Task type: ${TASK_TYPE}" echo "============================================================" -# Prepare stage config with correct GPU device and memory settings -prepare_config() { - local config_template="$1" - local config_name="$2" - local output_path="${RESULT_DIR}/${config_name}_stage_config.yaml" - - # Use sed to patch GPU device and memory utilization - sed \ - -e "s/devices: \"0\"/devices: \"${GPU_DEVICE}\"/g" \ - -e "s/gpu_memory_utilization: 0.3/gpu_memory_utilization: ${GPU_MEM_TALKER}/g" \ - -e "s/gpu_memory_utilization: 0.2/gpu_memory_utilization: ${GPU_MEM_CODE2WAV}/g" \ - "${config_template}" > "${output_path}" - - echo "${output_path}" -} - # Start server and wait for it to be ready start_server() { - local stage_config="$1" - local config_name="$2" + local config_name="$1" local log_file="${RESULT_DIR}/server_${config_name}_${TIMESTAMP}.log" echo "" echo "Starting server with config: ${config_name}" - echo " Stage config: ${stage_config}" + echo " Deploy config: ${DEPLOY_CONFIG}" + echo " Stage overrides: ${STAGE_OVERRIDES}" echo " Log file: ${log_file}" VLLM_WORKER_MULTIPROC_METHOD=spawn \ @@ -118,7 +126,8 @@ start_server() { --omni \ --host 127.0.0.1 \ --port "${PORT}" \ - --stage-configs-path "${stage_config}" \ + --deploy-config "${DEPLOY_CONFIG}" \ + --stage-overrides "${STAGE_OVERRIDES}" \ --stage-init-timeout 120 \ --trust-remote-code \ --disable-log-stats \ @@ -175,17 +184,13 @@ trap 'stop_server' EXIT # Run benchmark for a given config run_bench() { local config_name="$1" - local config_template="$2" echo "" echo "============================================================" echo " Benchmarking: ${config_name}" echo "============================================================" - local stage_config - stage_config=$(prepare_config "${config_template}" "${config_name}") - - start_server "${stage_config}" "${config_name}" + start_server "${config_name}" # Convert concurrency string to args local conc_args="" @@ -212,7 +217,7 @@ run_bench() { # Run vllm-omni benchmark if [ 
"${RUN_ASYNC}" = true ]; then - run_bench "async_chunk" "${SCRIPT_DIR}/${STAGE_CONFIG}" + run_bench "async_chunk" fi # Run HuggingFace baseline benchmark diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml deleted file mode 100644 index 2cc5cf5353..0000000000 --- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml +++ /dev/null @@ -1,94 +0,0 @@ -# Qwen3-TTS max_num_seqs=16 config (streaming with async_chunk) -# High-throughput concurrent request processing -# 2-stage pipeline: Talker -> Code2Wav -async_chunk: true -stage_args: - - stage_id: 0 - stage_type: llm - is_comprehension: true - runtime: - devices: "0" - engine_args: - max_num_seqs: 16 - model_stage: qwen3_tts - model_arch: Qwen3TTSTalkerForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - enforce_eager: false - trust_remote_code: true - async_scheduling: true - enable_prefix_caching: false - engine_output_type: latent - gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" - max_num_batched_tokens: 4096 - max_model_len: 4096 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk - output_connectors: - to_stage_1: connector_of_shared_memory - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: false - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 1 - stage_type: llm - runtime: - devices: "0" - engine_args: - max_num_seqs: 16 - model_stage: code2wav - model_arch: Qwen3TTSCode2Wav - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: true - enable_prefix_caching: false - engine_output_type: audio - gpu_memory_utilization: 0.2 - distributed_executor_backend: "mp" - max_num_batched_tokens: 16384 - max_model_len: 32768 - engine_input_source: [0] - final_output: true - final_output_type: audio - input_connectors: - from_stage_0: connector_of_shared_memory - tts_args: - max_instructions_length: 500 - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: true - repetition_penalty: 1.0 - -runtime: - enabled: true - defaults: - window_size: -1 - max_inflight: 16 - - connectors: - connector_of_shared_memory: - name: SharedMemoryConnector - extra: - shm_threshold_bytes: 65536 - codec_streaming: true - connector_get_sleep_s: 0.01 - connector_get_max_wait_first_chunk: 3000 - connector_get_max_wait: 300 - codec_chunk_frames: 25 - codec_left_context_frames: 25 - - edges: - - from: 0 - to: 1 - window_size: -1 diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml deleted file mode 100644 index 5de107d497..0000000000 --- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml +++ /dev/null @@ -1,94 +0,0 @@ -# Qwen3-TTS batch_size=4 config (streaming with async_chunk) -# Enables concurrent request processing -# 2-stage pipeline: Talker -> Code2Wav -async_chunk: true -stage_args: - - stage_id: 0 - stage_type: llm - is_comprehension: true - runtime: - devices: "0" - engine_args: - max_num_seqs: 4 - model_stage: qwen3_tts - model_arch: Qwen3TTSTalkerForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - enforce_eager: false - 
trust_remote_code: true - async_scheduling: true - enable_prefix_caching: false - engine_output_type: latent - gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" - max_num_batched_tokens: 512 - max_model_len: 4096 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk - output_connectors: - to_stage_1: connector_of_shared_memory - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: false - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 1 - stage_type: llm - runtime: - devices: "0" - engine_args: - max_num_seqs: 4 - model_stage: code2wav - model_arch: Qwen3TTSCode2Wav - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: true - enable_prefix_caching: false - engine_output_type: audio - gpu_memory_utilization: 0.2 - distributed_executor_backend: "mp" - max_num_batched_tokens: 8192 - max_model_len: 32768 - engine_input_source: [0] - final_output: true - final_output_type: audio - input_connectors: - from_stage_0: connector_of_shared_memory - tts_args: - max_instructions_length: 500 - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: true - repetition_penalty: 1.0 - -runtime: - enabled: true - defaults: - window_size: -1 - max_inflight: 4 - - connectors: - connector_of_shared_memory: - name: SharedMemoryConnector - extra: - shm_threshold_bytes: 65536 - codec_streaming: true - connector_get_sleep_s: 0.01 - connector_get_max_wait_first_chunk: 3000 - connector_get_max_wait: 300 - codec_chunk_frames: 25 - codec_left_context_frames: 25 - - edges: - - from: 0 - to: 1 - window_size: -1 diff --git a/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh index 61cf7757a9..0ede359ea3 100755 --- a/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh +++ b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh @@ -31,8 +31,11 @@ PORT_OFF="${PORT_OFF:-8001}" RESULT_DIR="${SCRIPT_DIR}/results" TIMESTAMP="$(date +%Y%m%d_%H%M%S)" -STAGE_CONFIG_ON="vllm_omni/model_executor/stage_configs/qwen3_tts.yaml" -STAGE_CONFIG_OFF="vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml" +# The bundled ``vllm_omni/deploy/qwen3_tts.yaml`` is auto-loaded by the model +# registry; no ``--deploy-config`` flag needed on the default (ON) path. +# async_chunk OFF is selected by the ``--no-async-chunk`` CLI flag — +# the single ``qwen3_tts`` pipeline dispatches to the end-to-end codec +# processor when ``deploy.async_chunk`` is false. mkdir -p "${RESULT_DIR}" @@ -77,7 +80,6 @@ wait_for_server() { echo "" echo "[Phase 1] Starting async_chunk ON server on port ${PORT_ON}..." CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \ - --stage-configs-path "${STAGE_CONFIG_ON}" \ --host 0.0.0.0 --port "${PORT_ON}" \ --trust-remote-code --enforce-eager --omni \ > "${RESULT_DIR}/server_on_${TIMESTAMP}.log" 2>&1 & @@ -104,7 +106,7 @@ sleep 5 echo "" echo "[Phase 2] Starting async_chunk OFF server on port ${PORT_OFF}..." 
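+# Phase 2 reuses the same bundled deploy; async_chunk is disabled solely via the --no-async-chunk flag below.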
CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \ - --stage-configs-path "${STAGE_CONFIG_OFF}" \ + --no-async-chunk \ --host 0.0.0.0 --port "${PORT_OFF}" \ --trust-remote-code --enforce-eager --omni \ > "${RESULT_DIR}/server_off_${TIMESTAMP}.log" 2>&1 & diff --git a/benchmarks/voxcpm/README.md b/benchmarks/voxcpm/README.md new file mode 100644 index 0000000000..17f904101b --- /dev/null +++ b/benchmarks/voxcpm/README.md @@ -0,0 +1,119 @@ +# VoxCPM Benchmark + +This directory contains both: + +- online serving benchmark through the OpenAI-compatible `/v1/audio/speech` API +- offline benchmark for `Omni` / `AsyncOmni` +- full offline smoke-matrix orchestration + +Both benchmark paths report: + +- TTFP: time to first PCM packet +- E2E latency +- RTF: real-time factor (`e2e / audio_duration`) + +## Offline Benchmark + +Single offline benchmark run: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_offline.py \ + --model /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm.yaml \ + --text "This is a split-stage VoxCPM synthesis example running on vLLM Omni." \ + --warmup-runs 1 \ + --output-dir benchmarks/voxcpm/results/offline_single +``` + +Streaming offline benchmark: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_offline.py \ + --model /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --text "This is a split-stage VoxCPM streaming example running on vLLM Omni." \ + --warmup-runs 1 \ + --output-dir benchmarks/voxcpm/results/offline_streaming +``` + +Full fixed offline matrix, equivalent to the old `examples/offline_inference/voxcpm/test.py`: + +```bash +python benchmarks/voxcpm/vllm_omni/run_offline_matrix.py \ + --model /path/to/voxcpm-model \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." \ + --output-root benchmarks/voxcpm/results/offline_matrix +``` + +The full matrix covers both routes: + +- streaming: `voxcpm_async_chunk.yaml` +- sync: `voxcpm.yaml` + +And these six scenarios under each route: + +- warmup + single TTS +- warmup + single voice cloning +- warmup + batch TTS +- warmup + batch voice cloning +- cold single TTS +- cold single voice cloning + +`bench_tts_offline.py` itself no longer writes `summary.json` / `results.json`; it prints TTFP / RTF inline and saves generated WAV files only. The matrix runner keeps only per-case `run.log`. + +## Start the Server + +Async-chunk: + +```bash +vllm serve /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --trust-remote-code \ + --enforce-eager \ + --omni \ + --port 8091 +``` + +Non-streaming: + +```bash +vllm serve /path/to/voxcpm-model \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm.yaml \ + --trust-remote-code \ + --enforce-eager \ + --omni \ + --port 8091 +``` + +## Run the Benchmark + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \ + --host 127.0.0.1 \ + --port 8091 \ + --num-prompts 20 \ + --max-concurrency 1 \ + --result-dir /tmp/voxcpm_bench +``` + +Voice cloning benchmark: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \ + --host 127.0.0.1 \ + --port 8091 \ + --num-prompts 10 \ + --max-concurrency 1 \ + --ref-audio https://example.com/reference.wav \ + --ref-text "The exact transcript spoken in the reference audio." 
\ + --result-dir /tmp/voxcpm_clone_bench +``` + +## Notes + +- The benchmark uses `stream=true` and `response_format=pcm` so TTFP is measured from the first audio packet. +- `RTF < 1.0` means the server generates audio faster than real time. +- For `voxcpm_async_chunk.yaml`, keep concurrency at `1`. This matches native VoxCPM streaming more closely. +- Do not benchmark concurrent online streaming on `voxcpm_async_chunk.yaml`; use `voxcpm.yaml` for multi-request throughput runs. +- For the offline matrix mode, `--ref-audio` and `--ref-text` are required because clone cases are part of the fixed coverage set. diff --git a/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py b/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py new file mode 100644 index 0000000000..a3bad3e692 --- /dev/null +++ b/benchmarks/voxcpm/vllm_omni/bench_tts_offline.py @@ -0,0 +1,890 @@ +"""Offline VoxCPM benchmark for vLLM Omni. + +Supports both: +- sync one-shot (Omni.generate) +- streaming (AsyncOmni.generate with async_chunk config) +- text-only synthesis +- voice cloning +- text/clone batch inputs from txt or jsonl +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import tempfile +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import AsyncOmni, Omni + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_STAGE_ASYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm_async_chunk.yaml" +DEFAULT_STAGE_SYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml" + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class PromptSpec: + text: str + label: str + ref_audio: str | None = None + ref_text: str | None = None + + +def _require_soundfile(): + try: + import soundfile as sf # type: ignore + except ModuleNotFoundError as exc: + raise RuntimeError( + "soundfile is required to write VoxCPM benchmark WAV outputs. 
Install it with: pip install soundfile" + ) from exc + return sf + + +def _build_prompt( + args, + *, + text: str, + ref_audio: str | None = None, + ref_text: str | None = None, + global_request_id: str | None = None, +) -> dict[str, Any]: + additional_information: dict[str, list[Any]] = { + "text": [text], + "cfg_value": [args.cfg_value], + "inference_timesteps": [args.inference_timesteps], + "min_len": [args.min_len], + "max_new_tokens": [args.max_new_tokens], + } + if args.streaming_prefix_len is not None: + additional_information["streaming_prefix_len"] = [args.streaming_prefix_len] + + if ref_audio: + additional_information["ref_audio"] = [ref_audio] + if ref_text: + additional_information["ref_text"] = [ref_text] + if global_request_id is not None: + additional_information["global_request_id"] = [global_request_id] + + return { + "prompt_token_ids": [1], + "additional_information": additional_information, + } + + +def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor: + audio = mm.get("audio", mm.get("model_outputs")) + if audio is None: + raise ValueError("No audio output found in multimodal output.") + if isinstance(audio, list): + parts = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio] + audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0) + if not isinstance(audio, torch.Tensor): + audio = torch.as_tensor(audio) + return audio.float().cpu().reshape(-1) + + +def _extract_sample_rate(mm: dict[str, Any]) -> int: + sr_raw = mm.get("sr", 24000) + if isinstance(sr_raw, list) and sr_raw: + sr_raw = sr_raw[-1] + if hasattr(sr_raw, "item"): + return int(sr_raw.item()) + return int(sr_raw) + + +def _emit_offline_metrics( + *, + request_id: str, + elapsed_s: float, + first_audio_elapsed: float | None, + audio_duration_s: float, +) -> None: + metrics = { + "request_id": request_id, + "ttfp_ms": round(first_audio_elapsed * 1000.0, 3) if first_audio_elapsed is not None else None, + "audio_duration_s": round(audio_duration_s, 6), + "rtf": round(elapsed_s / audio_duration_s, 6) if audio_duration_s > 0 else None, + } + print(f"[OfflineMetrics] {metrics}") + + +def _write_audio_tensor(output_path: Path, audio_tensor: Any, sample_rate: int) -> None: + sf = _require_soundfile() + if isinstance(audio_tensor, torch.Tensor): + audio_np = audio_tensor.float().cpu().clamp(-1.0, 1.0).numpy() + else: + audio_np = torch.as_tensor(audio_tensor).float().cpu().clamp(-1.0, 1.0).numpy() + sf.write( + output_path, + audio_np, + sample_rate, + format="WAV", + subtype="PCM_16", + ) + + +def _save_wav(mm: dict[str, Any], output_dir: Path, request_id: str) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"output_{request_id}.wav" + _write_audio_tensor(output_path, _extract_audio_tensor(mm), _extract_sample_rate(mm)) + return output_path + + +def _iter_request_multimodal_outputs(request_output: Any): + outputs = getattr(request_output, "outputs", None) + if outputs: + for output in outputs: + mm = getattr(output, "multimodal_output", None) + if isinstance(mm, dict): + yield mm + + mm = getattr(request_output, "multimodal_output", None) + if isinstance(mm, dict): + yield mm + + +def _read_non_empty_lines(path: str) -> list[str]: + with open(path, encoding="utf-8") as f: + return [line.strip() for line in f if line.strip()] + + +def _load_prompt_specs(args) -> list[PromptSpec]: + specs: list[PromptSpec] = [] + + if args.txt_prompts is not None: + texts = _read_non_empty_lines(args.txt_prompts) + if not texts: + raise ValueError(f"No prompts found in 
{args.txt_prompts}") + for idx, text in enumerate(texts, start=1): + specs.append( + PromptSpec( + text=text, + label=f"item{idx:03d}", + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + ) + return specs + + if args.jsonl_prompts is not None: + with open(args.jsonl_prompts, encoding="utf-8") as f: + for line_no, raw_line in enumerate(f, start=1): + line = raw_line.strip() + if not line: + continue + try: + item = json.loads(line) + except json.JSONDecodeError as exc: + raise ValueError(f"{args.jsonl_prompts}:{line_no} is not valid JSON: {exc}") from exc + if not isinstance(item, dict): + raise ValueError(f"{args.jsonl_prompts}:{line_no} must be a JSON object") + + text = item.get("text") + if not isinstance(text, str) or not text.strip(): + raise ValueError(f"{args.jsonl_prompts}:{line_no} requires non-empty string field 'text'") + + ref_audio = item.get("ref_audio", args.ref_audio) + ref_text = item.get("ref_text", args.ref_text) + if (ref_audio is None) != (ref_text is None): + raise ValueError( + f"{args.jsonl_prompts}:{line_no} must provide both 'ref_audio' and 'ref_text' together" + ) + + specs.append( + PromptSpec( + text=text.strip(), + label=f"item{len(specs) + 1:03d}", + ref_audio=ref_audio, + ref_text=ref_text, + ) + ) + + if not specs: + raise ValueError(f"No prompts found in {args.jsonl_prompts}") + return specs + + specs.append( + PromptSpec( + text=args.text, + label="item001", + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + ) + return specs + + +def _build_prompt_for_spec(args, spec: PromptSpec, *, global_request_id: str | None = None) -> dict[str, Any]: + return _build_prompt( + args, + text=spec.text, + ref_audio=spec.ref_audio, + ref_text=spec.ref_text, + global_request_id=global_request_id, + ) + + +def _count_voice_clone_prompts(prompt_specs: list[PromptSpec]) -> int: + return sum(1 for spec in prompt_specs if spec.ref_audio is not None) + + +def _get_warmup_specs(prompt_specs: list[PromptSpec]) -> list[PromptSpec]: + return prompt_specs[:1] + + +def _extract_stream_finished(stage_output: Any) -> bool: + request_output = getattr(stage_output, "request_output", None) + request_finished = getattr(request_output, "finished", None) + if request_finished is not None: + return bool(request_finished) + return bool(getattr(stage_output, "finished", False)) + + +def _build_profiled_stage_config( + stage_configs_path: str, + profiler_dir: str, +) -> str: + stage_config_path = Path(stage_configs_path) + yaml_text = stage_config_path.read_text(encoding="utf-8") + injected_lines: list[str] = [] + injected_count = 0 + + for line in yaml_text.splitlines(): + injected_lines.append(line) + if line.strip() != "engine_args:": + continue + indent = line[: len(line) - len(line.lstrip())] + child_indent = indent + " " + grandchild_indent = child_indent + " " + injected_lines.extend( + [ + f"{child_indent}profiler_config:", + f'{grandchild_indent}profiler: "torch"', + f'{grandchild_indent}torch_profiler_dir: "{profiler_dir}"', + f"{grandchild_indent}torch_profiler_with_stack: true", + ] + ) + injected_count += 1 + + if injected_count == 0: + raise ValueError(f"No engine_args block found in stage config: {stage_configs_path}") + + tmp = tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + delete=False, + suffix=".yaml", + prefix=f"{stage_config_path.stem}_profile_", + ) + tmp.write("\n".join(injected_lines) + "\n") + tmp.close() + return tmp.name + + +def parse_args(): + parser = FlexibleArgumentParser( + description="Offline split-stage VoxCPM inference with 
vLLM Omni (auto sync/streaming by stage config)" + ) + parser.add_argument( + "--model", + type=str, + default=os.environ.get("VOXCPM_MODEL"), + help="Local VoxCPM model directory. Defaults to $VOXCPM_MODEL.", + ) + parser.add_argument( + "--text", + type=str, + default="This is a split-stage VoxCPM synthesis example running on vLLM Omni.", + help="Text to synthesize. Ignored when --txt-prompts or --jsonl-prompts is used.", + ) + parser.add_argument( + "--txt-prompts", + type=str, + default=None, + help="Path to a .txt file with one synthesis text per line.", + ) + parser.add_argument( + "--jsonl-prompts", + type=str, + default=None, + help=( + "Path to a .jsonl file. Each line must contain at least {'text': ...}; " + "clone rows can also set ref_audio/ref_text, and ref_text must be the " + "real transcript of ref_audio." + ), + ) + parser.add_argument( + "--ref-audio", + type=str, + default=None, + help=( + "Optional reference audio path for voice cloning. With --txt-prompts, " + "the same reference is applied to every line." + ), + ) + parser.add_argument( + "--ref-text", + type=str, + default=None, + help=( + "Real transcript of the reference audio. Placeholder text or mismatched " + "text will usually produce noisy/electronic clone audio." + ), + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=str(DEFAULT_STAGE_SYNC), + help="Stage config YAML path. Routing is selected only from this path.", + ) + parser.add_argument( + "--cfg-value", + type=float, + default=2.0, + help="Classifier-free guidance value for VoxCPM.", + ) + parser.add_argument( + "--inference-timesteps", + type=int, + default=10, + help="Number of inference timesteps.", + ) + parser.add_argument( + "--min-len", + type=int, + default=2, + help="Minimum generated token length.", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=4096, + help="Maximum generated token length.", + ) + parser.add_argument( + "--streaming-prefix-len", + type=int, + default=None, + help="VoxCPM streaming window (optional, streaming mode only).", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory for output WAV files.", + ) + parser.add_argument( + "--stage-init-timeout", + type=int, + default=600, + help="Stage initialization timeout in seconds.", + ) + parser.add_argument( + "--log-stats", + dest="log_stats", + action="store_true", + help="Enable vLLM Omni stats logging.", + ) + parser.add_argument( + "--no-log-stats", + dest="log_stats", + action="store_false", + help="Disable vLLM Omni stats logging.", + ) + parser.set_defaults(log_stats=True) + parser.add_argument( + "--num-runs", + type=int, + default=1, + help="Number of full inference runs (same prompt each time). Default 1.", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=0, + help=( + "Optional number of warmup passes before measured runs. Warmup uses only " + "the first prompt and does not save outputs." + ), + ) + parser.add_argument( + "--enable-profiler", + action="store_true", + help=( + "Enable torch profiler for the configured stages. A temporary profiled " + "stage config is generated automatically." + ), + ) + parser.add_argument( + "--profiler-dir", + type=str, + default=None, + help="Directory for profiler traces. Defaults to /profiler when profiling is enabled.", + ) + parser.add_argument( + "--profiler-stages", + type=int, + nargs="*", + default=None, + help="Optional stage ids to profile. 
Defaults to all stages that have profiler_config.", + ) + parser.add_argument( + "--profiler-wait-seconds", + type=float, + default=30.0, + help="Seconds to wait after stop_profile for trace files to flush.", + ) + args = parser.parse_args() + + if not args.model: + parser.error("--model is required unless $VOXCPM_MODEL is set") + if args.txt_prompts is not None and args.jsonl_prompts is not None: + parser.error("--txt-prompts and --jsonl-prompts are mutually exclusive") + if (args.ref_audio is None) != (args.ref_text is None): + parser.error("--ref-audio and --ref-text must be provided together") + if args.num_runs < 1: + parser.error("--num-runs must be >= 1") + if args.warmup_runs < 0: + parser.error("--warmup-runs must be >= 0") + if args.output_dir is None: + args.output_dir = ( + "output_audio_streaming" if _is_streaming_stage_config(args.stage_configs_path) else "output_audio" + ) + if args.enable_profiler and args.profiler_dir is None: + args.profiler_dir = str(Path(args.output_dir) / "profiler") + try: + args.prompt_specs = _load_prompt_specs(args) + except ValueError as exc: + parser.error(str(exc)) + + return args + + +def _is_streaming_stage_config(stage_configs_path: str) -> bool: + cfg_name = Path(stage_configs_path).name.lower() + # Keep routing purely config-path based: + # - voxcpm.yaml => sync + # - voxcpm_async_chunk.yaml => streaming + return "async_chunk" in cfg_name + + +async def _collect_streaming_audio( + omni: AsyncOmni, + args: Any, + spec: PromptSpec, + request_id: str, + *, + phase_label: str, + prompt_index: int, + prompt_count: int, + print_prompt: bool = False, +) -> tuple[torch.Tensor, int, float, float | None]: + prompt = _build_prompt_for_spec(args, spec, global_request_id=request_id) + delta_chunks: list[torch.Tensor] = [] + sample_rate = 24000 + chunk_i = 0 + prev_total_samples = 0 + t_start = time.perf_counter() + first_audio_elapsed: float | None = None + + if print_prompt: + print(f"---prompt---:{prompt}") + + async for stage_output in omni.generate(prompt, request_id=request_id): + mm = getattr(stage_output, "multimodal_output", None) + if not isinstance(mm, dict): + ro = getattr(stage_output, "request_output", None) + if ro is None: + continue + mm = getattr(ro, "multimodal_output", None) + if not isinstance(mm, dict) and getattr(ro, "outputs", None): + seq = ro.outputs[0] + mm = getattr(seq, "multimodal_output", None) + if not isinstance(mm, dict): + continue + sample_rate = _extract_sample_rate(mm) + try: + w = _extract_audio_tensor(mm) + n = int(w.numel()) + if n == 0: + continue + finished = _extract_stream_finished(stage_output) + if n > prev_total_samples: + delta = w.reshape(-1)[prev_total_samples:] + prev_total_samples = n + elif finished and n == prev_total_samples: + delta = w.reshape(-1)[:0] + else: + delta = w.reshape(-1) + prev_total_samples += int(delta.numel()) + if int(delta.numel()) > 0: + delta_chunks.append(delta) + if first_audio_elapsed is None and int(delta.numel()) > 0: + first_audio_elapsed = time.perf_counter() - t_start + logger.info( + "%s prompt=%d/%d chunk=%d delta_samples=%d buf_len=%d finished=%s", + phase_label, + prompt_index + 1, + prompt_count, + chunk_i, + int(delta.numel()), + n, + finished, + ) + chunk_i += 1 + except ValueError: + if not _extract_stream_finished(stage_output): + logger.debug("skip non-audio partial output chunk=%d", chunk_i) + + if not delta_chunks: + raise RuntimeError("No audio chunks received; check stage config and logs.") + + audio_cat = torch.cat([c.reshape(-1) for c in delta_chunks], 
dim=0) + elapsed = time.perf_counter() - t_start + return audio_cat, sample_rate, elapsed, first_audio_elapsed + + +async def _abort_streaming_residual_work( + omni: AsyncOmni, + request_id: str, + *, + settle_seconds: float = 0.1, +) -> None: + """Stop any late stage-0 work once the final audio has been collected.""" + await omni.engine.abort_async([request_id]) + if settle_seconds > 0: + await asyncio.sleep(settle_seconds) + + +async def _run_streaming_single( + omni: AsyncOmni, + args: Any, + spec: PromptSpec, + output_dir: Path, + request_id: str, + *, + run_index: int, + num_runs: int, + prompt_index: int, + prompt_count: int, +) -> Path: + audio_cat, sample_rate, elapsed, first_audio_elapsed = await _collect_streaming_audio( + omni, + args, + spec, + request_id, + phase_label=f"run={run_index + 1}/{num_runs}", + prompt_index=prompt_index, + prompt_count=prompt_count, + print_prompt=(run_index == 0 and prompt_index == 0), + ) + await _abort_streaming_residual_work(omni, request_id) + output_path = output_dir / f"output_run{run_index + 1}_{spec.label}.wav" + _write_audio_tensor(output_path, audio_cat, sample_rate) + audio_duration_s = float(audio_cat.numel()) / float(sample_rate) if sample_rate > 0 else 0.0 + ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else "" + rtf_text = f", rtf={elapsed / audio_duration_s:.3f}" if audio_duration_s > 0 else "" + print( + f"Saved (streaming) run {run_index + 1}/{num_runs}, " + f"prompt {prompt_index + 1}/{prompt_count}: {output_path} ({elapsed:.2f}s{ttfp_text}{rtf_text})" + ) + _emit_offline_metrics( + request_id=request_id, + elapsed_s=elapsed, + first_audio_elapsed=first_audio_elapsed, + audio_duration_s=audio_duration_s, + ) + return output_path + + +async def _run_streaming_warmup(args, omni: AsyncOmni) -> None: + if args.warmup_runs == 0: + return + + warmup_specs = _get_warmup_specs(args.prompt_specs) + print( + f"Warmup: {args.warmup_runs} run(s) using the first prompt " + f"({len(warmup_specs)} prompt(s)); outputs will be discarded." 
+ ) + for warmup_index in range(args.warmup_runs): + t_warmup = time.perf_counter() + tasks = [] + request_ids: list[str] = [] + for prompt_index, spec in enumerate(warmup_specs): + request_id = f"warmup_stream_{warmup_index + 1}_{spec.label}_{uuid.uuid4().hex[:8]}" + request_ids.append(request_id) + tasks.append( + _collect_streaming_audio( + omni, + args, + spec, + request_id, + phase_label=f"warmup={warmup_index + 1}/{args.warmup_runs}", + prompt_index=prompt_index, + prompt_count=len(warmup_specs), + ) + ) + results = await asyncio.gather(*tasks) + for request_id in request_ids: + await _abort_streaming_residual_work(omni, request_id) + total_samples = sum(int(audio.numel()) for audio, _, _, _ in results) + warmup_ttfps = [ttfp for _, _, _, ttfp in results if ttfp is not None] + ttfp_text = f", ttfp={min(warmup_ttfps):.2f}s" if warmup_ttfps else "" + print( + f"Warmup (streaming) {warmup_index + 1}/{args.warmup_runs} finished: " + f"{len(results)} prompt(s), {total_samples} sample(s) " + f"({time.perf_counter() - t_warmup:.2f}s{ttfp_text})" + ) + + +async def _run_streaming(args) -> list[Path]: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + omni = AsyncOmni( + model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + + await _run_streaming_warmup(args, omni) + profiler_started = False + if args.enable_profiler: + profile_prefix = f"voxcpm_streaming_{int(time.time())}" + stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured" + print(f"Starting profiler (streaming): stages={stages_text}, dir={args.profiler_dir}") + await omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages) + profiler_started = True + t_total = time.perf_counter() + total_elapsed = 0.0 + paths: list[Path] = [] + prompt_specs: list[PromptSpec] = args.prompt_specs + try: + for run in range(args.num_runs): + for prompt_index, spec in enumerate(prompt_specs): + request_id = f"stream_{run + 1}_{spec.label}_{uuid.uuid4().hex[:8]}" + paths.append( + await _run_streaming_single( + omni, + args, + spec, + output_dir, + request_id, + run_index=run, + num_runs=args.num_runs, + prompt_index=prompt_index, + prompt_count=len(prompt_specs), + ) + ) + total_elapsed = time.perf_counter() - t_total + finally: + if profiler_started: + print("Stopping profiler (streaming)...") + await omni.stop_profile(stages=args.profiler_stages) + if args.profiler_wait_seconds > 0: + print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...") + await asyncio.sleep(args.profiler_wait_seconds) + + print( + f"All streaming runs finished: {args.num_runs} run(s), " + f"{len(prompt_specs)} prompt(s), {len(paths)} file(s) in {total_elapsed:.2f}s total" + ) + return paths + + +def _run_sync(args) -> list[Path]: + output_dir = Path(args.output_dir) + + omni = Omni( + model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + + def _run_sync_single( + spec: PromptSpec, + *, + request_prefix: str, + save_outputs: bool, + run_index: int | None = None, + ) -> tuple[list[Path], int, float | None, float, float, str]: + global_request_id = f"{request_prefix}_{spec.label}" + prompt = _build_prompt_for_spec(args, spec, global_request_id=global_request_id) + if save_outputs and run_index == 0 and spec.label == "item001": + print(f"---prompt---:{prompt}") + + saved_paths: 
list[Path] = [] + output_count = 0 + first_audio_elapsed: float | None = None + total_audio_duration_s = 0.0 + metrics_request_id = global_request_id + t_start = time.perf_counter() + for stage_outputs in omni.generate(prompt): + request_output = stage_outputs.request_output + if request_output is None: + continue + request_output_id = getattr(request_output, "request_id", None) + if isinstance(request_output_id, str) and request_output_id: + metrics_request_id = request_output_id + for j, mm in enumerate(_iter_request_multimodal_outputs(request_output)): + output_count += 1 + if first_audio_elapsed is None: + try: + audio_tensor = _extract_audio_tensor(mm) + if int(audio_tensor.numel()) > 0: + first_audio_elapsed = time.perf_counter() - t_start + total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm)) + except ValueError: + pass + else: + try: + audio_tensor = _extract_audio_tensor(mm) + total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm)) + except ValueError: + pass + if not save_outputs: + continue + save_stem = f"run{run_index + 1}_{spec.label}" if j == 0 else f"run{run_index + 1}_{spec.label}_{j}" + saved_paths.append(_save_wav(mm, output_dir, save_stem)) + + if output_count == 0: + raise RuntimeError("No output from Omni.generate") + elapsed_s = time.perf_counter() - t_start + return saved_paths, output_count, first_audio_elapsed, elapsed_s, total_audio_duration_s, metrics_request_id + + if args.warmup_runs: + warmup_specs = _get_warmup_specs(args.prompt_specs) + print( + f"Warmup: {args.warmup_runs} run(s) using the first prompt " + f"({len(warmup_specs)} prompt(s)); outputs will be discarded." + ) + for warmup_index in range(args.warmup_runs): + t_warmup = time.perf_counter() + _, output_count, first_audio_elapsed, elapsed_s, audio_duration_s, _ = _run_sync_single( + warmup_specs[0], + request_prefix=f"warmup_sync{warmup_index + 1}", + save_outputs=False, + ) + ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else "" + rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else "" + print( + f"Warmup (sync) {warmup_index + 1}/{args.warmup_runs} finished: " + f"{output_count} output(s) ({time.perf_counter() - t_warmup:.2f}s{ttfp_text}{rtf_text})" + ) + + profiler_started = False + if args.enable_profiler: + profile_prefix = f"voxcpm_sync_{int(time.time())}" + stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured" + print(f"Starting profiler (sync): stages={stages_text}, dir={args.profiler_dir}") + omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages) + profiler_started = True + + t_total = time.perf_counter() + total_elapsed = 0.0 + saved_paths: list[Path] = [] + prompt_specs: list[PromptSpec] = args.prompt_specs + try: + for run in range(args.num_runs): + t_run = time.perf_counter() + run_paths: list[Path] = [] + for prompt_index, spec in enumerate(prompt_specs): + prompt_paths, _, first_audio_elapsed, elapsed_s, audio_duration_s, metrics_request_id = ( + _run_sync_single( + spec, + request_prefix=f"sync_run{run + 1}_{prompt_index + 1:03d}", + save_outputs=True, + run_index=run, + ) + ) + run_paths.extend(prompt_paths) + ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else "" + rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else "" + print( + f"Saved (sync) run {run + 1}/{args.num_runs}, " + f"prompt {prompt_index + 
1}/{len(prompt_specs)}: {len(prompt_paths)} file(s){ttfp_text}{rtf_text}" + ) + _emit_offline_metrics( + request_id=metrics_request_id, + elapsed_s=elapsed_s, + first_audio_elapsed=first_audio_elapsed, + audio_duration_s=audio_duration_s, + ) + + saved_paths.extend(run_paths) + print( + f"Run {run + 1}/{args.num_runs} finished: {len(run_paths)} file(s) ({time.perf_counter() - t_run:.2f}s)" + ) + for path in run_paths: + print(f" {path}") + + total_elapsed = time.perf_counter() - t_total + finally: + if profiler_started: + print("Stopping profiler (sync)...") + omni.stop_profile(stages=args.profiler_stages) + if args.profiler_wait_seconds > 0: + print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...") + time.sleep(args.profiler_wait_seconds) + + print( + f"All sync runs finished: {args.num_runs} run(s), " + f"{len(prompt_specs)} prompt(s), {len(saved_paths)} file(s) in {total_elapsed:.2f}s total" + ) + return saved_paths + + +def main(args) -> int: + logging.basicConfig(level=logging.INFO) + profiled_stage_config_path: str | None = None + original_stage_config_path = args.stage_configs_path + if args.enable_profiler: + Path(args.profiler_dir).mkdir(parents=True, exist_ok=True) + profiled_stage_config_path = _build_profiled_stage_config( + args.stage_configs_path, + str(Path(args.profiler_dir).resolve()), + ) + args.stage_configs_path = profiled_stage_config_path + + is_streaming = _is_streaming_stage_config(args.stage_configs_path) + voice_clone_count = _count_voice_clone_prompts(args.prompt_specs) + print(f"Model: {args.model}") + print(f"Stage config: {original_stage_config_path}") + print(f"Route: {'streaming' if is_streaming else 'sync'} (from stage-configs-path)") + print(f"Prompt count: {len(args.prompt_specs)}") + print("Batch mode: sequential (aligned with native VoxCPM)") + print(f"Warmup runs: {args.warmup_runs}") + print(f"Voice cloning prompts: {voice_clone_count}/{len(args.prompt_specs)}") + if args.enable_profiler: + print(f"Profiler: enabled (dir={args.profiler_dir}, stages={args.profiler_stages or 'all-configured'})") + print(f"Profiled stage config: {args.stage_configs_path}") + if voice_clone_count: + print("Voice cloning note: --ref-text/ref_text must match the spoken content of the reference audio.") + print(f"Num runs: {args.num_runs}") + try: + if is_streaming: + asyncio.run(_run_streaming(args)) + else: + _run_sync(args) + finally: + if profiled_stage_config_path is not None and os.path.exists(profiled_stage_config_path): + os.unlink(profiled_stage_config_path) + return 0 + + +if __name__ == "__main__": + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + raise SystemExit(main(parse_args())) diff --git a/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py b/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py new file mode 100644 index 0000000000..816df32796 --- /dev/null +++ b/benchmarks/voxcpm/vllm_omni/bench_tts_serve.py @@ -0,0 +1,283 @@ +"""Benchmark VoxCPM via /v1/audio/speech. + +Reports TTFP (time to first packet), E2E latency, and RTF (real-time factor). 
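+
+Metric definitions (streamed PCM is assumed to be 16-bit mono at DEFAULT_SAMPLE_RATE, 24 kHz):
+TTFP is taken at the first non-empty streamed chunk; RTF = e2e / audio_duration,
+where audio_duration is derived from the total streamed PCM byte count.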
+""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import time +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path + +import aiohttp +import numpy as np +from tqdm.asyncio import tqdm + +DEFAULT_MODEL = "OpenBMB/VoxCPM1.5" +DEFAULT_SAMPLE_RATE = 24000 +PROMPTS = [ + "Hello, welcome to the VoxCPM speech benchmark.", + "This is a short benchmark prompt for online text-to-speech generation.", + "The quick brown fox jumps over the lazy dog near the riverbank.", + "Please remember to bring your identification documents tomorrow morning.", + "Learning a new language takes patience, practice, and curiosity.", + "This benchmark reports TTFP and RTF for the VoxCPM online serving path.", +] + + +@dataclass +class RequestResult: + success: bool = False + ttfp: float = 0.0 + e2e: float = 0.0 + audio_bytes: int = 0 + audio_duration: float = 0.0 + rtf: float = 0.0 + prompt: str = "" + error: str = "" + + +@dataclass +class BenchmarkResult: + concurrency: int = 0 + num_prompts: int = 0 + completed: int = 0 + failed: int = 0 + duration_s: float = 0.0 + mean_ttfp_ms: float = 0.0 + median_ttfp_ms: float = 0.0 + p95_ttfp_ms: float = 0.0 + mean_e2e_ms: float = 0.0 + median_e2e_ms: float = 0.0 + p95_e2e_ms: float = 0.0 + mean_rtf: float = 0.0 + median_rtf: float = 0.0 + p95_rtf: float = 0.0 + total_audio_duration_s: float = 0.0 + request_throughput: float = 0.0 + per_request: list[dict[str, float | str]] = field(default_factory=list) + + +def pcm_bytes_to_duration(num_bytes: int, sample_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2) -> float: + num_samples = num_bytes / sample_width + return num_samples / sample_rate + + +async def send_tts_request( + session: aiohttp.ClientSession, + api_url: str, + *, + model: str, + prompt: str, + ref_audio: str | None, + ref_text: str | None, + pbar: tqdm | None = None, +) -> RequestResult: + payload: dict[str, object] = { + "model": model, + "input": prompt, + "stream": True, + "response_format": "pcm", + } + if ref_audio is not None: + payload["ref_audio"] = ref_audio + if ref_text is not None: + payload["ref_text"] = ref_text + + result = RequestResult(prompt=prompt) + started_at = time.perf_counter() + + try: + async with session.post(api_url, json=payload) as response: + if response.status != 200: + result.error = f"HTTP {response.status}: {await response.text()}" + return result + + first_chunk = True + total_bytes = 0 + async for chunk in response.content.iter_any(): + if not chunk: + continue + if first_chunk: + result.ttfp = time.perf_counter() - started_at + first_chunk = False + total_bytes += len(chunk) + + result.e2e = time.perf_counter() - started_at + result.audio_bytes = total_bytes + result.audio_duration = pcm_bytes_to_duration(total_bytes) + if result.audio_duration > 0: + result.rtf = result.e2e / result.audio_duration + result.success = True + except Exception as e: + result.error = str(e) + result.e2e = time.perf_counter() - started_at + + if pbar is not None: + pbar.update(1) + return result + + +async def run_benchmark( + *, + host: str, + port: int, + model: str, + num_prompts: int, + max_concurrency: int, + num_warmups: int, + ref_audio: str | None, + ref_text: str | None, +) -> BenchmarkResult: + api_url = f"http://{host}:{port}/v1/audio/speech" + connector = aiohttp.TCPConnector(limit=max_concurrency, limit_per_host=max_concurrency, keepalive_timeout=60) + timeout = aiohttp.ClientTimeout(total=600) + + async with 
aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + if num_warmups > 0: + print(f" Warming up with {num_warmups} requests...") + warmup_tasks = [ + send_tts_request( + session, + api_url, + model=model, + prompt=PROMPTS[i % len(PROMPTS)], + ref_audio=ref_audio, + ref_text=ref_text, + ) + for i in range(num_warmups) + ] + await asyncio.gather(*warmup_tasks) + print(" Warmup done.") + + request_prompts = [PROMPTS[i % len(PROMPTS)] for i in range(num_prompts)] + semaphore = asyncio.Semaphore(max_concurrency) + pbar = tqdm(total=num_prompts, desc=f" concurrency={max_concurrency}") + + async def limited_request(prompt: str) -> RequestResult: + async with semaphore: + return await send_tts_request( + session, + api_url, + model=model, + prompt=prompt, + ref_audio=ref_audio, + ref_text=ref_text, + pbar=pbar, + ) + + started_at = time.perf_counter() + results = await asyncio.gather(*[asyncio.create_task(limited_request(prompt)) for prompt in request_prompts]) + duration = time.perf_counter() - started_at + pbar.close() + + succeeded = [result for result in results if result.success] + bench = BenchmarkResult( + concurrency=max_concurrency, + num_prompts=num_prompts, + completed=len(succeeded), + failed=len(results) - len(succeeded), + duration_s=duration, + ) + + if not succeeded: + return bench + + ttfps = np.array([result.ttfp * 1000 for result in succeeded], dtype=np.float64) + e2es = np.array([result.e2e * 1000 for result in succeeded], dtype=np.float64) + rtfs = np.array([result.rtf for result in succeeded], dtype=np.float64) + audio_durations = np.array([result.audio_duration for result in succeeded], dtype=np.float64) + + bench.mean_ttfp_ms = float(np.mean(ttfps)) + bench.median_ttfp_ms = float(np.median(ttfps)) + bench.p95_ttfp_ms = float(np.percentile(ttfps, 95)) + bench.mean_e2e_ms = float(np.mean(e2es)) + bench.median_e2e_ms = float(np.median(e2es)) + bench.p95_e2e_ms = float(np.percentile(e2es, 95)) + bench.mean_rtf = float(np.mean(rtfs)) + bench.median_rtf = float(np.median(rtfs)) + bench.p95_rtf = float(np.percentile(rtfs, 95)) + bench.total_audio_duration_s = float(np.sum(audio_durations)) + bench.request_throughput = len(succeeded) / duration if duration > 0 else 0.0 + bench.per_request = [ + { + "prompt": result.prompt, + "ttfp_ms": result.ttfp * 1000, + "e2e_ms": result.e2e * 1000, + "rtf": result.rtf, + "audio_duration_s": result.audio_duration, + } + for result in succeeded + ] + + return bench + + +def print_summary(result: BenchmarkResult) -> None: + width = 54 + print("") + print("=" * width) + print(f"{'VoxCPM Serving Benchmark':^{width}}") + print("=" * width) + print(f"concurrency : {result.concurrency}") + print(f"requests : {result.completed}/{result.num_prompts} succeeded") + print(f"wall time (s) : {result.duration_s:.3f}") + print(f"mean TTFP (ms) : {result.mean_ttfp_ms:.2f}") + print(f"p95 TTFP (ms) : {result.p95_ttfp_ms:.2f}") + print(f"mean E2E (ms) : {result.mean_e2e_ms:.2f}") + print(f"p95 E2E (ms) : {result.p95_e2e_ms:.2f}") + print(f"mean RTF : {result.mean_rtf:.3f}") + print(f"p95 RTF : {result.p95_rtf:.3f}") + print(f"request throughput : {result.request_throughput:.2f} req/s") + print("=" * width) + + +async def main_async(args) -> None: + result_dir = Path(args.result_dir) + result_dir.mkdir(parents=True, exist_ok=True) + + all_results: list[BenchmarkResult] = [] + for concurrency in args.max_concurrency: + result = await run_benchmark( + host=args.host, + port=args.port, + model=args.model, + num_prompts=args.num_prompts, + 
max_concurrency=concurrency, + num_warmups=args.num_warmups, + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + print_summary(result) + all_results.append(result) + + payload = { + "model": args.model, + "created_at": datetime.utcnow().isoformat() + "Z", + "results": [asdict(result) for result in all_results], + } + result_path = result_dir / "bench_tts_serve.json" + result_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + print(f"Saved results to: {result_path}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Benchmark VoxCPM via /v1/audio/speech") + parser.add_argument("--host", default="127.0.0.1", help="Server host") + parser.add_argument("--port", type=int, default=8091, help="Server port") + parser.add_argument("--model", default=DEFAULT_MODEL, help="Model name or path") + parser.add_argument("--num-prompts", type=int, default=20, help="Number of prompts to send") + parser.add_argument("--max-concurrency", type=int, nargs="+", default=[1], help="Concurrency levels to benchmark") + parser.add_argument("--num-warmups", type=int, default=3, help="Warmup request count") + parser.add_argument("--ref-audio", default=None, help="Reference audio URL or data URL for voice cloning") + parser.add_argument("--ref-text", default=None, help="Reference audio transcript for voice cloning") + parser.add_argument("--result-dir", default="results", help="Directory to save benchmark JSON") + return parser.parse_args() + + +if __name__ == "__main__": + asyncio.run(main_async(parse_args())) diff --git a/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py b/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py new file mode 100644 index 0000000000..cee46c0f86 --- /dev/null +++ b/benchmarks/voxcpm/vllm_omni/run_offline_matrix.py @@ -0,0 +1,303 @@ +"""Run the full offline VoxCPM smoke matrix. + +This script keeps the old `test.py` coverage, but delegates each case to +`bench_tts_offline.py` so the benchmark runner itself stays focused on a +single execution path. +""" + +from __future__ import annotations + +import shlex +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +from vllm.utils.argparse_utils import FlexibleArgumentParser + +REPO_ROOT = Path(__file__).resolve().parents[3] +BENCH_SCRIPT = Path(__file__).with_name("bench_tts_offline.py") +DEFAULT_STAGE_ASYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm_async_chunk.yaml" +DEFAULT_STAGE_SYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml" +DEFAULT_OUTPUT_ROOT = BENCH_SCRIPT.parents[1] / "results" / "offline_matrix" + +SINGLE_TTS_TEXT = "This is a single text-to-speech smoke test for VoxCPM on vLLM Omni." +SINGLE_CLONE_TEXT = "This sentence is synthesized with the cloned voice for validation." 
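+# Prompts for the "batch" cases below. _prepare_batch_inputs() writes them to
+# text files under <output-root>/inputs, and _build_case_command() passes those
+# files to bench_tts_offline.py via --txt-prompts.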
+BATCH_TTS_TEXTS = [ + "The first batch text-to-speech sample validates sequential batch execution.", + "The second batch text-to-speech sample checks another prompt in the same file.", + "The third batch text-to-speech sample completes the sequential batch path.", +] +BATCH_CLONE_TEXTS = [ + "The first cloned sample validates sequential batch voice cloning.", + "The second cloned sample checks the same reference voice on another prompt.", + "The third cloned sample finishes the shared-reference clone batch path.", +] + + +@dataclass(frozen=True, slots=True) +class ModeSpec: + name: str + stage_config: Path + + +@dataclass(frozen=True, slots=True) +class CaseSpec: + name: str + warmup_runs: int + prompt_kind: str + voice_clone: bool + + +@dataclass(frozen=True, slots=True) +class CaseResult: + mode: str + case: str + returncode: int + elapsed_s: float + output_dir: Path + log_path: Path + + @property + def ok(self) -> bool: + return self.returncode == 0 + + +MODE_SPECS = [ + ModeSpec(name="streaming", stage_config=DEFAULT_STAGE_ASYNC), + ModeSpec(name="sync", stage_config=DEFAULT_STAGE_SYNC), +] + +CASE_SPECS = [ + CaseSpec(name="warmup_single_tts", warmup_runs=1, prompt_kind="single", voice_clone=False), + CaseSpec(name="warmup_single_clone", warmup_runs=1, prompt_kind="single", voice_clone=True), + CaseSpec(name="warmup_batch_tts", warmup_runs=1, prompt_kind="batch", voice_clone=False), + CaseSpec(name="warmup_batch_clone", warmup_runs=1, prompt_kind="batch", voice_clone=True), + CaseSpec(name="cold_single_tts", warmup_runs=0, prompt_kind="single", voice_clone=False), + CaseSpec(name="cold_single_clone", warmup_runs=0, prompt_kind="single", voice_clone=True), +] + + +def _write_lines(path: Path, lines: list[str]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + +def _prepare_batch_inputs(output_root: Path) -> tuple[Path, Path]: + input_dir = output_root / "inputs" + batch_tts_path = input_dir / "batch_tts_prompts.txt" + batch_clone_path = input_dir / "batch_clone_prompts.txt" + _write_lines(batch_tts_path, BATCH_TTS_TEXTS) + _write_lines(batch_clone_path, BATCH_CLONE_TEXTS) + return batch_tts_path, batch_clone_path + + +def _base_command(args, mode: ModeSpec, output_dir: Path) -> list[str]: + cmd = [ + args.python, + str(BENCH_SCRIPT), + "--model", + args.model, + "--stage-configs-path", + str(mode.stage_config), + "--output-dir", + str(output_dir), + "--num-runs", + str(args.num_runs), + "--stage-init-timeout", + str(args.stage_init_timeout), + ] + cmd.append("--log-stats" if args.log_stats else "--no-log-stats") + cmd.extend(["--cfg-value", str(args.cfg_value)]) + cmd.extend(["--inference-timesteps", str(args.inference_timesteps)]) + cmd.extend(["--min-len", str(args.min_len)]) + cmd.extend(["--max-new-tokens", str(args.max_new_tokens)]) + if args.streaming_prefix_len is not None: + cmd.extend(["--streaming-prefix-len", str(args.streaming_prefix_len)]) + if args.enable_profiler: + profiler_dir = Path(args.profiler_dir) if args.profiler_dir is not None else (output_dir / "profiler") + cmd.append("--enable-profiler") + cmd.extend(["--profiler-dir", str(profiler_dir)]) + cmd.extend(["--profiler-wait-seconds", str(args.profiler_wait_seconds)]) + if args.profiler_stages is not None: + cmd.append("--profiler-stages") + cmd.extend(str(stage_id) for stage_id in args.profiler_stages) + return cmd + + +def _build_case_command( + args, + mode: ModeSpec, + case: CaseSpec, + *, + batch_tts_path: Path, + batch_clone_path: 
Path, + output_dir: Path, +) -> list[str]: + cmd = _base_command(args, mode, output_dir) + cmd.extend(["--warmup-runs", str(case.warmup_runs)]) + if case.prompt_kind == "single": + cmd.extend(["--text", SINGLE_CLONE_TEXT if case.voice_clone else SINGLE_TTS_TEXT]) + else: + cmd.extend(["--txt-prompts", str(batch_clone_path if case.voice_clone else batch_tts_path)]) + if case.voice_clone: + cmd.extend(["--ref-audio", args.ref_audio, "--ref-text", args.ref_text]) + return cmd + + +def _run_case( + args, + mode: ModeSpec, + case: CaseSpec, + *, + batch_tts_path: Path, + batch_clone_path: Path, + output_root: Path, +) -> CaseResult: + case_output_dir = output_root / mode.name / case.name + case_output_dir.mkdir(parents=True, exist_ok=True) + case_log_path = case_output_dir / "run.log" + cmd = _build_case_command( + args, + mode, + case, + batch_tts_path=batch_tts_path, + batch_clone_path=batch_clone_path, + output_dir=case_output_dir, + ) + + print() + print("=" * 80) + print(f"[{mode.name}] {case.name}") + print(f"Output directory: {case_output_dir}") + print(shlex.join(cmd)) + + start = time.perf_counter() + with case_log_path.open("w", encoding="utf-8") as log_fp: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + assert process.stdout is not None + for line in process.stdout: + print(line, end="") + log_fp.write(line) + process.wait() + + elapsed_s = time.perf_counter() - start + status = "PASS" if (process.returncode or 0) == 0 else f"FAIL({process.returncode})" + print(f"[{mode.name}] {case.name} -> {status} ({elapsed_s:.2f}s)") + return CaseResult( + mode=mode.name, + case=case.name, + returncode=int(process.returncode or 0), + elapsed_s=elapsed_s, + output_dir=case_output_dir, + log_path=case_log_path, + ) + + +def parse_args(): + parser = FlexibleArgumentParser(description="Run the full offline VoxCPM smoke matrix.") + parser.add_argument("--model", type=str, required=True, help="Local VoxCPM model directory.") + parser.add_argument("--ref-audio", type=str, required=True, help="Reference audio path for clone cases.") + parser.add_argument("--ref-text", type=str, required=True, help="Exact transcript spoken in --ref-audio.") + parser.add_argument("--output-root", type=str, default=str(DEFAULT_OUTPUT_ROOT), help="Root directory for outputs.") + parser.add_argument("--python", type=str, default=sys.executable, help="Python executable used to launch cases.") + parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.") + parser.add_argument("--log-stats", dest="log_stats", action="store_true", help="Enable vLLM Omni stats logging.") + parser.add_argument( + "--no-log-stats", + dest="log_stats", + action="store_false", + help="Disable vLLM Omni stats logging.", + ) + parser.set_defaults(log_stats=True) + parser.add_argument("--num-runs", type=int, default=1, help="Number of measured runs per case.") + parser.add_argument("--cfg-value", type=float, default=2.0, help="Classifier-free guidance value for VoxCPM.") + parser.add_argument("--inference-timesteps", type=int, default=10, help="Number of inference timesteps.") + parser.add_argument("--min-len", type=int, default=2, help="Minimum generated token length.") + parser.add_argument("--max-new-tokens", type=int, default=4096, help="Maximum generated token length.") + parser.add_argument( + "--streaming-prefix-len", + type=int, + default=None, + help="Optional VoxCPM streaming window passed to streaming cases.", + ) 
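+    # Profiler flags: parsed here and forwarded by _base_command() to each
+    # bench_tts_offline.py invocation; the profiler dir defaults to
+    # <case output dir>/profiler when --profiler-dir is not given.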
+ parser.add_argument("--enable-profiler", action="store_true", help="Enable torch profiler for each case.") + parser.add_argument( + "--profiler-dir", + type=str, + default=None, + help="Profiler output root. Defaults to /profiler.", + ) + parser.add_argument( + "--profiler-stages", + type=int, + nargs="*", + default=None, + help="Optional stage ids to profile. Defaults to all configured stages.", + ) + parser.add_argument( + "--profiler-wait-seconds", + type=float, + default=30.0, + help="Seconds to wait after stopping profiler for traces to flush.", + ) + args = parser.parse_args() + if args.num_runs < 1: + parser.error("--num-runs must be >= 1") + return args + + +def main(args) -> int: + output_root = Path(args.output_root) + output_root.mkdir(parents=True, exist_ok=True) + batch_tts_path, batch_clone_path = _prepare_batch_inputs(output_root) + + print(f"Model: {args.model}") + print(f"Reference audio: {args.ref_audio}") + print(f"Reference text: {args.ref_text}") + print(f"Python: {args.python}") + print(f"Output root: {output_root}") + print(f"Cases: {len(MODE_SPECS) * len(CASE_SPECS)}") + + results: list[CaseResult] = [] + for mode in MODE_SPECS: + for case in CASE_SPECS: + results.append( + _run_case( + args, + mode, + case, + batch_tts_path=batch_tts_path, + batch_clone_path=batch_clone_path, + output_root=output_root, + ) + ) + + failed = [result for result in results if not result.ok] + print() + print("=" * 80) + print("Summary:") + for result in results: + status = "PASS" if result.ok else f"FAIL({result.returncode})" + print(f"- [{result.mode}] {result.case}: {status} ({result.elapsed_s:.2f}s)") + print(f" output_dir={result.output_dir}") + print(f" log={result.log_path}") + + print(f"Passed: {len(results) - len(failed)}/{len(results)}") + if failed: + print("Failed cases:") + for result in failed: + print(f"- [{result.mode}] {result.case}: see {result.log_path}") + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main(parse_args())) diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci index 24ce39bafd..9cbf89d0b7 100644 --- a/docker/Dockerfile.ci +++ b/docker/Dockerfile.ci @@ -7,7 +7,7 @@ COPY . . # Install system dependencies RUN apt-get update && \ - apt-get install -y espeak-ng ffmpeg git sox libsox-fmt-all jq && \ + apt-get install -y espeak-ng git jq && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/docker/Dockerfile.cuda b/docker/Dockerfile.cuda index 754d491d86..28e10f4fb8 100644 --- a/docker/Dockerfile.cuda +++ b/docker/Dockerfile.cuda @@ -7,7 +7,7 @@ WORKDIR ${COMMON_WORKDIR} # Step 1: Setup - Install system dependencies RUN apt-get update && \ - apt-get install -y ffmpeg git sox libsox-fmt-all jq && \ + apt-get install -y git jq && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index bfbb060bcb..a54aa3b793 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -18,8 +18,10 @@ ARG COMMON_WORKDIR=/app WORKDIR ${COMMON_WORKDIR} # Step 1: Setup - Install system dependencies +# Need to include ffmpeg because vllm rocm upstream docker image +# does not include it. RUN apt-get update && \ - apt-get install -y espeak-ng ffmpeg git sox libsox-fmt-all jq && \ + apt-get install -y espeak-ng ffmpeg git jq && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -39,6 +41,24 @@ RUN if [ "${USE_NIGHTLY_BUILD}" = "1" ]; then \ # Step 3: Copy vllm-omni code and install without uv RUN mkdir -p ${COMMON_WORKDIR}/vllm-omni COPY . 
${COMMON_WORKDIR}/vllm-omni + +# This is a workaround to ensure pytest exits with the correct status code in CI tests. +RUN printf '%s\n' \ + 'import os' \ + '' \ + '_exit_code = 1' \ + '' \ + 'def pytest_sessionfinish(session, exitstatus):' \ + ' global _exit_code' \ + ' _exit_code = int(exitstatus)' \ + '' \ + 'def pytest_unconfigure(config):' \ + ' import sys' \ + ' sys.stdout.flush()' \ + ' sys.stderr.flush()' \ + ' os._exit(_exit_code)' \ + > ${COMMON_WORKDIR}/vllm-omni/conftest.py + RUN cd ${COMMON_WORKDIR}/vllm-omni && uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir ".[dev]" --no-build-isolation RUN ln -sf /usr/bin/python3 /usr/bin/python diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 17f1aebf0d..25d5d0c800 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -15,9 +15,7 @@ RUN apt clean && apt-get update -y && \ apt-get install -y --no-install-recommends --fix-missing \ curl \ espeak-ng \ - ffmpeg \ git \ - libsndfile1 \ libsm6 \ libxext6 \ libgl1 \ diff --git a/docs/.nav.yml b/docs/.nav.yml index 86ce4a3b0c..455a052505 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -64,6 +64,7 @@ nav: - FP8: user_guide/diffusion/quantization/fp8.md - Int8: user_guide/diffusion/quantization/int8.md - GGUF: user_guide/diffusion/quantization/gguf.md + - Frame Interpolation: user_guide/diffusion/frame_interpolation.md - Parallelism: - Overview: user_guide/diffusion/parallelism/overview.md - CFG Parallel: user_guide/diffusion/parallelism/cfg_parallel.md @@ -97,6 +98,7 @@ nav: - design/feature/disaggregated_inference.md - design/feature/ray_based_execution.md - design/feature/omni_connectors/ + - design/feature/prefix_caching.md - design/feature/cfg_parallel.md - design/feature/expert_parallel.md - design/feature/sequence_parallel.md @@ -105,7 +107,7 @@ nav: - design/feature/hsdp.md - design/feature/cache_dit.md - design/feature/teacache.md - - design/feature/async_chunk_design.md + - design/feature/async_chunk.md - design/feature/vae_parallel.md - design/feature/diffusion_step_execution.md - Module Design: diff --git a/docs/api/README.md b/docs/api/README.md index f65cbb525d..0147f19e12 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -5,7 +5,7 @@ Main entry points for vLLM-Omni inference and serving. - [vllm_omni.entrypoints.async_omni.AsyncOmni][] -- [vllm_omni.entrypoints.cfg_companion_tracker.CfgCompanionTracker][] +- [vllm_omni.engine.cfg_companion_tracker.CfgCompanionTracker][] - [vllm_omni.entrypoints.cli.benchmark.base.OmniBenchmarkSubcommandBase][] - [vllm_omni.entrypoints.cli.benchmark.main.OmniBenchmarkSubcommand][] - [vllm_omni.entrypoints.cli.benchmark.serve.OmniBenchmarkServingSubcommand][] diff --git a/docs/assets/WeChat.jpg b/docs/assets/WeChat.jpg index c32ece6c10..83252b7569 100644 Binary files a/docs/assets/WeChat.jpg and b/docs/assets/WeChat.jpg differ diff --git a/docs/configuration/README.md b/docs/configuration/README.md index b5761a7f1b..390176e9ce 100644 --- a/docs/configuration/README.md +++ b/docs/configuration/README.md @@ -6,7 +6,7 @@ For options within a vLLM Engine. Please refer to [vLLM Configuration](https://d Currently, the main options are maintained by stage configs for each model. -For specific example, please refer to [Qwen2.5-omni stage config](stage_configs/qwen2_5_omni.yaml) +For a specific example, see the [Qwen2.5-Omni deploy config](gh-file:vllm_omni/deploy/qwen2_5_omni.yaml). 
The matching frozen pipeline topology lives at [vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py](gh-file:vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py). For introduction, please check [Introduction for stage config](./stage_configs.md) diff --git a/docs/configuration/pd_disaggregation.md b/docs/configuration/pd_disaggregation.md index 1cf6189e60..9196bdb024 100644 --- a/docs/configuration/pd_disaggregation.md +++ b/docs/configuration/pd_disaggregation.md @@ -11,7 +11,7 @@ deployment-specific values usually change per environment: - connector backend and connector ports - connector IPs or bootstrap addresses -Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml) +Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml) and copy it to your own file, for example `qwen3_omni_pd.yaml`. Then apply the changes below. @@ -145,19 +145,13 @@ Compared with the default Qwen3-Omni config: ```yaml runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 edges: - from: 0 to: 1 - window_size: -1 - from: 1 to: 2 - window_size: -1 - from: 2 to: 3 - window_size: -1 ``` ## 4. Launch with your custom config diff --git a/docs/configuration/stage_configs.md b/docs/configuration/stage_configs.md index 95c42afcc7..55b4053cc7 100644 --- a/docs/configuration/stage_configs.md +++ b/docs/configuration/stage_configs.md @@ -3,7 +3,147 @@ In vLLM-Omni, the target model is separated into multiple stages, which are processed by different LLMEngines, DiffusionEngines or other types of engines. Depending on different types of stages, such as Autoregressive (AR) stage or Diffusion transformer (DiT) stage, each can choose corresponding schedulers, model workers to load with the Engines in a plug-in fashion. !!! note - Default stage config YAMLs (for example, `vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml` and `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`) are bundled and loaded automatically when `stage_configs_path` is not provided. They have been verified to work on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni. + Default deploy config YAMLs (for example, `vllm_omni/deploy/qwen2_5_omni.yaml`, `vllm_omni/deploy/qwen3_omni_moe.yaml`, and `vllm_omni/deploy/qwen3_tts.yaml`) are bundled and loaded automatically when neither `--stage-configs-path` nor `--deploy-config` is provided — the model registry resolves the right pipeline + deploy YAML by `model_type`. The bundled defaults have been verified on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni. Models that have not yet migrated to the new schema continue to use the legacy `vllm_omni/model_executor/stage_configs/.yaml` files via `--stage-configs-path`. + +## New deploy schema reference + +The new deploy schema lives under `vllm_omni/deploy/` and is paired with a frozen `PipelineConfig` registered by the model's `pipeline.py`. Each deploy YAML has these top-level fields: + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `base_config` | str (path) | optional | — | Overlay parent (relative or absolute). `stages:` / `platforms:` deep-merged by stage_id; other scalars overlay-wins. Intended for user-authored overlays; prod yamls stay flat. | +| `async_chunk` | bool | optional | `true` | Enable chunked streaming between stages. Pin to `false` if the pipeline runs end-to-end. 
| +| `connectors` | dict | optional | `null` | Named connector specs (`{name, extra}`). Referenced by each stage's `input_connectors` / `output_connectors`. See [Connector schema](#connector-schema). | +| `edges` | list | optional | `null` | Explicit edge list for the KV transfer graph. Auto-derived from stage inputs if omitted. | +| `stages` | list | required | — | Per-stage engine args + wiring (see [Stage fields](#stage-fields)). | +| `platforms` | dict | optional | `null` | Keyed by `npu` / `rocm` / `xpu`, each contains a `stages:` list with per-platform overrides applied on top of the CUDA defaults. | +| `pipeline` | str | optional | `null` | Override the auto-detected pipeline registry key (used for structural variants like `qwen2_5_omni_thinker_only`). | +| `trust_remote_code` | bool | optional | `true` | **Pipeline-wide.** Trust HF remote code on model load; applies to every stage. | +| `distributed_executor_backend` | str | optional | `"mp"` | **Pipeline-wide.** Executor backend (`"mp"` or `"ray"`). | +| `dtype` | str \| null | optional | `null` | **Pipeline-wide.** Model dtype for every stage. | +| `quantization` | str \| null | optional | `null` | **Pipeline-wide.** Quantization method for every stage. | +| `enable_prefix_caching` | bool | optional | `false` | **Pipeline-wide.** Prefix cache toggle applied to every stage. | +| `enable_chunked_prefill` | bool \| null | optional | `null` | **Pipeline-wide.** Chunked prefill toggle applied to every stage. | +| `data_parallel_size` | int | optional | `1` | **Pipeline-wide.** DP degree for every stage. | +| `pipeline_parallel_size` | int | optional | `1` | **Pipeline-wide.** PP degree for every stage. | + +### Stage fields + +Each entry under `stages:` accepts any `StageDeployConfig` field directly (no nested `engine_args:`). Only fields whose value legitimately varies across stages live here; pipeline-wide settings (trust_remote_code, distributed_executor_backend, dtype, quantization, prefix/chunked prefill, DP/PP sizes) are declared at the top level and applied to every stage. Unknown keys fall through to `engine_extras:` and are forwarded to the engine. + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `stage_id` | int | required | — | Stage identity; matched against `PipelineConfig.stages[*].stage_id`. | +| `max_num_seqs` | int | optional | `64` | Max concurrent sequences per stage. | +| `gpu_memory_utilization` | float | optional | `0.9` | Per-stage memory budget. | +| `tensor_parallel_size` | int | optional | `1` | TP degree for this stage. | +| `enforce_eager` | bool | optional | `false` | Disable CUDA graphs. | +| `max_num_batched_tokens` | int | optional | `32768` | Prefill budget. | +| `max_model_len` | int \| null | optional | `null` | Per-stage context length (auto-sets `VLLM_ALLOW_LONG_MAX_MODEL_LEN=1` when larger than HF default). | +| `async_scheduling` | bool \| null | optional | `null` | Per-stage async scheduling toggle. | +| `devices` | str | optional | `"0"` | `CUDA_VISIBLE_DEVICES`-style device list. | +| `output_connectors` | dict \| null | optional | `null` | Keyed by `to_stage_`; values are names registered under top-level `connectors:`. | +| `input_connectors` | dict \| null | optional | `null` | Keyed by `from_stage_`; values are names registered under top-level `connectors:`. | +| `default_sampling_params` | dict \| null | optional | `null` | Baseline sampling params. Deep-merged with pipeline `sampling_constraints` (pipeline wins). 
| +| `engine_extras` | dict | optional | `{}` | Catch-all for keys not listed above; deep-merged across overlays. Also carries per-stage overrides of pipeline-wide settings (e.g. stage-specific `dtype`). | + +### Connector schema + +Each entry under top-level `connectors:` follows this shape: + +```yaml +connectors: + <connector_name>: + name: <ConnectorClassName> # required — class registered in vllm_omni.distributed + extra: # optional — forwarded to the connector's __init__ + <key>: <value> + ... +``` + +| Connector class | Use case | `extra` keys | +|-----------------|----------|--------------| +| `SharedMemoryConnector` | Same-host KV transfer between stages (default for bundled YAMLs). | `shm_threshold_bytes` (int, default `65536`). | +| `MooncakeStoreConnector` | Cross-host KV transfer over TCP. Required for multi-node deployments. | `host`, `metadata_server`, `master`, `segment` (int bytes), `localbuf` (int bytes), `proto` (`"tcp"` / `"rdma"`). | + +A stage references a connector by name in its `input_connectors` / `output_connectors`: + +```yaml +connectors: + shm: + name: SharedMemoryConnector + +stages: + - stage_id: 0 + output_connectors: {to_stage_1: shm} + - stage_id: 1 + input_connectors: {from_stage_0: shm} +``` + +### CLI flags introduced in this refactor + +| Flag | Description | +|------|-------------| +| `--deploy-config PATH` | Load a new-schema deploy YAML. Takes precedence over `--stage-configs-path`. **Optional** — when omitted, the bundled `vllm_omni/deploy/<model_type>.yaml` is auto-loaded by the model registry. | +| `--stage-overrides JSON` | Per-stage JSON overrides, e.g. `'{"0":{"gpu_memory_utilization":0.5}}'`. Per-stage values always win over global flags. | +| `--async-chunk` / `--no-async-chunk` | Flip the deploy YAML's `async_chunk:` bool. Unset (default) leaves the YAML value in force. | +| `--stage-configs-path` | **Deprecated.** Accepts legacy `stage_args` yamls and (auto-detected) new deploy yamls; emits a deprecation warning. Migrate to `--deploy-config`. To be removed in a follow-up PR. | + +### Precedence + +From highest to lowest: + +1. Per-stage flags (`--stage-overrides` JSON, `--stage--` if registered) +2. Explicit global CLI flags (`--gpu-memory-utilization 0.85`, etc.) +3. Platform section (`platforms.npu.stages`, etc.) on top of the base `stages:` +4. Overlay YAML (via `base_config:`) on top of the base YAML +5. 
Parser defaults + +### Worked override example + +Starting from the bundled `vllm_omni/deploy/qwen3_omni_moe.yaml`: + +```yaml +# vllm_omni/deploy/qwen3_omni_moe.yaml (excerpt) +async_chunk: true +stages: + - stage_id: 0 + gpu_memory_utilization: 0.9 + max_num_seqs: 32 + - stage_id: 1 + gpu_memory_utilization: 0.7 + max_num_seqs: 16 +``` + +A user-authored overlay that inherits the base and overrides only stage 1: + +```yaml +# my_overrides.yaml +base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml +stages: + - stage_id: 1 + gpu_memory_utilization: 0.5 # smaller GPU +``` + +Launched with both an explicit global flag and a per-stage override: + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \ + --deploy-config my_overrides.yaml \ + --max-model-len 16384 \ + --stage-overrides '{"0": {"max_num_seqs": 8}}' +``` + +Effective config per stage after the merge: + +| Stage | Field | Final value | Source | +|-------|-------|-------------|--------| +| 0 | `gpu_memory_utilization` | `0.9` | base YAML (overlay didn't touch stage 0) | +| 0 | `max_num_seqs` | `8` | per-stage CLI (`--stage-overrides`) — wins over base `32` | +| 0 | `max_model_len` | `16384` | global CLI | +| 1 | `gpu_memory_utilization` | `0.5` | overlay YAML — wins over base `0.7` | +| 1 | `max_num_seqs` | `16` | base YAML (overlay didn't touch this field) | +| 1 | `max_model_len` | `16384` | global CLI | +| 2 | (all defaults) | — | base YAML (no overrides apply) | Therefore, as a core part of vLLM-Omni, the stage configs for a model have several main functions: @@ -35,7 +175,7 @@ stage_args: - stage_id: 0 # mark the unique id for each stage runtime: # The disaggregated configuration process: true # Run this stage in a separate process - devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) + devices: "0" # Logical device index for this stage (mapped through CUDA_VISIBLE_DEVICES / ASCEND_RT_VISIBLE_DEVICES if set) engine_args: # Engine arguments for a certain engine model_stage: thinker max_num_seqs: 1 @@ -114,16 +254,12 @@ stage_args: # Top-level runtime config (concise): default windows and stage edges runtime: enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage + edges: - from: 0 # thinker → talker: trigger only after receiving full input (-1) to: 1 - window_size: -1 - from: 1 # talker → code2wav: trigger only after receiving full input (-1) to: 2 - window_size: -1 ``` @@ -155,7 +291,9 @@ Default: `true` #### `runtime.devices` -Visible devices for this stage, specified as a string. This controls which GPU devices are available to the stage process, similar to setting `CUDA_VISIBLE_DEVICES` or using `torch.cuda.set_device()`. For example, `"0"` uses GPU 0, `"1"` uses GPU 1, and `"0,1"` makes both GPUs 0 and 1 visible. +Logical device indices for this stage, specified as a string. Values are **logical indices** (`0`, `1`, `2`, ...) — not physical GPU IDs — and are mapped through the platform's visibility env var (`CUDA_VISIBLE_DEVICES` on CUDA, `ASCEND_RT_VISIBLE_DEVICES` on NPU) before being applied via `torch.cuda.set_device()` (or the equivalent). + +Example: if `CUDA_VISIBLE_DEVICES=0,2,4` is set in the environment, then `devices: "0"` selects physical GPU 0 (the first visible), `devices: "1"` selects physical GPU 2, and `devices: "0,1"` makes physical GPUs 0 and 2 available to the stage. 
If no visibility env var is set, logical and physical IDs coincide. Default: `"0"` diff --git a/docs/configuration/stage_configs/qwen2_5_omni.yaml b/docs/configuration/stage_configs/qwen2_5_omni.yaml deleted file mode 100644 index 690577b84a..0000000000 --- a/docs/configuration/stage_configs/qwen2_5_omni.yaml +++ /dev/null @@ -1,94 +0,0 @@ -# stage config for running qwen2.5-omni with AsyncOmniEngine + Orchestrator runtime. -stage_args: - - stage_id: 0 - runtime: - process: true # Run this stage in a separate process - devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true # Now we only support eager mode - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - stage_id: 1 - runtime: - process: true - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: latent - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker - default_sampling_params: - temperature: 0.9 - top_p: 0.8 - top_k: 40 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - stop_token_ids: [8294] - - stage_id: 2 - runtime: - process: true - devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - gpu_memory_utilization: 0.15 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: audio - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage - edges: - - from: 0 # thinker → talker: trigger only after receiving full input (-1) - to: 1 - window_size: -1 - - from: 1 # talker → code2wav: trigger only after receiving full input (-1) - to: 2 - window_size: -1 diff --git a/docs/contributing/ci/CI_5levels.md b/docs/contributing/ci/CI_5levels.md index 74ae1a38eb..2452ef5d4a 100644 --- a/docs/contributing/ci/CI_5levels.md +++ b/docs/contributing/ci/CI_5levels.md @@ -86,7 +86,8 @@ Through five levels (L1-L5) and common (Common) specifications, the system clari /tests/e2e/online_serving/test_{model_name}_expansion.py
/tests/e2e/offline_inference/test_{model_name}_expansion.py
Performance:
- /tests/dfx/perf/tests/test.json
+ /tests/dfx/perf/tests/test_qwen_omni.json (Omni), test_tts.json (TTS),
+ and /tests/dfx/perf/tests/test_{diffusion_model}_vllm_omni.json (Diffusion)
Doc Test:
tests/example/online_serving/test_{model_name}.py
tests/example/offline_inference/test_{model_name}.py @@ -230,8 +231,7 @@ vllm_omni/ tests/ │ ├── test_qwen3_omni_expansion.py │ ├── test_mimo_audio.py │ ├── test_image_gen_edit.py - │ ├── test_images_generations_lora.py - │ └── stage_configs/ + │ └── test_images_generations_lora.py └── offline_inference/ ✅ ├── test_qwen2_5_omni.py ├── test_qwen3_omni.py @@ -242,16 +242,17 @@ vllm_omni/ tests/ ├── test_zimage_tensor_parallel.py ├── test_cache_dit.py ├── test_teacache.py - ├── test_stable_audio_model.py + ├── test_stable_audio_expansion.py ├── test_diffusion_cpu_offload.py ├── test_diffusion_layerwise_offload.py ├── test_diffusion_lora.py ├── test_sequence_parallel.py - └── stage_configs/ - ├── qwen2_5_omni_ci.yaml - ├── qwen3_omni_ci.yaml - ├── bagel_*.yaml - └── npu/, rocm/, etc. + └── stage_configs/ (legacy schema, still + ├── bagel_*.yaml present for unmigrated + └── npu/, rocm/, etc. models) + +# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under +# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md. ``` @@ -530,13 +531,13 @@ L4 level testing is a comprehensive quality audit before a version release. It e ### 3.2 Testing Content and Scope - ***Full Functionality Testing***: Executes all test cases defined in `test_{model_name}_expansion.py`, covering all implemented features, positive flows, boundary conditions, and exception handling. -- ***Performance Testing***: Uses the `tests/dfx/perf/tests/test.json` configuration file to drive performance testing tools for stress, load, and endurance tests, collecting metrics like throughput, response time, and resource utilization. +- ***Performance Testing***: Uses `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, and diffusion configs in the form `tests/dfx/perf/tests/test_*_vllm_omni.json` (passed to `run_benchmark.py` via `--test-config-file`) to drive performance testing tools for stress, load, and endurance tests, collecting metrics like throughput, response time, and resource utilization. - ***Documentation Testing***: Verifies whether the example code provided to users is runnable and its results match the description. ### 3.3 Test Directory and Execution Files - ***Functional Testing***: Same directories as L3. -- ***Performance Test Configuration***: `tests/dfx/perf/tests/test.json` +- ***Performance Test Configuration***: `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, and diffusion configs `tests/dfx/perf/tests/test_*_vllm_omni.json` (e.g. `test_qwen_image_vllm_omni.json`) - ***Documentation Example Tests***: - - `tests/example/online_serving/test_{model_name}.py` - `tests/example/offline_inference/test_{model_name}.py` diff --git a/docs/contributing/ci/test_examples/l4_performance_tests.inc.md b/docs/contributing/ci/test_examples/l4_performance_tests.inc.md index 8093e1459f..f1f3073dc5 100644 --- a/docs/contributing/ci/test_examples/l4_performance_tests.inc.md +++ b/docs/contributing/ci/test_examples/l4_performance_tests.inc.md @@ -1,4 +1,4 @@ -When you want to add L4-level ***performance test*** cases, you can refer to the following format for case addition in tests/dfx/perf/tests/test.json: +When you want to add L4-level ***performance test*** cases, you can refer to the following format for case addition in `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, or diffusion configs such as `tests/dfx/perf/tests/test_*_vllm_omni.json` (selected via `pytest ... 
run_benchmark.py --test-config-file `): ```JSON { diff --git a/docs/contributing/ci/test_guide.md b/docs/contributing/ci/test_guide.md index 425f24332c..08b2e3b4ea 100644 --- a/docs/contributing/ci/test_guide.md +++ b/docs/contributing/ci/test_guide.md @@ -45,7 +45,6 @@ Our test scripts use the pytest framework. First, please use `git clone https:// === "L3 level & L4 level" ```bash - cd tests pytest -s -v -m "advanced_model" --run-level=advanced_model ``` If you only want to run L3 test case, you can use: @@ -60,9 +59,9 @@ Our test scripts use the pytest framework. First, please use `git clone https:// ```bash pytest -s -v -m "core_model and distributed_cuda and L4" --run-level=core_model ``` - Note: To run performance tests, use: + Note: To run performance tests (defaults to ``test_qwen_omni.json``; use ``--test-config-file tests/dfx/perf/tests/test_tts.json`` for TTS): ```bash - pytest -s -v perf/scripts/run_benchmark.py + pytest -s -v tests/dfx/perf/scripts/run_benchmark.py ``` The latest L3 test commands for various test suites can be found in the [pipeline](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-merge.yml). diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md index 8b10cf4cc1..392f004721 100644 --- a/docs/contributing/ci/tests_style.md +++ b/docs/contributing/ci/tests_style.md @@ -135,8 +135,7 @@ vllm_omni/ tests/ │ ├── test_qwen3_omni_expansion.py │ ├── test_mimo_audio.py │ ├── test_image_gen_edit.py - │ ├── test_images_generations_lora.py - │ └── stage_configs/ + │ └── test_images_generations_lora.py └── offline_inference/ ✅ ├── test_qwen2_5_omni.py ├── test_qwen3_omni.py @@ -147,17 +146,18 @@ vllm_omni/ tests/ ├── test_zimage_tensor_parallel.py ├── test_cache_dit.py ├── test_teacache.py - ├── test_stable_audio_model.py + ├── test_stable_audio_expansion.py ├── test_diffusion_cpu_offload.py ├── test_diffusion_layerwise_offload.py ├── test_diffusion_lora.py ├── test_sequence_parallel.py ├── test_qwen_image_edit_expansion.py - └── stage_configs/ - ├── qwen2_5_omni_ci.yaml - ├── qwen3_omni_ci.yaml - ├── bagel_*.yaml + └── stage_configs/ (legacy schema, still present + ├── bagel_*.yaml for unmigrated models) └── npu/, rocm/, etc. + +# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under +# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md. examples/ tests │ └── examples ├── online_serving/ → ├── online_serving/ @@ -229,6 +229,7 @@ from tests.conftest import ( generate_synthetic_video, merge_base64_and_convert_to_text, ) +from tests.utils import get_deploy_config_path from vllm_omni.platforms import current_omni_platform # Edit: model name and stage config path @@ -236,7 +237,7 @@ models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] #If you use the default configuration file, you can directly use the following address. def get_default_config(): - return str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml") + return get_deploy_config_path("ci/qwen3_omni_moe.yaml") #If you need to modify the configuration file, you can use modify_stage_config. def get_chunk_config(): diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md index a0619e3381..1eaff10596 100644 --- a/docs/contributing/model/adding_omni_model.md +++ b/docs/contributing/model/adding_omni_model.md @@ -313,7 +313,7 @@ The registry uses lazy loading, so the model class is imported only when needed. 
## Stage Configuration -Create a YAML configuration file in `vllm_omni/model_executor/stage_configs/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml). +Create a YAML configuration file in `vllm_omni/deploy/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml). ### Key Configuration Fields @@ -408,18 +408,17 @@ Understanding the data structures is crucial for implementing stage transitions: **Input to your function:** - `stage_list[source_stage_id].engine_outputs`: List of `EngineCoreOutput` objects - - Each contains `outputs`: List of `RequestOutput` objects - - Each `RequestOutput` has: - - `token_ids`: Generated token IDs - - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc. - - These are the hidden states or intermediate outputs from the model's forward pass - - `prompt_token_ids`: Original prompt token IDs + - Each contains `outputs`: List of `RequestOutput` objects + - Each `RequestOutput` has: + - `token_ids`: Generated token IDs + - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc. These are the hidden states or intermediate outputs from the model's forward pass + - `prompt_token_ids`: Original prompt token IDs **Output from your function:** - Must return `list[OmniTokensPrompt]` where each `OmniTokensPrompt` contains: - - `prompt_token_ids`: List[int] - Token IDs for the next stage - - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states) - - `multi_modal_data`: Optional multimodal data if needed + - `prompt_token_ids`: List[int] - Token IDs for the next stage + - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states) + - `multi_modal_data`: Optional multimodal data if needed ### How Model Outputs Are Stored @@ -614,7 +613,7 @@ For a complete reference implementation, see: - **Thinker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py` - **Talker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py` - **Code2Wav**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py` -- **Stage config**: `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml` +- **Stage config**: `vllm_omni/deploy/qwen3_omni_moe.yaml` - **Input processors**: `vllm_omni/model_executor/stage_input_processors/qwen3_omni.py` - **Registry**: `vllm_omni/model_executor/models/registry.py` - **Testing**: `vllm_omni/tests/e2e/offline_inference/test_qwen3_omni.py` diff --git a/docs/contributing/model/adding_tts_model.md b/docs/contributing/model/adding_tts_model.md index e48ae5049f..622064173c 100644 --- a/docs/contributing/model/adding_tts_model.md +++ b/docs/contributing/model/adding_tts_model.md @@ -28,7 +28,7 @@ and can be placed on different devices. Qwen3-TTS has two stages: Each stage is a separate model class configured independently via YAML. The two stages are connected by the `async_chunk` framework, which enables inter-stage streaming for -low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk_design.md)). +low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk.md)). 
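As a quick orientation for the stage-transition contract documented above (previous-stage engine outputs in, a `list[OmniTokensPrompt]` out), the sketch below shows the general shape of such a processor. It is illustrative only: plain dicts stand in for `OmniTokensPrompt`, the fields follow the descriptions above, and the real implementations live under `vllm_omni/model_executor/stage_input_processors/`.

```python
# Illustrative sketch, not the actual vLLM-Omni API: maps the previous stage's
# outputs onto the prompt fields expected by the next stage.
def previous_stage_to_next_stage(stage_list, source_stage_id):
    next_prompts = []
    for engine_output in stage_list[source_stage_id].engine_outputs:
        for request_output in engine_output.outputs:
            next_prompts.append(
                {
                    # Generated tokens of the previous stage become the next prompt.
                    "prompt_token_ids": list(request_output.token_ids),
                    # Hidden states / intermediate tensors ride along as metadata.
                    "additional_information": dict(request_output.multimodal_output or {}),
                    # Attach raw multimodal inputs here only if the next stage needs them.
                    "multi_modal_data": None,
                }
            )
    return next_prompts
```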
### Without async_chunk (batch mode) @@ -120,8 +120,18 @@ vllm_omni/model_executor/stage_configs/ | `models/qwen3_tts/qwen3_tts.py` | Unified model class | | `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | Stage 0 - optimized AR | | `models/qwen3_tts/qwen3_tts_code2wav.py` | Stage 1 - decoder | -| `stage_configs/qwen3_tts.yaml` | Stage config (async_chunk enabled) | -| `stage_configs/qwen3_tts_batch.yaml` | Batch mode config | +| `deploy/qwen3_tts.yaml` (new schema) | Deploy config (async_chunk enabled) — paired with `models/qwen3_tts/pipeline.py` for the frozen topology | + +> **Chunked vs end-to-end modes**: `qwen3_tts` registers a single +> pipeline whose stage 1 declares alternate processor functions — an +> `async_chunk_process_next_stage_input_func` (per-chunk streaming, used +> when `deploy.async_chunk=True`) and a `sync_process_input_func` +> (batch-end, used when `deploy.async_chunk=False`). The loader selects +> one at merge time based on the bool, so `--no-async-chunk` alone +> switches modes — no variant yaml or variant pipeline registration is +> needed. Pipelines that only make sense in one mode (e.g. +> `qwen3_omni_moe` is always chunked) can keep using the unconditional +> `custom_process_*` fields. | `stage_input_processors/qwen3_tts.py` | Stage transition processors | ## Step-by-Step Implementation @@ -574,11 +584,12 @@ Adding a TTS model to vLLM-Omni involves: | `models/qwen3_tts/qwen3_tts.py` | Unified model class | | `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | AR stage with vLLM fused ops | | `models/qwen3_tts/qwen3_tts_code2wav.py` | Decoder stage with `chunked_decode_streaming()` | -| `stage_configs/qwen3_tts.yaml` | Stage configuration | +| `models/qwen3_tts/pipeline.py` | Frozen pipeline topology (registered at import time) | +| `deploy/qwen3_tts.yaml` | Deploy config (user-editable, async_chunk + SharedMemoryConnector) | | `stage_input_processors/qwen3_tts.py` | Stage transition processors | For more information, see: - [Architecture Overview](../../design/architecture_overview.md) -- [Async Chunk Design](../../design/feature/async_chunk_design.md) +- [Async Chunk Design](../../design/feature/async_chunk.md) - [Stage Configuration Guide](../../configuration/stage_configs.md) diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 7a2e64f131..6c209e5659 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -1,216 +1,193 @@ # Profiling vLLM-Omni -> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production. +> **Warning:** Profiling is for development and debugging only. It adds significant overhead and should not be enabled in production. -vLLM-Omni uses the PyTorch Profiler to analyze performance across both **multi-stage omni-modality models** and **diffusion models**. +vLLM-Omni supports two profiler backends through `profiler_config`: -### 1. Configure Profiling in the Stage YAML +- `torch`: detailed CPU/CUDA traces written to `torch_profiler_dir` +- `cuda`: low-overhead CUDA range control for NVIDIA Nsight Systems (`nsys`) -Enable profiling by adding `profiler_config` under `engine_args` for the stage(s) you want to profile in your stage config YAML: +## 1. Configure Profiling + +Use the same `profiler_config` shape everywhere: + +```yaml +profiler_config: + profiler: torch + torch_profiler_dir: ./perf +``` + +Supported fields: + +| Field | Description | +|---|---| +| `profiler` | Profiler backend. Supported values: `torch`, `cuda`. 
| +| `torch_profiler_dir` | Output directory for torch traces. Required when `profiler: torch`. | +| `delay_iterations` | Number of worker iterations to skip before profiling starts. | +| `max_iterations` | Maximum number of worker iterations to capture before auto-stop. | +| `warmup_iterations` | Torch-profiler warmup iterations. | +| `active_iterations` | Torch-profiler active iterations. | +| `wait_iterations` | Torch-profiler wait iterations before warmup. | + +For multi-stage omni pipelines, put `profiler_config` under the target stage's `engine_args`. ```yaml stage_args: - stage_id: 0 stage_type: llm engine_args: - # ... other engine args ... profiler_config: profiler: torch torch_profiler_dir: ./perf ``` -| Field | Description | -|---|---| -| `profiler` | Profiler backend to use. Currently supports `torch`. | -| `torch_profiler_dir` | Directory where trace files are saved. Created automatically if it doesn't exist. | - -> **Tip:** Only enable `profiler_config` on stages you actually need to profile. Stages without it will not start a profiler, keeping overhead minimal. - -### 2. Profiling Omni-Modality Models +For single-stage diffusion usage, pass `profiler_config` directly to `Omni(...)` or `vllm serve`. -**Selective Stage Profiling** +## 2. Profiling Omni Pipelines -It is highly recommended to profile specific stages to prevent producing overly large trace files: +It is usually best to profile only the stages you need. ```python -# Profile all stages -omni_llm.start_profile() +# Profile all stages. +omni.start_profile() -# Only profile Stage 1 -omni_llm.start_profile(stages=[1]) - -# Stage 0 (Thinker) and Stage 2 (Audio Decoder) for qwen omni -omni_llm.start_profile(stages=[0, 2]) +# Profile selected stages only. +omni.start_profile(stages=[0, 2]) +... +omni.stop_profile(stages=[0, 2]) ``` -> **Important:** Always pass the same `stages` list to both `start_profile()` and `stop_profile()`. If you omit `stages` from `stop_profile()`, it defaults to stopping all stages — including ones that were never started — which will produce errors. +Always stop the same stage set that you started. If only some stages have `profiler_config`, pass an explicit `stages=[...]` list instead of relying on the default "all stages" behavior. -**Python Usage**: Wrap your generation logic with `start_profile()` and `stop_profile()`. +Examples: -```python -profiler_stages = [0] # Only profile the stages you need +1. [Qwen2.5-Omni end2end](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py) +2. [Qwen3-Omni end2end](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py) -# 1. Start profiling -omni.start_profile(stages=profiler_stages) +## 3. Profiling Single-Stage Diffusion -# Initialize generator -omni_generator = omni.generate(prompts, sampling_params_list, py_generator=args.py_generator) +Single-stage diffusion models use the same `start_profile()` / `stop_profile()` controls, but you must provide `profiler_config` explicitly. -total_requests = len(prompts) -processed_count = 0 +### PyTorch profiler -# Main Processing Loop -for stage_outputs in omni_generator: - - # ... [Output processing logic for text/audio would go here] ... +```python +from vllm_omni import Omni + +omni = Omni( + model="Wan-AI/Wan2.2-I2V-A14B-Diffusers", + profiler_config={ + "profiler": "torch", + "torch_profiler_dir": "./perf", + }, +) + +omni.start_profile() +... 
+omni.stop_profile() +``` - # Update count to track when to stop profiling - processed_count += len(stage_outputs.request_output) +### Nsight Systems (`nsys`) - # 2. Check if all requests are done to stop the profiler safely - if profiler_enabled and processed_count >= total_requests: - print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...") +For Nsight Systems, use `profiler: cuda` and wrap the process with `nsys profile`. - # Stop the profiler while workers are still active - # Pass the same stages list used in start_profile() - omni_llm.stop_profile(stages=profiler_stages) +```bash +nsys profile \ + --trace-fork-before-exec=true \ + --cuda-graph-trace=node \ + --capture-range=cudaProfilerApi \ + --capture-range-end=repeat \ + -o diffusion_trace \ + python image_to_video.py ... +``` - # Wait for traces to flush to disk - print("[Info] Waiting 30s for workers to write trace files to disk...") - time.sleep(30) - print("[Info] Trace export wait time finished.") +The Python process being profiled must create the diffusion engine with: -omni_llm.close() +```python +profiler_config={"profiler": "cuda"} ``` +Then call `start_profile()` before the requests you want to capture and `stop_profile()` after them. The diffusion worker processes open and close the CUDA capture range themselves, so `nsys` sees the actual GPU work instead of only the parent process. -**CLI Usage** (using `end2end.py`): -```bash -# Profile only Stage 0 (Thinker) -python end2end.py --output-wav output_audio \ - --query-type text --enable-profiler --profiler-stages 0 +Examples: -# Profile Stage 0 and Stage 2 -python end2end.py --output-wav output_audio \ - --query-type text --enable-profiler --profiler-stages 0 2 - -# Profile all stages (omit --profiler-stages) -python end2end.py --output-wav output_audio \ - --query-type text --enable-profiler -``` +1. [Image edit example](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) +2. [Image to video example](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video) -**Examples**: +## 4. Profiling Online Serving -1. **Qwen2.5-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py) +When any stage has `profiler_config.profiler` set, the server exposes: -2. **Qwen3-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py) +- `POST /start_profile` +- `POST /stop_profile` -### 3. Profiling diffusion models +### Start the server -Diffusion profiling is End-to-End, capturing encoding, denoising loops, and decoding. Standalone diffusion scripts use `--profiler-dir` to enable profiling. +Multi-stage omni serving: -**CLI Usage:** ```bash -python image_to_video.py \ - --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ - --image qwen-bear.png \ - --prompt "A cat playing with yarn, smooth motion" \ - --profiler-dir \ - \ - # Minimize Spatial Dimensions (Optional but helpful): - # Drastically reduces memory usage so the profiler doesn't - # crash due to overhead, though for accurate performance - # tuning you often want target resolutions. 
- --height 48 \ - --width 64 \ - \ - # Minimize Temporal Dimension (Frames): - # Video models process 3D tensors (Time, Height, Width). - # Reducing frames to the absolute minimum (2) keeps the - # tensor size small, ensuring the trace file doesn't become - # multi-gigabytes in size. - --num-frames 2 \ - \ - # Minimize Iteration Loop (Steps): - # This is the most critical setting for profiling. - # Diffusion models run the same loop X times. - # Profiling 2 steps gives you the exact same performance - # data as 50 steps, but saves minutes of runtime and - # prevents the trace viewer from freezing. - --num-inference-steps 2 \ - \ - --guidance-scale 5.0 \ - --guidance-scale-high 6.0 \ - --boundary-ratio 0.875 \ - --flow-shift 12.0 \ - --fps 16 \ - --output i2v_output.mp4 +vllm serve Qwen/Qwen2.5-Omni-7B \ + --omni \ + --port 8091 ``` -> **Note:** For diffusion stages within a multi-stage omni pipeline, use `profiler_config` in the stage YAML instead (see Section 1). - -**Examples**: - -1. **Qwen image edit**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py) - -2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video) - -### 4. Profiling Online Serving +(The default deploy config at `vllm_omni/deploy/qwen2_5_omni.yaml` is loaded automatically. Pass `--deploy-config /path/to/custom.yaml` to override.) -When `profiler_config` is set in the stage YAML, the server automatically exposes `/start_profile` and `/stop_profile` HTTP endpoints. +Single-stage diffusion serving with torch profiler: -**1. Start the server** with a stage YAML that has `profiler_config` enabled: ```bash -vllm serve Qwen/Qwen2.5-Omni-7B \ - --omni \ - --stage-configs-path qwen2_5_omni.yaml \ - --port 8091 +vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --omni \ + --port 8091 \ + --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' ``` -Or for one stage diffusion models: +Single-stage diffusion serving with Nsight Systems: ```bash -vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers --omni --port 8091 --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' +nsys profile \ + --trace-fork-before-exec=true \ + --cuda-graph-trace=node \ + --capture-range=cudaProfilerApi \ + --capture-range-end=repeat \ + -o serving_trace \ + vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --omni \ + --port 8091 \ + --profiler-config '{"profiler": "cuda"}' ``` -**2. Start profiling** by sending a POST request: +### Control capture + ```bash -# Profile all stages that have profiler_config set +# Start profiling on all profiled stages. curl -X POST http://localhost:8091/start_profile -# Profile specific stages only +# Start profiling on selected stages. curl -X POST http://localhost:8091/start_profile \ - -H "Content-Type: application/json" \ - -d '{"stages": [0]}' -``` + -H "Content-Type: application/json" \ + -d '{"stages": [0]}' -**3. Send your inference requests** as normal while the profiler is running. - -**4. Stop profiling** and collect traces: -```bash -# Stop all stages +# Stop profiling. 
curl -X POST http://localhost:8091/stop_profile - -# Stop specific stages (must match the stages you started) -curl -X POST http://localhost:8091/stop_profile \ - -H "Content-Type: application/json" \ - -d '{"stages": [0]}' ``` -Trace files are written to the `torch_profiler_dir` specified in your stage YAML. +For mixed-stage pipelines, use explicit `stages` and pass the same stage list to both endpoints. + +## 5. Analyze Results -> **Important:** Always stop the same stages you started. Stopping a stage that was never started will produce errors. +Torch profiler output: -### 5. Analyzing Traces +- Chrome/Perfetto traces under `torch_profiler_dir` +- Optional aggregated CUDA-time tables under the same directory -Output files are saved to the `torch_profiler_dir` specified in your stage YAML config. +CUDA profiler / Nsight Systems output: -**Output** -**Chrome Trace** (`.json.gz`): Visual timeline of kernels and stages. Open in Perfetto UI. +- `.nsys-rep` report files written by `nsys -o ...` -**Viewing Tools:** +Recommended viewers: -- [Perfetto](https://ui.perfetto.dev/) (recommended) -- `chrome://tracing` (Chrome only) +- [Perfetto](https://ui.perfetto.dev/) for torch traces +- `nsys stats .nsys-rep` for CLI summaries +- Nsight Systems GUI for CUDA kernel timelines -**Note**: vLLM-Omni reuses the PyTorch Profiler infrastructure from vLLM. See the official vLLM profiler documentation: [vLLM Profiling Guide](https://docs.vllm.ai/en/stable/contributing/profiling/) +vLLM-Omni reuses the vLLM profiling infrastructure where possible. For the upstream reference, see the [vLLM profiling guide](https://docs.vllm.ai/en/stable/contributing/profiling/). diff --git a/docs/design/feature/async_chunk_design.md b/docs/design/feature/async_chunk.md similarity index 80% rename from docs/design/feature/async_chunk_design.md rename to docs/design/feature/async_chunk.md index 202ef0e18e..57b4209b8d 100644 --- a/docs/design/feature/async_chunk_design.md +++ b/docs/design/feature/async_chunk.md @@ -1,4 +1,4 @@ -# Async Chunk Design +# Async Chunk ## Table of Contents @@ -19,7 +19,7 @@ The `async_chunk` feature enables asynchronous, chunked processing of data acros For qwen3-omni: - **Thinker → Talker**: Per decode step (typically chunk_size=1) -- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFA. Use the per-request `initial_codec_chunk_frames` API field to override. +- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFP. Use the per-request `initial_codec_chunk_frames` API field to override. - **Code2Wav**: Streaming decode with code2wav chunk_size With `async_chunk`: @@ -75,26 +75,85 @@ Enabling **async_chunk** (False→True) sharply reduces time-to-first-audio (TTF

## Architecture -### Data Flow -#### Sequential Flow +### Async Chunk Pipeline Overview + +The following diagram illustrates the **Async Chunk Architecture** for multi-stage models (e.g., Qwen3-Omni with Thinker → Talker → Code2Wav), showing how data flows through the 4-stage pipeline with parallel processing and dual-stream output: +

-  *(Figure: Data Flow between stages)*
+  *(Figure: Async Chunk Pipeline Architecture)*

-#### Async Chunk Flow +**Diagram Legend:** + +| Step | Stage Type | Description | +|------|-----------|------------| +| `prefill` | Initialization | Context processing, KV cache initialization | +| `decode` | Autoregressive | Token-by-token generation in AR stages | +| `codes` | Audio Encoding | RVQ codec codes from Talker stage | +| `output` | Final Output | Text chunks or audio waveforms | + +### Data Flow + +#### Stage 0: Thinker (Multimodal Understanding + Text Generation) +- **Prefill**: Processes multimodal input (text/image/audio/video), initializes KV cache +- **Decode Loop**: Generates text tokens autoregressively +- **Chunk Triggers**: Each decode step (typically `chunk_size=1`) can trigger downstream processing +- **Dual Output**: + - **Text Stream**: `text_0`, `text_1`, `text_2`... `text_n` streamed to output + - **Hidden States**: Passed to Talker stage for audio synthesis + +#### Stage 1: Talker (Text → RVQ Audio Codes) +- **Prefill**: Receives hidden states from Thinker as semantic condition +- **Decode Loop**: Generates RVQ codec codes autoregressively +- **Accumulation**: Codes accumulate to `codec_chunk_frames` (default=25) before forwarding +- **Dynamic IC**: Initial chunk size auto-selected based on server load to optimize TTFP +- **Output**: `codes` blocks (chunk 0, 1, ... n) sent to Code2Wav + +#### Stage 2: Code2Wav (Vocoder Decoder) +- **Non-Autoregressive**: Processes RVQ codes in parallel batches +- **Streaming Decode**: Converts codes to audio waveforms chunk-by-chunk +- **Batching**: Supports batched inference for multiple concurrent requests +- **Output**: Audio segments `audio_0`, `audio_1`, ... `audio_n` + +#### Stage 3: Output (Dual Stream) +- **Text Streaming**: `text_0` → `text_1` → `text_2` → ... (user sees response in real-time) +- **Audio Streaming**: `audio_0` → `audio_1` → ... (user hears audio progressively) + +### Execution Timeline +``` +Timeline: Parallel vs Sequential + +Sequential (async_chunk=false): +[Thinker: ████████████████████] (2.0s) + [Talker: ████████████████████] (3.0s) + [Code2Wav: ████] (1.0s) +Total: 6.0s, TTFP: 6.0s + +Async Chunk (async_chunk=true): +[Thinker: ████░░░░████░░░░████] (2.0s, streaming) + [Talker: ░░████░░░░████░░] (3.0s, parallel) + [Code2Wav: ░░░░████░░] (1.0s, batched) +Total: ~3.5s, TTFP: ~0.5s + +█ = Active computation ░ = Waiting/idle +``` + +#### Sequential Flow (for comparison)

-  *(Figure: Data Flow between stages)*
+  *(Figure: Sequential Data Flow)*

-### Async Chunk architecture +In sequential mode, each stage must wait for the previous stage to complete entirely before starting. + +### Async Chunk System Architecture

diff --git a/docs/design/feature/cfg_parallel.md b/docs/design/feature/cfg_parallel.md index 64decbe956..c73a87749f 100644 --- a/docs/design/feature/cfg_parallel.md +++ b/docs/design/feature/cfg_parallel.md @@ -25,7 +25,9 @@ In standard Classifier-Free Guidance, each diffusion step requires two forward p 1. **Positive/Conditional**: Guided by the text prompt 2. **Negative/Unconditional**: Typically using empty or negative prompt -CFG-Parallel eliminates this bottleneck by distributing the two forward passes across different GPU ranks, allowing them to execute simultaneously rather than sequentially. +Some models require 3 or more CFG branches (see [N-Branch CFG](#n-branch-cfg-3-branches)). + +CFG-Parallel eliminates this bottleneck by distributing the forward passes across different GPU ranks, allowing them to execute simultaneously rather than sequentially. ### Architecture @@ -33,9 +35,11 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic. | Method | Purpose | Automatic Behavior | |--------|---------|-------------------| -| [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with CFG | Detects parallel mode, distributes computation, gathers results | +| [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with 2-branch CFG | Detects parallel mode, distributes computation, gathers results | +| [`predict_noise_with_multi_branch_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with N-branch CFG | Round-robin dispatches N branches across M GPUs | | [`scheduler_step_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Step scheduler | All ranks step locally (no broadcast needed) | -| [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine positive/negative | Applies CFG formula with optional normalization | +| [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine 2-branch predictions | Applies CFG formula with optional normalization | +| [`combine_multi_branch_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine N-branch predictions | Override for custom multi-branch combine logic | | [`predict_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Forward pass wrapper | Override for custom transformer calls | | [`cfg_normalize_function()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Normalize CFG output | Override for custom normalization | @@ -57,6 +61,22 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic. - All ranks compute the scheduler step locally — no broadcast needed because `predict_noise_maybe_with_cfg` already ensures all ranks have identical noise predictions after `all_gather` + local combine. +### N-Branch CFG (3+ branches) + +Some models require more than 2 CFG branches. For example, Bagel and OmniGen2 use 3 branches, DreamID Omni uses 4 branches. 
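+
+As a standalone illustration of the round-robin rule described below (branch `i` goes to rank `i % M`), the following sketch reproduces the dispatch table; `round_robin_dispatch` is an illustrative helper for this doc, not a vLLM-Omni API:
+
+```python
+def round_robin_dispatch(num_branches: int, num_ranks: int) -> list[list[int]]:
+    """Assign CFG branch indices to ranks via branch i -> rank i % M."""
+    assignment: list[list[int]] = [[] for _ in range(num_ranks)]
+    for branch in range(num_branches):
+        assignment[branch % num_ranks].append(branch)
+    return assignment
+
+print(round_robin_dispatch(3, 2))  # [[0, 2], [1]]
+print(round_robin_dispatch(4, 3))  # [[0, 3], [1], [2]]
+```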
+ +`predict_noise_with_multi_branch_cfg()` handles these by automatically dispatching N branches across M GPUs using round-robin (rule: branch `i` → rank `i % M`): + +| Branches (N) | GPUs (M) | Dispatch | +|:---:|:---:|:---| +| 3 | 2 | `[[0, 2], [1]]` | +| 3 | 3 | `[[0], [1], [2]]` | +| 4 | 2 | `[[0, 2], [1, 3]]` | +| 4 | 3 | `[[0, 3], [1], [2]]` | +| 4 | 4 | `[[0], [1], [2], [3]]` | + +When a rank handles multiple branches, it runs them sequentially. After `all_gather`, all ranks execute `combine_multi_branch_cfg_noise()` locally, producing identical results. + --- ## Step-by-Step Implementation @@ -98,6 +118,7 @@ class YourModelPipeline(nn.Module, CFGParallelMixin): - `positive_kwargs`: transformer arguments for conditional (text-guided) prediction - `negative_kwargs`: transformer arguments for unconditional prediction (set to `None` if CFG disabled) - For image editing pipelines, add `output_slice=image_seq_len` to extract the generative image portion +- For models with 3+ CFG branches, see [Multi-Branch CFG](#multi-branch-cfg-3-branches) in the Customization section ### Step 2: Call `diffuse` @@ -171,20 +192,42 @@ class LongCatImagePipeline(nn.Module, CFGParallelMixin): ``` -### Override `combine_cfg_noise()` for Multi-Output Models +### Multi-Branch CFG (3+ branches) + +For models with 3 or more CFG branches, use `predict_noise_with_multi_branch_cfg()` instead of `predict_noise_maybe_with_cfg()`, and override `combine_multi_branch_cfg_noise()` for custom combine logic. This interface also works for standard 2-branch CFG — just pass 2 branches in `branches_kwargs`. -When `predict_noise()` returns a tuple (e.g., video + audio), the default `combine_cfg_noise()` applies CFG to every element. Override it to apply different logic per element — for example, CFG on video but positive-only on audio: +**Example (3-branch with dual guidance scale):** ```python -class MyVideoAudioPipeline(nn.Module, CFGParallelMixin): - def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize): - (video_pos, audio_pos) = positive_noise_pred - (video_neg, audio_neg) = negative_noise_pred - video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize) - return (video_combined, audio_pos) # audio: positive only, no CFG +class YourMultiBranchPipeline(nn.Module, CFGParallelMixin): + def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False): + text_scale = true_cfg_scale["text"] + image_scale = true_cfg_scale["image"] + pos, ref, uncond = predictions + return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref) + + def diffuse(self, ...): + for i, t in enumerate(timesteps): + positive_kwargs = {...} # conditional prompt + ref_neg_kwargs = {...} # negative prompt + reference + uncond_kwargs = {...} # unconditional + + noise_pred = self.predict_noise_with_multi_branch_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale={"text": text_guidance_scale, "image": image_guidance_scale}, + branches_kwargs=[positive_kwargs, ref_neg_kwargs, uncond_kwargs], + ) + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + + return latents ``` -This also requires `predict_noise()` to return a tuple (see [Override predict_noise](#override-predict_noise-for-custom-transformer-calls) above). +### Override Combine Functions + +There are two combine functions for different scenarios: + +- **`combine_cfg_noise()`** — Used by `predict_noise_maybe_with_cfg()`. 
Override when `predict_noise()` returns a tuple (e.g., video + audio) and you need per-element CFG logic. +- **`combine_multi_branch_cfg_noise()`** — Used by `predict_noise_with_multi_branch_cfg()`. Override to implement custom multi-branch combine formulas (see [Multi-Branch CFG](#multi-branch-cfg-3-branches) above). ### Implement a Composite Scheduler for Multi-Output Models @@ -303,4 +346,5 @@ Adding CFG-Parallel support: 1. ✅ **Create mixin** - Inherit from `CFGParallelMixin` and implement `diffuse()` method 2. ✅ **(Optional) Customize** - Override `predict_noise()` or `cfg_normalize_function()` for custom behavior -3. ✅ **Test** - Verify with `--cfg-parallel-size 2` and compare performance +3. ✅ **(Optional) Multi-branch** - For 3+ branch models, use `predict_noise_with_multi_branch_cfg()` and override `combine_multi_branch_cfg_noise()` +4. ✅ **Test** - Verify with `--cfg-parallel-size 2` (or 3/4 for multi-branch) and compare performance diff --git a/docs/design/feature/expert_parallel.md b/docs/design/feature/expert_parallel.md index 9a7c4cdbac..e05eec3361 100644 --- a/docs/design/feature/expert_parallel.md +++ b/docs/design/feature/expert_parallel.md @@ -207,9 +207,9 @@ Complete examples in the codebase: | Model | Path | Pattern | Notes | |-------|------|---------|-------| -| **HunyuanImage3.0** | `vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py` | Standard EP | Full implementation with validation | +| **HunyuanImage3.0** | `vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py` | Standard EP | Full implementation with validation | | **EP Tests** | `vllm-omni/tests/e2e/offline_inference/test_expert_parallel.py` | E2E testing | EP correctness and performance | -| **Constraint Tests** | `vllm-omni/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py` | Unit testing | Validation logic | +| **Constraint Tests** | `vllm-omni/tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py` | Unit testing | Validation logic | --- ## Summary diff --git a/docs/design/feature/prefix_caching.md b/docs/design/feature/prefix_caching.md new file mode 100644 index 0000000000..ebad8b6910 --- /dev/null +++ b/docs/design/feature/prefix_caching.md @@ -0,0 +1,164 @@ +# Automatic Prefix Caching in Omni Models + + +--- + +## Table of Contents + +- [Overview](#overview) +- [High-Level Approach](#high-level-approach) +- [Example](#example) +- [What About Multimodal Inputs?](#what-about-multimodal-inputs) + +--- + +### Overview + +Prefix caching in the context of kv-cache management is a useful optimization for avoiding redundant computations. The main idea is that we store portions of the kv-cache from processed requests, so that we can reuse them if incoming requests have the same prefix as previous requests. + +vLLM manages the kv-cache as blocks, which represent a span of tokens of a fixed length. Blocks are hashable by the content that they contain, which typically means the tokens within the span, but also could be influenced by other factors, e.g., LoRA and multimodal data. + +vLLM implements automatic prefix caching for managing its kv-cache, which is best understood by reading the design document [here](https://docs.vllm.ai/en/latest/design/prefix_caching/). vLLM-Omni builds on top of the prefix caching mechanism in a noninvasive way to allow caching between stages in Omni pipelines. 
This typically means for a given stage we aim to support caching for the following: + +- The last hidden states produced by the stage +- Model / stage specific multimodal data + +!!! note "Note 1" + This document describes vLLM-Omni's mechanism for caching tensor outputs that are meant to be passed between stages, when requests have common prefixes, similar to the way in which vLLM has prefix caching for the kv-cache. This works in conjunction with vLLM's multimodal encoder caching, but is distinct. See the final section for a concrete example for how they tie together in practice. + +### High-Level Approach +!!! note "Note 2" + Prior to reading this section, it's recommended to take a look at the design documents in vLLM for [Automatic Prefix Caching](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching/), which will make some of the concepts more clear. + +The main focus of vLLM-Omni's approach to prefix caching stage outputs is to build on vLLM's prefix caching in the least invasive way possible while minimizing impact for cache misses, and consuming a minimal amount of GPU memory. To understand the implementation, there are a few important things to note: + +- Between stages, device tensors are generally moved to CPU; this is important since we're just caching the outputs of stages, so it is okay to keep the entire cache on the CPU. + +- For a tensor to be considered cacheable, the first dimension (currently) needs to be the same as the token count, as it allows us to reuse block/slot mappings for our externally maintained tensor caches. This allows us to dynamically discover the tensors to be marked as cacheable outputs in each Omni model without having to explicitly specify cacheable output field names in every model. + +With this in mind, consider the set of blocks in a 2D layout, where the row represents the index of blocks being considered, and the columns represent the slots corresponding to tokens within each block. Since we know the `num_blocks` and `block_size` from our kv cache config, if we want to cache a tensor with feature size `D`, we can preallocate a CPU tensor of size `(num_blocks, block_size, D)`, and use the same block index and slot mapping to retrieve the corresponding feature vector. + + +### Example +!!! note "Note 3" + Prefix caching in vLLM-Omni currently is only supported on AutoRegressive stages with one kv-cache group. It can be enabled/disabled per-stage via the `enable_prefix_caching` parameter in the model's stage config. + +The way in which vLLM-Omni ties into vLLM's prefix caching is best understood by example. Say that we have the following: + +- `num_blocks=8` +- `block_size=4` +- `hidden_size=2` +- A stage specific multimodal output tensor named `mm_feature` with feature dimension `16` + +The prefix cache flow is then outlined below. + +1. When the model is initialized, we can determine the `hidden_size` from the `ModelConfig`, and allocate a cache of size `(num_blocks, block_size, hidden_size)`. + +2. Say we process the request `The quick brown fox was tired and slept beneath the shady tree`, which is 12 tokens and evenly divides into 3 blocks as shown below. 
+ +``` + [ The quick brown fox ] [ was tired and slept ] [beneath the shady tree ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->| +``` + +When the request processes, we inspect the multimodal outputs and identify the `mm_feature` tensor, which will be of shape `(seq_len, feature_dim)`, i.e., `(12, 16)` in this example. We note that the first axis is dependent on the `seq_len` and add a new cache_tensor of shape `(num_blocks, block_size, feature_dim)` to our multimodal cache for tensors. + + +3. If we lay out the cache as a 2D tensor of shape (`num_blocks`, `block_size`), we'll have something like the following: + +``` +0: [ The quick brown fox ] +1: [ was tired and slept ] +2: [beneath the shady tree ] +3: [EMPTY] +... +7: [EMPTY] +``` + +Or, if we flatten it down to 1D, +``` +0: The +1: quick +2: brown +3: fox +... +11: tree +12: [EMPTY] +... +``` + +which we can think of as row indices into the hidden states tensor if we view it as the 2D shape `(num_blocks x block_size, feature_dim)`. That is, the analogous flattened (from 3D -> 2D) mapping of the cache for hidden states becomes the following. +``` +0: +1: +2: +3: +... +11: +12: [EMPTY] +... +``` + +Similarly, for the multimodal outputs cache, the flattened coordinates are the same, but the `mm_feature` maps to vectors of length `16` instead of the hidden size of `2`. Note that in practice, we may have multiple multimodal output tensors per forward pass, which may have different names and different feature dimensions. + + +4. Now, say that we receive a new request `The quick brown fox jumped over the dog`. + +``` + [ The quick brown fox ] [ jumped over the dog ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +``` + +Here, we will have a cache hit for `Block 1` which will be detected by vLLM based on the hash of the first block when it's handling the prefix caching on the kv-cache. As a result, when we get the output from the scheduler, we will see that `num_computed_tokens=4` (corresponding to the cached first block), and we only need to process the remaining 4 new tokens in the new prefill. + +Since we have the block indices / slot mappings from the kv cache manager, we can simply mirror the mappings and leverage the same indices for the cached hidden states and multimodal outputs. This allows us to look up the correct tensors from our externally maintained 3D caches. + +``` +0: [ The quick brown fox ] < already in the cache +1: [ was tired and slept ] +2: [beneath the shady tree ] +3: [ jumped over the dog ] < added on the second request +4: [EMPTY] +... +7: [EMPTY] +... +``` + +Finally, to pass the full hidden states and multimodal outputs to the next stage, we simply concatenate the cached contents with the corresponding new tensors computed from the current forward call. + + +### What About Multimodal Inputs? +It's also useful to consider the case about how Omni prefix caching is handled when we have multimodal inputs that don't cleanly end on block boundaries, as well as how this works with multimodal encoder caching in vLLM. For example: + +``` + [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 foo ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +``` + +In this case, only `Block 1` will have outputs stored in the prefix tensor cache, because vLLM does not store partial blocks. 
This may appear to be a problem at first glance, because the multimodal input is fragmented across a new block that wasn't cached. + +In reality, this isn't a big problem for correctness, because vLLM also maintains an encoder cache for multimodal inputs. In other words, after the first pass, we'll have the following: + +- The Block 1 hash, which is used for prefix caching +- The hash describing the image data starting at position 0 and with length 6 +- In vLLM's encoder cache, a mapping from the image hash above to the encoder output + + +To understand what happens, say we get the following input as a second request: +``` + [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 bar baz ] +Block 1: |<--- block tokens ---->| +Block 2: |<------- prefix ------>| |<--- block tokens --->| +``` + +First, the scheduler will check for a prefix cache hit, which we will see on `Block 1`. As a result, we will have 4 tokens marked as precomputed, and only see the remaining 4 tokens in the following prefill. + +Because we have multimodal data in a scheduled span that isn't fully precomputed, we still need to call the visual encoder. However, since we have the image hash and encoder cache, we will retrieve the encoder outputs for `Im4` and `Im5` as we create the multimodal embeddings. + +When we pass our multimodal tensors to the language model component in the same stage, we'll then expect the same outputs, because the prefix caching behaviors in vLLM-Omni / vLLM match, so the LLM will use vLLM's KV cache manager's prefix caching to correctly handle the attention information for `Block 1` while calculating the outputs for `Block 2`, giving us the correct results for processing `Block 2` with the context of `Block 1`. + +Finally, we look up the output hidden states/multimodal tensors corresponding to the prefix cache hit `Block 1` and concatenate it with the forward pass result to get the final result, which is expected to be identical to the full hidden states when prefix caching is disabled. diff --git a/docs/design/feature/teacache.md b/docs/design/feature/teacache.md index 9fa315cee7..8577cff1f0 100644 --- a/docs/design/feature/teacache.md +++ b/docs/design/feature/teacache.md @@ -326,9 +326,41 @@ for prompt in tqdm(prompts, desc="Collecting data"): # Estimate coefficients coeffs = estimator.estimate(poly_order=4) -print(f"Estimated coefficients: {coeffs.tolist()}") +print(f"Estimated coefficients: {coeffs}") ``` +Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation. +```python +from vllm_omni.diffusion.forward_context import set_forward_context +from vllm_omni.diffusion.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) +from vllm.config import VllmConfig +... + +if __name__ == "__main__": + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8192" + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + vllm_config = VllmConfig() + init_distributed_environment() + initialize_model_parallel() + + # NOTE: you may have to pass an initialized OmniDiffusionConfig as a kwarg + # here to make current sp checks happy; if this is the case, just create one + # .from_kwargs() with the model name to get around this check for now, + # since your estimator subclass should handle the actual model configuration. 
+ # + # This will be cleaned up in the future + with set_forward_context(vllm_config): + +``` + + **Data Statistics Guide:** | Metric | Good Range | Warning Signs | diff --git a/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png new file mode 100644 index 0000000000..15112d5862 Binary files /dev/null and b/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png differ diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png new file mode 100644 index 0000000000..2f0615f77b Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png differ diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png new file mode 100644 index 0000000000..62d8bc79b6 Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png differ diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png new file mode 100644 index 0000000000..5838b45319 Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png differ diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png new file mode 100644 index 0000000000..24be814b7e Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png differ diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png new file mode 100644 index 0000000000..c8df58ebcd Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png differ diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png new file mode 100644 index 0000000000..2d1a04e9c2 Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png differ diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png new file mode 100644 index 0000000000..e598b54343 Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png differ diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png new file mode 100644 index 0000000000..54452013eb Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png differ diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png new file mode 100644 index 0000000000..04c5ad7396 Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png differ diff --git a/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png b/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png new file mode 100644 index 0000000000..d93ba0b2af Binary files /dev/null and b/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png differ diff --git 
a/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png b/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png new file mode 100644 index 0000000000..04087b5910 Binary files /dev/null and b/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png differ diff --git a/docs/design/figures/omni/Summary_RTF_vs_features.png b/docs/design/figures/omni/Summary_RTF_vs_features.png new file mode 100644 index 0000000000..c2c8ad4083 Binary files /dev/null and b/docs/design/figures/omni/Summary_RTF_vs_features.png differ diff --git a/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png b/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png new file mode 100644 index 0000000000..3dcc1c5537 Binary files /dev/null and b/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png differ diff --git a/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png new file mode 100644 index 0000000000..9a5b6c9bda Binary files /dev/null and b/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png differ diff --git a/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png new file mode 100644 index 0000000000..68f0ef17e8 Binary files /dev/null and b/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png differ diff --git a/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png new file mode 100644 index 0000000000..44be96e96d Binary files /dev/null and b/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png differ diff --git a/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png new file mode 100644 index 0000000000..2e5d1482bd Binary files /dev/null and b/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png differ diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png new file mode 100644 index 0000000000..04d8f0bac5 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png differ diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png new file mode 100644 index 0000000000..eb85ec0dd4 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png differ diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png new file mode 100644 index 0000000000..6f0e0e2529 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png differ diff --git a/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png new file mode 100644 index 0000000000..89ea30a864 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png differ diff --git a/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png new file mode 100644 index 0000000000..2b207b8898 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png differ diff --git 
a/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png new file mode 100644 index 0000000000..f5f7ad72c8 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png differ diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png new file mode 100644 index 0000000000..6f8c1da4a5 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png differ diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png new file mode 100644 index 0000000000..b0fe1d02a9 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png differ diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png new file mode 100644 index 0000000000..008ba9bf78 Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png differ diff --git a/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png new file mode 100644 index 0000000000..7c65aa1177 Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png differ diff --git a/docs/design/figures/tts/Summary_mean_rtf_vs_features.png b/docs/design/figures/tts/Summary_mean_rtf_vs_features.png new file mode 100644 index 0000000000..71bb2c5468 Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_rtf_vs_features.png differ diff --git a/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png new file mode 100644 index 0000000000..cef2546d6f Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png differ diff --git a/docs/design/qwen3_omni_tts_performance_optimization.md b/docs/design/qwen3_omni_tts_performance_optimization.md new file mode 100644 index 0000000000..2f18a1b1bc --- /dev/null +++ b/docs/design/qwen3_omni_tts_performance_optimization.md @@ -0,0 +1,539 @@ +# Speech Generation on vLLM-Omni: Performance Optimizations for Qwen3-Omni and Qwen3-TTS + +## Summary + +vLLM-Omni supports end-to-end serving for speech-generating models, including both **Qwen3-Omni** (multimodal understanding + speech) and **Qwen3-TTS** (text-to-speech). Despite their different architectures, both models share the same multi-stage pipeline design and benefit from the same set of stacked optimizations: + +1. **Batching** improves GPU utilization stage by stage and increases overall throughput. +2. **CUDA Graph** reduces CPU launch overhead and decode-time jitter on stable shapes. +3. **Async Chunk and Streaming Output** overlap compute and communication across stages and emit audio incrementally, improving both TTFP and E2E. + +### Model architectures + +**Qwen3-Omni** is a native multimodal model that understands text, audio, image, and video inputs, and generates both text and speech outputs. 
Its pipeline has three stages: + +- **Thinker**: multimodal understanding and text generation +- **Talker (+ Talker-MTP / code predictor path)**: converts semantic/text representations into codec tokens +- **Code2Wav**: decodes codec tokens into waveform audio + +**Qwen3-TTS** is a lightweight, high-quality text-to-speech model. Its pipeline has two stages: + +- **Talker (AR decoder)**: auto-regressively generates codec tokens from text input +- **Code2Wav (vocoder)**: decodes codec tokens into waveform audio + +The optimizations described in this post apply to both models. We present results for each side by side. + +### vLLM-Omni vs HF Transformers + +Compared with **HF Transformers** (offline, single request), vLLM-Omni with the full optimization stack delivers dramatically lower latency and higher efficiency for both models. + +**Qwen3-Omni** (A100): + + + + + +
*(Figures: "Qwen3-Omni E2EL: vLLM vs HF", "Qwen3-Omni TTFP: vLLM vs HF", "Qwen3-Omni RTF: vLLM vs HF")*
+ +| Metric | vLLM-Omni | HF Transformers | Improvement | +| --- | --- | --- | --- | +| E2E latency (s) | 23.78 | 336.10 | ~93% reduction | +| TTFP (s) | 0.934 | 336.10 | ~99.7% reduction | +| RTF | 0.32 | 3.776 | ~91% reduction (~12× faster) | + +- **E2E latency**: 23.78 s vs 336.10 s - **~93%** reduction +- **TTFP**: 0.934 s vs 336.10 s - **~99.7%** reduction +- **RTF**: 0.32 vs 3.776 - **~91%** reduction (~12x faster) + +**Qwen3-TTS** (H200, concurrency 1): + + + + + +
*(Figures: "Qwen3-TTS E2EL: vLLM vs HF", "Qwen3-TTS TTFP: vLLM vs HF", "Qwen3-TTS RTF: vLLM vs HF")*
+ +| Metric | vLLM-Omni | HF Transformers | Improvement | +| --- | --- | --- | --- | +| E2E latency (ms) | 941 | 15,513 | ~94% reduction | +| TTFP (ms) | 64 | 15,513 | ~99.6% reduction (242× faster) | +| RTF | 0.16 | 2.64 | ~94% reduction (~16.5× faster) | + +- **E2E latency**: 941 ms vs 15,513 ms - **~94%** reduction +- **TTFP**: 64 ms vs 15,513 ms - **~99.6%** reduction (242x faster) +- **RTF**: 0.16 vs 2.64 - **~94%** reduction (~16.5x faster) + +### Stacked optimization summary + +Each optimization stacks on the previous one. The summary plots below show the cumulative effect at each step, with one line per concurrency level (1, 4, 10). + +**Qwen3-Omni** (A100): + + + + + +
*(Figures: "Qwen3-Omni E2EL: stacked optimization", "Qwen3-Omni TTFP: stacked optimization", "Qwen3-Omni RTF: stacked optimization")*
+ +- **E2EL reduction**: ~74% at concurrency 10 (410,054 ms -> 104,901 ms); ~90% at concurrency 1 (426,529 ms -> 41,216 ms) +- **TTFP reduction**: ~96% at concurrency 10 (409,705 ms -> 16,482 ms); ~99.7% at concurrency 1 (426,078 ms -> 1,164 ms) +- **RTF reduction**: ~74% at concurrency 10 (2.83 -> 0.74); ~90% at concurrency 1 (2.08 -> 0.21) + +**Qwen3-TTS** (H200): + + + + + +
*(Figures: "Qwen3-TTS E2EL: stacked optimization", "Qwen3-TTS TTFP: stacked optimization", "Qwen3-TTS RTF: stacked optimization")*
+ +- **E2EL reduction**: ~85% at concurrency 10 (12,141 ms -> 1,767 ms); ~29% at concurrency 1 (1,323 ms -> 941 ms) +- **TTFP reduction**: ~96.5% at concurrency 10 (12,141 ms -> 425 ms); ~95% at concurrency 1 (1,323 ms -> 64 ms) +- **RTF reduction**: ~86% at concurrency 10 (2.19 -> 0.31); ~30% at concurrency 1 (0.23 -> 0.16) + +**Benchmark environment:** + +| | Qwen3-Omni | Qwen3-TTS | +| --- |-----------------------------| --- | +| **GPU** | A100 | H200 | +| **Model** | Qwen3-Omni-30B-A3B-Instruct | Qwen3-TTS-12Hz-1.7B-CustomVoice | +| **vLLM** | v0.17.0 | v0.18.0 | +| **vllm-omni** | commit 199f7832 | v0.18.0rc2 | +| **CUDA** | 12.9 | 12.8 | + +This post walks through each optimization in the same order they are typically enabled in practice, then ends with deployment playbooks for both models. + +--- + +## Pipeline Batching + +### How stage-wise batching works + +For both Qwen3-Omni and Qwen3-TTS, batching is a pipeline-level optimization: + +- Requests are grouped per stage using `runtime.max_batch_size` +- Each stage executes batch inference with its own scheduler/worker +- Stage outputs are routed to downstream stages with per-request mapping preserved + +**Batching strategy by stage:** The understanding and decode stages (Thinker for Omni, Talker for both) use **continuous batching**: requests can join and leave the batch over time. Code2Wav uses **static batching**: once a batch is formed, the stage runs the whole batch before starting the next. This matches the decode pattern of Code2Wav and keeps implementation simple while still improving throughput. + +### Batching results (Baseline vs. Batch) + +Batching alone greatly reduces E2EL and RTF across all concurrencies. The biggest gains appear at high concurrency where requests share GPU resources. + +**Qwen3-Omni** (A100): + + + + + +
*(Figures: "Qwen3-Omni E2EL: Baseline vs Batch", "Qwen3-Omni TTFP: Baseline vs Batch", "Qwen3-Omni RTF: Baseline vs Batch")*
+ +| Metric | Concurrency | Baseline | + Batch | Improvement | +| --- | --- | --- | --- | --- | +| E2EL (ms) | 1 | 426,529 | 307,719 | 1.4× | +| E2EL (ms) | 4 | 407,213 | 376,934 | 1.1× | +| E2EL (ms) | 10 | 410,054 | 234,844 | 1.7× | +| TTFP (ms) | 1 | 426,078 | 307,262 | 1.4× | +| TTFP (ms) | 4 | 406,843 | 376,466 | 1.1× | +| TTFP (ms) | 10 | 409,705 | 234,557 | 1.7× | +| RTF | 1 | 2.08 | 1.51 | 1.4× | +| RTF | 4 | 2.55 | 1.83 | 1.4× | +| RTF | 10 | 2.83 | 2.28 | 1.2× | + +At concurrency 10, E2EL drops from ~410 s to ~235 s; at concurrency 1, from ~427 s to ~308 s. + +**Qwen3-TTS** (H200): + + + + + +
*(Figures: "Qwen3-TTS E2EL: Baseline vs Batch", "Qwen3-TTS TTFP: Baseline vs Batch", "Qwen3-TTS RTF: Baseline vs Batch")*
+ +| Metric | Concurrency | Baseline | + Batch | Improvement | +| --- | --- | --- | --- | --- | +| E2EL (ms) | 1 | 1,323 | 1,339 | 1.0× | +| E2EL (ms) | 4 | 5,171 | 1,471 | 3.5× | +| E2EL (ms) | 10 | 12,141 | 1,705 | 7.1× | +| RTF | 1 | 0.230 | 0.234 | 1.0× | +| RTF | 4 | 0.908 | 0.255 | 3.6× | +| RTF | 10 | 2.186 | 0.292 | 7.5× | +| Throughput (audio-s/wall-s) | 10 | 3.99 | 33.53 | 8.4× | + +At concurrency 10, batching alone brings Qwen3-TTS RTF from 2.19 (slower than realtime) down to 0.29 (faster than realtime), and throughput from 4.0 to 33.5 audio-sec/wall-sec. + +--- + +## CUDA Graph on the Critical Decode Path + +### Why CUDA Graph helps here + +In decode-heavy serving, repeatedly launching many small kernels from CPU can become a visible overhead. CUDA Graph reduces this overhead by capturing and replaying stable execution graphs. + +In stage configs, this is represented by `enforce_eager: false` for stages where graph capture is desired (Thinker/Talker), while Code2Wav keeps eager mode depending on stage behavior. + +### CUDA Graph results on top of batching + +**Qwen3-Omni** (A100): + + + + + +
*(Figures: "Qwen3-Omni E2EL: Batch vs CUDA Graph", "Qwen3-Omni TTFP: Batch vs CUDA Graph", "Qwen3-Omni RTF: Batch vs CUDA Graph")*
+ +| Metric | Concurrency | Batch | + CUDA Graph | Improvement | +| --- | --- | --- | --- | --- | +| E2EL (ms) | 1 | 307,719 | 61,613 | 5.0× | +| E2EL (ms) | 4 | 376,934 | 79,019 | 4.8× | +| E2EL (ms) | 10 | 234,844 | 126,867 | 1.9× | +| TTFP (ms) | 1 | 307,262 | 61,257 | 5.0× | +| TTFP (ms) | 4 | 376,466 | 78,634 | 4.8× | +| TTFP (ms) | 10 | 234,557 | 126,534 | 1.9× | +| RTF | 1 | 1.51 | 0.32 | 4.7× | +| RTF | 4 | 1.83 | 0.43 | 4.3× | +| RTF | 10 | 2.28 | 0.90 | 2.5× | + +For the larger Qwen3-Omni model (30B-A3B), CUDA Graph provides a significant improvement. At concurrency 1, E2EL drops from ~308 s to ~62 s; at concurrency 10, from ~235 s to ~127 s. + +**Qwen3-TTS** (H200): + + + + + +
*(Figures: "TTS E2EL: Batch vs +CG", "TTS TTFP: Batch vs +CG", "TTS RTF: Batch vs +CG")*
+ +| Metric | Concurrency | Batch | + CUDA Graph | Improvement | +| --- | --- | --- | --- | --- | +| E2EL (ms) | 1 | 1,339 | 733 | 1.8× | +| E2EL (ms) | 4 | 1,471 | 987 | 1.5× | +| E2EL (ms) | 10 | 1,705 | 1,197 | 1.4× | +| RTF | 1 | 0.234 | 0.124 | 1.9× | +| RTF | 10 | 0.292 | 0.203 | 1.4× | +| Throughput (audio-s/wall-s) | 10 | 33.53 | 47.15 | 1.4× | + +At concurrency 1, CUDA Graph reduces E2EL from 1,339 ms to 733 ms and RTF from 0.234 to 0.124 - nearly a 2x improvement. The benefit is consistent across all concurrency levels. + +--- + +## Async Chunk and Streaming Output: Earlier Audio and Cross-Stage Overlap + +### Why this step matters for first-packet latency + +Two mechanisms work together to improve user-visible latency: + +- **Streaming output**: audio streaming emits audio chunks as soon as they are decoded (lower **TTFP**). Without streaming, the client waits for larger buffers or end-of-sequence. +- **Async chunk** is the main enabler for *earlier* audio: instead of handing off whole-request results between stages, each stage forwards **chunks** so the next stage can start as soon as the first chunk is ready. For Omni: Thinker -> Talker forwards hidden-state chunks; for both: Talker -> Code2Wav forwards codec chunks; Code2Wav decodes and emits packets incrementally. This **overlaps compute and communication** across stages and directly reduces time-to-first-audio-packet (TTFP) and end-to-end latency (E2EL). + +So in practice: streaming output defines *how* bytes are sent to the client; async chunk defines *when* the pipeline can produce the first bytes. + +**Dependency between the two:** Async chunk and audio streaming output are mutually dependent. Without async chunk, **audio streaming output cannot truly take effect**. Without audio streaming output, async chunk's **TTFP advantage is not fully realized**: the client would still wait for larger buffers or end-of-sequence instead of hearing the first packet as soon as it is ready. We therefore recommend enabling **both** on top of batching + CUDA Graph; the benchmarks in this post use both. + +### Results: Batch + CUDA Graph vs. Batch + CUDA Graph + Async Chunk + Streaming Output + +**Qwen3-Omni** (A100): + + + + + +
*(Figures: "Qwen3-Omni E2EL: CG vs Async Chunk", "Qwen3-Omni TTFP: CG vs Async Chunk", "Qwen3-Omni RTF: CG vs Async Chunk")*
+ +| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement | +| --- | --- | --- | --- | --- | +| E2EL (ms) | 1 | 61,613 | 41,216 | 1.5× | +| E2EL (ms) | 4 | 79,019 | 67,584 | 1.2× | +| E2EL (ms) | 10 | 126,867 | 104,901 | 1.2× | +| TTFP (ms) | 1 | 61,257 | 1,164 | 53× | +| TTFP (ms) | 4 | 78,634 | 3,152 | 24.9× | +| TTFP (ms) | 10 | 126,534 | 16,482 | 7.7× | +| RTF | 1 | 0.32 | 0.21 | 1.5× | +| RTF | 4 | 0.43 | 0.34 | 1.3× | +| RTF | 10 | 0.90 | 0.74 | 1.2× | + +Enabling both brings TTFP down sharply (concurrency 1: 61,257 ms -> 1,164 ms, **~98% reduction**; concurrency 4: 78,634 ms -> 3,152 ms, **~96% reduction**). E2EL and RTF also improve at every concurrency. + +**Qwen3-TTS** (H200): + + + + + +
*(Figures: "Qwen3-TTS E2EL: CG vs Async Chunk", "Qwen3-TTS TTFP: CG vs Async Chunk", "Qwen3-TTS RTF: CG vs Async Chunk")*
+ +| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement | +| --- | --- | --- | --- | --- | +| TTFP (ms) | 1 | 733 | **64** | **11.5×** | +| TTFP (ms) | 4 | 987 | **119** | **8.3×** | +| TTFP (ms) | 10 | 1,197 | **425** | **2.8×** | +| E2EL (ms) | 1 | 733 | 941 | 0.8× | +| E2EL (ms) | 10 | 1,197 | 1,767 | 0.7× | +| RTF | 1 | 0.124 | 0.160 | 0.8× | +| RTF | 10 | 0.203 | 0.314 | 0.6× | + +The TTFP improvement is the headline result for both models. For Qwen3-TTS at concurrency 1, users hear the first audio in **64 ms** instead of 733 ms - an **11.5x reduction**. For Qwen3-Omni at concurrency 1, TTFP drops from 61 s to 1.2 s - a **53x reduction**. + +### Why E2EL and RTF are higher with async chunk (TTS) + +The table above shows that enabling async chunk + streaming *increases* E2EL and RTF for TTS compared to CUDA Graph alone. This is expected - the two configurations optimize for fundamentally different metrics: + +- **CUDA Graph (no async chunk)** generates the entire audio end-to-end before returning. No chunking overhead, so total compute is minimized. +- **Async Chunk + Streaming** splits the pipeline into incremental chunks, adding overhead from chunked transport, context overlap in Code2Wav (`codec_left_context_frames=25`), and smaller effective batch sizes per chunk. + +**The tradeoff is intentional.** Async chunk trades ~30% higher total compute for **11x faster time-to-first-audio**. For interactive applications (voice assistants, chatbots), TTFP determines perceived responsiveness. For offline batch processing, CUDA Graph without async chunk is the better choice. + +--- + +## TTS-Specific: Code Predictor Re-prefill + `torch.compile` + +Qwen3-TTS has a **code predictor** - a small 5-layer transformer that generates residual codebook tokens (groups 1 through Q-1) autoregressively. Each AR step operates on very short sequences (2 to ~16 tokens). + +The naive approach uses a KV cache for this small transformer, similar to the main Talker. But the KV cache machinery (block tables, slot mappings, paged attention) introduces significant overhead relative to the tiny model. Two optimizations replace that: + +### Re-prefill (stateless forward, no KV cache) + +Instead of maintaining a KV cache across steps, the code predictor **re-feeds the full growing sequence** at each AR step using `F.scaled_dot_product_attention`. With sequences of at most ~16 tokens through 5 layers, the O(T^2) attention cost is negligible - and removing the KV cache machinery (block table management, `set_forward_context`, slot mapping) saves far more time than it costs. + +### `torch.compile` on the code predictor forward + +The 5-layer transformer forward pass launches ~60 small CUDA kernels per step. `torch.compile(mode="default", dynamic=True)` fuses these into fewer kernels via Inductor: + +```python +self._compiled_model_fwd = torch.compile( + self.model.forward, + mode="default", # no Inductor CUDA graphs, avoids conflict with vLLM's CUDAGraphWrapper + dynamic=True, # sequence length grows each step (2, 3, ..., num_groups+1) +) +``` + +`mode="default"` is used instead of `mode="reduce-overhead"` to avoid conflicts with vLLM's own CUDA graph capture on the main Talker model. `dynamic=True` handles the growing sequence length without recompilation. + +These optimizations are always-on in the current codebase - all Qwen3-TTS benchmark results in this post include them. 
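+
+To make the re-prefill idea above concrete, here is a minimal, self-contained sketch of a stateless AR step that re-feeds the full (short) prefix through `F.scaled_dot_product_attention` on every step instead of maintaining a KV cache. The module and its layer names are illustrative stand-ins, not the actual code predictor:
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class TinyCodePredictor(nn.Module):
+    """Toy stand-in for a small AR head; sizes and names are illustrative only."""
+    def __init__(self, vocab: int = 1024, dim: int = 64):
+        super().__init__()
+        self.embed = nn.Embedding(vocab, dim)
+        self.qkv = nn.Linear(dim, dim * 3)
+        self.head = nn.Linear(dim, vocab)
+
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:  # tokens: (1, T), T <= ~16
+        h = self.embed(tokens)
+        q, k, v = self.qkv(h).chunk(3, dim=-1)
+        # Re-prefill: causal attention over the whole short prefix each step, so
+        # there is no KV cache, block table, or slot mapping to manage.
+        h = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        return self.head(h[:, -1])  # logits for the next residual code
+
+model = TinyCodePredictor()
+tokens = torch.tensor([[1, 2]])      # short initial prefix
+for _ in range(6):                   # a few AR steps; the sequence grows by 1 each time
+    next_tok = model(tokens).argmax(dim=-1, keepdim=True)
+    tokens = torch.cat([tokens, next_tok], dim=1)
+```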
+ +--- + +## TTS-Specific: Dynamic Initial Chunk for Faster First Audio + +In the async chunk pipeline, the standard `codec_chunk_frames` is 25 (each chunk = ~2 seconds of audio at 12 Hz). Waiting for 25 frames before forwarding the first chunk to Code2Wav adds unnecessary TTFP. The **initial codec chunk** optimization sends a smaller first chunk so Code2Wav can start decoding earlier. + +**Dynamic initial chunk sizing (default behavior):** + +Rather than using a fixed initial chunk size, vLLM-Omni dynamically selects it based on current server load. The initial chunk size is chosen from power-of-2 steps [2, 4, 8, 16] based on load factor (`active_requests / max_batch_size`): + +| Server load | Initial chunk frames | Rationale | +| --- | --- | --- | +| Low (e.g. 1/10 active) | **2** (~167 ms of audio) | Minimize TTFP when there's headroom | +| Medium (e.g. 5/10 active) | **4-8** | Balance TTFP vs decode efficiency | +| High (e.g. 10/10 active) | **16** | Larger first chunk to amortize decode cost | + +After the initial chunk, all subsequent chunks use the standard `codec_chunk_frames` (25) size. + +**How it works in the pipeline:** + +1. Talker generates codec tokens auto-regressively +2. The stage input processor checks current load and picks an initial chunk size (e.g. **2 frames** at low load) +3. After that many frames, the first chunk is forwarded to Code2Wav +4. Code2Wav decodes this small chunk and emits the first audio packet +5. Subsequent chunks use the standard 25-frame size for efficient batch decoding + +**Per-request override:** Clients can also set a fixed initial chunk size via the API: + +```json +{"initial_codec_chunk_frames": 2} +``` + +This overrides the dynamic calculation for that request. + +**Config (server-side):** + +```yaml +runtime: + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + codec_streaming: true + codec_chunk_frames: 25 # standard chunk size (~2s of audio) + codec_left_context_frames: 25 + # initial chunk is computed dynamically by default + # set initial_codec_chunk_frames: 2 to force a fixed value +``` + +The 64 ms TTFP result reported above for Qwen3-TTS at concurrency 1 uses the dynamic initial chunk, which picks `initial_codec_chunk_frames=2` at low load. At higher concurrency the dynamic sizing increases the initial chunk to maintain decode efficiency. + +--- + +## Live Demo: Streaming TTS over WebSocket + +vLLM-Omni supports real-time streaming audio output for Qwen3-TTS over WebSocket ([PR #1719](https://github.com/vllm-project/vllm-omni/pull/1719)). With `stream_audio: true`, the server sends chunked PCM audio frames as they are generated, so clients can start playback before full sentence synthesis completes. + +The WebSocket protocol uses `audio.start` / binary PCM chunks / `audio.done` framing per sentence: + +```json +// Client sends: +{"type":"session.config","voice":"Vivian","response_format":"pcm","stream_audio":true} +{"type":"input.text","text":"Hello world. This is a streaming demo."} +{"type":"input.done"} + +// Server streams back per sentence: +{"type":"audio.start","sentence_index":0,"sentence_text":"Hello world.","format":"pcm","sample_rate":24000} + + +... +{"type":"audio.done","sentence_index":0,"total_bytes":96000,"error":false} +{"type":"audio.start","sentence_index":1,"sentence_text":"This is a streaming demo.","format":"pcm","sample_rate":24000} + +... 
+{"type":"audio.done","sentence_index":1,"total_bytes":72000,"error":false} +{"type":"session.done","total_sentences":2} +``` + + + +--- + +## Deployment Playbook + +### Qwen3-Omni + +#### 1) Serve with the default 3-stage config + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --omni \ + --port 8091 +``` + +Notes: + +- `runtime.max_batch_size` controls stage-level batching. +- Thinker/Talker commonly use `enforce_eager: false` for CUDA Graph paths. +- Code2Wav often remains eager (`enforce_eager: true`) depending on runtime behavior. + +#### 2) Enable async chunk + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --omni \ + --port 8091 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml +``` + +#### 3) Key config knobs + +```yaml +async_chunk: true +stage_args: + - stage_id: 0 # thinker + runtime: + max_batch_size: 64 + engine_args: + enforce_eager: false + max_num_batched_tokens: 32768 + custom_process_next_stage_input_func: >- + vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk + + - stage_id: 1 # talker + runtime: + max_batch_size: 64 + engine_args: + enforce_eager: false + max_num_batched_tokens: 32768 + custom_process_next_stage_input_func: >- + vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk + + - stage_id: 2 # code2wav + runtime: + max_batch_size: 64 + engine_args: + enforce_eager: true + max_num_batched_tokens: 51200 +``` + +#### Reproduce Qwen3-Omni benchmarks + +```bash +vllm bench serve \ + --dataset-name random \ + --port ${PORT} \ + --model ${MODEL_PATH} \ + --endpoint /v1/chat/completions \ + --backend openai-chat-omni \ + --max-concurrency ${MAX_CONCURRENCY} \ + --num-prompts ${NUM_PROMPTS} \ + --random-input-len 2500 \ + --ignore-eos \ + --percentile-metrics ttft,tpot,itl,e2el,audio_ttfp,audio_rtf \ + --random-output-len 900 \ + --extra_body '{"modalities": ["text","audio"]}' +``` + +### Qwen3-TTS + +#### 1) Serve with async chunk (recommended) + +```bash +vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ + --omni \ + --port 8000 +``` + +The default config (`qwen3_tts.yaml`) enables the full optimization stack: + +- Batching with `max_batch_size: 10` on the Talker stage +- CUDA Graph on the Talker (`enforce_eager: false`) +- Async chunk with streaming transport + +#### 2) Serve without async chunk (for comparison) + +```bash +vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ + --omni \ + --port 8000 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml +``` + +#### 3) Key config knobs + +```yaml +async_chunk: true +stage_args: + - stage_id: 0 # Talker (AR decoder) + runtime: + max_batch_size: 10 + engine_args: + enforce_eager: false + max_num_batched_tokens: 512 + custom_process_next_stage_input_func: >- + vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + + - stage_id: 1 # Code2Wav (vocoder) + runtime: + max_batch_size: 1 + engine_args: + enforce_eager: true + max_num_batched_tokens: 8192 + +runtime: + connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + codec_streaming: true + codec_chunk_frames: 25 + codec_left_context_frames: 25 +``` + +#### Reproduce Qwen3-TTS benchmarks + +```bash +GPU_DEVICE=0 \ +MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ +NUM_PROMPTS=50 \ +CONCURRENCY="1 4 10" \ +bash benchmarks/qwen3-tts/vllm_omni/run_stacked_benchmark.sh +``` + +This cycles through four configs (Baseline -> + Batch -> + CUDA 
Graph -> + Async Chunk + Streaming), benchmarks each at the specified concurrency levels, and generates all comparison figures automatically. diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index 1a683d174f..5dfea8d2ff 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -26,7 +26,7 @@ uv pip install vllm-omni # Optional if want to run Qwen3 TTS uv pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm -uv pip install onnxruntime-rocm sox +uv pip install onnxruntime-rocm ``` # --8<-- [end:pre-built-wheels] diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 0f9c8fff60..b4976b09c4 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -47,6 +47,7 @@ th { | `Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | `FluxKontextPipeline` | FLUX.1-Kontext-dev | `black-forest-labs/FLUX.1-Kontext-dev` | ✅︎ | ✅︎ | | | | `FluxPipeline` | FLUX.1-dev | `black-forest-labs/FLUX.1-dev` | ✅︎ | ✅︎ | | ✅︎ | +| `FluxPipeline` | FLUX.1-schnell | `black-forest-labs/FLUX.1-schnell` | ✅︎ | ✅︎ | | ✅︎ | | `OmniGen2Pipeline` | OmniGen2 | `OmniGen2/OmniGen2` | ✅︎ | ✅︎ | | ✅︎ | | `StableAudioPipeline` | Stable-Audio-Open | `stabilityai/stable-audio-open-1.0` | ✅︎ | ✅︎ | | ✅︎ | | `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-CustomVoice | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | ✅︎ | ✅︎ | ✅︎ | ✅︎ | diff --git a/docs/serving/image_edit_api.md b/docs/serving/image_edit_api.md index d254ac06ad..79303e1a69 100644 --- a/docs/serving/image_edit_api.md +++ b/docs/serving/image_edit_api.md @@ -104,6 +104,8 @@ Content-Type: multipart/form-data | `guidance_scale` | float | model defaults | Classifier-free guidance scale (typically 0.0-20.0) | | `true_cfg_scale` | float | model defaults | True CFG scale (model-specific parameter, may be ignored if not supported) | | `seed` | integer | null | Random seed for reproducibility | +| `reference_image` | string or array | null | Reference image for inpainting | +| `mask_image` | string or array | null | Mask for inpainting (white areas will be inpainted) | ### Response Format diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md index ecbe8d9ac9..733811081a 100644 --- a/docs/serving/speech_api.md +++ b/docs/serving/speech_api.md @@ -15,7 +15,7 @@ Each server instance runs a single model (specified at startup via `vllm serve < ```bash # Qwen3-TTS: CustomVoice model (predefined speakers) vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --omni \ --port 8091 \ --trust-remote-code \ @@ -300,7 +300,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \ ```bash # Start server with VoiceDesign model first vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --omni \ --port 8091 \ --trust-remote-code \ @@ -322,7 +322,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \ ```bash # Start server with Base model first vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --omni \ --port 8091 
\ --trust-remote-code \ @@ -517,15 +517,16 @@ for result in response.json()["results"]: All items are fanned out to `generate()` concurrently. The engine's stage worker automatically batches them up to the configured `max_batch_size` and queues the rest — no client-side throttling needed. -For best throughput, use a batch-optimized stage config with `max_batch_size > 1`: +For best throughput, set both stages' `max_num_seqs` to ≥4 via `--stage-overrides`: ```bash vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \ - --omni --port 8091 --trust-remote-code --enforce-eager + --omni --port 8091 --trust-remote-code --enforce-eager \ + --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2}, + "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}' ``` -The default `qwen3_tts.yaml` uses `max_batch_size: 1` (single request). The `qwen3_tts_batch.yaml` config sets `max_batch_size: 4` for ~4x throughput. +The bundled `qwen3_tts.yaml` uses `max_num_seqs: 1` (single request) on both stages. Bumping to 4 yields roughly 4× throughput on the talker and lets stage 1 batch chunks across in-flight requests. ## Supported Models @@ -617,7 +618,7 @@ Enable debug logging: ```bash vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --omni \ --port 8091 \ --trust-remote-code \ diff --git a/docs/source/architecture/async-chunk-architecture.png b/docs/source/architecture/async-chunk-architecture.png index 249de53bfe..7b3e95e4df 100644 Binary files a/docs/source/architecture/async-chunk-architecture.png and b/docs/source/architecture/async-chunk-architecture.png differ diff --git a/docs/source/architecture/qwen3-omni-async-chunk.png b/docs/source/architecture/qwen3-omni-async-chunk.png index b2d98b80f3..e73ca84b28 100644 Binary files a/docs/source/architecture/qwen3-omni-async-chunk.png and b/docs/source/architecture/qwen3-omni-async-chunk.png differ diff --git a/docs/source/architecture/qwen3-omni-non-async-chunk.png b/docs/source/architecture/qwen3-omni-non-async-chunk.png index da5610a11b..47a9ba66a5 100644 Binary files a/docs/source/architecture/qwen3-omni-non-async-chunk.png and b/docs/source/architecture/qwen3-omni-non-async-chunk.png differ diff --git a/docs/source/architecture/vllm-omni-dataflow-between-stages.png b/docs/source/architecture/vllm-omni-dataflow-between-stages.png index cdbc9a8b7b..74abc81ff0 100644 Binary files a/docs/source/architecture/vllm-omni-dataflow-between-stages.png and b/docs/source/architecture/vllm-omni-dataflow-between-stages.png differ diff --git a/docs/usage/faq.md b/docs/usage/faq.md index c080eae402..0539e158b0 100644 --- a/docs/usage/faq.md +++ b/docs/usage/faq.md @@ -4,14 +4,6 @@ A: Now, we support natively disaggregated deployment for different model stages within a model. There is a restriction that one chip can only have one AutoRegressive model stage. This is because the unified KV cache management of vLLM. Stages of other types can coexist within a chip. The restriction will be resolved in later version. -> Q: When trying to run examples, I encounter error about backend of librosa or soundfile. How to solve it? - -A: If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - > Q: I see GPU OOM or "free memory is less than desired GPU memory utilization" errors. 
How can I fix it? A: Refer to [GPU memory calculation and configuration](../configuration/gpu_memory_utilization.md) for guidance on tuning `gpu_memory_utilization` and related settings. diff --git a/docs/user_guide/diffusion/frame_interpolation.md b/docs/user_guide/diffusion/frame_interpolation.md new file mode 100644 index 0000000000..349af50c51 --- /dev/null +++ b/docs/user_guide/diffusion/frame_interpolation.md @@ -0,0 +1,92 @@ +# Frame Interpolation + +## Overview + +vLLM-Omni supports post-generation frame interpolation for supported video +diffusion pipelines. This feature inserts synthesized intermediate frames +between adjacent generated frames to improve temporal smoothness without +rerunning the diffusion denoising loop. + +Frame interpolation runs in the diffusion worker post-processing path instead +of the API server encoding path. This allows the interpolation step to reuse +the worker's current accelerator device and keeps the FastAPI event loop free +from heavy synchronous PyTorch work. + +For an input video with `N` generated frames and interpolation exponent `exp`, +the output frame count is: + +```text +(N - 1) * 2**exp + 1 +``` + +The output FPS is multiplied by `2**exp` so the clip duration remains close to +the original generated video. + +## Supported Pipelines + +Frame interpolation is currently supported for: + +- `WanPipeline` (Wan2.2 text-to-video) +- `WanImageToVideoPipeline` +- `Wan22TI2VPipeline` + +## Request Parameters + +The video APIs `/v1/videos` and `/v1/videos/sync` accept: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `enable_frame_interpolation` | bool | `false` | Enable post-generation frame interpolation | +| `frame_interpolation_exp` | int | `1` | Interpolation exponent. `1=2x`, `2=4x`, etc. | +| `frame_interpolation_scale` | float | `1.0` | RIFE inference scale | +| `frame_interpolation_model_path` | str | `None` | Local directory or Hugging Face repo ID containing `flownet.pkl` | + +## Execution Flow + +For supported Wan2.2 pipelines, the execution order is: + +1. Diffusion worker finishes denoising and decodes the raw video tensor. +2. Worker-side model-specific post-processing runs. +3. If frame interpolation is enabled, RIFE interpolates the decoded video + tensor on the worker side and records a FPS multiplier in `custom_output`. +4. The API server receives the already-interpolated video and only performs + MP4 export. + +This design keeps interpolation close to the generated tensor and avoids +introducing another heavyweight GPU context in the API server process. + +## Example + +Start the server: + +```bash +vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091 +``` + +Run a sync request with interpolation enabled: + +```bash +curl -X POST http://localhost:8091/v1/videos/sync \ + -F "prompt=A dog running through a park" \ + -F "num_frames=81" \ + -F "width=832" \ + -F "height=480" \ + -F "fps=16" \ + -F "num_inference_steps=40" \ + -F "guidance_scale=1.0" \ + -F "guidance_scale_2=1.0" \ + -F "enable_frame_interpolation=true" \ + -F "frame_interpolation_exp=1" \ + -F "frame_interpolation_scale=1.0" \ + -F "seed=42" \ + -o sync_t2v_interpolated.mp4 +``` + +## Notes + +- This is a post-processing feature. It does not modify the diffusion denoising + schedule. +- Higher interpolation exponents increase post-processing time and memory usage. 
+- If the interpolation model weights are not available locally, + `frame_interpolation_model_path` may point to a Hugging Face repo containing + `flownet.pkl`. diff --git a/docs/user_guide/diffusion/lora.md b/docs/user_guide/diffusion/lora.md index e45c033b84..256698752a 100644 --- a/docs/user_guide/diffusion/lora.md +++ b/docs/user_guide/diffusion/lora.md @@ -56,6 +56,92 @@ outputs = omni.generate( !!! note "Server-side Path Requirement" The LoRA adapter path (`local_path`) must be readable on the **server** machine. If your client and server are on different machines, ensure the LoRA adapter is accessible via a shared mount or copied to the server. +## Wan2.2 LightX2V Offline Assembly + +This workflow is LoRA-adjacent: it uses external LightX2V conversion plus +`Wan2.2-Distill-Loras` to bake converted Wan2.2 I2V checkpoints into a local +Diffusers directory, instead of loading LoRA adapters at runtime. + +### Required assets + +- Base model: `Wan-AI/Wan2.2-I2V-A14B` +- Diffusers skeleton: `Wan-AI/Wan2.2-I2V-A14B-Diffusers` +- Optional external converter from the LightX2V project (not shipped in this repository) +- Optional LoRA weights: `lightx2v/Wan2.2-Distill-Loras` + +### Step 1: Optional - convert high/low-noise DiT weights with LightX2V + +Install or clone LightX2V from the upstream repository +(`https://github.com/ModelTC/LightX2V`). After cloning, the converter used +below is available at `/tools/convert/converter.py`. + +```bash +python /path/to/lightx2v/tools/convert/converter.py \ + --source /path/to/Wan2.2-I2V-A14B/high_noise_model \ + --output /tmp/wan22_lightx2v/high_noise_out \ + --output_ext .safetensors \ + --output_name diffusion_pytorch_model \ + --model_type wan_dit \ + --direction forward \ + --lora_path /path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors \ + --lora_key_convert auto \ + --single_file + +python /path/to/lightx2v/tools/convert/converter.py \ + --source /path/to/Wan2.2-I2V-A14B/low_noise_model \ + --output /tmp/wan22_lightx2v/low_noise_out \ + --output_ext .safetensors \ + --output_name diffusion_pytorch_model \ + --model_type wan_dit \ + --direction forward \ + --lora_path /path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors \ + --lora_key_convert auto \ + --single_file +``` + +If you are not using LightX2V, skip this step and either keep the original +Diffusers weights from the skeleton or point Step 2 at any other converted +`transformer/` and `transformer_2/` checkpoints. + +### Step 2: Assemble a final Diffusers-style directory + +```bash +python tools/wan22/assemble_wan22_i2v_diffusers.py \ + --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \ + --transformer-weight /tmp/wan22_lightx2v/high_noise_out \ + --transformer-2-weight /tmp/wan22_lightx2v/low_noise_out \ + --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \ + --asset-mode symlink \ + --overwrite +``` + +`--transformer-weight` and `--transformer-2-weight` are optional. If you omit +them, the tool keeps the original weights from the Diffusers skeleton. 
+ +### Step 3: Run offline inference + +```bash +python examples/offline_inference/image_to_video/image_to_video.py \ + --model /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \ + --image /path/to/input.jpg \ + --prompt "A cat playing with yarn" \ + --num-frames 81 \ + --num-inference-steps 4 \ + --tensor-parallel-size 4 \ + --height 480 \ + --width 832 \ + --flow-shift 12 \ + --sample-solver euler \ + --guidance-scale 1.0 \ + --guidance-scale-high 1.0 \ + --boundary-ratio 0.875 +``` + +Notes: + +- This route avoids runtime LoRA loading changes in vLLM-Omni when you choose to bake converted weights into a local Diffusers directory. +- Output quality and speed depend on the replacement checkpoints and sampling params you choose. + ## See Also diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 7e08851812..7bdeede446 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -14,7 +14,7 @@ vLLM-Omni supports various advanced features for diffusion models: - Acceleration: **cache methods**, **parallelism methods**, **startup optimizations** - Memory optimization: **cpu offloading**, **quantization** -- Extensions: **LoRA inference** +- Extensions: **LoRA inference**, **frame interpolation** - Execution modes: **step execution** ## Supported Features @@ -69,6 +69,7 @@ Extension methods add specialized capabilities to diffusion models beyond standa | Method | Description | Best For | |--------|-------------|----------| | **[LoRA Inference](diffusion/lora.md)** | Enables inference with Low-Rank Adaptation (LoRA) adapters weights | Reinforcement learning extensions | +| **[Frame Interpolation](diffusion/frame_interpolation.md)** | Inserts intermediate video frames after generation for smoother motion | Video generation pipelines that need higher temporal smoothness | ### Execution Modes @@ -108,17 +109,18 @@ The following tables show which models support each feature: |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| | **Bagel** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **FLUX.1-dev** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | +| **FLUX.1-schnell** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | | **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | -| **FLUX.2-dev** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | +| **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | | **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | -| **LongCat-Image-Edit** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LongCat-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **MagiHuman** | ❌ | ❌ | ❌ | ❓ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | -| **OmniGen2** | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **OmniGen2** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | **Ovis-Image** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ | | **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ | @@ -138,16 +140,21 @@ The following tables show which models support each 
feature: |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| | **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ | | **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ | -| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | -| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | + +**Frame Interpolation Support** + +- **Supported**: Wan2.2 text-to-video, image-to-video, and TI2V pipelines +- **Not supported**: Wan2.1-VACE, LTX-2, Helios, HunyuanVideo-1.5, DreamID-Omni ### AudioGen | Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution | |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| -| **Stable-Audio-Open** | ❌ | ❌ | ❓ | ❓ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | +| **Stable-Audio-Open** | ✅ | ❌ | ❓ | ❓ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ## Feature Compatibility @@ -258,6 +265,7 @@ Measured on NVIDIA H800: **Extensions:** - **[LoRA Inference Guide](diffusion/lora.md)** - Low-Rank Adaptation for style customization and fine-tuning +- **[Frame Interpolation Guide](diffusion/frame_interpolation.md)** - Worker-side post-generation video frame interpolation for smoother motion **Execution Modes:** diff --git a/docs/user_guide/examples/offline_inference/bagel.md b/docs/user_guide/examples/offline_inference/bagel.md index 5f458750b4..1fb4d40457 100644 --- a/docs/user_guide/examples/offline_inference/bagel.md +++ b/docs/user_guide/examples/offline_inference/bagel.md @@ -176,8 +176,6 @@ Example configuration for TP=2 on GPUs 0 and 1: | Parameter | Value | Description | | :-------------------- | :------ | :------------------------------- | -| `window_size` | `-1` | Window size (-1 means unlimited) | -| `max_inflight` | `1` | Maximum inflight requests | | `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) | ## Using Mooncake Connector @@ -250,13 +248,6 @@ For more details on the Mooncake connector and multi-node setup, see the [Moonca ## FAQ -- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. - -```bash -sudo apt update -sudo apt install ffmpeg -``` - - If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. | Stage | VRAM | diff --git a/docs/user_guide/examples/offline_inference/cosyvoice3.md b/docs/user_guide/examples/offline_inference/cosyvoice3.md index d912f1c62e..ebb7c02efc 100644 --- a/docs/user_guide/examples/offline_inference/cosyvoice3.md +++ b/docs/user_guide/examples/offline_inference/cosyvoice3.md @@ -10,7 +10,7 @@ Install dependencies: uv pip install -e . ``` -> **Note:** This includes required libraries such as `librosa`, `soundfile`, +> **Note:** This includes required libraries such as `soundfile`, > `onnxruntime`, `x-transformers`, and `einops` via > `requirements/common.txt` and platform-specific requirements files. 
diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md index 7a750aeff3..6e105741a7 100644 --- a/docs/user_guide/examples/offline_inference/image_to_video.md +++ b/docs/user_guide/examples/offline_inference/image_to_video.md @@ -62,12 +62,13 @@ Key arguments: - `--negative-prompt`: Optional list of artifacts to suppress. - `--boundary-ratio`: Boundary split ratio for two-stage MoE models. - `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. - `--num-inference-steps`: Number of denoising steps (default 50). - `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). - `--output`: Path to save the generated video. - `--vae-use-slicing`: Enable VAE slicing for memory optimization. - `--vae-use-tiling`: Enable VAE tiling for memory optimization. -- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md). - `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. - `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs. @@ -78,6 +79,9 @@ Key arguments: > ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage. +For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA +assets, see the [LoRA guide](../../diffusion/lora.md#wan22-lightx2v-offline-assembly). + ## Example materials ??? abstract "image_to_video.py" diff --git a/docs/user_guide/examples/offline_inference/mimo_audio.md b/docs/user_guide/examples/offline_inference/mimo_audio.md index 1a3be15d69..4e80526971 100644 --- a/docs/user_guide/examples/offline_inference/mimo_audio.md +++ b/docs/user_guide/examples/offline_inference/mimo_audio.md @@ -189,29 +189,6 @@ Note: This task uses hardcoded message lists in the script. ## Troubleshooting -### Audio dependencies (soundfile, librosa) - -This example depends on **soundfile** (read/write WAV) and **librosa** (load audio including MP3). Install the project requirements first: - -```bash -pip install -r requirements/common.txt -# or at least: pip install soundfile>=0.13.1 librosa>=0.11.0 -``` - -- **`soundfile` / libsndfile not found** - `soundfile` uses the C library **libsndfile**. On Linux, install the system package before pip: - - Debian/Ubuntu: `sudo apt-get install libsndfile1` - - For development builds: `sudo apt-get install libsndfile1-dev` - - Then: `pip install soundfile` - -- **`librosa` fails to load MP3 or reports "No backend available"** - Loading MP3 (e.g. in `spoken_dialogue_sft_multiturn` with `.mp3` files) uses **ffmpeg** as the backend. Install ffmpeg: - - Debian/Ubuntu: `sudo apt-get install ffmpeg` - - macOS: `brew install ffmpeg` - -- **`ImportError: No module named 'soundfile'` or `ModuleNotFoundError: ... 
librosa`** - Ensure you are in the same Python environment where vLLM Omni and the example dependencies are installed, and that `requirements/common.txt` (or the packages above) are installed. - ### Tokenizer path - **`MIMO_AUDIO_TOKENIZER_PATH` not set or model fails to find tokenizer** diff --git a/docs/user_guide/examples/offline_inference/qwen2_5_omni.md b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md index 07a56cf9a0..c54976b540 100644 --- a/docs/user_guide/examples/offline_inference/qwen2_5_omni.md +++ b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md @@ -64,14 +64,6 @@ If media file paths are not provided, the script will use default assets. Suppor - `use_audio_in_video`: Extract audio from video - `text`: Text-only query -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - ## Example materials ??? abstract "end2end.py" diff --git a/docs/user_guide/examples/offline_inference/qwen3_omni.md b/docs/user_guide/examples/offline_inference/qwen3_omni.md index 6577092bbf..2d856f7380 100644 --- a/docs/user_guide/examples/offline_inference/qwen3_omni.md +++ b/docs/user_guide/examples/offline_inference/qwen3_omni.md @@ -112,14 +112,6 @@ python end2end_async_chunk.py \ > async_chunk example when you need the stage-level concurrency semantics > described in PR #962 / #1151. -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - ## Example materials ??? abstract "end2end.py" diff --git a/docs/user_guide/examples/offline_inference/qwen3_tts.md b/docs/user_guide/examples/offline_inference/qwen3_tts.md index 19fea4132c..7226ac1fe4 100644 --- a/docs/user_guide/examples/offline_inference/qwen3_tts.md +++ b/docs/user_guide/examples/offline_inference/qwen3_tts.md @@ -18,11 +18,11 @@ Please refer to the [stage configuration documentation](https://docs.vllm.ai/pro ### ROCm Dependencies -You will need to install these two dependencies `onnxruntime-rocm` and `sox`. +You will need to install the dependency `onnxruntime-rocm`. ``` pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm -pip install onnxruntime-rocm sox +pip install onnxruntime-rocm ``` ## Quick Start @@ -144,13 +144,13 @@ completes. This demonstrates that audio data is available progressively rather t ## Batched Decoding -The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_num_seqs > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`. +The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`. ``` python end2end.py --query-type CustomVoice \ --txt-prompts benchmark_prompts.txt \ --batch-size 4 \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml + --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}' ``` **Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) 
because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently. diff --git a/docs/user_guide/examples/offline_inference/x_to_video_audio.md b/docs/user_guide/examples/offline_inference/x_to_video_audio.md index 8ea39d8115..cec8d47c59 100644 --- a/docs/user_guide/examples/offline_inference/x_to_video_audio.md +++ b/docs/user_guide/examples/offline_inference/x_to_video_audio.md @@ -31,9 +31,9 @@ dreamid_omni/ ``` ### Run the Inference -``` +```python python x_to_video_audio.py \ - --model /xx/dreamid_omni \ + --model /path/to/dreamid_omni \ --prompt "Two people walking together and singing happily" \ --image-path ./example0.png ./example1.png \ --audio-path ./example0.wav ./example1.wav \ @@ -43,11 +43,33 @@ python x_to_video_audio.py \ --num-inference-steps 45 \ --height 704 \ --width 1280 \ - --output dreamid_omni.mp4 + --output out_dreamid_omni_twoip.mp4 ``` In the current test scenario (2 images + 2 audio inputs), the VRAM requirement is 72GB, regardless of whether cfg-parallel is enabled or disabled. The VRAM usage can be reduced by enabling CPU offload via --enable-cpu-offload. + +You could take reference images/audios from the test cases in the official repo: https://github.com/Guoxu1233/DreamID-Omni + +For example, single IP ref resources can be found under https://github.com/Guoxu1233/DreamID-Omni/tree/main/test_case/oneip, you could download them correspondingly to your local and use them for testing. + +```python +# Example usage for oneip, ref media from the official repo DreamID-Omni +python x_to_video_audio.py \ + --model /path/to/dreamid_omni \ + --prompt ": In the frame, a woman with black long hair is identified as .\n**Overall Environment/Scene**: A lively open-kitchen café at night; stove flames flare, steam rises, and warm pendant lights swing slightly as staff move behind her. The shot is an upper-body close-up.\n**Main Characters/Subjects Appearance**: is a young woman with thick dark wavy hair and a side part. She wears a fitted black top under a light apron, a thin gold chain necklace, and small stud earrings.\n**Main Characters/Subjects Actions**: tastes the sauce with a spoon, then turns her face toward the camera while still holding the spoon, her expression shifting from focused to conflicted.\n maintains eye contact, swallows as if choosing her words, and says, I keep telling myself I’m fine,but some nights it feels like I’m just performing calm." \ + --image-path 9.png \ + --audio-path 9.wav \ + --video-negative-prompt "jitter, bad hands, blur, distortion" \ + --audio-negative-prompt "robotic, muffled, echo, distorted" \ + --cfg-parallel-size 2 \ + --num-inference-steps 45 \ + --height 704 \ + --width 1280 \ + --output out_dreamid_omni_oneip.mp4 +``` + + Key arguments: - `--prompt`: text description (string). - `--model`: path to the model local directory. 
diff --git a/docs/user_guide/examples/online_serving/bagel.md b/docs/user_guide/examples/online_serving/bagel.md index 4a6094c089..9de31926aa 100644 --- a/docs/user_guide/examples/online_serving/bagel.md +++ b/docs/user_guide/examples/online_serving/bagel.md @@ -357,13 +357,6 @@ curl http://localhost:8091/v1/chat/completions \ ## FAQ -- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. - -```bash -sudo apt update -sudo apt install ffmpeg -``` - - If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. | Stage | VRAM | diff --git a/docs/user_guide/examples/online_serving/image_to_video.md b/docs/user_guide/examples/online_serving/image_to_video.md index 00b67d74e2..781f0c2a5e 100644 --- a/docs/user_guide/examples/online_serving/image_to_video.md +++ b/docs/user_guide/examples/online_serving/image_to_video.md @@ -72,6 +72,9 @@ curl -X POST http://localhost:8091/v1/videos/sync \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F "enable_frame_interpolation=true" \ + -F "frame_interpolation_exp=1" \ + -F "frame_interpolation_scale=1.0" \ -F "seed=42" \ -o sync_i2v_output.mp4 ``` @@ -114,6 +117,9 @@ create_response=$(curl -s http://localhost:8091/v1/videos \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F "enable_frame_interpolation=true" \ + -F "frame_interpolation_exp=1" \ + -F "frame_interpolation_scale=1.0" \ -F "seed=42") video_id=$(echo "$create_response" | jq -r '.id') @@ -172,9 +178,35 @@ curl -X POST http://localhost:8091/v1/videos \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F "enable_frame_interpolation=true" \ + -F "frame_interpolation_exp=1" \ + -F "frame_interpolation_scale=1.0" \ -F "seed=42" ``` +Frame interpolation is also available for supported Wan2.2 I2V requests. See +[Frame Interpolation](../../diffusion/frame_interpolation.md) for worker-side +execution details and feature constraints. + +### Frame Interpolation Example + +```bash +curl -X POST http://localhost:8091/v1/videos/sync \ + -F "prompt=A bear playing with yarn, smooth motion" \ + -F "input_reference=@/path/to/qwen-bear.png" \ + -F "width=832" \ + -F "height=480" \ + -F "num_frames=33" \ + -F "fps=16" \ + -F "num_inference_steps=40" \ + -F "guidance_scale=1.0" \ + -F "guidance_scale_2=1.0" \ + -F "enable_frame_interpolation=true" \ + -F "frame_interpolation_exp=1" \ + -F "frame_interpolation_scale=1.0" \ + -o sync_i2v_interpolated.mp4 +``` + ## Create Response Format `POST /v1/videos` returns a job record, not inline base64 video data. diff --git a/docs/user_guide/examples/online_serving/qwen2_5_omni.md b/docs/user_guide/examples/online_serving/qwen2_5_omni.md index 4357646924..b3a2c9f2ac 100644 --- a/docs/user_guide/examples/online_serving/qwen2_5_omni.md +++ b/docs/user_guide/examples/online_serving/qwen2_5_omni.md @@ -218,14 +218,6 @@ The gradio script supports the following arguments: - `--port`: Port for Gradio server (default: 7861) - `--share`: Share the Gradio demo publicly (creates a public link) -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - ## Example materials ??? 
abstract "gradio_demo.py" diff --git a/docs/user_guide/examples/online_serving/qwen3_omni.md b/docs/user_guide/examples/online_serving/qwen3_omni.md index 69de24852f..611eb6fd3f 100644 --- a/docs/user_guide/examples/online_serving/qwen3_omni.md +++ b/docs/user_guide/examples/online_serving/qwen3_omni.md @@ -18,12 +18,12 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 If you want to open async chunking for qwen3-omni, launch the server with command below ```bash -vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /vllm_omni/deploy/qwen3_omni_moe.yaml ``` If you have custom stage configs file, launch the server with command below ```bash -vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file ``` ### Send Multi-modal Request @@ -64,15 +64,6 @@ python openai_chat_completion_client_for_multimodal_generation.py \ bash run_curl_multimodal_generation.sh use_image ``` - -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - ## Modality control You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance. @@ -196,7 +187,7 @@ The script supports the following arguments: - `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct) - `--server-port`: Port for vLLM server (default: 8091) - `--gradio-port`: Port for Gradio demo (default: 7861) -- `--stage-configs-path`: Path to custom stage configs YAML file (optional) +- `--deploy-config`: Path to custom deploy config YAML file (optional) - `--server-host`: Host for vLLM server (default: 0.0.0.0) - `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1) - `--share`: Share Gradio demo publicly (creates a public link) @@ -211,7 +202,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 If you have custom stage configs file: ```bash -vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file ``` **Step 2: Run the Gradio demo** diff --git a/docs/user_guide/examples/online_serving/qwen3_tts.md b/docs/user_guide/examples/online_serving/qwen3_tts.md index 156c4942cd..95f234f02d 100644 --- a/docs/user_guide/examples/online_serving/qwen3_tts.md +++ b/docs/user_guide/examples/online_serving/qwen3_tts.md @@ -58,7 +58,7 @@ Then open http://localhost:7860 in your browser. 
```bash # CustomVoice model (predefined speakers) vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --omni \ --port 8091 \ --trust-remote-code \ @@ -66,7 +66,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ # VoiceDesign model vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --omni \ --port 8091 \ --trust-remote-code \ @@ -74,7 +74,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \ # Base model (voice cloning) vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --omni \ --port 8091 \ --trust-remote-code \ @@ -211,14 +211,6 @@ with open("output.wav", "wb") as f: f.write(response.content) ``` -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - ## API Reference ### Voices Endpoint diff --git a/docs/user_guide/examples/online_serving/text_to_video.md b/docs/user_guide/examples/online_serving/text_to_video.md index d58296fcc7..00a9c16723 100644 --- a/docs/user_guide/examples/online_serving/text_to_video.md +++ b/docs/user_guide/examples/online_serving/text_to_video.md @@ -3,17 +3,28 @@ Source . -This example demonstrates how to deploy the Wan2.2 text-to-video model for online video generation using vLLM-Omni. +This example demonstrates how to deploy text-to-video models for online video generation using vLLM-Omni. -## Start Server +## Supported Models -### Basic Start +| Model | Model ID | +|-------|----------| +| Wan2.1 T2V (1.3B) | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | +| Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | +| Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | +| LTX-2 | `Lightricks/LTX-2` | + +## Wan2.2 T2V + +### Start Server + +#### Basic Start ```bash vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091 ``` -### Start with Parameters +#### Start with Parameters Or use the startup script: @@ -154,6 +165,9 @@ curl -X POST http://localhost:8091/v1/videos \ -F "guidance_scale_2=4.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=5.0" \ + -F "enable_frame_interpolation=true" \ + -F "frame_interpolation_exp=1" \ + -F "frame_interpolation_scale=1.0" \ -F "seed=42" ``` @@ -176,6 +190,35 @@ curl -X POST http://localhost:8091/v1/videos \ | `flow_shift` | float | None | Scheduler flow shift (Wan2.2) | | `seed` | int | None | Random seed (reproducible) | | `lora` | object | None | LoRA configuration | +| `enable_frame_interpolation` | bool | false | Enable RIFE frame interpolation before MP4 encoding | +| `frame_interpolation_exp` | int | 1 | Interpolation exponent; 1=2x temporal resolution, 2=4x | +| `frame_interpolation_scale` | float | 1.0 | RIFE inference scale; use 0.5 for high-resolution inputs | +| `frame_interpolation_model_path` | str | None | Local directory or Hugging Face repo ID with `flownet.pkl`; defaults to `elfgum/RIFE-4.22.lite` | + +## Frame Interpolation + +Frame interpolation is an optional post-processing step for `/v1/videos` and +`/v1/videos/sync`. It synthesizes intermediate frames between generated frames +without rerunning the diffusion model. If the generated video has `N` frames, +the interpolated output frame count is `(N - 1) * 2**exp + 1`. 
The encoder FPS +is multiplied by `2**exp` so the output duration remains close to the original. + +Frame interpolation runs in the diffusion worker post-processing path instead of +the API server encoding path, so it can reuse the worker's current accelerator +device without blocking the FastAPI event loop. + +Example: generate 5 frames and interpolate to 9 frames: + +```bash +curl -X POST http://localhost:8091/v1/videos/sync \ + -F "prompt=A dog running through a park" \ + -F "num_frames=5" \ + -F "fps=8" \ + -F "enable_frame_interpolation=true" \ + -F "frame_interpolation_exp=1" \ + -F "frame_interpolation_scale=1.0" \ + -o sync_t2v_interpolated.mp4 +``` ## Create Response Format @@ -234,8 +277,94 @@ while true; do done ``` +## LTX-2 + +### Start Server + +#### Basic Start + +```bash +vllm serve Lightricks/LTX-2 --omni --port 8098 \ + --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0 +``` + +#### Start with Optimization Presets + +Use the LTX-2 startup script with built-in optimization presets: + +```bash +# Baseline (1 GPU, eager) +bash run_server_ltx2.sh baseline + +# 4-GPU Ulysses sequence parallelism (lossless) +bash run_server_ltx2.sh ulysses4 + +# Cache-DiT lossy acceleration (1 GPU, ~1.4× speedup) +bash run_server_ltx2.sh cache-dit + +# Best combo: 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup) +bash run_server_ltx2.sh best-combo +``` + +#### Optimization Benchmarks + +Benchmarked on H800, online serving (480×768, 41 frames, 20 steps, `seed=42`). +"Inference" is the server-reported inference time; excludes HTTP/poll overhead. + +| Preset | Server Command | Inference (s) | Speedup | Type | +|--------|---------------|---------------|---------|------| +| `baseline` | `--enforce-eager` | 10.3 | 1.00× | — | +| `compile` | *(default, no --enforce-eager)* | ~10.3 (warm) | ~1.00× | Lossless | +| `ulysses4` | `--enforce-eager --usp 4` | ~10.3 | ~1.00× | Lossless | +| `cache-dit` | `--enforce-eager --cache-backend cache_dit` | 7.4 avg | ~1.4× | Lossy | +| `best-combo` | `--enforce-eager --usp 4 --cache-backend cache_dit` | 4.7 avg | **~2.2×** | Lossless + Lossy | + +**Observations**: +- **torch.compile**: On H800, warm-request inference time matches the eager baseline (~10.3s). + The first request pays ~6s compilation overhead. Benefit depends on model architecture and GPU. +- **Ulysses SP (4 GPU)**: No measurable speedup alone for 41-frame generation at this resolution. + Communication overhead outweighs gains at this sequence length. +- **Cache-DiT**: Inference varies per request (6–10s) due to dynamic caching decisions. + Average is ~7.4s (~1.4× speedup) with slight quality tradeoff. +- **Best combo**: 4-GPU Ulysses SP + Cache-DiT synergize well — Cache-DiT reduces per-step + computation, making the communication overhead of Ulysses SP worthwhile. Average ~4.7s + (~2.2× speedup). +- **FP8 quantization**: Reduces VRAM but does not speed up LTX-2 on H800 (compute-bound). + +**Deployment Recommendations**: +- For **production with quality priority**: use `baseline` with `--enforce-eager` +- For **maximum throughput** (4 GPUs, quality tradeoff): use `best-combo` (~2.2× speedup) +- For **single-GPU throughput**: use `cache-dit` (~1.4× speedup) +- `--enforce-eager` is recommended to avoid torch.compile warmup latency on first request + +### Send Requests (curl) + +```bash +# Using the provided script +bash run_curl_ltx2.sh + +# Or directly +curl -sS -X POST http://localhost:8098/v1/videos \ + -H "Accept: application/json" \ + -F "prompt=A serene lakeside sunrise with mist over the water." 
\ + -F "width=768" \ + -F "height=480" \ + -F "num_frames=41" \ + -F "fps=24" \ + -F "num_inference_steps=20" \ + -F "guidance_scale=3.0" \ + -F "seed=42" +``` + ## Example materials +??? abstract "response.json" + ``````json + --8<-- "examples/online_serving/text_to_video/response.json" + `````` +??? abstract "run_curl_ltx2.sh" + ``````sh + --8<-- "examples/online_serving/text_to_video/run_curl_ltx2.sh" ??? abstract "run_curl_hunyuan_video_15.sh" ``````sh --8<-- "examples/online_serving/text_to_video/run_curl_hunyuan_video_15.sh" @@ -248,6 +377,9 @@ done ``````sh --8<-- "examples/online_serving/text_to_video/run_server.sh" `````` +??? abstract "run_server_ltx2.sh" + ``````sh + --8<-- "examples/online_serving/text_to_video/run_server_ltx2.sh" ??? abstract "run_server_hunyuan_video_15.sh" ``````sh --8<-- "examples/online_serving/text_to_video/run_server_hunyuan_video_15.sh" diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md index 226c009f79..3e653d0e3a 100644 --- a/examples/offline_inference/bagel/README.md +++ b/examples/offline_inference/bagel/README.md @@ -173,8 +173,6 @@ Example configuration for TP=2 on GPUs 0 and 1: | Parameter | Value | Description | | :-------------------- | :------ | :------------------------------- | -| `window_size` | `-1` | Window size (-1 means unlimited) | -| `max_inflight` | `1` | Maximum inflight requests | | `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) | ## Using Mooncake Connector @@ -247,13 +245,6 @@ For more details on the Mooncake connector and multi-node setup, see the [Moonca ## FAQ -- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. - -```bash -sudo apt update -sudo apt install ffmpeg -``` - - If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. | Stage | VRAM | diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index 472d748d1e..ed5fa57e8d 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -97,6 +97,24 @@ def parse_args(): default=False, help="Enable thinking mode: AR stage decodes ... 
planning tokens before image generation.", ) + parser.add_argument( + "--max-think-tokens", + type=int, + default=1000, + help="Maximum number of tokens for thinking text generation (default: 1000).", + ) + parser.add_argument( + "--do-sample", + action="store_true", + default=False, + help="Enable sampling for text generation (default: greedy).", + ) + parser.add_argument( + "--text-temperature", + type=float, + default=0.3, + help="Temperature for text generation sampling (default: 0.3).", + ) args = parser.parse_args() return args @@ -108,7 +126,6 @@ def main(): model_name = args.model prompts: list[OmniPromptType] = [] try: - # Preferred: load from txt file (one prompt per line) if getattr(args, "txt_prompts", None) and args.prompt_type == "text": with open(args.txt_prompts, encoding="utf-8") as f: lines = [ln.strip() for ln in f.readlines()] @@ -121,10 +138,8 @@ def main(): raise if not prompts: - # Default prompt for text2img test if none provided prompts = ["A cute cat"] print(f"[Info] No prompts provided, using default: {prompts}") - omni_outputs = [] from PIL import Image @@ -132,11 +147,13 @@ def main(): omni_kwargs = {} stage_configs_path = args.stage_configs_path + is_single_stage = stage_configs_path and "single_stage" in stage_configs_path if args.think and stage_configs_path is None: stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml" print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}") if stage_configs_path: omni_kwargs["stage_configs_path"] = stage_configs_path + is_single_stage = "single_stage" in stage_configs_path omni_kwargs.update( { @@ -198,40 +215,61 @@ def main(): formatted_prompts.append(prompt_dict) params_list = omni.default_sampling_params_list + + # For single-stage DiT, think/text params go into the diffusion sampling params extra_args. + # For 2-stage, diffusion params are at index 1. 
+ diffusion_params_idx = 0 if is_single_stage else (1 if len(params_list) > 1 else 0) + diffusion_params = params_list[diffusion_params_idx] + if args.modality in ("text2img", "img2img"): - if len(params_list) > 1: - diffusion_params = params_list[1] - diffusion_params.num_inference_steps = args.steps # type: ignore - diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore - if args.seed is not None: - diffusion_params.seed = args.seed # type: ignore - extra = { - "cfg_text_scale": args.cfg_text_scale, - "cfg_img_scale": args.cfg_img_scale, - } - if args.cfg_interval is not None: - extra["cfg_interval"] = tuple(args.cfg_interval) - if args.cfg_renorm_type is not None: - extra["cfg_renorm_type"] = args.cfg_renorm_type - if args.cfg_renorm_min is not None: - extra["cfg_renorm_min"] = args.cfg_renorm_min - if args.negative_prompt is not None: - extra["negative_prompt"] = args.negative_prompt - diffusion_params.extra_args = extra # type: ignore + diffusion_params.num_inference_steps = args.steps # type: ignore + diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore + if args.seed is not None: + diffusion_params.seed = args.seed # type: ignore + + extra = getattr(diffusion_params, "extra_args", {}) or {} + extra["cfg_text_scale"] = args.cfg_text_scale + extra["cfg_img_scale"] = args.cfg_img_scale + if args.cfg_interval is not None: + extra["cfg_interval"] = tuple(args.cfg_interval) + if args.cfg_renorm_type is not None: + extra["cfg_renorm_type"] = args.cfg_renorm_type + if args.cfg_renorm_min is not None: + extra["cfg_renorm_min"] = args.cfg_renorm_min + if args.negative_prompt is not None: + extra["negative_prompt"] = args.negative_prompt + + needs_text_gen = is_single_stage and (args.think or args.modality in ("text2text", "img2text")) + if needs_text_gen: + if args.think: + extra["think"] = True + extra["max_think_tokens"] = args.max_think_tokens + extra["do_sample"] = args.do_sample + extra["text_temperature"] = args.text_temperature + diffusion_params.extra_args = extra # type: ignore omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list)) img_idx = 0 for req_output in omni_outputs: - if args.think: - ro = getattr(req_output, "request_output", None) - if ro and getattr(ro, "outputs", None): - txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs) - if txt: - print(txt) + # 2-stage think mode: text output from thinker stage + ro = getattr(req_output, "request_output", None) + if ro and getattr(ro, "outputs", None): + txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs) + if txt: + if args.think: + print(f"[Think]\n{txt}") + else: + print(f"[Output] Text:\n{txt}") - images = getattr(req_output, "images", None) + # Single-stage DiT: text from custom_output + custom = getattr(req_output, "_custom_output", {}) or {} + if custom.get("think_text"): + print(f"[Think]\n{custom['think_text']}") + if custom.get("text_output"): + print(f"[Output] Text:\n{custom['text_output']}") + images = getattr(req_output, "images", None) if not images: continue @@ -241,8 +279,6 @@ def main(): print(f"[Output] Saved image to {save_path}") img_idx += 1 - print(omni_outputs) - if __name__ == "__main__": main() diff --git a/examples/offline_inference/cosyvoice3/README.md b/examples/offline_inference/cosyvoice3/README.md index 895d3f660f..e16134e6ef 100644 --- a/examples/offline_inference/cosyvoice3/README.md +++ b/examples/offline_inference/cosyvoice3/README.md @@ -7,7 +7,7 @@ Install dependencies: uv pip install -e . 
``` -> **Note:** This includes required libraries such as `librosa`, `soundfile`, +> **Note:** This includes required libraries such as `soundfile`, > `onnxruntime`, `x-transformers`, and `einops` via > `requirements/common.txt` and platform-specific requirements files. diff --git a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py index 68ab72b387..6311bbc901 100644 --- a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py +++ b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py @@ -2,13 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import os -from pathlib import Path -import librosa import numpy as np import soundfile as sf from vllm import SamplingParams from vllm.assets.audio import AudioAsset +from vllm.multimodal.media.audio import load_audio from vllm_omni.entrypoints.omni import Omni from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config @@ -16,22 +15,6 @@ from vllm_omni.model_executor.models.cosyvoice3.utils import extract_text_token -def _ensure_mel_filters_asset() -> None: - repo_root = Path(__file__).resolve().parents[3] - filters_path = repo_root / "vllm_omni" / "model_executor" / "models" / "cosyvoice3" / "assets" / "mel_filters.npz" - if filters_path.exists(): - return - - source_url = "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz" - raise FileNotFoundError( - "Missing CosyVoice3 mel filter asset:\n" - f" {filters_path}\n" - "Download it with:\n" - f" mkdir -p {filters_path.parent} && " - f"curl -L {source_url} -o {filters_path}" - ) - - def run_e2e(): parser = argparse.ArgumentParser() # ""FunAudioLLM/Fun-CosyVoice3-0.5B-2512 @@ -56,7 +39,6 @@ def run_e2e(): help="Path to tokenizer directory (e.g., /CosyVoice-BlankEN).", ) args = parser.parse_args() - _ensure_mel_filters_asset() # Ensure tokenizer directory exists if not os.path.exists(args.tokenizer): raise FileNotFoundError(f"{args.tokenizer} does not exist!") @@ -85,7 +67,7 @@ def run_e2e(): if not os.path.exists(args.audio_path): raise FileNotFoundError(f"Audio file not found: {args.audio_path}") # Load at native sample rate - audio_signal, sr = librosa.load(args.audio_path, sr=None) + audio_signal, sr = load_audio(args.audio_path, sr=None) # Validate sample rate before processing (similar to original CosyVoice) min_sr = 16000 diff --git a/examples/offline_inference/hunyuan_image3/prompt_utils.py b/examples/offline_inference/hunyuan_image3/prompt_utils.py new file mode 100644 index 0000000000..a5ef8e1536 --- /dev/null +++ b/examples/offline_inference/hunyuan_image3/prompt_utils.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Prompt construction utilities for HunyuanImage-3.0-Instruct examples. + +Wraps system_prompt.get_system_prompt() with task-aware presets so that +examples and tests don't need to manually concatenate system prompts, +, , and tags. 
+ +Usage: + from prompt_utils import build_prompt + + # IT2I (image editing, think+recaption mode) + prompt = build_prompt("Make the petals neon pink", task="it2i_think") + + # I2T (image understanding) + prompt = build_prompt("Describe the content of the picture.", task="i2t") +""" + +from __future__ import annotations + +from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import ( + get_system_prompt, +) + +# task → (sys_type, bot_task, trigger_tag) +# trigger_tag: "", "", or None +_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { + # Pure text generation (text → text, no image) + "t2t": ("en_unified", None, None), + # Image understanding (image → text) + "i2t": ("en_unified", None, None), + # Image editing (image+text → image), think+recaption mode + "it2i_think": ("en_unified", "think", ""), + # Image editing, recaption-only mode + "it2i_recaption": ("en_unified", "recaption", ""), + # Text-to-image, think mode + "t2i_think": ("en_unified", "think", ""), + # Text-to-image, recaption mode + "t2i_recaption": ("en_unified", "recaption", ""), + # Text-to-image, vanilla (no CoT) + "t2i_vanilla": ("en_vanilla", "image", None), +} + + +def build_prompt( + user_prompt: str, + task: str = "it2i_think", + sys_type: str | None = None, + custom_system_prompt: str | None = None, +) -> str: + """Build a complete HunyuanImage-3.0 prompt with auto-selected system + prompt and mode trigger tags. + + Args: + user_prompt: The user's raw instruction or question. + task: One of the preset task keys (see _TASK_PRESETS). + sys_type: Override the preset's sys_type for get_system_prompt(). + custom_system_prompt: Custom system prompt text (used when + sys_type="custom"). + + Returns: + Fully formatted prompt string ready for Omni.generate(). + """ + if task not in _TASK_PRESETS: + raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") + + preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + effective_sys_type = sys_type or preset_sys_type + + system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) + sys_text = system_prompt.strip() if system_prompt else "" + + has_image_input = task.startswith("i2t") or task.startswith("it2i") + + parts = ["<|startoftext|>"] + if sys_text: + parts.append(sys_text) + # Instruct conversation template: \n\nUser: ... 
\n\nAssistant: + parts.append("\n\nUser: ") + if has_image_input: + parts.append("") + parts.append(user_prompt) + parts.append("\n\nAssistant: ") + if trigger_tag: + parts.append(trigger_tag) + + return "".join(parts) diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index a8035a3fdc..1a7e86f13c 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -297,8 +297,8 @@ def parse_args() -> argparse.Namespace: "--cfg-parallel-size", type=int, default=1, - choices=[1, 2], - help="Number of GPUs used for classifier free guidance parallel size.", + choices=[1, 2, 3], + help="Number of GPUs used for classifier free guidance parallel size (max 3 branches).", ) parser.add_argument( "--enforce-eager", diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md index 2692c76df2..a458850a02 100644 --- a/examples/offline_inference/image_to_video/README.md +++ b/examples/offline_inference/image_to_video/README.md @@ -59,12 +59,13 @@ Key arguments: - `--negative-prompt`: Optional list of artifacts to suppress. - `--boundary-ratio`: Boundary split ratio for two-stage MoE models. - `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. - `--num-inference-steps`: Number of denoising steps (default 50). - `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). - `--output`: Path to save the generated video. - `--vae-use-slicing`: Enable VAE slicing for memory optimization. - `--vae-use-tiling`: Enable VAE tiling for memory optimization. -- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md). - `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. - `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs. @@ -74,3 +75,6 @@ Key arguments: > ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage. + +For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA +assets, see the [LoRA guide](../../../docs/user_guide/diffusion/lora.md#wan22-lightx2v-offline-assembly). diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 7e7cfbf84e..53319c8221 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -84,6 +84,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--flow-shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." ) + parser.add_argument( + "--sample-solver", + type=str, + default="unipc", + choices=["unipc", "euler"], + help="Sampling solver for Wan2.2 pipelines. 
Use 'euler' for Lightning/Distill setups.", + ) parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.") parser.add_argument( @@ -305,6 +312,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Frames: {args.num_frames}") + print(f" Solver: {args.sample_solver}") print( f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}," f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}" @@ -326,9 +334,14 @@ def main(): generator=generator, guidance_scale=guidance_scale, guidance_scale_2=args.guidance_scale_high, + boundary_ratio=args.boundary_ratio, num_inference_steps=num_inference_steps, num_frames=num_frames, frame_rate=frame_rate, + extra_args={ + "sample_solver": args.sample_solver, + "flow_shift": args.flow_shift, + }, ), ) generation_end = time.perf_counter() diff --git a/examples/offline_inference/mimo_audio/README.md b/examples/offline_inference/mimo_audio/README.md index 747e734cc2..596afabeef 100644 --- a/examples/offline_inference/mimo_audio/README.md +++ b/examples/offline_inference/mimo_audio/README.md @@ -190,29 +190,6 @@ Note: This task uses hardcoded message lists in the script. ## Troubleshooting -### Audio dependencies (soundfile, librosa) - -This example depends on **soundfile** (read/write WAV) and **librosa** (load audio including MP3). Install the project requirements first: - -```bash -pip install -r requirements/common.txt -# or at least: pip install soundfile>=0.13.1 librosa>=0.11.0 -``` - -- **`soundfile` / libsndfile not found** - `soundfile` uses the C library **libsndfile**. On Linux, install the system package before pip: - - Debian/Ubuntu: `sudo apt-get install libsndfile1` - - For development builds: `sudo apt-get install libsndfile1-dev` - - Then: `pip install soundfile` - -- **`librosa` fails to load MP3 or reports "No backend available"** - Loading MP3 (e.g. in `spoken_dialogue_sft_multiturn` with `.mp3` files) uses **ffmpeg** as the backend. Install ffmpeg: - - Debian/Ubuntu: `sudo apt-get install ffmpeg` - - macOS: `brew install ffmpeg` - -- **`ImportError: No module named 'soundfile'` or `ModuleNotFoundError: ... librosa`** - Ensure you are in the same Python environment where vLLM Omni and the example dependencies are installed, and that `requirements/common.txt` (or the packages above) are installed. 
- ### Tokenizer path - **`MIMO_AUDIO_TOKENIZER_PATH` not set or model fails to find tokenizer** diff --git a/examples/offline_inference/mimo_audio/message_convert.py b/examples/offline_inference/mimo_audio/message_convert.py index ebcc59c6b4..416f21ccfa 100644 --- a/examples/offline_inference/mimo_audio/message_convert.py +++ b/examples/offline_inference/mimo_audio/message_convert.py @@ -5,12 +5,12 @@ import re from collections.abc import Callable -import librosa import numpy as np import torch import torchaudio from process_speechdata import InputSegment, StreamingInputSegment from torchaudio.transforms import MelSpectrogram +from vllm.multimodal.media.audio import load_audio speech_zeroemb_idx = 151667 empty_token = "<|empty|>" @@ -685,7 +685,7 @@ def get_audio_data(audio_url): # File path audio_file = audio_url - audio_signal, sr = librosa.load(audio_file, sr=24000) + audio_signal, sr = load_audio(audio_file, sr=24000) audio_data = (audio_signal.astype(np.float32), sr) return audio_data diff --git a/examples/offline_inference/ming_flash_omni/README.md b/examples/offline_inference/ming_flash_omni/README.md new file mode 100644 index 0000000000..7414163fc0 --- /dev/null +++ b/examples/offline_inference/ming_flash_omni/README.md @@ -0,0 +1,76 @@ +# Ming-flash-omni 2.0 + +[Ming-flash-omni-2.0](https://github.com/inclusionAI/Ming) is an omni-modal model supporting text, image, video, and audio understanding, with outputs in text, image, and audio. For now, Ming-flash-omni-2.0 in vLLM-Omni is supported with thinker stage (multi-modal understanding). + +## Setup + +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. + +## Run examples + +### Text-only +```bash +python examples/offline_inference/ming_flash_omni/end2end.py --query-type text +``` + +#### Reasoning (Thinking Mode) + +Reasoning (Thinking) mode is enabled via applying "detailed thinking on" when building the system prompt template (in `apply_chat_template`). 
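+
+The relevant call in `end2end.py` (see `get_reasoning_query`) boils down to roughly the sketch below; the question text here is only a placeholder:
+
+```python
+from transformers import AutoProcessor
+
+# importing the processor class also registers it with AutoProcessor
+from vllm_omni.transformers_utils.processors.ming import MingFlashOmniProcessor  # noqa: F401
+
+processor = AutoProcessor.from_pretrained("Jonathan1909/Ming-flash-omni-2.0", trust_remote_code=True)
+conversation = [{"role": "HUMAN", "content": "Solve the puzzle step by step."}]
+# use_cot_system_prompt=True selects the thinking-mode ("detailed thinking on") system prompt template
+prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True)
+```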
+
+The end2end example provides a default problem for thinking mode, taken from the example usage in Ming's cookbook.
+To use it, first download the example figure from https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png
+
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png
+```
+
+### Image understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image
+
+# With a local image
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image --image-path /path/to/image.jpg
+```
+
+### Audio understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio
+
+# With a local audio file
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --audio-path /path/to/audio.wav
+```
+
+### Video understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video
+
+# With a local video and custom frame count
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video --video-path /path/to/video.mp4 --num-frames 16
+```
+
+### Mixed modalities (image + audio)
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_mixed_modalities \
+    --image-path /path/to/image.jpg \
+    --audio-path /path/to/audio.wav
+```
+
+If media file paths are not provided, the script uses built-in default assets.
+
+### Modality control
+To control output modalities (e.g. text-only output):
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --modalities text
+```
+
+*For now, only text output is supported.*
+
+### Custom stage config
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image \
+    --stage-configs-path /path/to/your_config.yaml
+```
+
+## Online serving
+
+For online serving via the OpenAI-compatible API, see [examples/online_serving/ming_flash_omni/README.md](../../online_serving/ming_flash_omni/README.md).
diff --git a/examples/offline_inference/ming_flash_omni/end2end.py b/examples/offline_inference/ming_flash_omni/end2end.py new file mode 100644 index 0000000000..49cdbcc018 --- /dev/null +++ b/examples/offline_inference/ming_flash_omni/end2end.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Partial example cases are referred from +# https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/cookbook.ipynb +import os +import time +from typing import NamedTuple + +import librosa +import numpy as np +import vllm +from PIL import Image +from transformers import AutoProcessor +from vllm import SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset, video_to_ndarrays +from vllm.multimodal.image import convert_image_mode +from vllm.utils.argparse_utils import FlexibleArgumentParser + +import vllm_omni +from vllm_omni.entrypoints.omni import Omni + +# Imports the processor also registers itself +from vllm_omni.transformers_utils.processors.ming import MingFlashOmniProcessor # noqa: F401 + +SEED = 42 +MODEL_NAME = "Jonathan1909/Ming-flash-omni-2.0" + + +class QueryResult(NamedTuple): + inputs: dict + limit_mm_per_prompt: dict[str, int] + + +def get_text_query(processor: MingFlashOmniProcessor, question: str | None = None) -> QueryResult: + if question is None: + question = "请详细介绍鹦鹉的生活习性。" + conversation = [{"role": "HUMAN", "content": question}] + prompt = processor.apply_chat_template(conversation, tokenize=False) + return QueryResult( + inputs={"prompt": prompt}, + limit_mm_per_prompt={}, + ) + + +def get_image_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + image_path: str | None = None, +) -> QueryResult: + if question is None: + question = "Describe this image in detail." + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + image_data = convert_image_mode(Image.open(image_path), "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image_data}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"image": image_data}, + }, + limit_mm_per_prompt={"image": 1}, + ) + + +def get_audio_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + audio_path: str | None = None, + sampling_rate: int = 16000, +) -> QueryResult: + if question is None: + question = "Please recognize the language of this speech and transcribe it. Format: oral." 
+ + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (audio_signal.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + # Use a string for "audio" so the processor counts it as 1 audio input + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "audio", "audio": "input"}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"audio": audio_data}, + }, + limit_mm_per_prompt={"audio": 1}, + ) + + +def get_video_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + video_path: str | None = None, + num_frames: int = 16, +) -> QueryResult: + if question is None: + question = "Describe what is happening in this video." + + if video_path: + if not os.path.exists(video_path): + raise FileNotFoundError(f"Video file not found: {video_path}") + video_frames = video_to_ndarrays(video_path, num_frames=num_frames) + else: + video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays + + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "video"}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"video": video_frames}, + }, + limit_mm_per_prompt={"video": 1}, + ) + + +def get_mixed_modalities_query( + processor: MingFlashOmniProcessor, + image_path: str | None = None, + audio_path: str | None = None, + sampling_rate: int = 16000, +) -> QueryResult: + """Mixed image + audio understanding.""" + question = "Describe the image, and recognize the language of this speech and transcribe it. Format: oral" + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + image_data = convert_image_mode(Image.open(image_path), "RGB") + else: + image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") + + if audio_path: + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + sig, sr = librosa.load(audio_path, sr=sampling_rate) + audio_data = (sig.astype(np.float32), sr) + else: + audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate + + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image_data}, + {"type": "audio", "audio": "input"}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False) + + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"image": image_data, "audio": audio_data}, + }, + limit_mm_per_prompt={"image": 1, "audio": 1}, + ) + + +def get_reasoning_query( + processor: MingFlashOmniProcessor, + question: str | None = None, + image_path: str | None = None, +) -> QueryResult: + if question is None: + # NOTE: To use the following default question, input with example figure provided by Ming + # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png + # E.g., + # python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png + # Otherwise, the problem solving might be false. 
+ question = ( + "Based on the following rules:\n•\tYou control the smiley face character\n" + "•\tYou can move up, down, left, and right, and only a single square at a time\n" + "•\tWalls are dark grey and cannot be moved into\n•\tThe brown square is a box\n•" + "\tThe box can be pushed by moving into it (i.e., if you are in the square " + "adjacent to the box to the left, and move onto the square with the box, " + "the box will move one square to the right).\n" + "•\tThe box cannot be pushed into walls\n" + "•\tThe blue door at the bottom is locked and cannot be passed through, " + "unless the box is placed on the blue square\n" + "•\tThe square beneath the blue door is the exit\n" + "•\tMoving from one square to another\n\n" + "Let's assume a coordinate system where the smiley face is " + "on the top left at (1,1) and the square below it is (1,2). " + "The smiley face performs the following moves: {down, right, right, right}, " + "such that the smiley face is at square (4,2) and the box is in square (5,2). " + "What are the next sequence of moves that must be done to move the box down to (5,3)? " + "Give your answer as a comma separated list." + ) + + if image_path: + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image file not found: {image_path}") + image_data = convert_image_mode(Image.open(image_path), "RGB") + conversation = [ + { + "role": "HUMAN", + "content": [ + {"type": "image", "image": image_data}, + {"type": "text", "text": question}, + ], + } + ] + prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True) + return QueryResult( + inputs={ + "prompt": prompt, + "multi_modal_data": {"image": image_data}, + }, + limit_mm_per_prompt={"image": 1}, + ) + + conversation = [{"role": "HUMAN", "content": question}] + prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True) + return QueryResult( + inputs={"prompt": prompt}, + limit_mm_per_prompt={}, + ) + + +query_map = { + "text": get_text_query, + "use_audio": get_audio_query, + "use_image": get_image_query, + "use_video": get_video_query, + "use_mixed_modalities": get_mixed_modalities_query, + "reasoning": get_reasoning_query, +} + + +def main(args): + print( + "=" * 20, + "\n", + f"vllm version: {vllm.__version__}\n", + f"vllm-omni version: {vllm_omni.__version__}\n", + "=" * 20, + sep="", + ) + + processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True) + assert isinstance(processor, MingFlashOmniProcessor), f"Wrong processor type being used: {type(processor)}" + + query_func = query_map[args.query_type] + if args.query_type == "use_image": + query_result = query_func(processor, image_path=args.image_path) + elif args.query_type == "use_audio": + query_result = query_func(processor, audio_path=args.audio_path, sampling_rate=args.sampling_rate) + elif args.query_type == "use_video": + query_result = query_func(processor, video_path=args.video_path, num_frames=args.num_frames) + elif args.query_type == "use_mixed_modalities": + query_result = query_func( + processor, + image_path=args.image_path, + audio_path=args.audio_path, + sampling_rate=args.sampling_rate, + ) + elif args.query_type == "reasoning": + query_result = query_func(processor, image_path=args.image_path) + else: + query_result = query_func(processor) + + # Initialize Omni (with thinker-only stage config) + omni = Omni( + model=MODEL_NAME, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + init_timeout=args.init_timeout, + 
stage_init_timeout=args.stage_init_timeout, + ) + + # Thinker sampling params + thinker_sampling_params = SamplingParams( + temperature=0.4, + top_p=0.9, + max_tokens=args.max_tokens, + repetition_penalty=1.05, + seed=SEED, + detokenize=True, + ) + sampling_params_list = [thinker_sampling_params] + + prompts = [query_result.inputs for _ in range(args.num_prompts)] + + if args.modalities is not None: + output_modalities = args.modalities.split(",") + for prompt in prompts: + prompt["modalities"] = output_modalities + + total_requests = len(prompts) + processed_count = 0 + print(f"Query type: {args.query_type}") + print(f"Number of prompts: {total_requests}") + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + profiler_enabled = args.enable_profiler + if profiler_enabled: + omni.start_profile(stages=args.profiler_stages) + + for stage_outputs in omni.generate(prompts, sampling_params_list): + output = stage_outputs.request_output + if stage_outputs.final_output_type == "text": + request_id = output.request_id + text_output = output.outputs[0].text + lines = [] + lines.append("Prompt:\n") + lines.append(str(output.prompt) + "\n") + lines.append("Text Output:\n") + lines.append(str(text_output).strip() + "\n") + print(*lines, sep="") + + # Save to file + out_txt = os.path.join(output_dir, f"{request_id}.txt") + try: + with open(out_txt, "w", encoding="utf-8") as f: + f.writelines(lines) + print(f"Request ID: {request_id}, text saved to {out_txt}") + except Exception as e: + print(f"Failed to write output file {out_txt}: {e}") + + elif stage_outputs.final_output_type == "audio": + raise NotImplementedError("Add audio example after talker supported.") + + processed_count += 1 + if profiler_enabled and processed_count >= total_requests: + print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...") + # Stop the profiler while workers are still alive + omni.stop_profile(stages=args.profiler_stages) + + print("[Info] Waiting 30s for workers to write trace files to disk...") + time.sleep(30) + print("[Info] Trace export wait time finished.") + + omni.close() + + +def parse_args(): + parser = FlexibleArgumentParser(description="Ming-flash-omni 2.0 offline inference example") + parser.add_argument( + "--query-type", + "-q", + type=str, + default="text", + choices=query_map.keys(), + help="Query type.", + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=None, + help="Path to a stage configs YAML file.", + ) + parser.add_argument( + "--log-stats", + action="store_true", + default=False, + help="Enable detailed statistics logging.", + ) + parser.add_argument("--init-timeout", type=int, default=2000, help="Timeout for initializing in seconds.") + parser.add_argument( + "--stage-init-timeout", + type=int, + default=2000, + help="Timeout for initializing a single stage in seconds.", + ) + parser.add_argument( + "--enable-profiler", + action="store_true", + default=False, + help="Enables profiling when set.", + ) + parser.add_argument( + "--profiler-stages", + type=int, + nargs="*", + default=[0], + help="List of stage IDs to profile. If not set, profiles all stages.", + ) + parser.add_argument( + "--image-path", + "-i", + type=str, + default=None, + help="Path to local image file. Uses default asset if not provided.", + ) + parser.add_argument( + "--audio-path", + "-a", + type=str, + default=None, + help="Path to local audio file. 
Uses default asset if not provided.", + ) + parser.add_argument( + "--video-path", + "-v", + type=str, + default=None, + help="Path to local video file. Uses default asset if not provided.", + ) + parser.add_argument( + "--num-frames", + type=int, + default=16, + help="Number of frames to extract from video.", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=16000, + help="Sampling rate for audio loading.", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=16384, + help="Maximum tokens to generate.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1, + help="Number of prompts to generate.", + ) + parser.add_argument( + "--modalities", + type=str, + default=None, + help="Output modalities (comma-separated).", + ) + parser.add_argument( + "--output-dir", + type=str, + default="output_ming", + help="Output directory for results.", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/omnivoice/end2end.py b/examples/offline_inference/omnivoice/end2end.py index b41379b011..9371c95142 100644 --- a/examples/offline_inference/omnivoice/end2end.py +++ b/examples/offline_inference/omnivoice/end2end.py @@ -103,9 +103,9 @@ def run_e2e(): if not os.path.exists(args.ref_audio): raise FileNotFoundError(f"Reference audio not found: {args.ref_audio}") - import librosa + from vllm.multimodal.media.audio import load_audio - audio_signal, sr = librosa.load(args.ref_audio, sr=None) + audio_signal, sr = load_audio(args.ref_audio, sr=None) multi_modal_data["audio"] = (audio_signal.astype(np.float32), sr) mm_processor_kwargs["ref_text"] = args.ref_text or "" mm_processor_kwargs["sample_rate"] = sr diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md index 20740a0da0..e2eae8a96b 100644 --- a/examples/offline_inference/qwen2_5_omni/README.md +++ b/examples/offline_inference/qwen2_5_omni/README.md @@ -60,11 +60,3 @@ If media file paths are not provided, the script will use default assets. Suppor - `mixed_modalities`: Audio + image + video - `use_audio_in_video`: Extract audio from video - `text`: Text-only query - -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. 
-``` -sudo apt update -sudo apt install ffmpeg -``` diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py index 7bba599830..dfe124700d 100644 --- a/examples/offline_inference/qwen2_5_omni/end2end.py +++ b/examples/offline_inference/qwen2_5_omni/end2end.py @@ -9,7 +9,6 @@ import time from typing import NamedTuple -import librosa import numpy as np import soundfile as sf from PIL import Image @@ -17,6 +16,7 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset, video_to_ndarrays from vllm.multimodal.image import convert_image_mode +from vllm.multimodal.media.audio import load_audio from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -96,7 +96,7 @@ def get_mixed_modalities_query( if audio_path: if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_signal, sr = load_audio(audio_path, sr=sampling_rate) audio_data = (audio_signal.astype(np.float32), sr) else: audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate @@ -130,7 +130,7 @@ def get_use_audio_in_video_query( raise FileNotFoundError(f"Video file not found: {video_path}") video_frames = video_to_ndarrays(video_path, num_frames=num_frames) # Extract audio from video file - audio_signal, sr = librosa.load(video_path, sr=sampling_rate) + audio_signal, sr = load_audio(video_path, sr=sampling_rate) audio = (audio_signal.astype(np.float32), sr) else: asset = VideoAsset(name="baby_reading", num_frames=num_frames) @@ -165,7 +165,7 @@ def get_multi_audios_query(audio_path: str | None = None, sampling_rate: int = 1 if audio_path: if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_signal, sr = load_audio(audio_path, sr=sampling_rate) audio_data = (audio_signal.astype(np.float32), sr) # Use the provided audio as the first audio, default as second audio_list = [ @@ -261,7 +261,7 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin if audio_path: if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_signal, sr = load_audio(audio_path, sr=sampling_rate) audio_data = (audio_signal.astype(np.float32), sr) else: audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate @@ -320,14 +320,7 @@ def main(args): query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate) else: query_result = query_func() - omni = Omni( - model=model_name, - log_stats=args.log_stats, - stage_init_timeout=args.stage_init_timeout, - batch_timeout=args.batch_timeout, - init_timeout=args.init_timeout, - shm_threshold_bytes=args.shm_threshold_bytes, - ) + omni = Omni.from_cli_args(args, model=model_name) thinker_sampling_params = SamplingParams( temperature=0.0, # Deterministic - no randomness top_p=1.0, # Disable nucleus sampling diff --git a/examples/offline_inference/qwen3_omni/README.md b/examples/offline_inference/qwen3_omni/README.md index b3e8592532..0710faa133 100644 --- a/examples/offline_inference/qwen3_omni/README.md +++ b/examples/offline_inference/qwen3_omni/README.md @@ -70,8 +70,8 @@ For true stage-level concurrency -- where downstream stages (Talker, Code2Wav) start **before** the upstream stage (Thinker) finishes 
-- use the async_chunk example. This requires: -1. A stage config YAML with ``async_chunk: true`` (e.g. - ``qwen3_omni_moe_async_chunk.yaml``). +1. A deploy config YAML with ``async_chunk: true`` (e.g. + ``qwen3_omni_moe.yaml``). 2. Hardware that matches the config (e.g. 2x H100 for the default 3-stage config). @@ -101,18 +101,10 @@ python end2end_async_chunk.py --query-type text --modalities text ```bash python end2end_async_chunk.py \ --query-type use_audio \ - --stage-configs-path /path/to/your_async_chunk.yaml + --deploy-config /path/to/your_deploy_config.yaml ``` > **Note**: The synchronous ``end2end.py`` (using ``Omni``) is still the > recommended entry point for non-async-chunk workflows. Only use the > async_chunk example when you need the stage-level concurrency semantics > described in PR #962 / #1151. - -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py index 155eca4ed9..f028c32aa1 100644 --- a/examples/offline_inference/qwen3_omni/end2end.py +++ b/examples/offline_inference/qwen3_omni/end2end.py @@ -9,7 +9,6 @@ import time from typing import NamedTuple -import librosa import numpy as np import soundfile as sf import vllm @@ -19,6 +18,7 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset, video_to_ndarrays from vllm.multimodal.image import convert_image_mode +from vllm.multimodal.media.audio import load_audio from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm_omni.entrypoints.omni import Omni @@ -129,7 +129,7 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin if audio_path: if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_signal, sr = load_audio(audio_path, sr=sampling_rate) audio_data = (audio_signal.astype(np.float32), sr) else: audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate @@ -183,7 +183,7 @@ def get_mixed_modalities_query( if audio_path: if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_signal, sr = load_audio(audio_path, sr=sampling_rate) audio_data = (audio_signal.astype(np.float32), sr) else: audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate @@ -294,14 +294,7 @@ def main(args): else: query_result = query_func() - omni = Omni( - model=model_name, - dtype=args.dtype, - stage_configs_path=args.stage_configs_path, - log_stats=args.log_stats, - stage_init_timeout=args.stage_init_timeout, - init_timeout=args.init_timeout, - ) + omni = Omni.from_cli_args(args, model=model_name) thinker_sampling_params = SamplingParams( temperature=0.9, diff --git a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py index 8adbae9eb6..f38922e943 100644 --- a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py +++ b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py @@ -14,7 +14,7 @@ Usage ----- python end2end_async_chunk.py --query-type use_audio \ - --stage-configs-path + --deploy-config See ``--help`` for all options. 
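+
+If ``--deploy-config`` is omitted, the example falls back to the bundled
+``vllm_omni/deploy/qwen3_omni_moe.yaml`` (which ships with ``async_chunk: true``)
+when run from a repository checkout; otherwise pass the flag explicitly.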
""" @@ -32,13 +32,13 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" -import librosa from PIL import Image from vllm import SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset, video_to_ndarrays from vllm.multimodal.image import convert_image_mode +from vllm.multimodal.media.audio import load_audio from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm_omni.entrypoints.async_omni import AsyncOmni @@ -89,7 +89,7 @@ def get_audio_query( if audio_path: if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") - audio_signal, sr = librosa.load(audio_path, sr=sampling_rate) + audio_signal, sr = load_audio(audio_path, sr=sampling_rate) audio_data = (audio_signal.astype(np.float32), sr) else: audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate @@ -179,20 +179,26 @@ def clone_prompt_for_request(template: dict) -> dict: return cloned -def _default_async_chunk_stage_configs_path() -> str | None: - """Best-effort default stage config for running Qwen3-Omni with async_chunk. +def _default_deploy_config_path() -> str | None: + """Best-effort default deploy config for running Qwen3-Omni with async_chunk. - When this example is executed from within the repository, we resolve the - default YAML path relative to this file. When installed elsewhere, the - file may not exist and callers should pass --stage-configs-path explicitly. + The default ``vllm_omni/deploy/qwen3_omni_moe.yaml`` ships with + ``async_chunk: true`` at the top level, so loading it is enough to + enable async-chunk semantics. To disable it, copy the YAML and set + ``async_chunk: false`` (or pass ``--deploy-config`` to a YAML that + overrides the flag). + + When this example is executed from within the repository, we resolve + the default YAML path relative to this file. When installed elsewhere, + the file may not exist and callers should pass ``--deploy-config`` + explicitly. """ repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) candidate = os.path.join( repo_root, "vllm_omni", - "model_executor", - "stage_configs", - "qwen3_omni_moe_async_chunk.yaml", + "deploy", + "qwen3_omni_moe.yaml", ) return candidate if os.path.exists(candidate) else None @@ -374,15 +380,16 @@ async def run_all(args): prompt["modalities"] = output_modalities # Create AsyncOmni - print(f"[Info] Creating AsyncOmni with stage_configs_path={args.stage_configs_path}") + print(f"[Info] Creating AsyncOmni with deploy_config={args.deploy_config}") async_omni = None try: - async_omni = AsyncOmni( - model=args.model, - stage_configs_path=args.stage_configs_path, - log_stats=args.log_stats, - stage_init_timeout=args.stage_init_timeout, - ) + # ``from_cli_args`` expands vars(args) into kwargs and auto-captures + # ``_cli_explicit_keys`` from ``sys.argv[1:]`` so argparse defaults + # do not silently override deploy YAML values. Mirrors the + # ``EngineArgs.from_cli_args`` pattern used throughout vllm / + # vllm-omni. ``deploy_config=None`` (the default) falls through to + # the bundled ``vllm_omni/deploy/qwen3_omni_moe.yaml``. + async_omni = AsyncOmni.from_cli_args(args) # Use default sampling params from stage config (they are pre-configured # in the YAML for each stage). 
@@ -470,11 +477,11 @@ def parse_args(): help="Query type.", ) parser.add_argument( - "--stage-configs-path", + "--deploy-config", type=str, - default=_default_async_chunk_stage_configs_path(), + default=_default_deploy_config_path(), help=( - "Path to an async_chunk stage config YAML. " + "Path to a deploy config YAML. " "If not set, uses the model's default config " "(make sure it has async_chunk: true)." ), diff --git a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh index 809054867c..2f2be20915 100755 --- a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh +++ b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh @@ -17,7 +17,7 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" python "${SCRIPT_DIR}/end2end_async_chunk.py" \ --query-type text \ --txt-prompts "${SCRIPT_DIR}/text_prompts_10.txt" \ - --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \ + --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \ --output-dir output_audio_async_chunk \ --max-in-flight 2 \ "$@" diff --git a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh index 918c7ee4fd..9ef69293cb 100755 --- a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh +++ b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh @@ -6,13 +6,13 @@ # achieving true stage-level concurrency via chunk-level streaming. # # Prerequisites: -# - An async_chunk stage config YAML (e.g. qwen3_omni_moe_async_chunk.yaml) +# - A deploy config YAML (e.g. qwen3_omni_moe.yaml) # - Hardware matching the config (e.g. 2x H100 for the default 3-stage config) # # Usage: # bash run_single_prompt_async_chunk.sh # bash run_single_prompt_async_chunk.sh --query-type text --modalities text -# bash run_single_prompt_async_chunk.sh --stage-configs-path /path/to/custom.yaml +# bash run_single_prompt_async_chunk.sh --deploy-config /path/to/custom.yaml set -euo pipefail @@ -21,6 +21,6 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" python "${SCRIPT_DIR}/end2end_async_chunk.py" \ --query-type use_audio \ - --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \ + --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \ --output-dir output_audio_async_chunk \ "$@" diff --git a/examples/offline_inference/qwen3_tts/README.md b/examples/offline_inference/qwen3_tts/README.md index bf59dc9ba4..2971ad716a 100644 --- a/examples/offline_inference/qwen3_tts/README.md +++ b/examples/offline_inference/qwen3_tts/README.md @@ -15,11 +15,11 @@ Please refer to the [stage configuration documentation](https://docs.vllm.ai/pro ### ROCm Dependencies -You will need to install these two dependencies `onnxruntime-rocm` and `sox`. +You will need to install the dependency `onnxruntime-rocm`. ``` pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm -pip install onnxruntime-rocm sox +pip install onnxruntime-rocm ``` ## Quick Start @@ -104,13 +104,13 @@ completes. This demonstrates that audio data is available progressively rather t ## Batched Decoding -The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. 
To use it, provide a stage config with `max_num_seqs > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`. +The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`. ``` python end2end.py --query-type CustomVoice \ --txt-prompts benchmark_prompts.txt \ --batch-size 4 \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml + --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}' ``` **Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently. diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py index 901418c39b..77da356b4f 100644 --- a/examples/offline_inference/qwen3_tts/end2end.py +++ b/examples/offline_inference/qwen3_tts/end2end.py @@ -366,12 +366,7 @@ def main(args): output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) - omni = Omni( - model=model_name, - stage_configs_path=args.stage_configs_path, - log_stats=args.log_stats, - stage_init_timeout=args.stage_init_timeout, - ) + omni = Omni.from_cli_args(args, model=model_name) batch_size = args.batch_size for batch_start in range(0, len(inputs), batch_size): @@ -387,12 +382,7 @@ async def main_streaming(args): output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) - omni = AsyncOmni( - model=model_name, - stage_configs_path=args.stage_configs_path, - log_stats=args.log_stats, - stage_init_timeout=args.stage_init_timeout, - ) + omni = AsyncOmni.from_cli_args(args, model=model_name) for i, prompt in enumerate(inputs): request_id = str(i) diff --git a/examples/offline_inference/text_to_audio/README.md b/examples/offline_inference/text_to_audio/README.md index 7edc38092a..50bab3e2f2 100644 --- a/examples/offline_inference/text_to_audio/README.md +++ b/examples/offline_inference/text_to_audio/README.md @@ -23,6 +23,7 @@ python text_to_audio.py \ --guidance-scale 7.0 \ --audio-length 10.0 \ --num-inference-steps 100 \ + --cache-backend tea_cache \ --output stable_audio_output.wav ``` @@ -34,4 +35,5 @@ Key arguments: - `--guidance-scale`: classifier-free guidance scale. - `--audio-length`: audio duration in seconds. - `--num-inference-steps`: diffusion sampling steps.(more steps = higher quality, slower). +- `--cache-backend`: cache acceleration backend. Stable Audio currently supports `tea_cache`. - `--output`: path to save the generated WAV file. 
diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py index a6968c419f..3adb3ad53a 100644 --- a/examples/offline_inference/text_to_audio/text_to_audio.py +++ b/examples/offline_inference/text_to_audio/text_to_audio.py @@ -11,6 +11,7 @@ python text_to_audio.py --prompt "The sound of a dog barking" python text_to_audio.py --prompt "A piano playing a gentle melody" --audio-length 10.0 python text_to_audio.py --prompt "Thunder and rain sounds" --negative-prompt "Low quality" + python text_to_audio.py --prompt "A soft synth pad" --cache-backend tea_cache """ import argparse @@ -90,6 +91,23 @@ def parse_args() -> argparse.Namespace: default=44100, help="Sample rate for output audio (Stable Audio uses 44100 Hz).", ) + parser.add_argument( + "--cache-backend", + type=str, + default=None, + choices=["tea_cache"], + help=( + "Cache backend to use for acceleration. " + "Stable Audio currently supports 'tea_cache'. " + "Default: None (no cache acceleration)." + ), + ) + parser.add_argument( + "--tea-cache-rel-l1-thresh", + type=float, + default=0.2, + help="[tea_cache] Threshold for accumulated relative L1 distance.", + ) parser.add_argument( "--enable-diffusion-pipeline-profiler", action="store_true", @@ -124,6 +142,11 @@ def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 4410 def main(): args = parse_args() generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + cache_config = None + if args.cache_backend == "tea_cache": + cache_config = { + "rel_l1_thresh": args.tea_cache_rel_l1_thresh, + } print(f"\n{'=' * 60}") print("Stable Audio Open - Text-to-Audio Generation") @@ -134,12 +157,15 @@ def main(): print(f" Audio length: {args.audio_length}s") print(f" Inference steps: {args.num_inference_steps}") print(f" Guidance scale: {args.guidance_scale}") + print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") print(f" Seed: {args.seed}") print(f"{'=' * 60}\n") # Initialize Omni with Stable Audio model omni = Omni( model=args.model, + cache_backend=args.cache_backend, + cache_config=cache_config, enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, ) diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index cc295e8279..c71773972b 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -29,7 +29,8 @@ This folder provides several entrypoints for experimenting with text-to-image di | `AIDC-AI/Ovis-Image-7B` | 1024 x 1024 | 71.8 | 17.1 | | `OmniGen2/OmniGen2` | 1024 x 1024 | 20.1 | 14.7 | | `stabilityai/stable-diffusion-3.5-medium` | 1024 x 1024 | 20.1 | 15.6 | -| `black-forest-labs/FLUX.1-dev` | 1024 x 1024 | 77.6 | 31.4 | +| `black-forest-labs/FLUX.1-dev` | 1024 x 1024 | 33.9 | 31.4 | +| `black-forest-labs/FLUX.1-schnell` | 1024 x 1024 | 33.9 | 31.4 | | `black-forest-labs/FLUX.2-klein-4B` | 1024 x 1024 | 72.7 | 14.9 | | `black-forest-labs/FLUX.2-klein-9B` | 1024 x 1024 | 37.1 | 32.3 | | `black-forest-labs/FLUX.2-dev` | 1024 x 1024 | 65.7 | >80 (CPU offload required) | diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 615e4067ed..3b3f8e77cf 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -242,6 
+242,18 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable logging of diffusion pipeline stats.", ) + parser.add_argument( + "--init-timeout", + type=int, + default=600, + help="Timeout for initializing a single stage in seconds (default: 600s)", + ) + parser.add_argument( + "--stage-init-timeout", + type=int, + default=600, + help="Timeout for initializing a single stage in seconds (default: 600s)", + ) parser.add_argument( "--use-system-prompt", type=str, @@ -346,6 +358,8 @@ def main(): "mode": "text-to-image", "log_stats": args.log_stats, "enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler, + "init_timeout": args.init_timeout, + "stage_init_timeout": args.stage_init_timeout, **lora_args, **quant_kwargs, } diff --git a/examples/offline_inference/voxcpm/README.md b/examples/offline_inference/voxcpm/README.md new file mode 100644 index 0000000000..1eaea9b0db --- /dev/null +++ b/examples/offline_inference/voxcpm/README.md @@ -0,0 +1,123 @@ +# VoxCPM Offline Example + +This directory contains the minimal offline VoxCPM example for vLLM Omni. + +`end2end.py` is intentionally small and only covers: + +- single text-to-speech +- single voice cloning with `ref_audio` + `ref_text` +- non-streaming with `vllm_omni/model_executor/stage_configs/voxcpm.yaml` +- streaming with `vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml` + +Advanced workflows were moved out of the getting-started example: + +- `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py`: warmup, batch prompts, profiler, offline TTFP / RTF +- `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`: fixed offline smoke matrix +- `benchmarks/voxcpm/`: benchmark scripts and benchmark docs + +## Prerequisites + +Install VoxCPM in one of these ways: + +```bash +pip install voxcpm +``` + +or point vLLM Omni to the local VoxCPM source tree: + +```bash +export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src +``` + +The example writes WAV files with `soundfile`: + +```bash +pip install soundfile +``` + +## Model Path + +Pass the native VoxCPM model directory directly: + +```bash +export VOXCPM_MODEL=/path/to/voxcpm-model +``` + +If the native VoxCPM `config.json` does not contain HuggingFace metadata such as +`model_type`, prepare a persistent HF-compatible config directory and point the +stage configs to it with `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH`: + +```bash +export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config +mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH" +cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json" +cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true +python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)' +``` + +If the model directory itself already has `model_type`, this extra directory is +not required. + +## Quick Start + +Single text-to-speech, non-streaming: + +```bash +python examples/offline_inference/voxcpm/end2end.py \ + --model "$VOXCPM_MODEL" \ + --text "This is a split-stage VoxCPM synthesis example running on vLLM Omni." 
+``` + +Single voice cloning, non-streaming: + +```bash +python examples/offline_inference/voxcpm/end2end.py \ + --model "$VOXCPM_MODEL" \ + --text "This sentence is synthesized with a cloned voice." \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." +``` + +Streaming: + +```bash +python examples/offline_inference/voxcpm/end2end.py \ + --model "$VOXCPM_MODEL" \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --text "This is a split-stage VoxCPM streaming example running on vLLM Omni." +``` + +By default, `end2end.py` writes to `output_audio/` for non-streaming and +`output_audio_streaming/` for streaming. + +## Advanced Workflows + +Use `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py` when you need: + +- warmup runs +- prompt files +- batch JSONL inputs +- profiler injection +- offline TTFP / RTF emission + +Use `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py` when you need the fixed offline smoke matrix that previously lived in `test.py`. + +Full matrix benchmark example: + +```bash +python benchmarks/voxcpm/vllm_omni/run_offline_matrix.py \ + --model "$VOXCPM_MODEL" \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." +``` + +For online serving examples, see [examples/online_serving/voxcpm](../../online_serving/voxcpm/README.md). + +For benchmark reporting, see [benchmarks/voxcpm](../../../benchmarks/voxcpm/README.md). + +## Notes + +- `voxcpm.yaml` is the default non-streaming stage config. +- `voxcpm_async_chunk.yaml` is the streaming stage config. +- Streaming is currently single-request oriented; the fixed smoke matrix now lives in `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`. +- `ref_text` must be the real transcript of the reference audio. Mismatched text usually causes obvious quality degradation. 
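+
+For reference, `end2end.py` hands VoxCPM a prompt dictionary rather than a raw string. A rough sketch of the structure built by `_build_prompt` (field names follow the script; the numeric values shown are placeholders, not the script defaults):
+
+```python
+prompt = {
+    # the example passes a single dummy token id; the TTS inputs travel in additional_information
+    "prompt_token_ids": [1],
+    "additional_information": {
+        "text": ["This sentence is synthesized with a cloned voice."],
+        "cfg_value": [2.0],            # placeholder value
+        "inference_timesteps": [10],   # placeholder value
+        # optional voice-cloning inputs
+        "ref_audio": ["/path/to/reference.wav"],
+        "ref_text": ["The exact transcript spoken in reference.wav."],
+    },
+}
+```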
diff --git a/examples/offline_inference/voxcpm/end2end.py b/examples/offline_inference/voxcpm/end2end.py new file mode 100644 index 0000000000..980410feae --- /dev/null +++ b/examples/offline_inference/voxcpm/end2end.py @@ -0,0 +1,206 @@ +"""Minimal offline VoxCPM example for vLLM Omni.""" + +from __future__ import annotations + +import asyncio +import time +from pathlib import Path +from typing import Any + +import soundfile as sf +import torch +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import AsyncOmni, Omni + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_SYNC_STAGE_CONFIG = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml" + + +def _build_prompt(args) -> dict[str, Any]: + additional_information: dict[str, list[Any]] = { + "text": [args.text], + "cfg_value": [args.cfg_value], + "inference_timesteps": [args.inference_timesteps], + "min_len": [args.min_len], + "max_new_tokens": [args.max_new_tokens], + } + if args.streaming_prefix_len is not None: + additional_information["streaming_prefix_len"] = [args.streaming_prefix_len] + if args.ref_audio is not None: + additional_information["ref_audio"] = [args.ref_audio] + if args.ref_text is not None: + additional_information["ref_text"] = [args.ref_text] + return { + "prompt_token_ids": [1], + "additional_information": additional_information, + } + + +def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor: + audio = mm.get("audio", mm.get("model_outputs")) + if audio is None: + raise ValueError("No audio output found in multimodal output.") + if isinstance(audio, list): + parts = [torch.as_tensor(item).float().cpu().reshape(-1) for item in audio] + audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0) + if not isinstance(audio, torch.Tensor): + audio = torch.as_tensor(audio) + return audio.float().cpu().reshape(-1) + + +def _extract_sample_rate(mm: dict[str, Any]) -> int: + sr_raw = mm.get("sr", 24000) + if isinstance(sr_raw, list) and sr_raw: + sr_raw = sr_raw[-1] + if hasattr(sr_raw, "item"): + return int(sr_raw.item()) + return int(sr_raw) + + +def _is_streaming_stage_config(stage_config_path: str) -> bool: + return "async_chunk" in Path(stage_config_path).stem + + +def _save_audio(audio: torch.Tensor, sample_rate: int, output_dir: Path, request_id: str) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / f"output_{request_id}.wav" + sf.write( + output_path, + audio.float().cpu().clamp(-1.0, 1.0).numpy(), + sample_rate, + format="WAV", + subtype="PCM_16", + ) + return output_path + + +async def _run_streaming(args) -> Path: + prompt = _build_prompt(args) + output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio_streaming") + request_id = "streaming_example" + sample_rate = 24000 + buffered_samples = 0 + chunks: list[torch.Tensor] = [] + started = time.perf_counter() + omni = AsyncOmni( + model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + try: + async for stage_output in omni.generate(prompt, request_id=request_id): + mm = getattr(stage_output, "multimodal_output", None) + if not isinstance(mm, dict): + request_output = getattr(stage_output, "request_output", None) + if request_output is None: + continue + mm = getattr(request_output, "multimodal_output", None) + if not isinstance(mm, dict) and getattr(request_output, "outputs", None): + mm = getattr(request_output.outputs[0], 
"multimodal_output", None) + if not isinstance(mm, dict): + continue + audio = _extract_audio_tensor(mm) + if audio.numel() == 0: + continue + sample_rate = _extract_sample_rate(mm) + if audio.numel() > buffered_samples: + delta = audio[buffered_samples:] + buffered_samples = int(audio.numel()) + else: + delta = audio + buffered_samples += int(delta.numel()) + if delta.numel() > 0: + chunks.append(delta) + if not chunks: + raise RuntimeError("No streaming audio chunks received from VoxCPM.") + output_audio = torch.cat(chunks, dim=0) + output_path = _save_audio(output_audio, sample_rate, output_dir, request_id) + print(f"Saved streaming audio to: {output_path} ({time.perf_counter() - started:.2f}s)") + return output_path + finally: + omni.shutdown() + + +def _run_sync(args) -> Path: + prompt = _build_prompt(args) + output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio") + request_id = "sync_example" + started = time.perf_counter() + last_mm: dict[str, Any] | None = None + omni = Omni( + model=args.model, + stage_configs_path=args.stage_configs_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + for stage_outputs in omni.generate(prompt): + request_output = getattr(stage_outputs, "request_output", None) + if request_output is None: + continue + outputs = getattr(request_output, "outputs", None) + if outputs: + for output in outputs: + mm = getattr(output, "multimodal_output", None) + if isinstance(mm, dict): + last_mm = mm + mm = getattr(request_output, "multimodal_output", None) + if isinstance(mm, dict): + last_mm = mm + if last_mm is None: + raise RuntimeError("No audio output received from VoxCPM.") + output_path = _save_audio( + _extract_audio_tensor(last_mm), + _extract_sample_rate(last_mm), + output_dir, + request_id, + ) + print(f"Saved audio to: {output_path} ({time.perf_counter() - started:.2f}s)") + return output_path + + +def parse_args(): + parser = FlexibleArgumentParser(description="Minimal offline VoxCPM example for vLLM Omni.") + parser.add_argument("--model", type=str, required=True, help="Local VoxCPM model directory.") + parser.add_argument( + "--stage-configs-path", + type=str, + default=str(DEFAULT_SYNC_STAGE_CONFIG), + help=("Stage config path. 
Use voxcpm.yaml for non-streaming or voxcpm_async_chunk.yaml for streaming."), + ) + parser.add_argument("--text", type=str, required=True, help="Input text for synthesis.") + parser.add_argument("--ref-audio", type=str, default=None, help="Reference audio path for voice cloning.") + parser.add_argument("--ref-text", type=str, default=None, help="Transcript of the reference audio.") + parser.add_argument("--output-dir", type=str, default=None, help="Output directory for generated wav files.") + parser.add_argument("--cfg-value", type=float, default=2.0, help="Guidance value passed to VoxCPM.") + parser.add_argument("--inference-timesteps", type=int, default=10, help="Number of diffusion timesteps.") + parser.add_argument("--min-len", type=int, default=2, help="Minimum latent length.") + parser.add_argument("--max-new-tokens", type=int, default=4096, help="Maximum latent length.") + parser.add_argument( + "--streaming-prefix-len", + type=int, + default=3, + help="Streaming prefix length used by voxcpm_async_chunk.yaml.", + ) + parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.") + parser.add_argument("--log-stats", action="store_true", help="Enable vLLM Omni stats logging.") + args = parser.parse_args() + if (args.ref_audio is None) != (args.ref_text is None): + raise ValueError("Voice cloning requires --ref-audio and --ref-text together.") + return args + + +def main(args) -> None: + route = "streaming" if _is_streaming_stage_config(args.stage_configs_path) else "sync" + print(f"Model: {args.model}") + print(f"Stage config: {args.stage_configs_path}") + print(f"Route: {route}") + if route == "streaming": + asyncio.run(_run_streaming(args)) + else: + _run_sync(args) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/examples/offline_inference/voxcpm2/README.md b/examples/offline_inference/voxcpm2/README.md new file mode 100644 index 0000000000..e982730799 --- /dev/null +++ b/examples/offline_inference/voxcpm2/README.md @@ -0,0 +1,83 @@ +# VoxCPM2 Offline Inference (Native AR) + +VoxCPM2 is a 2B-parameter tokenizer-free diffusion AR TTS model. It produces 48kHz audio and supports 30+ languages with a single-stage native AR pipeline backed by MiniCPM4. + +## Prerequisites + +Install the `voxcpm` package, or set the environment variable pointing to the source tree: + +```bash +# Option A: install package +pip install voxcpm + +# Option B: use source checkout +export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/voxcpm +``` + +## Quick Start + +Zero-shot synthesis: + +```bash +python examples/offline_inference/voxcpm2/end2end.py \ + --model openbmb/VoxCPM2 \ + --text "Hello, this is a VoxCPM2 demo." \ + --output-dir output_audio +``` + +Voice cloning with a reference audio: + +```bash +python examples/offline_inference/voxcpm2/end2end.py \ + --text "Hello, this is a voice clone demo." \ + --reference-audio /path/to/reference.wav \ + --output-dir output_clone +``` + +Prompt continuation (matched audio + text prefix): + +```bash +python examples/offline_inference/voxcpm2/end2end.py \ + --text "Continuation target sentence." \ + --prompt-audio /path/to/prompt.wav \ + --prompt-text "Transcript of the prompt audio." 
\ + --output-dir output_cont +``` + +The script accepts the following arguments: + +| Argument | Default | Description | +|---|---|---| +| `--model` | `openbmb/VoxCPM2` | HuggingFace repo ID or local path | +| `--text` | (example sentence) | Text to synthesize | +| `--output-dir` | `output_audio` | Directory for output WAV files | +| `--stage-configs-path` | `voxcpm2.yaml` | Stage config YAML path | +| `--reference-audio` | `None` | Reference audio for voice cloning (isolated) | +| `--prompt-audio` | `None` | Prompt audio for continuation mode | +| `--prompt-text` | `None` | Transcript matching `--prompt-audio` | + +## Performance + +Measured on a single H20 GPU (80 GB): + +| Input length | RTF | Sample rate | +|---|---|---| +| Short (~10 tokens) | ~0.28 | 48 kHz | +| Long (~100 tokens) | ~0.34 | 48 kHz | + +RTF < 1.0 means faster than real time. + +## Architecture + +VoxCPM2 uses a single-stage native AR pipeline: + +``` +feat_encoder +└─► MiniCPM4 (base LM) + └─► FSQ (finite scalar quantization) + └─► residual_lm (residual AR) + └─► LocDiT (local diffusion transformer) + └─► AudioVAE → 48 kHz waveform +``` + +All stages are fused into one vllm-native execution graph via `voxcpm2.yaml`, eliminating inter-stage coordination overhead and enabling true end-to-end batching. diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py new file mode 100644 index 0000000000..6b6bf78ddf --- /dev/null +++ b/examples/offline_inference/voxcpm2/end2end.py @@ -0,0 +1,171 @@ +"""Offline VoxCPM2 inference example (native AR pipeline). + +Uses the single-stage native AR config (voxcpm2.yaml). +Requires the `voxcpm` package or VLLM_OMNI_VOXCPM_CODE_PATH env var. +""" + +from __future__ import annotations + +import os +import time +from pathlib import Path + +import soundfile as sf +import torch +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from vllm_omni import Omni + +REPO_ROOT = Path(__file__).resolve().parents[3] +DEFAULT_STAGE_CONFIGS_PATH = str(REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm2.yaml") +SAMPLE_RATE = 48_000 + + +def parse_args(): + parser = FlexibleArgumentParser(description="Offline VoxCPM2 native AR inference") + parser.add_argument( + "--model", + type=str, + default="openbmb/VoxCPM2", + help="VoxCPM2 model path or HuggingFace repo ID.", + ) + parser.add_argument( + "--text", + type=str, + default="This is a VoxCPM2 native AR synthesis example running on vLLM Omni.", + help="Text to synthesize.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="output_audio", + help="Directory for output WAV files.", + ) + parser.add_argument( + "--stage-configs-path", + type=str, + default=DEFAULT_STAGE_CONFIGS_PATH, + help="Path to the stage config YAML file.", + ) + parser.add_argument( + "--reference-audio", + type=str, + default=None, + help="Path to reference audio for voice cloning (isolated ref mode).", + ) + parser.add_argument( + "--prompt-audio", + type=str, + default=None, + help="Path to prompt audio for continuation mode (requires --prompt-text).", + ) + parser.add_argument( + "--prompt-text", + type=str, + default=None, + help="Text matching --prompt-audio for continuation mode.", + ) + parser.add_argument( + "--ref-text", + type=str, + default=None, + help="Optional transcript of --reference-audio (enables ref_continuation mode).", + ) + return parser.parse_args() + + +def extract_audio(multimodal_output: dict) -> torch.Tensor: + """Extract the final complete audio tensor from 
multimodal output. + + The output processor concatenates per-step delta tensors under + ``model_outputs``. Falls back to ``audio`` for backwards compat. + """ + audio = multimodal_output.get("model_outputs") + if audio is None: + audio = multimodal_output.get("audio") + if audio is None: + raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}") + + if isinstance(audio, list): + # Defensive: usually the output processor consolidates into a single + # tensor at request completion, but concatenate here too in case the + # caller consumes intermediate (pre-consolidation) outputs. + valid = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio if a is not None] + if not valid: + raise ValueError("Audio list is empty or all elements are None.") + return torch.cat(valid, dim=0) if len(valid) > 1 else valid[0] + + return torch.as_tensor(audio).float().cpu().reshape(-1) + + +def main(): + args = parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + engine = Omni( + model=args.model, + stage_configs_path=args.stage_configs_path, + ) + + from transformers import AutoTokenizer + + from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( + build_cjk_split_map, + build_voxcpm2_prompt, + ) + + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + split_map = build_cjk_split_map(tokenizer) + hf_config = engine.engine.stage_vllm_configs[0].model_config.hf_config + + ref_audio_arg = args.reference_audio or args.prompt_audio + ref_text_arg = args.ref_text or args.prompt_text + ref_wav, ref_sr = (None, None) + if ref_audio_arg: + ref_wav_arr, ref_sr = sf.read(ref_audio_arg) + ref_wav = ref_wav_arr.mean(axis=-1).tolist() if ref_wav_arr.ndim > 1 else ref_wav_arr.tolist() + + prompt = build_voxcpm2_prompt( + hf_config=hf_config, + tokenizer=tokenizer, + split_map=split_map, + text=args.text, + ref_audio=ref_wav, + ref_sr=ref_sr, + ref_text=ref_text_arg, + ) + + print(f"Model : {args.model}") + print(f"Text : {args.text}") + if ref_audio_arg: + print(f"Ref audio : {ref_audio_arg}") + if ref_text_arg: + print(f"Ref text : {ref_text_arg}") + print(f"Output dir : {output_dir}") + + t_start = time.perf_counter() + outputs = engine.generate([prompt]) + elapsed = time.perf_counter() - t_start + + # outputs[0].outputs[0].multimodal_output["audio"] is a list of tensors + request_output = outputs[0] + mm = request_output.outputs[0].multimodal_output + audio = extract_audio(mm) + + duration = audio.numel() / SAMPLE_RATE + rtf = elapsed / duration if duration > 0 else float("inf") + + output_path = output_dir / "output.wav" + sf.write(str(output_path), audio.numpy(), SAMPLE_RATE, format="WAV") + + print(f"Saved : {output_path}") + print(f"Duration : {duration:.2f}s") + print(f"Inference : {elapsed:.2f}s") + print(f"RTF : {rtf:.3f}") + + +if __name__ == "__main__": + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + main() diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.md b/examples/offline_inference/x_to_video_audio/x_to_video_audio.md index 4b5188f41b..13f2cfe7c0 100644 --- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.md +++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.md @@ -30,9 +30,9 @@ dreamid_omni/ ``` ### Run the Inference -``` +```python python x_to_video_audio.py \ - --model /xx/dreamid_omni \ + --model /path/to/dreamid_omni \ --prompt "Two people walking together and singing happily" \ --image-path ./example0.png 
./example1.png \ --audio-path ./example0.wav ./example1.wav \ @@ -42,11 +42,33 @@ python x_to_video_audio.py \ --num-inference-steps 45 \ --height 704 \ --width 1280 \ - --output dreamid_omni.mp4 + --output out_dreamid_omni_twoip.mp4 ``` In the current test scenario (2 images + 2 audio inputs), the VRAM requirement is 72GB, regardless of whether cfg-parallel is enabled or disabled. The VRAM usage can be reduced by enabling CPU offload via --enable-cpu-offload. + +You could take reference images/audios from the test cases in the official repo: https://github.com/Guoxu1233/DreamID-Omni + +For example, single IP ref resources can be found under https://github.com/Guoxu1233/DreamID-Omni/tree/main/test_case/oneip, you could download them correspondingly to your local and use them for testing. + +```python +# Example usage for oneip, ref media from the official repo DreamID-Omni +python x_to_video_audio.py \ + --model /path/to/dreamid_omni \ + --prompt ": In the frame, a woman with black long hair is identified as .\n**Overall Environment/Scene**: A lively open-kitchen café at night; stove flames flare, steam rises, and warm pendant lights swing slightly as staff move behind her. The shot is an upper-body close-up.\n**Main Characters/Subjects Appearance**: is a young woman with thick dark wavy hair and a side part. She wears a fitted black top under a light apron, a thin gold chain necklace, and small stud earrings.\n**Main Characters/Subjects Actions**: tastes the sauce with a spoon, then turns her face toward the camera while still holding the spoon, her expression shifting from focused to conflicted.\n maintains eye contact, swallows as if choosing her words, and says, I keep telling myself I’m fine,but some nights it feels like I’m just performing calm." \ + --image-path 9.png \ + --audio-path 9.wav \ + --video-negative-prompt "jitter, bad hands, blur, distortion" \ + --audio-negative-prompt "robotic, muffled, echo, distorted" \ + --cfg-parallel-size 2 \ + --num-inference-steps 45 \ + --height 704 \ + --width 1280 \ + --output out_dreamid_omni_oneip.mp4 +``` + + Key arguments: - `--prompt`: text description (string). - `--model`: path to the model local directory. 
diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py index e0424add69..497284ceb9 100644 --- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py +++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py @@ -5,10 +5,12 @@ import re import time -import librosa +import numpy as np from PIL import Image +from vllm.multimodal.media.audio import load_audio from vllm_omni.diffusion.data import DiffusionParallelConfig +from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -36,8 +38,8 @@ def parse_args() -> argparse.Namespace: "--cfg-parallel-size", type=int, default=1, - choices=[1, 2], - help="Number of GPUs used for classifier free guidance parallel size.", + choices=[1, 2, 3, 4], + help="Number of GPUs used for classifier free guidance parallel size (max 4 branches).", ) parser.add_argument( "--video-negative-prompt", @@ -56,6 +58,11 @@ def parse_args() -> argparse.Namespace: default=False, help="Enable CPU offloading for diffusion models.", ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) return parser.parse_args() @@ -69,7 +76,7 @@ def load_image_and_audio(image_paths, audio_paths): image.append(img) for path in audio_paths: - audio_array, sr = librosa.load(path, sr=16000) + audio_array, sr = load_audio(path, sr=16000) audio_array = audio_array[int(sr * 1) : int(sr * 3)] audio.append(audio_array) return image, audio @@ -124,6 +131,7 @@ def main() -> None: parallel_config=parallel_config, model_type=args.model_type, enable_cpu_offload=args.enable_cpu_offload, + enable_layerwise_offload=args.enable_layerwise_offload, ) start = time.perf_counter() outputs = omni.generate(prompt, sampling_params) @@ -131,15 +139,35 @@ def main() -> None: if not outputs: raise RuntimeError("No output returned from DreamID-Omni.") - output = outputs[0].request_output - generated_video = output.images[0][0] - generated_audio = output.images[0][1] - try: - from dreamid_omni.utils.io_utils import save_video - except Exception as e: - raise RuntimeError(f"Failed to extract video and audio from DreamID-Omni output. Error: {e}") + result = outputs[0] + if not result.images: + raise RuntimeError("No video frames found in DreamID-Omni output.") + generated_video = result.images[0] + mm = result.multimodal_output or {} + generated_audio = mm.get("audio") + fps = int(mm.get("fps", 24)) + sample_rate = int(mm.get("audio_sample_rate", 16000)) + + # DreamID-Omni returns video as (C, F, H, W) float32 in [-1, 1]. + # mux_video_audio_bytes expects (F, H, W, C) uint8. 
+ if not isinstance(generated_video, np.ndarray) or generated_video.ndim != 4: + raise RuntimeError(f"Unexpected video shape: {getattr(generated_video, 'shape', None)}") + frames = generated_video.transpose(1, 2, 3, 0) + frames = (np.clip((frames + 1.0) / 2.0, 0.0, 1.0) * 255.0).round().astype(np.uint8) + + audio_np = None + if generated_audio is not None: + audio_np = np.squeeze(np.asarray(generated_audio)).astype(np.float32) + output_path = args.output - save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000) + video_bytes = mux_video_audio_bytes( + frames, + audio_np, + fps=float(fps), + audio_sample_rate=sample_rate, + ) + with open(output_path, "wb") as f: + f.write(video_bytes) print(f"Saved generated video to {output_path}") print(f"Total time: {elapsed:.2f}s") diff --git a/examples/online_serving/bagel/README.md b/examples/online_serving/bagel/README.md index 9b74acae10..0939bc5f38 100644 --- a/examples/online_serving/bagel/README.md +++ b/examples/online_serving/bagel/README.md @@ -354,13 +354,6 @@ curl http://localhost:8091/v1/chat/completions \ ## FAQ -- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below. - -```bash -sudo apt update -sudo apt install ffmpeg -``` - - If you don’t know how much VRAM is needed for the model or encounter the OOM error, you can try to decrease the max_model_len. | Stage | VRAM | diff --git a/examples/online_serving/image_to_video/README.md b/examples/online_serving/image_to_video/README.md index 49283bd9a0..285eeb2798 100644 --- a/examples/online_serving/image_to_video/README.md +++ b/examples/online_serving/image_to_video/README.md @@ -26,6 +26,23 @@ The script allows overriding: - `CACHE_BACKEND` (default: `none`) - `ENABLE_CACHE_DIT_SUMMARY` (default: `0`) +### Ascend / Local LightX2V Example + +For a local Wan2.2-LightX2V Diffusers directory on Ascend/NPU, you can start the server like this: + +```bash +vllm serve /path/to/Wan2.2-I2V-A14B-LightX2V-Diffusers-Lightning \ + --omni \ + --port 8091 \ + --flow-shift 12 \ + --cfg-parallel-size 1 \ + --ulysses-degree 4 \ + --use-hsdp \ + --trust-remote-code \ + --allowed-local-media-path / \ + --seed 42 +``` + ## Async Job Behavior `POST /v1/videos` is asynchronous. It creates a video job and immediately @@ -69,10 +86,35 @@ curl -X POST http://localhost:8091/v1/videos/sync \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F 'extra_params={"sample_solver":"euler"}' \ -F "seed=42" \ -o sync_i2v_output.mp4 ``` +For Wan Lightning/Distill checkpoints, pass `{"sample_solver":"euler"}` via `extra_params`. The default solver is `unipc`. + +Example matching the local LightX2V deployment above: + +```bash +curl -sS -X POST http://localhost:8091/v1/videos/sync \ + -H "Accept: video/mp4" \ + -F "prompt=A cat playing with yarn" \ + -F "input_reference=@/path/to/input.jpg" \ + -F "width=832" \ + -F "height=480" \ + -F "num_frames=81" \ + -F "fps=16" \ + -F "num_inference_steps=4" \ + -F "guidance_scale=1.0" \ + -F "guidance_scale_2=1.0" \ + -F "boundary_ratio=0.875" \ + -F "seed=42" \ + -F 'extra_params={"sample_solver":"euler"}' \ + -o ./output.mp4 +``` + +Use `/v1/videos/sync` if you want to write the MP4 directly to a file. `POST /v1/videos` is async and returns job metadata, not inline `b64_json`. + ## Storage Generated video files are stored on local disk by the async video API. 
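The synchronous endpoint bypasses this on-disk storage and returns the MP4 bytes inline, so a scripted client can save them directly. Below is a minimal Python sketch mirroring the `/v1/videos/sync` curl example above, assuming `requests` is installed and the server is listening on `localhost:8091`:

```python
# Python equivalent of the /v1/videos/sync curl example above
# (hypothetical client sketch; form fields mirror the curl flags).
import json

import requests

with open("/path/to/input.jpg", "rb") as image_file:
    response = requests.post(
        "http://localhost:8091/v1/videos/sync",
        headers={"Accept": "video/mp4"},
        files={"input_reference": image_file},
        data={
            "prompt": "A cat playing with yarn",
            "width": "832",
            "height": "480",
            "num_frames": "81",
            "fps": "16",
            "num_inference_steps": "4",
            "guidance_scale": "1.0",
            "guidance_scale_2": "1.0",
            "boundary_ratio": "0.875",
            "seed": "42",
            # Wan Lightning/Distill checkpoints expect the euler solver.
            "extra_params": json.dumps({"sample_solver": "euler"}),
        },
        timeout=600,
    )

response.raise_for_status()
with open("sync_i2v_output.mp4", "wb") as f:
    f.write(response.content)
```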
@@ -96,6 +138,9 @@ export VLLM_OMNI_STORAGE_MAX_CONCURRENCY=8 # Basic image-to-video generation bash run_curl_image_to_video.sh +# Wan Lightning/Distill checkpoints +SAMPLE_SOLVER=euler bash run_curl_image_to_video.sh + # Or execute directly (OpenAI-style multipart) create_response=$(curl -s http://localhost:8091/v1/videos \ -H "Accept: application/json" \ @@ -111,6 +156,7 @@ create_response=$(curl -s http://localhost:8091/v1/videos \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F 'extra_params={"sample_solver":"euler"}' \ -F "seed=42") video_id=$(echo "$create_response" | jq -r '.id') @@ -169,9 +215,12 @@ curl -X POST http://localhost:8091/v1/videos \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F 'extra_params={"sample_solver":"euler"}' \ -F "seed=42" ``` +`sample_solver` is supported by Wan2.2 online serving through the existing `extra_params` field, which is merged into the pipeline `extra_args`. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. + ## Create Response Format `POST /v1/videos` returns a job record, not inline base64 video data. diff --git a/examples/online_serving/image_to_video/run_curl_image_to_video.sh b/examples/online_serving/image_to_video/run_curl_image_to_video.sh index f4c1496a69..6f6a6f96d5 100644 --- a/examples/online_serving/image_to_video/run_curl_image_to_video.sh +++ b/examples/online_serving/image_to_video/run_curl_image_to_video.sh @@ -7,6 +7,7 @@ INPUT_IMAGE="${INPUT_IMAGE:-../../offline_inference/image_to_video/qwen-bear.png BASE_URL="${BASE_URL:-http://localhost:8099}" OUTPUT_PATH="${OUTPUT_PATH:-wan22_i2v_output.mp4}" NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-}" +SAMPLE_SOLVER="${SAMPLE_SOLVER:-}" POLL_INTERVAL="${POLL_INTERVAL:-2}" if [ ! -f "$INPUT_IMAGE" ]; then @@ -34,6 +35,10 @@ if [ -n "${NEGATIVE_PROMPT}" ]; then create_cmd+=(-F "negative_prompt=${NEGATIVE_PROMPT}") fi +if [ -n "${SAMPLE_SOLVER}" ]; then + create_cmd+=(-F "extra_params={\"sample_solver\":\"${SAMPLE_SOLVER}\"}") +fi + create_response="$("${create_cmd[@]}")" video_id="$(echo "${create_response}" | jq -r '.id')" if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then diff --git a/examples/online_serving/ming_flash_omni/README.md b/examples/online_serving/ming_flash_omni/README.md new file mode 100644 index 0000000000..502232725c --- /dev/null +++ b/examples/online_serving/ming_flash_omni/README.md @@ -0,0 +1,204 @@ +# Ming-flash-omni 2.0 + +## Installation + +Please refer to [README.md](../../../README.md) + +## Run examples (Ming-flash-omni 2.0) + +### Launch the Server + +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 +``` + +If you have custom stage configs file, launch the server with command below +```bash +vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +``` + +### Send Multi-modal Request + +#### Send request via python + +```bash +python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py --model Jonathan1909/Ming-flash-omni-2.0 --query-type use_mixed_modalities --port 8091 --host "localhost" --modalities text +``` + +The Python client supports the following command-line arguments: + +- `--query-type` (or `-q`): Query type. Options: `text`, `use_audio`, `use_image`, `use_video`, `use_mixed_modalities` +- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type uses video, uses default video URL. 
Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4` +- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type uses image, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png` +- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3` +- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"` +- `--modalities`: Output modalities. For now, only `text` is supported. Example: `--modalities text` + + +#### Send request via curl + +```bash +bash run_curl_multimodal_generation.sh text +bash run_curl_multimodal_generation.sh use_image +bash run_curl_multimodal_generation.sh use_audio +bash run_curl_multimodal_generation.sh use_video +bash run_curl_multimodal_generation.sh use_mixed_modalities +``` + +## Modality control + +Ming-flash-omni 2.0 currently supports text output only (thinker stage). + +| Modalities | Output | +|------------|--------| +| `["text"]` | Text only | +| Not specified | Text only (default) | + +### Using curl + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} + ], + "modalities": ["text"] + }' +``` + +### Using OpenAI Python SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}, + ], + modalities=["text"], +) +print(response.choices[0].message.content) +``` + +### Multi-modal input with OpenAI Python SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"}}, + {"type": "text", "text": "Describe this image in detail."}, + ], + }, + ], + modalities=["text"], +) +print(response.choices[0].message.content) +``` + +## Streaming Output + +To enable streaming output: + +```bash +python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \ + --query-type use_image \ + --model Jonathan1909/Ming-flash-omni-2.0 \ + --modalities text \ + --stream +``` + +Or with the OpenAI Python SDK: + 
+```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}, + ], + modalities=["text"], + stream=True, +) +for chunk in response: + for choice in chunk.choices: + if hasattr(choice, "delta") and choice.delta.content: + print(choice.delta.content, end="", flush=True) +print() +``` + +Or using curl: + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]}, + {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"} + ], + "modalities": ["text"], + "stream": true, + }' +``` + + +## Reasoning (Thinking Mode) + +To enable reasoning/thinking mode, change `detailed thinking off` to `detailed thinking on` in the system prompt: + +### Using curl + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Jonathan1909/Ming-flash-omni-2.0", + "messages": [ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "https://example.com/math_problem.png"}}, + {"type": "text", "text": "Solve this math problem step by step."} + ]} + ], + "modalities": ["text"] + }' +``` + +### Using OpenAI Python SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY") + +response = client.chat.completions.create( + model="Jonathan1909/Ming-flash-omni-2.0", + messages=[ + {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]}, + {"role": "user", "content": "If a train travels 120 km in 2 hours, what is its average speed?"}, + ], + modalities=["text"], +) +print(response.choices[0].message.content) +``` diff --git a/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh new file mode 100755 index 0000000000..768a424e45 --- /dev/null +++ b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh @@ -0,0 +1,145 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Server port +PORT="${PORT:-8091}" +# Default query type +QUERY_TYPE="${1:-text}" + +# Validate query type +if [[ ! "$QUERY_TYPE" =~ ^(text|use_audio|use_image|use_video|use_mixed_modalities)$ ]]; then + echo "Error: Invalid query type '$QUERY_TYPE'" + echo "Usage: $0 [text|use_audio|use_image|use_video|use_mixed_modalities]" + echo " text: Text-only query" + echo " use_audio: Audio + Text query" + echo " use_image: Image + Text query" + echo " use_video: Video + Text query" + echo " use_mixed_modalities: Audio + Image + Video + Text query" + exit 1 +fi + +thinker_sampling_params='{ + "temperature": 0.4, + "top_p": 0.9, + "top_k": -1, + "max_tokens": 16384, + "seed": 42, + "detokenize": true, + "repetition_penalty": 1.05 +}' +# Above is optional, it has a default setting in stage_configs of the corresponding model. 
+ +# Define URLs for assets +MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg" +CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg" +SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4" + +# Build user content based on query type +case "$QUERY_TYPE" in + text) + user_content='[ + { + "type": "text", + "text": "请详细介绍鹦鹉的生活习性。" + } + ]' + ;; + use_image) + user_content='[ + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "text", + "text": "Describe this image in detail." + } + ]' + ;; + use_audio) + user_content='[ + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "Please recognize the language of this speech and transcribe it. Format: oral." + } + ]' + ;; + use_video) + user_content='[ + { + "type": "video_url", + "video_url": { + "url": "'"$SAMPLE_VIDEO_URL"'" + } + }, + { + "type": "text", + "text": "Describe what is happening in this video." + } + ]' + ;; + use_mixed_modalities) + user_content='[ + { + "type": "image_url", + "image_url": { + "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'" + } + }, + { + "type": "audio_url", + "audio_url": { + "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'" + } + }, + { + "type": "text", + "text": "Describe the image, and recognize the language of this speech and transcribe it. Format: oral" + } + ]' + ;; +esac + +echo "Running query type: $QUERY_TYPE" +echo "" + +request_body=$(cat < **Note on `--no-async-chunk`**: Flips the deploy yaml's `async_chunk:` +> bool. Pipelines that implement alternate processor functions for +> chunked vs end-to-end modes (e.g. qwen3_tts code2wav) dispatch +> automatically based on that bool — no extra flag or variant yaml is +> needed. + +> ⚠️ **For multi-stage models that share GPUs (qwen3_omni_moe by default +> shares cuda:1 between stages 1 and 2), avoid using global memory flags.** +> A global `--gpu-memory-utilization 0.85` would apply to every stage and +> oversubscribe the shared device. Use per-stage overrides instead — see +> below. + +#### 2. Per-stage overrides via `--stage-overrides` (recommended for memory) + +```bash +# Lower stage 1's memory budget; leave others at the YAML default +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \ + --stage-overrides '{ + "1": {"gpu_memory_utilization": 0.5}, + "2": {"max_num_batched_tokens": 65536} + }' +``` + +Per-stage values are always treated as explicit and beat YAML defaults for +the named stage. Other stages keep their YAML values. + +#### 3. Custom deploy YAML + +When per-stage overrides get long, write a small overlay YAML that inherits +from the bundled default: + +```yaml +# my_qwen3_omni_overrides.yaml +base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml + +stages: + - stage_id: 0 + max_num_batched_tokens: 65536 + enforce_eager: true + - stage_id: 1 + gpu_memory_utilization: 0.5 + - stage_id: 2 + max_model_len: 8192 ``` +Then start the server with `--deploy-config my_qwen3_omni_overrides.yaml`. +The `base_config:` line tells the loader to inherit everything else (stages, +connectors, edges, platforms section) from the bundled production YAML, so +you only need to spell out the deltas. + +#### 4. 
Multi-node deployment (cross-host transfer connector) + +The bundled `qwen3_omni_moe.yaml` uses `SharedMemoryConnector` between stages, +which only works when all stages run on the same physical host. For +**cross-node** deployments, write a small overlay YAML that swaps in a +network-capable connector (e.g. `MooncakeStoreConnector`) and re-points each +stage's connector wiring at it. The connector spec carries your own server +addresses — there is no checked-in default because every cluster is +different. + +```yaml +# my_qwen3_omni_multinode.yaml +base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml + +connectors: + mooncake_connector: + name: MooncakeStoreConnector + extra: + host: "127.0.0.1" + metadata_server: "http://YOUR_METADATA_HOST:8080/metadata" + master: "YOUR_MASTER_HOST:50051" + segment: 512000000 # 512 MB transfer segment + localbuf: 64000000 # 64 MB local buffer + proto: "tcp" + +stages: + - stage_id: 0 + output_connectors: + to_stage_1: mooncake_connector + - stage_id: 1 + input_connectors: + from_stage_0: mooncake_connector + output_connectors: + to_stage_2: mooncake_connector + - stage_id: 2 + input_connectors: + from_stage_1: mooncake_connector +``` + +Then launch with `--deploy-config my_qwen3_omni_multinode.yaml`. Same +pattern works for Qwen2.5-Omni — replace `base_config:` with the path to +`vllm_omni/deploy/qwen2_5_omni.yaml`. + +> ⚠️ Replace `YOUR_METADATA_HOST` / `YOUR_MASTER_HOST` with the actual +> mooncake server addresses for your cluster. The `base_config:` overlay +> inherits all stage budgets, devices, and edges from the bundled prod +> YAML — you only need to spell out the connector swap. + ### Send Multi-modal Request Get into the example folder @@ -38,38 +180,43 @@ python examples/online_serving/openai_chat_completion_client_for_multimodal_gene #### Realtime WebSocket client (`openai_realtime_client.py`) -[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, uploads a local audio file as **PCM16 mono @ 16 kHz** chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and prints **streaming transcription** (`transcription.delta` / `transcription.done`). +[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, streams a local WAV as **PCM16 mono @ 16 kHz** in fixed-size chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and receives **`response.audio.delta`** (incremental PCM for the reply) plus **`transcription.*`** events. By default it concatenates audio deltas and writes **`--output-wav`** (model output is typically **24 kHz**). Optional **`--delta-dump-dir`** saves each delta as `delta_000001.wav`, … for debugging. + +Streaming input works well for translation-style use cases; if the Thinker runs while input is still incomplete, consider limiting **`max_tokens`** in your session / server defaults to avoid over-generation. **Dependencies:** ```bash -pip install websockets librosa numpy +pip install websockets ``` -(ffmpeg may be required by `librosa` for some formats; see the FAQ below.) 
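The client expects `--input-wav` to be mono, 16-bit PCM at 16 kHz. If your clip is not already in that format, it can be converted ahead of time; the sketch below uses `soundfile` plus simple linear resampling (the helper name is illustrative, and any proper resampler works equally well):

```python
# Rough input-conversion helper (illustrative, not part of the client):
# downmix to mono, linearly resample to 16 kHz, write 16-bit PCM WAV.
import numpy as np
import soundfile as sf


def to_pcm16_16k_mono(src_path: str, dst_path: str, target_sr: int = 16000) -> None:
    audio, sr = sf.read(src_path, dtype="float32")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)  # downmix to mono
    if sr != target_sr:
        target_len = int(round(audio.shape[0] * target_sr / sr))
        audio = np.interp(
            np.linspace(0.0, audio.shape[0] - 1, target_len),
            np.arange(audio.shape[0]),
            audio,
        )
    sf.write(dst_path, audio, target_sr, format="WAV", subtype="PCM_16")


to_pcm16_16k_mono("input_48k_stereo.wav", "input_16k_mono.wav")
```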
- **From this directory** (`examples/online_serving/qwen3_omni`): ```bash python openai_realtime_client.py \ - --host localhost \ - --port 8091 \ + --url ws://localhost:8091/v1/realtime \ --model Qwen/Qwen3-Omni-30B-A3B-Instruct \ - --audio_path /path/to/your.wav + --input-wav /path/to/input_16k_mono.wav \ + --output-wav realtime_output.wav \ + --delta-dump-dir ./rt_delta_wavs ``` -If `--audio_path` is omitted, the script uses a bundled default clip (`mary_had_lamb` via vLLM assets). - **Arguments:** | Flag | Default | Description | |------|---------|-------------| -| `--host` | `localhost` | API server host | -| `--port` | `8000` | API server port (match your `vllm serve` port, e.g. `8091`) | -| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (also sent in `session.update`) | -| `--audio_path` | *(optional)* | Path to input audio; resampled to 16 kHz mono inside the client | - -Ensure the vLLM-Omni server is running with realtime support for this endpoint, for example: +| `--url` | `ws://localhost:8091/v1/realtime` | Full WebSocket URL including path | +| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (sent in `session.update`) | +| `--input-wav` | *(required)* | Input WAV: mono, 16-bit PCM, **16 kHz** | +| `--output-wav` | `realtime_output.wav` | Output path for concatenated reply audio | +| `--output-text` | *(optional)* | If set, write final transcription text to this path | +| `--chunk-ms` | `200` | Size of each uploaded audio chunk (milliseconds of audio) | +| `--send-delay-ms` | `0` | Delay between chunk sends (simulate realtime upload) | +| `--delta-dump-dir` | *(optional)* | Directory to write per-`response.audio.delta` WAV files | +| `--num-requests` | `1` | Number of sequential sessions (see `--concurrency`) | +| `--concurrency` | `1` | Max concurrent WebSocket sessions when `--num-requests` > 1 | + +Ensure the server is running **without** `async_chunk` if you use `/v1/realtime`, for example: ```bash vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 @@ -105,12 +252,6 @@ bash run_curl_multimodal_generation.sh use_image ### FAQ -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - ## Modality control You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance. 
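For a quick text-only request against the server above, the OpenAI SDK pattern used elsewhere in these examples applies here as well. This is a sketch assuming the default port 8091:

```python
# Text-only request sketch: modalities=["text"] skips the audio
# generation stages (assumes a server started as shown above).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-Omni-30B-A3B-Instruct",
    messages=[{"role": "user", "content": "Give a one-paragraph summary of what an omni model is."}],
    modalities=["text"],
)
print(response.choices[0].message.content)
```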
@@ -284,7 +425,7 @@ The script supports the following arguments: - `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct) - `--server-port`: Port for vLLM server (default: 8091) - `--gradio-port`: Port for Gradio demo (default: 7861) -- `--stage-configs-path`: Path to custom stage configs YAML file (optional) +- `--deploy-config`: Path to custom deploy config YAML file (optional) - `--server-host`: Host for vLLM server (default: 0.0.0.0) - `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1) - `--share`: Share Gradio demo publicly (creates a public link) @@ -299,7 +440,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 If you have custom stage configs file: ```bash -vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file ``` **Step 2: Run the Gradio demo** diff --git a/examples/online_serving/qwen3_omni/openai_realtime_client.py b/examples/online_serving/qwen3_omni/openai_realtime_client.py index 4fa043c481..79e30a3f50 100644 --- a/examples/online_serving/qwen3_omni/openai_realtime_client.py +++ b/examples/online_serving/qwen3_omni/openai_realtime_client.py @@ -1,81 +1,118 @@ -""" -This script demonstrates how to use the vLLM-Omni Realtime WebSocket API to perform -audio transcription by uploading an audio file. +"""Realtime client for vLLM-Omni /v1/realtime (audio + text events). + +This client: +1) Reads a local WAV file (must be mono, 16-bit PCM, 16kHz), +2) Streams PCM16 chunks to /v1/realtime with OpenAI-style events, +3) Receives response.audio.* and transcription.* events, +4) Saves synthesized audio to an output WAV file and optional text file. -Before running this script, you must start the vLLM-Omni server with a realtime-capable -model, for example: +By default each ``response.audio.delta`` is treated as an **incremental PCM** +chunk and all chunks are concatenated into the final ``--output-wav``. - vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni +Optional debugging: pass ``--delta-dump-dir DIR`` to write every +``response.audio.delta`` payload as ``delta_000001.wav``, ``delta_000002.wav``, … -Requirements: -- vllm with audio support -- websockets -- librosa -- numpy +Usage: + python openai_realtime_client.py \ + --url ws://localhost:8091/v1/realtime \ + --model Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --input-wav input_16k_mono.wav \ + --output-wav realtime_output.wav \ + --delta-dump-dir ./rt_delta_wavs -The script: -1. Connects to the Realtime WebSocket endpoint -2. Converts an audio file to PCM16 @ 16kHz -3. Sends audio chunks to the server -4. 
Receives and prints transcription as it streams +Dependencies: + pip install websockets """ +from __future__ import annotations + import argparse import asyncio import base64 import json +import wave +from pathlib import Path + +try: + import websockets +except ImportError: + print("Please install websockets: pip install websockets") + raise SystemExit(1) + + +def _read_wav_pcm16(path: Path) -> bytes: + with wave.open(str(path), "rb") as wf: + nchannels = wf.getnchannels() + sampwidth = wf.getsampwidth() + framerate = wf.getframerate() + comptype = wf.getcomptype() + nframes = wf.getnframes() + + if nchannels != 1: + raise ValueError(f"Input WAV must be mono (got {nchannels} channels).") + if sampwidth != 2: + raise ValueError(f"Input WAV must be 16-bit PCM (got sample width={sampwidth}).") + if framerate != 16000: + raise ValueError(f"Input WAV must be 16kHz (got {framerate} Hz).") + if comptype != "NONE": + raise ValueError(f"Input WAV must be uncompressed PCM (got comptype={comptype}).") + if nframes <= 0: + raise ValueError("Input WAV has no audio frames.") + + return wf.readframes(nframes) + + +def _write_wav_pcm16(path: Path, pcm16_bytes: bytes, sample_rate_hz: int) -> None: + with wave.open(str(path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate_hz) + wf.writeframes(pcm16_bytes) + + +async def run_client( + url: str, + model: str, + input_wav: Path, + output_wav: Path, + output_text: Path | None, + chunk_ms: int, + send_delay_ms: int, + delta_dump_dir: Path | None, + request_idx: int = 1, + total_requests: int = 1, +) -> None: + log_prefix = f"[req {request_idx:02d}/{total_requests:02d}] " if total_requests > 1 else "" + pcm16 = _read_wav_pcm16(input_wav) + bytes_per_ms = 16000 * 2 // 1000 # mono PCM16 at 16kHz + chunk_bytes = max(bytes_per_ms * chunk_ms, 2) -import librosa -import numpy as np -import websockets -from vllm.assets.audio import AudioAsset - - -def audio_to_pcm16_base64(audio_path: str) -> str: - """ - Load an audio file and convert it to base64-encoded PCM16 @ 16kHz. - """ - # Load audio and resample to 16kHz mono - audio, _ = librosa.load(audio_path, sr=16000, mono=True) - # Convert to PCM16 - pcm16 = (audio * 32767).astype(np.int16) - # Encode as base64 - return base64.b64encode(pcm16.tobytes()).decode("utf-8") - - -async def realtime_transcribe(audio_path: str, host: str, port: int, model: str): - """ - Connect to the Realtime API and transcribe an audio file. 
- """ - uri = f"ws://{host}:{port}/v1/realtime" - - async with websockets.connect(uri) as ws: - # Wait for session.created - response = json.loads(await ws.recv()) - if response["type"] == "session.created": - print(f"Session created: {response['id']}") - else: - print(f"Unexpected response: {response}") - return - - # Validate model - await ws.send(json.dumps({"type": "session.update", "model": model})) - - # Signal ready to start - await ws.send(json.dumps({"type": "input_audio_buffer.commit"})) - - # Convert audio file to base64 PCM16 - print(f"Loading audio from: {audio_path}") - audio_base64 = audio_to_pcm16_base64(audio_path) - - # Send audio in chunks (4KB of raw audio = ~8KB base64) - chunk_size = 4096 - audio_bytes = base64.b64decode(audio_base64) - total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size - - print(f"Sending {total_chunks} audio chunks...") - for i in range(0, len(audio_bytes), chunk_size): - chunk = audio_bytes[i : i + chunk_size] + incremental_pcm_parts: list[bytes] = [] + output_sample_rate = 24000 + delta_index = 0 + text_chunks: list[str] = [] + final_text: str = "" + + if delta_dump_dir is not None: + delta_dump_dir.mkdir(parents=True, exist_ok=True) + + async with websockets.connect(url, max_size=64 * 1024 * 1024) as ws: + # 1) Validate model. + await ws.send( + json.dumps( + { + "type": "session.update", + "model": model, + } + ) + ) + + # 2) Start generation once (non-final commit). + await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False})) + + # 3) Stream audio chunks. + for i in range(0, len(pcm16), chunk_bytes): + chunk = pcm16[i : i + chunk_bytes] await ws.send( json.dumps( { @@ -84,63 +121,212 @@ async def realtime_transcribe(audio_path: str, host: str, port: int, model: str) } ) ) + if send_delay_ms > 0: + await asyncio.sleep(send_delay_ms / 1000.0) - # Signal all audio is sent + # 4) Final commit closes input stream. await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True})) - print("Audio sent. Waiting for transcription...\n") - # Receive transcription - print("Transcription: ", end="", flush=True) + # 5) Receive server events until audio done. while True: - response = json.loads(await ws.recv()) - if response["type"] == "transcription.delta": - print(response["delta"], end="", flush=True) - elif response["type"] == "transcription.done": - print(f"\n\nFinal transcription: {response['text']}") - if response.get("usage"): - print(f"Usage: {response['usage']}") - break - elif response["type"] == "error": - print(f"\nError: {response['error']}") + message = await ws.recv() + if isinstance(message, bytes): + # We only expect JSON text frames. 
+ continue + + event = json.loads(message) + event_type = event.get("type") + + if event_type == "session.created": + continue + + if event_type == "response.audio.delta": + sr = event.get("sample_rate_hz") + if isinstance(sr, int) and sr > 0: + output_sample_rate = sr + audio_b64 = event.get("audio", "") + if audio_b64: + pcm_delta = base64.b64decode(audio_b64) + incremental_pcm_parts.append(pcm_delta) + if delta_dump_dir is not None and pcm_delta: + delta_index += 1 + dump_path = delta_dump_dir / f"delta_{delta_index:06d}.wav" + _write_wav_pcm16(dump_path, pcm_delta, output_sample_rate) + print( + f"{log_prefix}delta dump #{delta_index}: {dump_path} " + f"(pcm bytes={len(pcm_delta)}, sr={output_sample_rate})" + ) + continue + + if event_type == "transcription.delta": + delta = event.get("delta", "") + if delta: + text_chunks.append(delta) + print(delta, end="", flush=True) + continue + + if event_type == "transcription.done": + final_text = event.get("text", "") or "".join(text_chunks) + usage = event.get("usage") + final_text_with_tag = f"Final transcription: {final_text}" + if text_chunks: + print() + print(f"{log_prefix}{final_text_with_tag}") + if usage: + print(f"{log_prefix}text usage: {usage}") + continue + + if event_type == "response.audio.done": break + if event_type == "error": + raise RuntimeError(f"Server error: {event}") -def main(args): - if args.audio_path: - audio_path = args.audio_path - else: - # Use default audio asset - audio_path = str(AudioAsset("mary_had_lamb").get_local_path()) - print(f"No audio path provided, using default: {audio_path}") + all_pcm16 = b"".join(incremental_pcm_parts) + if not all_pcm16: + raise RuntimeError("No audio received from server.") - asyncio.run(realtime_transcribe(audio_path, args.host, args.port, args.model)) + output_wav.parent.mkdir(parents=True, exist_ok=True) + _write_wav_pcm16(output_wav, all_pcm16, output_sample_rate) + print(f"{log_prefix}Saved realtime audio to: {output_wav} (incremental chunks joined)") + if output_text is not None: + text_to_save = final_text if final_text else "".join(text_chunks) + output_text.parent.mkdir(parents=True, exist_ok=True) + output_text.write_text(text_to_save, encoding="utf-8") + print(f"{log_prefix}Saved realtime text to: {output_text}") -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Realtime WebSocket Transcription Client") + +def _indexed_output_path(path: Path | None, index: int, total: int) -> Path | None: + if path is None or total <= 1: + return path + return path.with_name(f"{path.stem}_{index:02d}{path.suffix}") + + +async def run_clients_concurrent( + *, + url: str, + model: str, + input_wav: Path, + output_wav: Path, + output_text: Path | None, + chunk_ms: int, + send_delay_ms: int, + delta_dump_dir: Path | None, + num_requests: int, + concurrency: int, +) -> None: + sem = asyncio.Semaphore(concurrency) + + async def _run_one(index: int) -> tuple[int, bool, str | None]: + per_output_wav = _indexed_output_path(output_wav, index, num_requests) + per_output_text = _indexed_output_path(output_text, index, num_requests) + per_delta_dir = None + if delta_dump_dir is not None: + per_delta_dir = delta_dump_dir / f"req_{index:02d}" + async with sem: + try: + await run_client( + url=url, + model=model, + input_wav=input_wav, + output_wav=per_output_wav, + output_text=per_output_text, + chunk_ms=chunk_ms, + send_delay_ms=send_delay_ms, + delta_dump_dir=per_delta_dir, + request_idx=index, + total_requests=num_requests, + ) + return index, True, None + except 
Exception as exc: + return index, False, str(exc) + + tasks = [asyncio.create_task(_run_one(i), name=f"rt-client-{i}") for i in range(1, num_requests + 1)] + results = await asyncio.gather(*tasks) + + failed = [(idx, err) for idx, ok, err in results if not ok] + succeeded = num_requests - len(failed) + print(f"[summary] succeeded={succeeded}, failed={len(failed)}, total={num_requests}") + if failed: + for idx, err in failed: + print(f"[summary] req {idx:02d} failed: {err}") + raise RuntimeError(f"{len(failed)} concurrent request(s) failed") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Realtime audio/text client for vLLM-Omni") + parser.add_argument("--url", default="ws://localhost:8091/v1/realtime", help="WebSocket URL") parser.add_argument( "--model", - type=str, default="Qwen/Qwen3-Omni-30B-A3B-Instruct", - help="Model that is served and should be pinged.", + help="Model name for session.update", ) + parser.add_argument("--input-wav", required=True, type=Path, help="Input WAV (mono, PCM16, 16kHz)") + parser.add_argument("--output-wav", default=Path("realtime_output.wav"), type=Path, help="Output WAV path") parser.add_argument( - "--audio_path", - type=str, + "--output-text", default=None, - help="Path to the audio file to transcribe.", + type=Path, + help="Optional output text path for final transcription", ) + parser.add_argument("--chunk-ms", type=int, default=200, help="Input chunk size in milliseconds") parser.add_argument( - "--host", - type=str, - default="localhost", - help="vLLM-Omni server host (default: localhost)", + "--send-delay-ms", + type=int, + default=0, + help="Delay between chunk sends; set >0 to simulate realtime upload", ) parser.add_argument( - "--port", + "--delta-dump-dir", + type=Path, + default=None, + help="If set, each response.audio.delta is saved as delta_NNNNNN.wav under this directory", + ) + parser.add_argument("--num-requests", type=int, default=1, help="Total number of requests to send") + parser.add_argument( + "--concurrency", type=int, - default=8000, - help="vLLM-Omni server port (default: 8000)", + default=1, + help="Maximum number of concurrent websocket requests", ) args = parser.parse_args() - main(args) + + if args.num_requests <= 0: + raise ValueError("--num-requests must be >= 1") + if args.concurrency <= 0: + raise ValueError("--concurrency must be >= 1") + concurrency = min(args.concurrency, args.num_requests) + + if args.num_requests == 1: + asyncio.run( + run_client( + url=args.url, + model=args.model, + input_wav=args.input_wav, + output_wav=args.output_wav, + output_text=args.output_text, + chunk_ms=args.chunk_ms, + send_delay_ms=args.send_delay_ms, + delta_dump_dir=args.delta_dump_dir, + ) + ) + else: + asyncio.run( + run_clients_concurrent( + url=args.url, + model=args.model, + input_wav=args.input_wav, + output_wav=args.output_wav, + output_text=args.output_text, + chunk_ms=args.chunk_ms, + send_delay_ms=args.send_delay_ms, + delta_dump_dir=args.delta_dump_dir, + num_requests=args.num_requests, + concurrency=concurrency, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md index 5504b5737a..350fcb71ca 100644 --- a/examples/online_serving/qwen3_tts/README.md +++ b/examples/online_serving/qwen3_tts/README.md @@ -43,7 +43,7 @@ Then open http://localhost:7860 in your browser. ### Launch the Server -The default stage config is located at `vllm_omni/model_executor/stage_configs/qwen3_tts.yaml`. 
For other platforms (e.g., NPU), refer to `vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml`. +The default deploy config is located at `vllm_omni/deploy/qwen3_tts.yaml` and is loaded automatically by the model registry — no `--deploy-config` flag needed for default use. Platform-specific deltas (NPU, ROCm, XPU) are merged in automatically from the `platforms:` block of the same YAML based on the detected runtime. ```bash # CustomVoice model (predefined speakers) @@ -70,6 +70,22 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ --port 8091 ``` +#### Sync vs async-chunk mode + +Qwen3-TTS supports both **chunked streaming** (default, lower latency) and +**synchronous end-to-end** modes from the same deploy YAML. The bundled +`qwen3_tts.yaml` ships with `async_chunk: true`; flip with `--no-async-chunk` +and the pipeline automatically dispatches to the end-to-end codec processor: + +```bash +vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice --omni --port 8091 \ + --no-async-chunk +``` + +No variant YAML or extra flag is needed — the `StagePipelineConfig` on each +stage declares both processor functions and the runtime picks based on the +`async_chunk:` bool. + Alternatively, use the convenience script: ```bash ./run_server.sh # Default: CustomVoice model @@ -192,14 +208,6 @@ with open("output.wav", "wb") as f: f.write(response.content) ``` -### FAQ - -If you encounter error about backend of librosa, try to install ffmpeg with command below. -``` -sudo apt update -sudo apt install ffmpeg -``` - ## API Reference ### Voices Endpoint @@ -386,6 +394,54 @@ Server -> Client: {"type": "session.done", "total_sentences": 1} ``` +## Choosing an Execution Backend: Uniproc vs Multiprocessing + +Qwen3-TTS stage configs support two execution backends controlled by the +`distributed_executor_backend` engine arg. The performance tradeoff between +them is **both hardware- and task-dependent**, so there is no single best +default (see [#2603](https://github.com/vllm-project/vllm-omni/issues/2603), +[#2604](https://github.com/vllm-project/vllm-omni/pull/2604) for the full +investigation). + +| Backend | Stage config setting | Behaviour | +| ------- | -------------------- | --------- | +| **Uniproc** (default, world_size=1) | `distributed_executor_backend` omitted | Both stages run inside the orchestrator process. Avoids IPC serialisation, D2H copies, and msgpack overhead between stages. | +| **Multiprocessing** | `distributed_executor_backend: "mp"` | Each stage runs in its own subprocess. The Talker can continue decoding while Code2Wav runs the vocoder in parallel, improving pipeline utilisation under concurrency. | + +> **Note:** When `distributed_executor_backend` is omitted and `world_size=1`, +> vLLM [automatically uses the uniproc executor](https://github.com/vllm-project/vllm/blob/main/vllm/config/parallel.py#L825). +> When `world_size > 1`, it defaults to `mp`. + +### When uniproc wins + +The uniproc path eliminates inter-process data transfer (D2H copies, +msgpack serialisation/deserialisation, tensor detaching). This matters most +when per-request processing is heavy relative to autoregressive decode. + +The Base cloning task involves reference-audio encoding on every request, making IPC +overhead a larger fraction of total cost. Qwen3-Omni shows a similar pattern. + +### When multiprocessing (`mp`) wins + +For lighter per-request workloads, process-level parallelism between the +Talker and Code2Wav stages dominates. 
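+
+In stage-config terms that means opting both stages into the `mp` executor. A
+minimal sketch is shown below; the stage layout and surrounding key names are
+illustrative, and the bundled `qwen3_tts.yaml` remains the source of truth for
+everything except the `distributed_executor_backend` line:
+
+```yaml
+# Illustrative excerpt only: enable the multiprocessing executor per stage.
+stage_args:
+  - stage_id: 0                          # Talker stage (layout assumed)
+    distributed_executor_backend: "mp"
+  - stage_id: 1                          # Code2Wav stage (layout assumed)
+    distributed_executor_backend: "mp"
+```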
+ +CustomVoice is lighter per-request (no reference audio encoding), so the +process-level parallelism of `mp` outweighs its serialisation cost at +concurrency ≥ 4. + +### How to switch + +To use the uniproc executor on a single-GPU setup, pass the +`qwen3_tts_uniproc.yaml` stage config: + +```bash +vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \ + --omni \ + --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml \ + --port 8091 +``` + ## Limitations - **Single request**: Batch processing is not yet optimized for online serving. diff --git a/examples/online_serving/qwen3_tts/batch_speech_client.py b/examples/online_serving/qwen3_tts/batch_speech_client.py index 7d48e650f8..47fdc3691c 100644 --- a/examples/online_serving/qwen3_tts/batch_speech_client.py +++ b/examples/online_serving/qwen3_tts/batch_speech_client.py @@ -5,11 +5,13 @@ batch level and generate many utterances in the cloned voice without repeating the reference for each item. -Start the server (with batch-optimized config for best throughput): +Start the server (with batch-optimized stage settings for best throughput): vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \ - --trust-remote-code + --omni \ + --trust-remote-code \ + --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2}, + "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}' Examples: # Batch with a predefined voice diff --git a/examples/online_serving/qwen3_tts/run_gradio_demo.sh b/examples/online_serving/qwen3_tts/run_gradio_demo.sh index bcc0ddb7cf..d79be3c2ab 100644 --- a/examples/online_serving/qwen3_tts/run_gradio_demo.sh +++ b/examples/online_serving/qwen3_tts/run_gradio_demo.sh @@ -127,7 +127,7 @@ echo "Starting vLLM server..." LOG_FILE="/tmp/vllm_tts_server_${SERVER_PORT}.log" vllm-omni serve "$MODEL" \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --host "$SERVER_HOST" \ --port "$SERVER_PORT" \ --gpu-memory-utilization 0.9 \ diff --git a/examples/online_serving/qwen3_tts/run_server.sh b/examples/online_serving/qwen3_tts/run_server.sh index 6f4aa83a0b..78dd2c305d 100755 --- a/examples/online_serving/qwen3_tts/run_server.sh +++ b/examples/online_serving/qwen3_tts/run_server.sh @@ -31,7 +31,7 @@ esac echo "Starting Qwen3-TTS server with model: $MODEL" vllm-omni serve "$MODEL" \ - --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \ + --deploy-config vllm_omni/deploy/qwen3_tts.yaml \ --host 0.0.0.0 \ --port 8091 \ --gpu-memory-utilization 0.9 \ diff --git a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py index e6786f8869..7790fa5127 100644 --- a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py +++ b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py @@ -5,7 +5,7 @@ using SLERP and sends the result to the /v1/audio/speech API. 
Requirements: - pip install torch librosa soundfile numpy httpx + pip install torch soundfile numpy httpx Examples: # Extract and save an embedding @@ -143,17 +143,18 @@ def _load_speaker_encoder_weights(encoder: torch.nn.Module, model_path: str) -> def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor: """Compute 128-bin mel spectrogram matching Qwen3-TTS's extraction pipeline.""" - import librosa + from vllm.multimodal.audio import AudioResampler # Resample to 24kHz if needed if sr != 24000: - audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=24000) + resampler = AudioResampler(target_sr=24000) + audio = resampler.resample(audio.astype(np.float32), orig_sr=sr) y = torch.from_numpy(audio).unsqueeze(0).float() - from librosa.filters import mel as librosa_mel_fn + from vllm_omni.utils.audio import mel_filter_bank - mel_basis = torch.from_numpy(librosa_mel_fn(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)).float() + mel_basis = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000) n_fft = 1024 hop_size = 256 @@ -180,9 +181,9 @@ def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor: @torch.inference_mode() def extract_embedding(encoder: torch.nn.Module, audio_path: str, device: str = "cpu") -> np.ndarray: """Extract a 1024-dim speaker embedding from an audio file.""" - import librosa + from vllm.multimodal.media.audio import load_audio - audio, sr = librosa.load(audio_path, sr=None, mono=True) + audio, sr = load_audio(audio_path, sr=None, mono=True) mel = compute_mel_spectrogram(audio, sr).to(device) embedding = encoder(mel.to(next(encoder.parameters()).dtype))[0] return embedding.float().cpu().numpy() diff --git a/examples/online_serving/text_to_video/README.md b/examples/online_serving/text_to_video/README.md index 44e676671f..c01e0602ff 100644 --- a/examples/online_serving/text_to_video/README.md +++ b/examples/online_serving/text_to_video/README.md @@ -1,16 +1,27 @@ # Text-To-Video -This example demonstrates how to deploy the Wan2.2 text-to-video model for online video generation using vLLM-Omni. +This example demonstrates how to deploy text-to-video models for online video generation using vLLM-Omni. -## Start Server +## Supported Models -### Basic Start +| Model | Model ID | +|-------|----------| +| Wan2.1 T2V (1.3B) | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | +| Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | +| Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | +| LTX-2 | `Lightricks/LTX-2` | + +## Wan2.2 T2V + +### Start Server + +#### Basic Start ```bash vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091 ``` -### Start with Parameters +#### Start with Parameters Or use the startup script: @@ -230,3 +241,82 @@ while true; do sleep 2 done ``` + +## LTX-2 + +### Start Server + +#### Basic Start + +```bash +vllm serve Lightricks/LTX-2 --omni --port 8098 \ + --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0 +``` + +#### Start with Optimization Presets + +Use the LTX-2 startup script with built-in optimization presets: + +```bash +# Baseline (1 GPU, eager) +bash run_server_ltx2.sh baseline + +# 4-GPU Ulysses sequence parallelism (lossless) +bash run_server_ltx2.sh ulysses4 + +# Cache-DiT lossy acceleration (1 GPU, ~1.4× speedup) +bash run_server_ltx2.sh cache-dit + +# Best combo: 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup) +bash run_server_ltx2.sh best-combo +``` + +#### Optimization Benchmarks + +Benchmarked on H800, online serving (480×768, 41 frames, 20 steps, `seed=42`). 
+"Inference" is the server-reported inference time; excludes HTTP/poll overhead. + +| Preset | Server Command | Inference (s) | Speedup | Type | +|--------|---------------|---------------|---------|------| +| `baseline` | `--enforce-eager` | 10.3 | 1.00× | — | +| `compile` | *(default, no --enforce-eager)* | ~10.3 (warm) | ~1.00× | Lossless | +| `ulysses4` | `--enforce-eager --usp 4` | ~10.3 | ~1.00× | Lossless | +| `cache-dit` | `--enforce-eager --cache-backend cache_dit` | 7.4 avg | ~1.4× | Lossy | +| `best-combo` | `--enforce-eager --usp 4 --cache-backend cache_dit` | 4.7 avg | **~2.2×** | Lossless + Lossy | + +**Observations**: +- **torch.compile**: On H800, warm-request inference time matches the eager baseline (~10.3s). + The first request pays ~6s compilation overhead. Benefit depends on model architecture and GPU. +- **Ulysses SP (4 GPU)**: No measurable speedup alone for 41-frame generation at this resolution. + Communication overhead outweighs gains at this sequence length. +- **Cache-DiT**: Inference varies per request (6–10s) due to dynamic caching decisions. + Average is ~7.4s (~1.4× speedup) with slight quality tradeoff. +- **Best combo**: 4-GPU Ulysses SP + Cache-DiT synergize well — Cache-DiT reduces per-step + computation, making the communication overhead of Ulysses SP worthwhile. Average ~4.7s + (~2.2× speedup). +- **FP8 quantization**: Reduces VRAM but does not speed up LTX-2 on H800 (compute-bound). + +**Deployment Recommendations**: +- For **production with quality priority**: use `baseline` with `--enforce-eager` +- For **maximum throughput** (4 GPUs, quality tradeoff): use `best-combo` (~2.2× speedup) +- For **single-GPU throughput**: use `cache-dit` (~1.4× speedup) +- `--enforce-eager` is recommended to avoid torch.compile warmup latency on first request + +### Send Requests (curl) + +```bash +# Using the provided script +bash run_curl_ltx2.sh + +# Or directly +curl -sS -X POST http://localhost:8098/v1/videos \ + -H "Accept: application/json" \ + -F "prompt=A serene lakeside sunrise with mist over the water." \ + -F "width=768" \ + -F "height=480" \ + -F "num_frames=41" \ + -F "fps=24" \ + -F "num_inference_steps=20" \ + -F "guidance_scale=3.0" \ + -F "seed=42" +``` diff --git a/examples/online_serving/text_to_video/run_curl_ltx2.sh b/examples/online_serving/text_to_video/run_curl_ltx2.sh new file mode 100644 index 0000000000..b82f672eaa --- /dev/null +++ b/examples/online_serving/text_to_video/run_curl_ltx2.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# LTX-2 text-to-video curl example using the async video job API. +# Start the server first: bash run_server_ltx2.sh best-combo + +set -euo pipefail + +BASE_URL="${BASE_URL:-http://localhost:8098}" +OUTPUT_PATH="${OUTPUT_PATH:-ltx2_output.mp4}" +POLL_INTERVAL="${POLL_INTERVAL:-2}" + +PROMPT="${PROMPT:-A serene lakeside sunrise with mist over the water.}" + +create_response=$( + curl -sS -X POST "${BASE_URL}/v1/videos" \ + -H "Accept: application/json" \ + -F "prompt=${PROMPT}" \ + -F "width=768" \ + -F "height=480" \ + -F "num_frames=41" \ + -F "fps=24" \ + -F "num_inference_steps=20" \ + -F "guidance_scale=3.0" \ + -F "seed=42" +) + +video_id="$(echo "${create_response}" | jq -r '.id')" +if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then + echo "Failed to create video job:" + echo "${create_response}" | jq . + exit 1 +fi + +echo "Created video job ${video_id}" +echo "${create_response}" | jq . 
+ +while true; do + status_response="$(curl -sS "${BASE_URL}/v1/videos/${video_id}")" + status="$(echo "${status_response}" | jq -r '.status')" + + case "${status}" in + queued|in_progress) + echo "Video job ${video_id} status: ${status}" + sleep "${POLL_INTERVAL}" + ;; + completed) + echo "${status_response}" | jq . + break + ;; + failed) + echo "Video generation failed:" + echo "${status_response}" | jq . + exit 1 + ;; + *) + echo "Unexpected status response:" + echo "${status_response}" | jq . + exit 1 + ;; + esac +done + +curl -sS -L "${BASE_URL}/v1/videos/${video_id}/content" -o "${OUTPUT_PATH}" +echo "Saved video to ${OUTPUT_PATH}" diff --git a/examples/online_serving/text_to_video/run_server_ltx2.sh b/examples/online_serving/text_to_video/run_server_ltx2.sh new file mode 100644 index 0000000000..f4597d3cd2 --- /dev/null +++ b/examples/online_serving/text_to_video/run_server_ltx2.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# LTX-2 online serving startup script with optimization presets. +# +# Usage: +# bash run_server_ltx2.sh # baseline (1 GPU, eager) +# bash run_server_ltx2.sh ulysses4 # 4-GPU Ulysses SP +# bash run_server_ltx2.sh cache-dit # 1 GPU + Cache-DiT +# bash run_server_ltx2.sh best-combo # 4-GPU Ulysses SP + Cache-DiT +# +# Online serving benchmarks on H800 (480×768, 41 frames, 20 steps): +# baseline : 10.3s inference (1.00×) +# compile : ~10.3s warm (~1.00×) first request +6s warmup +# ulysses4 : ~10.3s (~1.00×) no gain at 41 frames +# cache-dit : 7.4s avg (~1.4×) lossy, variable per request +# best-combo : 4.7s avg (~2.2×) 4-GPU ulysses + cache-dit + +set -euo pipefail + +MODEL="${MODEL:-Lightricks/LTX-2}" +PORT="${PORT:-8098}" +FLOW_SHIFT="${FLOW_SHIFT:-1.0}" +BOUNDARY_RATIO="${BOUNDARY_RATIO:-1.0}" + +PRESET="${1:-baseline}" + +EXTRA_ARGS=() +case "$PRESET" in + baseline) + echo "=== LTX-2 Preset: baseline (1 GPU, enforce-eager) ===" + EXTRA_ARGS+=(--enforce-eager) + ;; + ulysses2) + echo "=== LTX-2 Preset: 2-GPU Ulysses SP (lossless) ===" + EXTRA_ARGS+=(--enforce-eager --usp 2) + ;; + ulysses4) + echo "=== LTX-2 Preset: 4-GPU Ulysses SP (lossless) ===" + EXTRA_ARGS+=(--enforce-eager --usp 4) + ;; + cache-dit) + echo "=== LTX-2 Preset: Cache-DiT (1 GPU, lossy) ===" + EXTRA_ARGS+=(--enforce-eager --cache-backend cache_dit) + ;; + best-combo) + echo "=== LTX-2 Preset: 4-GPU Ulysses SP + Cache-DiT (best combo) ===" + EXTRA_ARGS+=(--enforce-eager --usp 4 --cache-backend cache_dit) + ;; + compile) + echo "=== LTX-2 Preset: torch.compile (1 GPU, lossless) ===" + # torch.compile is the default (no --enforce-eager) + ;; + *) + echo "Usage: $0 {baseline|ulysses2|ulysses4|cache-dit|best-combo|compile}" + echo "" + echo "Presets:" + echo " baseline - 1 GPU, eager execution (reference)" + echo " ulysses2 - 2-GPU Ulysses SP (lossless)" + echo " ulysses4 - 4-GPU Ulysses SP (lossless)" + echo " cache-dit - 1 GPU + Cache-DiT (lossy, ~1.4× speedup)" + echo " best-combo - 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup)" + echo " compile - 1 GPU + torch.compile (slower first request)" + echo "" + echo "Environment variables:" + echo " MODEL - Model path (default: Lightricks/LTX-2)" + echo " PORT - Server port (default: 8098)" + echo " FLOW_SHIFT - Scheduler flow shift (default: 1.0)" + echo " BOUNDARY_RATIO - Boundary ratio (default: 1.0)" + exit 1 + ;; +esac + +echo "Model: $MODEL" +echo "Port: $PORT" +echo "Flow shift: $FLOW_SHIFT" +echo "Boundary ratio: $BOUNDARY_RATIO" + +vllm serve 
"$MODEL" --omni \ + --port "$PORT" \ + --flow-shift "$FLOW_SHIFT" \ + --boundary-ratio "$BOUNDARY_RATIO" \ + "${EXTRA_ARGS[@]}" diff --git a/examples/online_serving/voxcpm/README.md b/examples/online_serving/voxcpm/README.md new file mode 100644 index 0000000000..78e1bf4aaa --- /dev/null +++ b/examples/online_serving/voxcpm/README.md @@ -0,0 +1,166 @@ +# VoxCPM + +## Prerequisites + +Install VoxCPM in one of these ways: + +```bash +pip install voxcpm +``` + +or point vLLM-Omni to a local VoxCPM source tree: + +```bash +export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src +``` + +If the native VoxCPM `config.json` lacks HF metadata such as `model_type`, +prepare a persistent HF-compatible config directory and export: + +```bash +export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config +mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH" +cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json" +cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true +python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)' +``` + +The VoxCPM stage configs read `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH` directly. The `python3 -c` form above avoids heredoc/indentation issues in interactive shells. + +## Launch the Server + +Use the async-chunk stage config by default: + +```bash +export VOXCPM_MODEL=/path/to/voxcpm-model +cd examples/online_serving/voxcpm +./run_server.sh +``` + +Use the non-streaming stage config: + +```bash +./run_server.sh sync +``` + +You can also launch the server directly: + +```bash +vllm serve "$VOXCPM_MODEL" \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \ + --trust-remote-code \ + --enforce-eager \ + --omni \ + --port 8091 +``` + +## Send Requests + +### Basic text-to-speech + +```bash +python openai_speech_client.py \ + --model "$VOXCPM_MODEL" \ + --text "This is a VoxCPM online text-to-speech example." +``` + +### Voice cloning + +```bash +python openai_speech_client.py \ + --model "$VOXCPM_MODEL" \ + --text "This sentence is synthesized with a cloned voice." \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in reference.wav." +``` + +`ref_text` must be the real transcript of the reference audio. Placeholder text or mismatched text will usually degrade quality badly. + +### Streaming PCM output + +```bash +python openai_speech_client.py \ + --model "$VOXCPM_MODEL" \ + --text "This is a streaming VoxCPM request." 
\ + --stream \ + --output voxcpm_stream.pcm +``` + +### Using curl + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "OpenBMB/VoxCPM1.5", + "input": "Hello from VoxCPM online serving.", + "response_format": "wav" + }' --output output.wav +``` + +Voice cloning: + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "OpenBMB/VoxCPM1.5", + "input": "This sentence uses a cloned voice.", + "ref_audio": "https://example.com/reference.wav", + "ref_text": "The exact transcript spoken in the reference audio.", + "response_format": "wav" + }' --output cloned.wav +``` + +Streaming PCM: + +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "model": "OpenBMB/VoxCPM1.5", + "input": "This is a streaming VoxCPM request.", + "stream": true, + "response_format": "pcm" + }' --output output.pcm +``` + +## Supported Request Shape + +VoxCPM online serving currently supports: + +- plain text-to-speech +- voice cloning with `ref_audio` + `ref_text` +- `stream=true` with `response_format=pcm` or `wav` + +VoxCPM online serving does not use these generic TTS fields: + +- `voice` +- `instructions` +- `language` +- `speaker_embedding` +- `x_vector_only_mode` + +## Streaming vs Non-Streaming + +- `voxcpm_async_chunk.yaml` enables async-chunk streaming and is best for single-request streaming latency. +- `voxcpm.yaml` performs one-shot latent generation then VAE decode. + +Like native VoxCPM, the async streaming path should be treated as single-request. If you need stable throughput benchmarking, prefer `voxcpm.yaml`. + +Do not use `voxcpm_async_chunk.yaml` for concurrent online streaming or `/v1/audio/speech/batch`. For multiple requests, prefer `voxcpm.yaml`. + +## Benchmark + +The serving benchmark reports TTFP and RTF: + +```bash +python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \ + --host 127.0.0.1 \ + --port 8091 \ + --num-prompts 10 \ + --max-concurrency 1 \ + --result-dir /tmp/voxcpm_bench +``` + +For the async-chunk server, keep `--max-concurrency 1`. diff --git a/examples/online_serving/voxcpm/openai_speech_client.py b/examples/online_serving/voxcpm/openai_speech_client.py new file mode 100644 index 0000000000..c400114e8b --- /dev/null +++ b/examples/online_serving/voxcpm/openai_speech_client.py @@ -0,0 +1,155 @@ +"""OpenAI-compatible client for VoxCPM via /v1/audio/speech. + +Examples: + # Basic text-to-speech + python openai_speech_client.py --text "Hello from VoxCPM" + + # Voice cloning + python openai_speech_client.py \ + --text "This sentence uses the cloned voice." \ + --ref-audio /path/to/reference.wav \ + --ref-text "The exact transcript spoken in the reference audio." + + # Streaming PCM output + python openai_speech_client.py \ + --text "This is a streaming VoxCPM request." 
\ + --stream \ + --output output.pcm +""" + +import argparse +import base64 +import os + +import httpx + +DEFAULT_API_BASE = "http://localhost:8091" +DEFAULT_API_KEY = "EMPTY" +DEFAULT_MODEL = "OpenBMB/VoxCPM1.5" + + +def encode_audio_to_base64(audio_path: str) -> str: + """Encode a local audio file to base64 data URL.""" + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + ext = audio_path.lower().rsplit(".", 1)[-1] + mime_map = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "flac": "audio/flac", + "ogg": "audio/ogg", + } + mime_type = mime_map.get(ext, "audio/wav") + + with open(audio_path, "rb") as f: + audio_b64 = base64.b64encode(f.read()).decode("utf-8") + return f"data:{mime_type};base64,{audio_b64}" + + +def build_payload(args) -> dict[str, object]: + payload: dict[str, object] = { + "model": args.model, + "input": args.text, + "response_format": "pcm" if args.stream else args.response_format, + } + + if args.ref_audio: + if args.ref_audio.startswith(("http://", "https://", "data:")): + payload["ref_audio"] = args.ref_audio + else: + payload["ref_audio"] = encode_audio_to_base64(args.ref_audio) + if args.ref_text: + payload["ref_text"] = args.ref_text + if args.max_new_tokens is not None: + payload["max_new_tokens"] = args.max_new_tokens + if args.stream: + payload["stream"] = True + + return payload + + +def run_tts(args) -> None: + payload = build_payload(args) + api_url = f"{args.api_base}/v1/audio/speech" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {args.api_key}", + } + + print(f"Model: {args.model}") + print(f"Text: {args.text}") + if args.ref_audio: + print("Mode: voice cloning") + print(f"Reference audio: {args.ref_audio}") + else: + print("Mode: text-to-speech") + + if args.stream: + output_path = args.output or "voxcpm_output.pcm" + with httpx.Client(timeout=300.0) as client: + with client.stream("POST", api_url, json=payload, headers=headers) as response: + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.read().decode("utf-8", errors="ignore")) + return + + total_bytes = 0 + with open(output_path, "wb") as f: + for chunk in response.iter_bytes(): + if not chunk: + continue + f.write(chunk) + total_bytes += len(chunk) + print(f"Streamed {total_bytes} bytes to: {output_path}") + return + + with httpx.Client(timeout=300.0) as client: + response = client.post(api_url, json=payload, headers=headers) + + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.text) + return + + try: + text = response.content.decode("utf-8") + if text.startswith('{"error"'): + print(f"Error: {text}") + return + except UnicodeDecodeError: + pass + + output_path = args.output or "voxcpm_output.wav" + with open(output_path, "wb") as f: + f.write(response.content) + print(f"Audio saved to: {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="VoxCPM OpenAI-compatible speech client") + parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="API base URL") + parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key") + parser.add_argument("--model", "-m", default=DEFAULT_MODEL, help="Model name or path") + parser.add_argument("--text", required=True, help="Text to synthesize") + parser.add_argument("--ref-audio", default=None, help="Reference audio path, URL, or data URL") + parser.add_argument( + "--ref-text", + default=None, + help="The exact transcript spoken in the reference 
audio", + ) + parser.add_argument("--stream", action="store_true", help="Enable streaming PCM output") + parser.add_argument( + "--response-format", + default="wav", + choices=["wav", "pcm", "flac", "mp3", "aac", "opus"], + help="Audio format for non-streaming mode (default: wav)", + ) + parser.add_argument("--max-new-tokens", type=int, default=None, help="Maximum tokens to generate") + parser.add_argument("--output", "-o", default=None, help="Output file path") + args = parser.parse_args() + run_tts(args) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/voxcpm/run_server.sh b/examples/online_serving/voxcpm/run_server.sh new file mode 100755 index 0000000000..ab4b6fe854 --- /dev/null +++ b/examples/online_serving/voxcpm/run_server.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Launch vLLM-Omni server for VoxCPM online speech serving. +# +# Usage: +# ./run_server.sh # default: async_chunk stage config +# ./run_server.sh async # async_chunk stage config +# ./run_server.sh sync # no-async-chunk stage config +# VOXCPM_MODEL=/path/to/model ./run_server.sh + +set -e + +MODE="${1:-async}" +MODEL="${VOXCPM_MODEL:-OpenBMB/VoxCPM1.5}" + +case "$MODE" in + async) + STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml" + ;; + sync) + STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm.yaml" + ;; + *) + echo "Unknown mode: $MODE" + echo "Supported: async, sync" + exit 1 + ;; +esac + +echo "Starting VoxCPM server with model: $MODEL" +echo "Stage config: $STAGE_CONFIG" + +vllm serve "$MODEL" \ + --stage-configs-path "$STAGE_CONFIG" \ + --host 0.0.0.0 \ + --port 8091 \ + --trust-remote-code \ + --enforce-eager \ + --omni diff --git a/examples/online_serving/voxcpm2/README.md b/examples/online_serving/voxcpm2/README.md new file mode 100644 index 0000000000..8735180f0a --- /dev/null +++ b/examples/online_serving/voxcpm2/README.md @@ -0,0 +1,42 @@ +# VoxCPM2 Online Serving + +Serve VoxCPM2 TTS via the OpenAI-compatible `/v1/audio/speech` endpoint. + +## Start the Server + +```bash +python -m vllm_omni.entrypoints.openai.api_server \ + --model openbmb/VoxCPM2 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \ + --host 0.0.0.0 --port 8000 +``` + +## Zero-shot Synthesis + +```bash +python openai_speech_client.py --text "Hello, this is VoxCPM2." +``` + +Or with curl: + +```bash +curl -X POST http://localhost:8000/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{"model": "voxcpm2", "input": "Hello, this is VoxCPM2.", "voice": "default"}' \ + --output output.wav +``` + +## Voice Cloning + +Clone a speaker's voice using a reference audio file: + +```bash +python openai_speech_client.py \ + --text "This should sound like the reference speaker." \ + --ref-audio /path/to/reference.wav +``` + +The `--ref-audio` parameter accepts: +- Local file path (auto-encoded to base64) +- URL (`https://...`) +- Base64 data URI (`data:audio/wav;base64,...`) diff --git a/examples/online_serving/voxcpm2/gradio_demo.py b/examples/online_serving/voxcpm2/gradio_demo.py new file mode 100644 index 0000000000..a33a2d9245 --- /dev/null +++ b/examples/online_serving/voxcpm2/gradio_demo.py @@ -0,0 +1,602 @@ +"""Gradio demo for VoxCPM2 TTS with gapless streaming audio playback. + +Uses a custom AudioWorklet-based player for gap-free streaming +(adapted from the Qwen3-TTS demo). 
Audio is streamed from the vLLM +server through a same-origin proxy and played via the Web Audio API's +AudioWorklet, which maintains a FIFO buffer queue and plays samples at +the audio clock rate. + +Usage: + # Start the vLLM server first: + python -m vllm_omni.entrypoints.openai.api_server \ + --model openbmb/VoxCPM2 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \ + --host 0.0.0.0 --port 8000 + + # Then launch the demo: + python gradio_demo.py --api-base http://localhost:8000 +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import json +import logging + +import gradio as gr +import httpx +import numpy as np +import soundfile as sf +from fastapi import FastAPI, Request +from fastapi.responses import Response, StreamingResponse + +logger = logging.getLogger(__name__) + +SAMPLE_RATE = 48000 + +# ── AudioWorklet processor (loaded in browser via Blob URL) ────────── +WORKLET_JS = r""" +class TTSPlaybackProcessor extends AudioWorkletProcessor { + constructor() { + super(); + this.queue = []; + this.buf = null; + this.pos = 0; + this.playing = false; + this.played = 0; + this.port.onmessage = (e) => { + if (e.data && e.data.type === 'clear') { + this.queue = []; this.buf = null; this.pos = 0; this.played = 0; + if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped'}); } + return; + } + this.queue.push(e.data); + }; + } + process(inputs, outputs) { + const out = outputs[0][0]; + for (let i = 0; i < out.length; i++) { + if (!this.buf || this.pos >= this.buf.length) { + if (this.queue.length > 0) { + this.buf = this.queue.shift(); this.pos = 0; + } else { + for (let j = i; j < out.length; j++) out[j] = 0; + if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped', played:this.played}); } + return true; + } + } + out[i] = this.buf[this.pos++] / 32768; + this.played++; + } + if (!this.playing) { this.playing = true; this.port.postMessage({type:'started'}); } + return true; + } +} +registerProcessor('tts-playback-processor', TTSPlaybackProcessor); +""" + +PLAYER_HTML = """ +

+  <!-- player markup placeholder: status readout (initially "Ready") and the
+       streaming playback controls for the AudioWorklet-based player -->
+""" + + +def _build_player_js() -> str: + return f""" + +""" + + +def _encode_audio(audio_data: tuple) -> str: + sr, audio_np = audio_data + if audio_np.dtype in (np.float32, np.float64): + audio_np = np.clip(audio_np, -1.0, 1.0) + audio_np = (audio_np * 32767).astype(np.int16) + elif audio_np.dtype != np.int16: + audio_np = audio_np.astype(np.int16) + buf = io.BytesIO() + sf.write(buf, audio_np, sr, format="WAV") + return f"data:audio/wav;base64,{base64.b64encode(buf.getvalue()).decode()}" + + +def create_app(api_base: str): + app = FastAPI() + _pending: dict[str, dict] = {} + + @app.post("/proxy/v1/audio/speech") + async def proxy_speech(request: Request): + body = await request.json() + req_id = body.get("_req_id") + if req_id and req_id in _pending: + body = _pending.pop(req_id) + logger.info("Proxy: %s", {k: (f"<{len(str(v))} chars>" if k == "ref_audio" else v) for k, v in body.items()}) + try: + client = httpx.AsyncClient(timeout=300) + resp = await client.send( + client.build_request( + "POST", + f"{api_base}/v1/audio/speech", + json=body, + headers={"Authorization": "Bearer EMPTY", "Content-Type": "application/json"}, + ), + stream=True, + ) + except Exception as exc: + logger.exception("Proxy connection error") + await client.aclose() + return Response(content=str(exc), status_code=502) + if resp.status_code != 200: + content = await resp.aread() + await resp.aclose() + await client.aclose() + return Response(content=content, status_code=resp.status_code) + + async def relay(): + try: + async for chunk in resp.aiter_bytes(): + yield chunk + finally: + await resp.aclose() + await client.aclose() + + return StreamingResponse(relay(), media_type="application/octet-stream") + + css = """ + #generate-btn button { width: 100%; } + #streaming-player { border: 1px solid var(--border-color-primary) !important; border-radius: var(--block-radius) !important; padding: var(--block-padding) !important; } + """ + theme = gr.themes.Default( + primary_hue=gr.themes.Color( + c50="#f0f5ff", + c100="#dce6f9", + c200="#b8cef3", + c300="#8eb2eb", + c400="#6496e0", + c500="#4A90D9", + c600="#3a7bc8", + c700="#2d66b0", + c800="#1f4f8f", + c900="#163a6e", + c950="#0e2650", + ), + ) + + with gr.Blocks(title="VoxCPM2 TTS Demo") as demo: + gr.HTML(f""" +
+            <!-- header banner placeholder: vLLM-Omni logo, "VoxCPM2 Streaming
+                 Demo" title, and the tagline
+                 "Served by vLLM-Omni · {api_base} · 48 kHz" -->
+ """) + + gr.Markdown( + "**Three modes:** " + "**Voice Design** (control instruction only) · " + "**Controllable Cloning** (ref audio + optional style control) · " + "**Ultimate Cloning** (ref audio + transcript for audio continuation)" + ) + + with gr.Row(): + with gr.Column(scale=3): + text_input = gr.Textbox( + label="Target Text", + placeholder="Enter text to synthesize...", + lines=4, + ) + control_instruction = gr.Textbox( + label="Control Instruction (optional)", + placeholder="e.g. A warm young woman / Excited and fast-paced", + lines=2, + info="Describe voice style, emotion, pace. Works for both Voice Design and Controllable Cloning.", + ) + + with gr.Accordion("Voice Cloning", open=False): + ref_audio = gr.Audio( + label="Reference Audio (upload for cloning)", + type="numpy", + sources=["upload", "microphone"], + ) + ref_audio_url = gr.Textbox( + label="or Reference Audio URL", + placeholder="https://example.com/reference.wav", + ) + ultimate_clone = gr.Checkbox( + label="Ultimate Cloning Mode", + value=False, + info="Provide transcript of ref audio for audio continuation (disables control instruction)", + ) + prompt_text = gr.Textbox( + label="Reference Audio Transcript", + placeholder="Transcript of your reference audio (for ultimate cloning)", + lines=2, + visible=False, + ) + + with gr.Row(): + stream_checkbox = gr.Checkbox( + label="Stream (gapless)", + value=True, + info="AudioWorklet streaming", + ) + with gr.Row(): + generate_btn = gr.Button( + "Generate Speech", + variant="primary", + size="lg", + elem_id="generate-btn", + scale=3, + ) + reset_btn = gr.Button("Reset", variant="secondary", size="lg", scale=1) + + with gr.Column(scale=2): + player_html = gr.HTML( + value=PLAYER_HTML, + visible=True, + label="streaming player", + elem_id="streaming-player", + ) + audio_output = gr.Audio( + label="generated audio", + interactive=False, + autoplay=True, + visible=False, + ) + gr.Examples( + examples=[ + ["Hello, this is a VoxCPM2 demo running on vLLM-Omni.", ""], + [ + "I have a dream that my four little children will one day live in a nation " + "where they will not be judged by the color of their skin but by the content " + "of their character.", + "", + ], + [ + "I never asked you to stay. It's not like I care or anything. " + "But why does it still hurt so much now that you're gone?", + "A young girl with a soft, sweet voice. Speaks slowly with a melancholic tone.", + ], + ], + inputs=[text_input, control_instruction], + label="examples", + ) + gr.HTML(""" +
+                <!-- footer placeholder: "vLLM-Omni" project link -->
+ """) + + hidden_payload = gr.Textbox(visible=False, elem_id="tts-payload") + + def on_ultimate_toggle(checked): + return ( + gr.update(visible=checked), # prompt_text + gr.update(interactive=not checked), # control_instruction + ) + + ultimate_clone.change( + fn=on_ultimate_toggle, + inputs=[ultimate_clone], + outputs=[prompt_text, control_instruction], + ) + + def on_stream_change(stream: bool): + if stream: + return gr.update(visible=True), gr.update(visible=False) + return gr.update(visible=False), gr.update(visible=True) + + stream_checkbox.change( + fn=on_stream_change, + inputs=[stream_checkbox], + outputs=[player_html, audio_output], + ) + + def on_reset(): + return "", "", None, "", False, "", PLAYER_HTML + + reset_btn.click( + fn=on_reset, + outputs=[ + text_input, + control_instruction, + audio_output, + hidden_payload, + ultimate_clone, + prompt_text, + player_html, + ], + js="() => { if (window.ttsStop) window.ttsStop(); }", + ) + + def on_generate(stream_enabled, text, ctrl_instr, ref_a, ref_url, ult_clone, p_text): + import time as _time + + if not text or not text.strip(): + raise gr.Error("Please enter text to synthesize.") + + # VoxCPM2 uses "(instruction)text" format for control + ctrl = ctrl_instr.strip() if ctrl_instr and not ult_clone else "" + final_text = f"({ctrl}){text.strip()}" if ctrl else text.strip() + + payload: dict = { + "input": final_text, + "voice": "default", + "response_format": "pcm" if stream_enabled else "wav", + "stream": stream_enabled, + } + + # Reference audio for cloning + ref_url_s = ref_url.strip() if ref_url else "" + if ref_url_s: + payload["ref_audio"] = ref_url_s + elif ref_a is not None: + payload["ref_audio"] = _encode_audio(ref_a) + + # Ultimate cloning: prompt_audio + prompt_text for continuation + if ult_clone and p_text and p_text.strip(): + if ref_url_s: + payload["prompt_audio"] = ref_url_s + elif ref_a is not None: + payload["prompt_audio"] = payload.get("ref_audio", "") + payload["prompt_text"] = p_text.strip() + + if stream_enabled: + if ref_a is not None and not ref_url_s: + req_id = f"req-{int(_time.time() * 1000)}" + _pending[req_id] = payload + browser_payload = {"_req_id": req_id, "_nonce": int(_time.time() * 1000)} + return json.dumps(browser_payload), gr.update() + payload["_nonce"] = int(_time.time() * 1000) + return json.dumps(payload), gr.update() + else: + try: + with httpx.Client(timeout=300.0) as client: + resp = client.post( + f"{api_base}/v1/audio/speech", + json=payload, + headers={"Content-Type": "application/json", "Authorization": "Bearer EMPTY"}, + ) + except httpx.ConnectError: + raise gr.Error(f"Cannot connect to server at {api_base}.") + if resp.status_code != 200: + raise gr.Error(f"Server error ({resp.status_code}): {resp.text[:200]}") + audio_np, sr = sf.read(io.BytesIO(resp.content)) + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + return "", (sr, audio_np.astype(np.float32)) + + generate_btn.click( + fn=on_generate, + inputs=[ + stream_checkbox, + text_input, + control_instruction, + ref_audio, + ref_audio_url, + ultimate_clone, + prompt_text, + ], + outputs=[hidden_payload, audio_output], + ).then( + fn=lambda p: p, + inputs=[hidden_payload], + outputs=[hidden_payload], + js="(p) => { if (p && p.trim()) { const d = JSON.parse(p); delete d._nonce; window.ttsGenerate(d); } return p; }", + ) + + demo.queue() + + return gr.mount_gradio_app(app, demo, path="/", css=css, theme=theme, head=_build_player_js()) + + +def main(): + parser = argparse.ArgumentParser(description="VoxCPM2 streaming Gradio 
demo") + parser.add_argument("--api-base", default="http://localhost:8000", help="vLLM API server URL") + parser.add_argument("--host", default="0.0.0.0", help="Gradio server host") + parser.add_argument("--port", type=int, default=7860, help="Gradio server port") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + print(f"Connecting to vLLM server at: {args.api_base}") + + import uvicorn + + uvicorn.run(create_app(args.api_base), host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/voxcpm2/openai_speech_client.py b/examples/online_serving/voxcpm2/openai_speech_client.py new file mode 100644 index 0000000000..a117d24fd1 --- /dev/null +++ b/examples/online_serving/voxcpm2/openai_speech_client.py @@ -0,0 +1,108 @@ +"""OpenAI-compatible client for VoxCPM2 TTS via /v1/audio/speech endpoint. + +Examples: + # Zero-shot synthesis + python openai_speech_client.py --text "Hello, this is VoxCPM2." + + # Voice cloning with a local reference audio file + python openai_speech_client.py --text "Hello world" \ + --ref-audio /path/to/reference.wav + + # Voice cloning with a URL + python openai_speech_client.py --text "Hello world" \ + --ref-audio "https://example.com/reference.wav" + +Server setup: + python -m vllm_omni.entrypoints.openai.api_server \ + --model openbmb/VoxCPM2 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \ + --host 0.0.0.0 --port 8000 +""" + +from __future__ import annotations + +import argparse +import base64 +import os + +import httpx + +DEFAULT_API_BASE = "http://localhost:8000" +DEFAULT_API_KEY = "sk-empty" + + +def encode_audio_to_base64(audio_path: str) -> str: + """Encode a local audio file to a base64 data URL.""" + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found: {audio_path}") + + ext = audio_path.lower().rsplit(".", 1)[-1] + mime = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "flac": "audio/flac", + "ogg": "audio/ogg", + }.get(ext, "audio/wav") + + with open(audio_path, "rb") as f: + b64 = base64.b64encode(f.read()).decode("utf-8") + return f"data:{mime};base64,{b64}" + + +def main() -> None: + parser = argparse.ArgumentParser(description="VoxCPM2 OpenAI speech client") + parser.add_argument("--text", type=str, required=True, help="Text to synthesize") + parser.add_argument( + "--ref-audio", + type=str, + default=None, + help="Reference audio for voice cloning (local path, URL, or data: URI)", + ) + parser.add_argument("--model", type=str, default="voxcpm2") + parser.add_argument("--output", type=str, default="output.wav") + parser.add_argument("--api-base", type=str, default=DEFAULT_API_BASE) + parser.add_argument("--api-key", type=str, default=DEFAULT_API_KEY) + parser.add_argument("--response-format", type=str, default="wav") + args = parser.parse_args() + + # VoxCPM2 has no predefined voices. The "voice" field is required by + # the OpenAI API schema but ignored by VoxCPM2 — use any placeholder. + # For voice cloning, pass --ref-audio instead. 
+ payload: dict = { + "model": args.model, + "input": args.text, + "voice": "default", + "response_format": args.response_format, + } + + if args.ref_audio: + ref = args.ref_audio + if ref.startswith(("http://", "https://", "data:")): + payload["ref_audio"] = ref + else: + payload["ref_audio"] = encode_audio_to_base64(ref) + + url = f"{args.api_base}/v1/audio/speech" + print(f"POST {url}") + print(f" text: {args.text}") + if args.ref_audio: + print(f" ref_audio: {args.ref_audio[:80]}...") + + with httpx.Client(timeout=300) as client: + resp = client.post( + url, + json=payload, + headers={"Authorization": f"Bearer {args.api_key}"}, + ) + + if resp.status_code != 200: + print(f"Error {resp.status_code}: {resp.text[:500]}") + return + + with open(args.output, "wb") as f: + f.write(resp.content) + print(f"Saved: {args.output} ({len(resp.content):,} bytes)") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index e49aa6e325..9b034a7c8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,12 +55,24 @@ dev = [ "pyttsx3>=2.99", "opencc>=1.2.0", "mistune>=3.2.0", # for example tests + "torchmetrics>=1.4.0", # for accuracy similarity metrics ] demo = [ "gradio>=6.7.0", ] +# Seed-TTS serve benchmark WER (BytedanceSpeech/seed-tts-eval run_wer.py protocol). +seed-tts-eval = [ + "jiwer>=3.0.0", + "zhon>=2.0.0", + "zhconv>=1.4.2", + "scipy>=1.10.0", + "soundfile>=0.12.0", + "transformers>=4.36.0", + "funasr>=1.0.0", +] + docs = [ "mkdocs>=1.5.0", "mkdocs-api-autonav", @@ -182,6 +194,7 @@ markers = [ "H100: Tests that require H100 GPU", "L4: Tests that require L4 GPU", "MI325: Tests that require MI325 GPU (AMD/ROCm)", + "B60: Tests that require Intel Arc Pro B60 XPU", "S5000: Tests that require S5000 GPU (Moore Threads/MUSA)", "A2: Tests that require A2 NPU", "A3: Tests that require A3 NPU", diff --git a/recipes/Qwen/Qwen3-Omni.md b/recipes/Qwen/Qwen3-Omni.md new file mode 100644 index 0000000000..081e1453d3 --- /dev/null +++ b/recipes/Qwen/Qwen3-Omni.md @@ -0,0 +1,90 @@ +# Qwen3-Omni for multimodal chat on 1x A100 80GB + +## Summary + +- Vendor: Qwen +- Model: `Qwen/Qwen3-Omni-30B-A3B-Instruct` +- Task: Multimodal chat with text, image, audio, or video input +- Mode: Online serving with the OpenAI-compatible API +- Maintainer: Community + +## When to use this recipe + +Use this recipe when you want a known-good starting point for serving +`Qwen/Qwen3-Omni-30B-A3B-Instruct` with vLLM-Omni on a single 80 GB A100 and +validate the deployment with the existing multimodal client examples in this +repository. + +## References + +- Upstream or canonical docs: + [`docs/user_guide/examples/online_serving/qwen3_omni.md`](../../docs/user_guide/examples/online_serving/qwen3_omni.md) +- Related example under `examples/`: + [`examples/online_serving/qwen3_omni/README.md`](../../examples/online_serving/qwen3_omni/README.md) +- Related issue or discussion: + [RFC: add recipes folder](https://github.com/vllm-project/vllm-omni/issues/2645) + +## Hardware Support + +This recipe currently documents one tested-style reference configuration for +CUDA GPU serving. Add more sections for other hardware as community validation +lands. 
+ +## GPU + +### 1x A100 80GB + +#### Environment + +- OS: Linux +- Python: 3.10+ +- Driver / runtime: NVIDIA CUDA environment with an A100 80 GB GPU +- vLLM version: Match the repository requirements for your checkout +- vLLM-Omni version or commit: Use the commit you are deploying from + +#### Command + +Start the server from the repository root: + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 +``` + +To enable async chunking, use the bundled stage config: + +```bash +vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --omni \ + --port 8091 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml +``` + +#### Verification + +Run one of the existing example clients after the server is ready: + +```bash +python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \ + --model Qwen/Qwen3-Omni-30B-A3B-Instruct \ + --query-type use_image \ + --port 8091 \ + --host localhost +``` + +For a quick API smoke test, request text-only output: + +```bash +curl http://localhost:8091/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "messages": [{"role": "user", "content": "Describe vLLM in brief."}], + "modalities": ["text"] + }' +``` + +#### Notes + +- Memory usage: Size depends on runtime options and output modalities; leave headroom for multimodal workloads. +- Key flags: `--omni` is required; `--stage-configs-path` is optional for custom or async-chunk stage configs. +- Known limitations: This starter recipe is intentionally narrow and focuses on the single-GPU online-serving path already documented in the repo examples. diff --git a/recipes/README.md b/recipes/README.md new file mode 100644 index 0000000000..5b3dfb5430 --- /dev/null +++ b/recipes/README.md @@ -0,0 +1,35 @@ +# Community Recipes + +This directory contains community-maintained recipes for answering a +practical user question: + +> How do I run model X on hardware Y for task Z? + +Add recipes for this repository under this in-repo `recipes/` directory. To +keep naming and layout consistent, organize recipes by model vendor in a way +that is aligned with +[`vllm-project/recipes`](https://github.com/vllm-project/recipes), but treat +that external repository as a reference for structure rather than the place to +add files for this repo. Use one Markdown file per model family by default. + +Example layout: + +```text +recipes/ + Qwen/ + Qwen3-Omni.md + Qwen3-TTS.md + Tencent-Hunyuan/ + HunyuanVideo.md +``` + +## Available Recipes + +- [`Qwen/Qwen3-Omni.md`](./Qwen/Qwen3-Omni.md): online serving recipe for + multimodal chat on `1x A100 80GB` + +Within a single recipe file, include different hardware support sections such +as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like +`1x A100 80GB` or `2x L40S` inside those sections when applicable. + +See [TEMPLATE.md](./TEMPLATE.md) for the recommended format. diff --git a/recipes/TEMPLATE.md b/recipes/TEMPLATE.md new file mode 100644 index 0000000000..9bf8cb9c75 --- /dev/null +++ b/recipes/TEMPLATE.md @@ -0,0 +1,82 @@ +# Recipe Title + +> Example: Qwen3-Omni for speech chat on 1x A100 80GB + +## Summary + +- Vendor: +- Model: +- Task: +- Mode: +- Maintainer: + +## When to use this recipe + +Briefly describe the concrete scenario this recipe covers. 
+ +## References + +- Upstream or canonical docs: +- Related example under `examples/`: +- Related issue or discussion: + +## Hardware Support + +Add one section per platform, such as `GPU`, `ROCm`, or `NPU`. Under each +platform section, document one or more tested hardware configurations. + +## GPU + +### 1x A100 80GB + +#### Environment + +- OS: +- Python: +- Driver / runtime: +- vLLM version: +- vLLM-Omni version or commit: + +#### Command + +```bash +# Add the exact command(s) here +``` + +#### Verification + +```bash +# Add a quick validation command or expected output here +``` + +#### Notes + +- Memory usage: +- Key flags: +- Known limitations: + +### 2x L40S + +Repeat the same structure for other hardware setups as needed. + +## ROCm + +### Example hardware configuration + +Repeat the same nested structure for ROCm setups as needed: + +- `#### Environment` +- `#### Command` +- `#### Verification` +- `#### Notes` + +## NPU + +### Example hardware configuration + +Repeat the same nested structure for NPU setups as needed: + +- `#### Environment` +- `#### Command` +- `#### Verification` +- `#### Notes` diff --git a/requirements/common.txt b/requirements/common.txt index 89eaac32bc..63e16d580f 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -1,8 +1,6 @@ # Common dependencies for all platforms av>=14.0.0 omegaconf>=2.3.0 -librosa>=0.11.0 -resampy>=0.4.3 diffusers>=0.36.0 accelerate==1.12.0 soundfile>=0.13.1 @@ -11,7 +9,6 @@ tqdm>=4.66.0 torchsde>=0.2.6 openai-whisper>=20250625 imageio[ffmpeg]>=2.37.2 -sox>=1.5.0 x-transformers>=2.12.2 einops>=0.8.1 prettytable>=3.8.0 diff --git a/tests/comfyui/conftest.py b/tests/comfyui/conftest.py index 0b4565e946..4280d3506f 100644 --- a/tests/comfyui/conftest.py +++ b/tests/comfyui/conftest.py @@ -9,8 +9,8 @@ import os import sys +from types import ModuleType, SimpleNamespace from typing import BinaryIO, TypedDict -from unittest.mock import MagicMock def pytest_configure(config): @@ -58,15 +58,15 @@ def save_to(self, file: str | BinaryIO): else: file.write(self._data) - mock_comfy_api = MagicMock() - mock_comfy_api_input = MagicMock() + mock_comfy_api = ModuleType("comfy_api") + mock_comfy_api_input = ModuleType("comfy_api.input") mock_comfy_api_input.AudioInput = AudioInput mock_comfy_api_input.VideoInput = VideoInput mock_comfy_api.input = mock_comfy_api_input - mock_comfy_api_latest = MagicMock() - mock_comfy_api_latest.Types.VideoComponents = MagicMock(side_effect=lambda **kwargs: kwargs) - mock_comfy_api_latest.InputImpl.VideoFromComponents = MagicMock( - side_effect=lambda _: VideoInput(b"mock_video_from_components") + mock_comfy_api_latest = ModuleType("comfy_api.latest") + mock_comfy_api_latest.Types = SimpleNamespace(VideoComponents=lambda **kwargs: kwargs) + mock_comfy_api_latest.InputImpl = SimpleNamespace( + VideoFromComponents=lambda _: VideoInput(b"mock_video_from_components") ) mock_comfy_api.latest = mock_comfy_api_latest @@ -76,8 +76,8 @@ def mock_load(_: str | BinaryIO): sample_rate = 24000 return waveform, sample_rate - mock_comfy_extras = MagicMock() - mock_nodes_audio = MagicMock() + mock_comfy_extras = ModuleType("comfy_extras") + mock_nodes_audio = ModuleType("comfy_extras.nodes_audio") mock_nodes_audio.load = mock_load mock_comfy_extras.nodes_audio = mock_nodes_audio diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py index f6ce82f9b2..5164f3b9ac 100644 --- a/tests/comfyui/test_comfyui_integration.py +++ b/tests/comfyui/test_comfyui_integration.py @@ -13,7 
+13,6 @@ from enum import StrEnum, auto from types import SimpleNamespace from typing import Any, NamedTuple -from unittest.mock import AsyncMock, MagicMock, patch import pytest import requests @@ -28,6 +27,7 @@ ) from comfyui_vllm_omni.utils.types import AutoregressionSamplingParams, DiffusionSamplingParams, WanModelSpecificParams from PIL import Image +from pytest_mock import MockerFixture from vllm import SamplingParams from vllm.outputs import CompletionOutput, RequestOutput from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -217,9 +217,10 @@ def _build_diffusion_video_output() -> OmniRequestOutput: def _build_diffusion_image_output_for_chat_endpoint() -> OmniRequestOutput: - request_output = MagicMock() - request_output.images = [_build_image_output(color="blue")] - request_output.finished = True + request_output = SimpleNamespace( + images=[_build_image_output(color="blue")], + finished=True, + ) return OmniRequestOutput( request_id="test_req_img_chat", finished=True, @@ -389,51 +390,55 @@ def sampling_case(request) -> SamplingCase: @pytest.fixture -def mock_async_omni(server_case: ServerCase, sampling_case: SamplingCase): +def mock_async_omni( + server_case: ServerCase, + sampling_case: SamplingCase, + monkeypatch: pytest.MonkeyPatch, + mocker: MockerFixture, +): async def _mock_preprocess_chat(self, *args, **kwargs): return ([{"role": "user", "content": "test"}], [{"prompt": "test prompt"}]) # Need to mock AsyncOmni itself (not only its generate method) because # 1. The API layer uses its stage_list and stage_configs attributes # 2. Its __init__ method has slow side effects (model & config loading). - with ( - patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni") as MockAsyncOmni, - patch( - "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat", - new=_mock_preprocess_chat, - ), - ): - mock_instance = AsyncMock(spec=RealAsyncOmni) - mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case) - - mock_instance.stage_list = server_case.stage_list - mock_instance.stage_configs = server_case.stage_configs - mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs) - mock_instance.default_sampling_params_list = [ - SamplingParams() if _stage_type(stage) != "diffusion" else MagicMock() - for stage in server_case.stage_configs - ] - mock_instance.errored = False - mock_instance.dead_error = RuntimeError("Mock engine error") - mock_instance.model_config = MagicMock( - max_model_len=4096, - io_processor_plugin=None, - allowed_local_media_path=None, - allowed_media_domains=None, - ) - # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes. 
- mock_instance.model_config.hf_config = MagicMock() - mock_instance.model_config.hf_config.talker_config = MagicMock() - mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0} - mock_instance.io_processor = MagicMock() - mock_instance.input_processor = MagicMock() - mock_instance.shutdown = MagicMock() - mock_instance.get_vllm_config = AsyncMock(return_value=None) - mock_instance.get_supported_tasks = AsyncMock(return_value=["generate"]) - mock_instance.get_tokenizer = AsyncMock(return_value=None) + mock_async_omni_cls = mocker.patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni") + monkeypatch.setattr( + "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat", + _mock_preprocess_chat, + ) + + mock_instance = mocker.AsyncMock(spec=RealAsyncOmni) + mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case) + + mock_instance.stage_list = server_case.stage_list + mock_instance.stage_configs = server_case.stage_configs + mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs) + mock_instance.default_sampling_params_list = [ + SamplingParams() if _stage_type(stage) != "diffusion" else mocker.MagicMock() + for stage in server_case.stage_configs + ] + mock_instance.errored = False + mock_instance.dead_error = RuntimeError("Mock engine error") + mock_instance.model_config = mocker.MagicMock( + max_model_len=4096, + io_processor_plugin=None, + allowed_local_media_path=None, + allowed_media_domains=None, + ) + # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes. + mock_instance.model_config.hf_config = mocker.MagicMock() + mock_instance.model_config.hf_config.talker_config = mocker.MagicMock() + mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0} + mock_instance.io_processor = mocker.MagicMock() + mock_instance.input_processor = mocker.MagicMock() + mock_instance.shutdown = mocker.MagicMock() + mock_instance.get_vllm_config = mocker.AsyncMock(return_value=None) + mock_instance.get_supported_tasks = mocker.AsyncMock(return_value=["generate"]) + mock_instance.get_tokenizer = mocker.AsyncMock(return_value=None) - MockAsyncOmni.return_value = mock_instance - yield MockAsyncOmni + mock_async_omni_cls.return_value = mock_instance + yield mock_async_omni_cls @pytest.fixture @@ -518,6 +523,7 @@ def run_server(): "Qwen/Qwen-Image-Edit", True, id="image-to-image-dalle-endpoint", + marks=pytest.mark.skip(reason="Temporarily disabled due to failure."), ), pytest.param( ServerCase( @@ -583,9 +589,9 @@ async def test_image_generation_node(api_server: str, model: str, image_input: b ServerCase( served_model="Qwen/Qwen2.5-Omni-7B", stage_list=[ - MagicMock(is_comprehension=True, model_stage="llm"), - MagicMock(is_comprehension=False, model_stage="llm"), - MagicMock(is_comprehension=False, model_stage="llm"), + SimpleNamespace(is_comprehension=True, model_stage="llm"), + SimpleNamespace(is_comprehension=False, model_stage="llm"), + SimpleNamespace(is_comprehension=False, model_stage="llm"), ], stage_configs=[ _make_stage_config("llm", is_comprehension=True, model_stage="thinker"), diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/config/test_pipeline_registry.py b/tests/config/test_pipeline_registry.py new file mode 100644 index 0000000000..3483d530c6 --- /dev/null +++ b/tests/config/test_pipeline_registry.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the central pipeline registry (2.5/N).""" + +from __future__ import annotations + +import pytest + +from vllm_omni.config.pipeline_registry import ( + _DIFFUSION_PIPELINES, + _OMNI_PIPELINES, + _VLLM_OMNI_PIPELINES, +) +from vllm_omni.config.stage_config import ( + _PIPELINE_REGISTRY, + PipelineConfig, + StageExecutionType, + StagePipelineConfig, + register_pipeline, +) + + +class TestCentralRegistryDeclarations: + """Every in-tree pipeline must be declared exactly once in the central registry.""" + + def test_union_contains_all_omni(self): + for key in _OMNI_PIPELINES: + assert key in _VLLM_OMNI_PIPELINES + + def test_union_contains_all_diffusion(self): + for key in _DIFFUSION_PIPELINES: + assert key in _VLLM_OMNI_PIPELINES + + def test_no_duplicate_model_type_between_omni_and_diffusion(self): + overlap = set(_OMNI_PIPELINES) & set(_DIFFUSION_PIPELINES) + assert not overlap, f"Duplicate model_types across omni/diffusion: {overlap}" + + def test_expected_omni_pipelines_present(self): + # Guard against accidental removal during future refactors. + assert "qwen2_5_omni" in _OMNI_PIPELINES + assert "qwen2_5_omni_thinker_only" in _OMNI_PIPELINES + assert "qwen3_omni_moe" in _OMNI_PIPELINES + assert "qwen3_tts" in _OMNI_PIPELINES + + +class TestLazyLoading: + """Pipelines are imported only on first access.""" + + def test_contains_without_import(self): + # ``in`` hits the lazy map, not the loaded cache. + assert "qwen3_omni_moe" in _PIPELINE_REGISTRY + + def test_getitem_loads_correct_pipeline(self): + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + assert pipeline.model_type == "qwen3_omni_moe" + assert pipeline.model_arch == "Qwen3OmniMoeForConditionalGeneration" + + def test_unknown_model_type_returns_none_via_get(self): + assert _PIPELINE_REGISTRY.get("not_a_real_pipeline") is None + + def test_unknown_model_type_raises_keyerror_via_getitem(self): + with pytest.raises(KeyError): + _PIPELINE_REGISTRY["not_a_real_pipeline"] + + def test_iteration_yields_registered_pipelines(self): + keys = set(_PIPELINE_REGISTRY) + assert "qwen2_5_omni" in keys + assert "qwen3_omni_moe" in keys + + +class TestDynamicRegistration: + """``register_pipeline()`` still works for plugins and tests.""" + + def test_register_adds_to_registry(self): + custom = PipelineConfig( + model_type="_test_dynamic_registration", + model_arch="DynamicTestModel", + stages=( + StagePipelineConfig( + stage_id=0, + model_stage="test", + execution_type=StageExecutionType.LLM_AR, + input_sources=(), + final_output=True, + ), + ), + ) + register_pipeline(custom) + try: + assert "_test_dynamic_registration" in _PIPELINE_REGISTRY + assert _PIPELINE_REGISTRY["_test_dynamic_registration"] is custom + finally: + # Don't leak the test registration into other tests. + if "_test_dynamic_registration" in _PIPELINE_REGISTRY: + del _PIPELINE_REGISTRY["_test_dynamic_registration"] + + def test_dynamic_registration_overrides_lazy_entry(self): + # Build a substitute for qwen3_omni_moe that we can distinguish. + original = _PIPELINE_REGISTRY["qwen3_omni_moe"] + override = PipelineConfig( + model_type="qwen3_omni_moe", + model_arch="OverriddenArch", + stages=original.stages, + ) + register_pipeline(override) + try: + assert _PIPELINE_REGISTRY["qwen3_omni_moe"].model_arch == "OverriddenArch" + finally: + # Remove the dynamic override so later tests see the original. 
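+ # Assumed behaviour, not verified against the registry implementation:
+ # register_pipeline() places the override in the loaded cache, so deleting the
+ # key below only drops the override; the lazy factory entry survives and
+ # re-materializes the original PipelineConfig on the next __getitem__.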
+ if "qwen3_omni_moe" in _PIPELINE_REGISTRY._loaded: + del _PIPELINE_REGISTRY["qwen3_omni_moe"] diff --git a/tests/conftest.py b/tests/conftest.py index 27833fe282..83752521f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import atexit import base64 import datetime import io @@ -46,6 +47,7 @@ from vllm.logger import init_logger from vllm.utils.network_utils import get_open_port +from vllm_omni.config.stage_config import resolve_deploy_yaml from vllm_omni.entrypoints.omni import Omni from vllm_omni.inputs.data import OmniSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -75,6 +77,7 @@ class OmniServerParams(NamedTuple): use_omni: bool = True use_stage_cli: bool = False init_timeout: int | None = None + stage_init_timeout: int | None = None # None defers to the server's own default (300 s) def assert_image_diffusion_response( @@ -166,7 +169,6 @@ def assert_audio_diffusion_response( Validate audio diffusion response. """ raise NotImplementedError("Audio validation is not implemented yet") - # consider using assert_audio_valid defined above def _maybe_int(value: Any) -> int | None: @@ -276,15 +278,32 @@ def assert_video_valid( pass -def assert_audio_valid(path: Path, *, sample_rate: int, channels: int, duration_s: float) -> None: - """Assert the WAV has the expected sample rate, channel count, and duration.""" +def assert_audio_valid( + audio_or_path: Path | np.ndarray, + *, + sample_rate: int, + channels: int, + duration_s: float, +) -> None: + """Assert WAV file or (batch, channels, samples) ndarray matches expected audio format.""" + expected_samples = int(duration_s * sample_rate) + if isinstance(audio_or_path, np.ndarray): + audio = audio_or_path + assert audio.ndim == 3, f"Expected audio ndim=3 (batch, channels, samples), got shape {audio.shape}" + assert audio.shape[0] == 1, f"Expected batch size 1, got {audio.shape[0]}" + assert audio.shape[1] == channels, f"Expected {channels} channels, got {audio.shape[1]}" + assert audio.shape[2] == expected_samples, ( + f"Expected {expected_samples} samples ({duration_s}s @ {sample_rate} Hz), got {audio.shape[2]}" + ) + return + + path = audio_or_path assert path.exists(), f"Audio not found: {path}" info = sf.info(str(path)) assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}" assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}" - expected_frames = int(duration_s * sample_rate) - assert info.frames == expected_frames, ( - f"Expected {expected_frames} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}" + assert info.frames == expected_samples, ( + f"Expected {expected_samples} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}" ) @@ -1321,12 +1340,14 @@ def delete_by_path(config_dict: dict, path: str) -> None: else: print(f"Path {path} does not exist") + _stage_key = "stages" if "stages" in config else "stage_args" + # Apply deletions first if deletes: for key, value in deletes.items(): - if key == "stage_args": + if key in ("stage_args", "stages"): if value and isinstance(value, dict): - stage_args = config.get("stage_args", []) + stage_args = config.get(_stage_key, []) if not stage_args: raise ValueError("stage_args does not exist in config") @@ -1345,9 +1366,10 @@ def delete_by_path(config_dict: dict, path: str) -> None: continue # Delete specified paths in this stage - for path in delete_paths: - if path: # Skip empty paths - delete_by_path(target_stage, path) + # Avoid shadowing the original YAML Path 
used for the output filename below. + for delete_path in delete_paths: + if delete_path: # Skip empty paths + delete_by_path(target_stage, delete_path) elif "." in key: # Delete using dot-separated path delete_by_path(config, key) @@ -1358,9 +1380,9 @@ def delete_by_path(config_dict: dict, path: str) -> None: # Apply updates if updates: for key, value in updates.items(): - if key == "stage_args": + if key in ("stage_args", "stages"): if value and isinstance(value, dict): - stage_args = config.get("stage_args", []) + stage_args = config.get(_stage_key, []) if not stage_args: raise ValueError("stage_args does not exist in config") @@ -1377,15 +1399,15 @@ def delete_by_path(config_dict: dict, path: str) -> None: raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}") # Apply updates to this stage - for path, val in stage_updates.items(): + for update_path, val in stage_updates.items(): # Check if this is a simple key (not dot-separated) # Example: 'engine_input_source' vs 'engine_args.max_model_len' - if "." not in path: + if "." not in update_path: # Direct key assignment (e.g., updating a list value) - target_stage[path] = val + target_stage[update_path] = val else: # Dot-separated path (e.g., nested dict access) - apply_update(target_stage, path, val) + apply_update(target_stage, update_path, val) elif "." in key: # Apply using dot-separated path apply_update(config, key, value) @@ -1397,13 +1419,14 @@ def delete_by_path(config_dict: dict, path: str) -> None: # within the same second (e.g. test_qwen3_omni_expansion imports both # get_chunk_config and get_batch_token_config). int(time.time()) would collide # and the later write would overwrite the earlier YAML on disk. - base_name = yaml_path.rsplit(".", 1)[0] if "." in yaml_path else yaml_path - output_path = f"{base_name}_{time.time_ns()}.yaml" + # Keep generated configs outside the repo and delete them when pytest exits. + output_fd, output_path = tempfile.mkstemp(prefix=f"{path.stem}_", suffix=".yaml") + atexit.register(Path(output_path).unlink, missing_ok=True) - with open(output_path, "w", encoding="utf-8") as f: + with os.fdopen(output_fd, "w", encoding="utf-8") as f: yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2) - return output_path + return str(output_path) class OmniServer: @@ -1565,32 +1588,46 @@ def __init__( self.stage_config_path = stage_config_path self.master_port = get_open_port() self.visible_device_list = self._load_visible_device_list(env_dict) - self.stage_runtime_devices = self._load_stage_runtime_devices(stage_config_path) - self.stage_ids = stage_ids or self._load_stage_ids(stage_config_path) + resolved_cfg = resolve_deploy_yaml(stage_config_path) + # Dump the resolved deploy config so CI logs show each stage's + # gpu_memory_utilization / max_model_len / max_num_seqs after + # base_config inheritance and overlay merge — essential when + # diagnosing OOMs that depend on the merged values. 
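+ # Illustrative shape of what lands in the CI log (keys and values are examples
+ # only; the real dump mirrors the deploy YAML schema, legacy ``stage_args`` or
+ # new ``stages``):
+ #   stages:
+ #     - stage_id: 0
+ #       gpu_memory_utilization: 0.9
+ #       max_model_len: 32768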
+ print( + f"[OmniServerStageCli] Resolved deploy config from {stage_config_path}:\n" + f"{yaml.safe_dump(resolved_cfg, sort_keys=False, default_flow_style=False)}", + flush=True, + ) + self.stage_runtime_devices = self._load_stage_runtime_devices(resolved_cfg) + self.stage_ids = stage_ids or self._load_stage_ids(resolved_cfg) if 0 not in self.stage_ids: raise ValueError(f"Stage CLI test requires stage_id=0 in config: {stage_config_path}") self.stage_procs: dict[int, subprocess.Popen] = {} self.proc = None @staticmethod - def _load_stage_ids(stage_config_path: str) -> list[int]: - with open(stage_config_path, encoding="utf-8") as f: - cfg = yaml.safe_load(f) or {} + def _stage_entries(cfg: dict) -> list[dict]: + """Return the list of stage entries from either legacy (``stage_args``) + or new-schema (``stages``) deploy YAMLs.""" + return cfg.get("stage_args") or cfg.get("stages") or [] - stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage] + @staticmethod + def _load_stage_ids(resolved_config: dict) -> list[int]: + stage_ids = [ + stage["stage_id"] for stage in OmniServerStageCli._stage_entries(resolved_config) if "stage_id" in stage + ] if not stage_ids: - raise ValueError(f"No stage IDs found in config: {stage_config_path}") + raise ValueError("No stage IDs found in resolved config") return stage_ids @staticmethod - def _load_stage_runtime_devices(stage_config_path: str) -> dict[int, str]: - with open(stage_config_path, encoding="utf-8") as f: - cfg = yaml.safe_load(f) or {} - + def _load_stage_runtime_devices(resolved_config: dict) -> dict[int, str]: runtime_devices: dict[int, str] = {} - for stage in cfg.get("stage_args", []): + for stage in OmniServerStageCli._stage_entries(resolved_config): stage_id = stage.get("stage_id") - devices = stage.get("runtime", {}).get("devices") + # New schema: stage.devices is flat at stage level. + # Legacy schema: stage.runtime.devices is nested. + devices = stage.get("devices") or stage.get("runtime", {}).get("devices") if stage_id is not None and devices: runtime_devices[int(stage_id)] = str(devices) return runtime_devices @@ -1676,10 +1713,21 @@ def _launch_stage(self, stage_id: int, *, headless: bool) -> None: cmd = self._build_stage_cmd(stage_id, headless=headless) print(f"Launching OmniServerStageCli stage {stage_id}: {' '.join(cmd)}") + # Capture each subprocess's stdout+stderr to a per-stage log file so + # debugging "Stage N exited before API server ready" doesn't rely on + # guessing; the file is surfaced in the RuntimeError message. 
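+ # Example path: /tmp/omni_stage_1_52713.log (port value is illustrative); the
+ # master_port suffix keeps concurrent runs on the same host from overwriting
+ # each other's stage logs.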
+ log_path = Path(tempfile.gettempdir()) / f"omni_stage_{stage_id}_{self.master_port}.log" + self._stage_log_paths = getattr(self, "_stage_log_paths", {}) + self._stage_log_paths[stage_id] = log_path + log_fh = open(log_path, "w", buffering=1) # noqa: SIM115 - closed in __exit__ + self._stage_log_files = getattr(self, "_stage_log_files", {}) + self._stage_log_files[stage_id] = log_fh proc = subprocess.Popen( cmd, env=env, cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + stdout=log_fh, + stderr=subprocess.STDOUT, ) self.stage_procs[stage_id] = proc if stage_id == 0: @@ -1689,7 +1737,18 @@ def _ensure_stage_processes_alive(self) -> None: for stage_id, proc in self.stage_procs.items(): ret = proc.poll() if ret is not None: - raise RuntimeError(f"Stage {stage_id} exited with code {ret} before API server became ready.") + log_path = getattr(self, "_stage_log_paths", {}).get(stage_id) + tail = "" + if log_path and log_path.exists(): + try: + with open(log_path, encoding="utf-8", errors="replace") as f: + lines = f.readlines() + tail = "\n=== Last 60 lines of stage {} log ({}) ===\n{}".format( + stage_id, log_path, "".join(lines[-60:]) or "" + ) + except Exception as exc: # pragma: no cover - diagnostic only + tail = f"\n" + raise RuntimeError(f"Stage {stage_id} exited with code {ret} before API server became ready.{tail}") def _start_server(self) -> None: ordered_stage_ids = [0, *[stage_id for stage_id in self.stage_ids if stage_id != 0]] @@ -1715,7 +1774,46 @@ def _start_server(self) -> None: raise RuntimeError(f"OmniServerStageCli failed to start within {max_wait} seconds") + def _dump_stage_logs_for_debug(self, head_lines: int = 300, tail_lines: int = 500) -> None: + """Tail each stage's subprocess log back to stdout on teardown. + + Stage subprocesses redirect stdout/stderr to ``/tmp/omni_stage_*.log`` + so we don't spam the main CI stream while tests run; but that also + hides engine init (KV cache size, Available KV cache memory, vLLM + engine config) when things go wrong. Dump them here so buildkite + captures them post-run. Head covers engine init; tail covers + whatever state the stage was in when it was torn down. + """ + log_paths = getattr(self, "_stage_log_paths", {}) or {} + for stage_id in sorted(log_paths): + log_path = log_paths[stage_id] + if not log_path or not log_path.exists(): + continue + try: + with open(log_path, encoding="utf-8", errors="replace") as f: + lines = f.readlines() + except Exception as exc: # pragma: no cover - diagnostic only + print(f"[OmniServerStageCli] stage {stage_id} log read failed: {exc}", flush=True) + continue + total = len(lines) + if total <= head_lines + tail_lines: + head_chunk = lines + tail_chunk = [] + elided = 0 + else: + head_chunk = lines[:head_lines] + tail_chunk = lines[-tail_lines:] + elided = total - head_lines - tail_lines + print(f"\n=== stage {stage_id} log HEAD ({log_path}) ===", flush=True) + print("".join(head_chunk).rstrip("\n"), flush=True) + if tail_chunk: + print(f"\n... 
[{elided} lines elided] ...", flush=True) + print(f"\n=== stage {stage_id} log TAIL ({log_path}) ===", flush=True) + print("".join(tail_chunk).rstrip("\n"), flush=True) + print(f"=== end stage {stage_id} log ===\n", flush=True) + def __exit__(self, exc_type, exc_val, exc_tb): + self._dump_stage_logs_for_debug() for stage_id in sorted(self.stage_procs, reverse=True): proc = self.stage_procs[stage_id] if proc.poll() is None: @@ -1761,22 +1859,35 @@ def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: st if run_level == "advanced_model" and stage_config_path is not None: with open(stage_config_path, encoding="utf-8") as f: cfg = yaml.safe_load(f) or {} - stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage] + # Strip ``load_format: dummy`` (CI overlay default) so advanced_model + # tests use real weights. New schema (``stages:``) writes the field + # flat at stage level; legacy schema (``stage_args:``) nests it as + # ``engine_args.load_format``. Handle both. + new_schema_stages = cfg.get("stages") + stage_key = "stages" if new_schema_stages is not None else "stage_args" + delete_path = "load_format" if new_schema_stages is not None else "engine_args.load_format" + stage_entries = cfg.get(stage_key, []) + stage_ids = [stage["stage_id"] for stage in stage_entries if "stage_id" in stage] stage_config_path = modify_stage_config( stage_config_path, - deletes={"stage_args": {stage_id: ["engine_args.load_format"] for stage_id in stage_ids}}, + deletes={stage_key: {stage_id: [delete_path] for stage_id in stage_ids}}, ) server_args = params.server_args or [] - if params.use_omni: - server_args = ["--stage-init-timeout", "120", *server_args] + if params.use_omni and params.stage_init_timeout is not None: + server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)] + else: + server_args = [*server_args, "--stage-init-timeout", "600"] if params.init_timeout is not None: server_args = [*server_args, "--init-timeout", str(params.init_timeout)] + else: + server_args = [*server_args, "--init-timeout", "900"] if params.use_stage_cli: if not params.use_omni: raise ValueError("omni_server with use_stage_cli=True requires use_omni=True") if stage_config_path is None: raise ValueError("omni_server with use_stage_cli=True requires a stage_config_path") + server_args += ["--stage-configs-path", stage_config_path] with OmniServerStageCli( model, @@ -1826,6 +1937,7 @@ class OmniResponse: e2e_latency: float | None = None success: bool = False error_message: str | None = None + cached_tokens: int | None = None @dataclass @@ -2321,6 +2433,11 @@ def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse: if hasattr(choice.message, "content") and choice.message.content is not None: text_content = choice.message.content + # Extract cached_tokens for prefix caching tests + usage = getattr(chat_completion, "usage", None) + if usage and (details := getattr(usage, "prompt_tokens_details", None)): + result.cached_tokens = details.cached_tokens + # Calculate end-to-end latency result.e2e_latency = time.perf_counter() - start_time @@ -2373,7 +2490,7 @@ def _process_diffusion_response(self, chat_completion) -> DiffusionResponse: image_url = item.get("image_url", {}).get("url") else: image_url_obj = getattr(item, "image_url", None) - image_url = hasattr(image_url_obj, "url", None) if image_url_obj else None + image_url = getattr(image_url_obj, "url", None) if image_url_obj else None if image_url and 
image_url.startswith("data:image"): b64_data = image_url.split(",", 1)[1] img = decode_b64_image(b64_data) @@ -2679,7 +2796,7 @@ def _stream_task(): return responses - def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]: + def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[DiffusionResponse]: """ Send OpenAI requests for diffusion models. @@ -2687,9 +2804,9 @@ def send_diffusion_request(self, request_config: dict[str, Any], request_num: in request_config: Request configuration dictionary containing parameters like model, messages request_num: Number of requests to send concurrently, defaults to 1 (single request) Returns: - List[OmniResponse]: List of response objects + List[DiffusionResponse]: List of response objects """ - responses = [] + responses: list[DiffusionResponse] = [] stream = request_config.get("stream", False) modalities = request_config.get("modalities", omit) # Most diffusion models don't require modalities param extra_body = request_config.get("extra_body", None) @@ -2869,9 +2986,9 @@ def __init__( self, model_name: str, seed: int = 42, - stage_init_timeout: int = 300, + stage_init_timeout: int = 600, batch_timeout: int = 10, - init_timeout: int = 300, + init_timeout: int = 900, shm_threshold_bytes: int = 65536, log_stats: bool = False, stage_configs_path: str | None = None, @@ -2996,6 +3113,10 @@ def get_omni_inputs( video_padding_token = "<|video_pad|>" image_padding_token = "<|image_pad|>" audio_padding_token = "<|audio_pad|>" + elif "Ming-flash-omni" in self.model_name: + video_padding_token = "
", "STOP"]) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.stop == [] + + def test_clears_stop_token_ids(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop_token_ids=[151643, 151644]) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.stop_token_ids == [] + + def test_clears_include_stop_str_in_output(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, include_stop_str_in_output=True) + result = omni._prepare_prefill_sampling_params("req-1", sp) + assert result.include_stop_str_in_output is False + + def test_original_sp_unchanged(self, monkeypatch): + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + sp = SamplingParams(max_tokens=2048, stop=["
"], stop_token_ids=[151643]) + _ = omni._prepare_prefill_sampling_params("req-1", sp) + assert sp.stop == [""] + assert sp.stop_token_ids == [151643] + + +# =================================================================== +# Tests: Failure mode & memory leak prevention +# =================================================================== +# NOTE: Full generate()-level failure mode tests are removed for now. +# The _run_generation error handler (line 1344-1350 in omni.py) calls +# _drop_pd_kv_params but does not increment completed_requests, causing +# the while-loop to hang. These tests need to be revisited once the +# production error-handling path is fixed to properly terminate on +# stage errors. + + +# =================================================================== +# Tests: TP size validation +# =================================================================== + + +class TestTPSizeValidation: + """Tests that _validate_pd_separation_config checks tensor_parallel_size.""" + + def test_matching_tp_passes(self, monkeypatch): + """Same TP size should not raise.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["tensor_parallel_size"] = 2 + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["tensor_parallel_size"] = 2 + omni = _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + assert omni._pd_separation_pair == (0, 1) + + def test_mismatched_tp_raises(self, monkeypatch): + """Different TP sizes should raise ValueError.""" + prefill_cfg = _prefill_stage_cfg() + prefill_cfg["engine_args"]["tensor_parallel_size"] = 2 + decode_cfg = _decode_stage_cfg(engine_input_source=[0]) + decode_cfg["engine_args"]["tensor_parallel_size"] = 4 + with pytest.raises(ValueError, match="tensor_parallel_size"): + _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg]) + + def test_default_tp_no_error(self, monkeypatch): + """Stages without explicit TP (defaults to 1) should pass.""" + omni = _make_pd_omni( + monkeypatch, + [ + _prefill_stage_cfg(), + _decode_stage_cfg(engine_input_source=[0]), + ], + ) + assert omni._pd_separation_pair == (0, 1) diff --git a/tests/entrypoints/test_realtime_connection_helpers.py b/tests/entrypoints/test_realtime_connection_helpers.py new file mode 100644 index 0000000000..e795aa92d0 --- /dev/null +++ b/tests/entrypoints/test_realtime_connection_helpers.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for realtime streaming helpers (PR #2581 /v1/realtime path).""" + +from __future__ import annotations + +import base64 + +import numpy as np +import pytest +import torch +from vllm.sampling_params import RequestOutputKind, SamplingParams + +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture +def realtime_conn() -> RealtimeConnection: + return RealtimeConnection.__new__(RealtimeConnection) + + +class TestRealtimeConnectionTensorAndPcm: + def test_tensor_to_numpy_none(self) -> None: + assert RealtimeConnection._tensor_to_numpy(None) is None + + def test_tensor_to_numpy_1d_numpy(self) -> None: + arr = np.array([1.0, 2.0], dtype=np.float64) + out = RealtimeConnection._tensor_to_numpy(arr) + assert out is not None + assert out.dtype == np.float32 + assert out.shape == (2,) + + def test_tensor_to_numpy_2d_numpy_flattened(self) -> None: + arr = np.array([[0.5], [-0.5]], 
dtype=np.float32) + out = RealtimeConnection._tensor_to_numpy(arr) + assert out is not None + assert out.shape == (2,) + + def test_tensor_to_numpy_torch(self) -> None: + t = torch.tensor([[0.25, -0.25]], dtype=torch.float32) + out = RealtimeConnection._tensor_to_numpy(t) + assert out is not None + assert out.shape == (2,) + np.testing.assert_allclose(out, [0.25, -0.25], rtol=1e-5) + + def test_pcm16_b64_roundtrip(self) -> None: + audio = np.array([0.0, 1.0, -1.0], dtype=np.float32) + b64 = RealtimeConnection._pcm16_b64(audio) + raw = base64.b64decode(b64) + assert len(raw) == 6 + pcm = np.frombuffer(raw, dtype=np.int16) + assert pcm[0] == 0 + assert pcm[1] == 32767 + assert pcm[2] == -32767 + + +class TestAsyncOmniStreamingParamsValidation: + def test_accepts_streaming_friendly_params(self) -> None: + p = SamplingParams( + n=1, + stop=[], + output_kind=RequestOutputKind.DELTA, + ) + AsyncOmni._validate_streaming_input_sampling_params(p) + + def test_rejects_non_sampling_params(self) -> None: + with pytest.raises(ValueError, match="Input streaming"): + AsyncOmni._validate_streaming_input_sampling_params(object()) # type: ignore[arg-type] + + def test_rejects_n_greater_than_one(self) -> None: + p = SamplingParams(n=2, stop=[], output_kind=RequestOutputKind.DELTA) + with pytest.raises(ValueError, match="Input streaming"): + AsyncOmni._validate_streaming_input_sampling_params(p) + + def test_rejects_final_only(self) -> None: + p = SamplingParams(n=1, stop=[], output_kind=RequestOutputKind.FINAL_ONLY) + with pytest.raises(ValueError, match="Input streaming"): + AsyncOmni._validate_streaming_input_sampling_params(p) + + def test_rejects_stop_strings(self) -> None: + p = SamplingParams(n=1, stop=["\n"], output_kind=RequestOutputKind.DELTA) + with pytest.raises(ValueError, match="Input streaming"): + AsyncOmni._validate_streaming_input_sampling_params(p) diff --git a/tests/entrypoints/test_serve.py b/tests/entrypoints/test_serve.py index 916db3cc22..e60afc9cd7 100644 --- a/tests/entrypoints/test_serve.py +++ b/tests/entrypoints/test_serve.py @@ -3,15 +3,37 @@ from __future__ import annotations import argparse -from unittest.mock import Mock, patch import pytest +from pytest_mock import MockerFixture -from vllm_omni.entrypoints.cli.serve import run_headless +from vllm_omni.entrypoints.cli.serve import OmniServeCommand, run_headless +from vllm_omni.entrypoints.utils import detect_explicit_cli_keys pytestmark = [pytest.mark.core_model, pytest.mark.cpu] +def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None: + """``--no-async-chunk`` should parse to ``async_chunk=False`` and mark the + shared deploy-level dest as explicitly provided by the user.""" + try: + from vllm.utils.argparse_utils import FlexibleArgumentParser + except Exception as exc: + pytest.skip(f"Cannot build parser in this environment: {exc}") + + root = FlexibleArgumentParser() + subparsers = root.add_subparsers(dest="subcommand") + cmd = OmniServeCommand() + serve_parser = cmd.subparser_init(subparsers) + + argv = ["serve", "fake-model", "--omni", "--no-async-chunk"] + args = root.parse_args(argv) + + assert args.async_chunk is False + explicit = detect_explicit_cli_keys(argv, serve_parser) + assert "async_chunk" in explicit + + def _make_headless_args() -> argparse.Namespace: return argparse.Namespace( model="fake-model", @@ -26,45 +48,43 @@ def _make_headless_args() -> argparse.Namespace: ) -def test_run_headless_registers_stage_once_and_launches_all_local_engines() -> None: +def 
test_run_headless_registers_stage_once_and_launches_all_local_engines(mocker: MockerFixture) -> None: args = _make_headless_args() - stage_cfg = Mock(stage_id=3) + stage_cfg = mocker.Mock(stage_id=3) stage_cfgs = [stage_cfg] - parallel_config = Mock( + parallel_config = mocker.Mock( data_parallel_size_local=2, data_parallel_rank=4, data_parallel_rank_local=1, node_rank_within_dp=0, ) - vllm_config = Mock(parallel_config=parallel_config) - executor_class = Mock() - engine_manager = Mock() - - with ( - patch( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - return_value=("/fake/stages.yaml", stage_cfgs), - ), - patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"), - patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()), - patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}), - patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}), - patch( - "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ), - patch( - "vllm_omni.engine.stage_init_utils.build_vllm_config", - return_value=(vllm_config, executor_class), - ) as mock_build_vllm_config, - patch( - "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", - return_value="tcp://127.0.0.1:26001", - ) as mock_register, - patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) as mock_manager_cls, - patch("signal.signal"), - ): - run_headless(args) + vllm_config = mocker.Mock(parallel_config=parallel_config) + executor_class = mocker.Mock() + engine_manager = mocker.Mock() + + mocker.patch( + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", + return_value=("/fake/stages.yaml", stage_cfgs), + ) + mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment") + mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock()) + mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}) + mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}) + mocker.patch( + "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", + return_value=(None, None, None), + ) + mock_build_vllm_config = mocker.patch( + "vllm_omni.engine.stage_init_utils.build_vllm_config", + return_value=(vllm_config, executor_class), + ) + mock_register = mocker.patch( + "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", + return_value="tcp://127.0.0.1:26001", + ) + mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) + mocker.patch("signal.signal") + run_headless(args) mock_build_vllm_config.assert_called_once_with( stage_cfg, @@ -92,89 +112,85 @@ def test_run_headless_registers_stage_once_and_launches_all_local_engines() -> N engine_manager.shutdown.assert_called_once_with() -def test_run_headless_honors_explicit_log_stats_flag() -> None: +def test_run_headless_honors_explicit_log_stats_flag(mocker: MockerFixture) -> None: args = _make_headless_args() args.log_stats = True - stage_cfg = Mock(stage_id=3) + stage_cfg = mocker.Mock(stage_id=3) stage_cfgs = [stage_cfg] - parallel_config = Mock( + parallel_config = mocker.Mock( data_parallel_size_local=2, data_parallel_rank=4, data_parallel_rank_local=1, node_rank_within_dp=0, ) - vllm_config = Mock(parallel_config=parallel_config) - 
executor_class = Mock() - engine_manager = Mock() - - with ( - patch( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - return_value=("/fake/stages.yaml", stage_cfgs), - ), - patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"), - patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()), - patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}), - patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}), - patch( - "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ), - patch( - "vllm_omni.engine.stage_init_utils.build_vllm_config", - return_value=(vllm_config, executor_class), - ), - patch( - "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", - return_value="tcp://127.0.0.1:26001", - ), - patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) as mock_manager_cls, - patch("signal.signal"), - ): - run_headless(args) + vllm_config = mocker.Mock(parallel_config=parallel_config) + executor_class = mocker.Mock() + engine_manager = mocker.Mock() + + mocker.patch( + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", + return_value=("/fake/stages.yaml", stage_cfgs), + ) + mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment") + mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock()) + mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={}) + mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={}) + mocker.patch( + "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", + return_value=(None, None, None), + ) + mocker.patch( + "vllm_omni.engine.stage_init_utils.build_vllm_config", + return_value=(vllm_config, executor_class), + ) + mocker.patch( + "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", + return_value="tcp://127.0.0.1:26001", + ) + mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager) + mocker.patch("signal.signal") + run_headless(args) manager_kwargs = mock_manager_cls.call_args.kwargs assert manager_kwargs["log_stats"] is True -def test_run_headless_launches_diffusion_stage_via_omni_master() -> None: +def test_run_headless_launches_diffusion_stage_via_omni_master(mocker: MockerFixture) -> None: args = _make_headless_args() - stage_cfg = Mock(stage_id=3, stage_type="diffusion") - stage_cfg.engine_args = Mock() + stage_cfg = mocker.Mock(stage_id=3, stage_type="diffusion") + stage_cfg.engine_args = mocker.Mock() stage_cfg.engine_input_source = [] stage_cfgs = [stage_cfg] - metadata = Mock(stage_id=3) - od_config = Mock() - proc = Mock() + metadata = mocker.Mock(stage_id=3) + od_config = mocker.Mock() + proc = mocker.Mock() proc.exitcode = 0 proc.is_alive.return_value = False - with ( - patch( - "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", - return_value=("/fake/stages.yaml", stage_cfgs), - ), - patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment"), - patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=Mock()), - patch( - "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ), - 
patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata), - patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info") as mock_inject_stage_info, - patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config), - patch( - "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", - return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"), - ) as mock_register, - patch( - "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc", - return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"), - ) as mock_spawn, - patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake") as mock_handshake, - patch("signal.signal"), - ): - run_headless(args) + mocker.patch( + "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs", + return_value=("/fake/stages.yaml", stage_cfgs), + ) + mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment") + mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock()) + mocker.patch( + "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage", + return_value=(None, None, None), + ) + mocker.patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata) + mock_inject_stage_info = mocker.patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info") + mocker.patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config) + mock_register = mocker.patch( + "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", + return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"), + ) + mock_spawn = mocker.patch( + "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc", + return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"), + ) + mock_handshake = mocker.patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake") + mocker.patch("signal.signal") + run_headless(args) mock_inject_stage_info.assert_called_once_with(stage_cfg, 3) mock_register.assert_called_once_with( diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py index 94e254c250..248629d51d 100644 --- a/tests/entrypoints/test_utils.py +++ b/tests/entrypoints/test_utils.py @@ -310,6 +310,39 @@ def mock_exists(path): assert result is not None assert "glm_image.yaml" in result + def test_voxcpm_transformers_format_resolution(self, mocker: MockerFixture): + """Test VoxCPM transformers config resolves to the voxcpm stage config.""" + mocker.patch( + "vllm_omni.entrypoints.utils.get_config", + side_effect=ValueError("missing transformers config"), + ) + mocker.patch( + "vllm_omni.entrypoints.utils.file_or_path_exists", + side_effect=lambda _model, filename, revision=None: filename == "config.json", + ) + mocker.patch( + "vllm_omni.entrypoints.utils.get_hf_file_to_dict", + return_value={"model_type": "voxcpm"}, + ) + mocker.patch( + "vllm_omni.entrypoints.utils.current_omni_platform.get_default_stage_config_path", + return_value="vllm_omni/model_executor/stage_configs", + ) + + original_exists = os.path.exists + + def mock_exists(path): + if "voxcpm.yaml" in str(path): + return True + return original_exists(path) + + mocker.patch("os.path.exists", side_effect=mock_exists) + + result = resolve_model_config_path("OpenBMB/VoxCPM1.5") + + assert result is not None + 
assert "voxcpm.yaml" in result + class TestLoadAndResolveStageConfigs: def test_load_and_resolve_with_kwargs(self): diff --git a/tests/examples/online_serving/test_qwen2_5_omni.py b/tests/examples/online_serving/test_qwen2_5_omni.py index a78ccf5924..2813b2fda8 100644 --- a/tests/examples/online_serving/test_qwen2_5_omni.py +++ b/tests/examples/online_serving/test_qwen2_5_omni.py @@ -5,8 +5,6 @@ import os -from vllm_omni.platforms import current_omni_platform - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" from pathlib import Path @@ -19,19 +17,15 @@ run_cmd, strip_trailing_audio_saved_line, ) -from tests.utils import hardware_test +from tests.utils import get_deploy_config_path, hardware_test pytestmark = [pytest.mark.advanced_model, pytest.mark.example] models = ["Qwen/Qwen2.5-Omni-7B"] - -stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen2_5_omni_ci.yaml")] - -if current_omni_platform.is_xpu(): - stage_configs = [ - str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen2_5_omni_ci.yaml") - ] +# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the +# platforms: section in vllm_omni/deploy/ci/qwen2_5_omni.yaml. +stage_configs = [get_deploy_config_path("ci/qwen2_5_omni.yaml")] example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving") # Create parameter combinations for model and stage config diff --git a/tests/examples/online_serving/test_qwen3_omni.py b/tests/examples/online_serving/test_qwen3_omni.py index 65f99d7bf2..e9ee2763bb 100644 --- a/tests/examples/online_serving/test_qwen3_omni.py +++ b/tests/examples/online_serving/test_qwen3_omni.py @@ -5,8 +5,6 @@ import os -from vllm_omni.platforms import current_omni_platform - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" from pathlib import Path @@ -19,17 +17,14 @@ run_cmd, strip_trailing_audio_saved_line, ) -from tests.utils import hardware_test +from tests.utils import get_deploy_config_path, hardware_test pytestmark = [pytest.mark.advanced_model, pytest.mark.example] models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"] -stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen3_omni_ci.yaml")] - -if current_omni_platform.is_xpu(): - stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")] +stage_configs = [get_deploy_config_path("ci/qwen3_omni_moe.yaml")] example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving") diff --git a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py index 85c0e8b56e..8858d1f8f1 100644 --- a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py +++ b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from types import SimpleNamespace -from unittest.mock import Mock import pytest import torch +from pytest_mock import MockerFixture from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID from vllm_omni.model_executor.models.mimo_audio.mimo_audio_code2wav import ( @@ -51,7 +51,7 @@ def _make_invalid_flat_immediate_eostm(eostm_id: int = 666) -> torch.Tensor: return g.reshape(-1) -def _minimal_model(): +def _minimal_model(mocker: MockerFixture): """Avoid __init__ (HF 
tokenizer paths); only fields used by _batch_decode_waveforms.""" model = object.__new__(MiMoAudioToken2WavForConditionalGenerationVLLM) model.device = torch.device("cpu") @@ -59,7 +59,7 @@ def _minimal_model(): model.streamer_config = AudioStreamerConfig(group_size=_GROUP, audio_channels=_AC) model.codes = _codes_ns() - decode_vq = Mock( + decode_vq = mocker.Mock( side_effect=lambda audio_codes: torch.ones( audio_codes.shape[1], 7, @@ -67,7 +67,7 @@ def _minimal_model(): device=audio_codes.device, ) ) - decoder = Mock() + decoder = mocker.Mock() audio_tok = SimpleNamespace( encoder=SimpleNamespace(decode_vq=decode_vq), @@ -78,9 +78,9 @@ def _minimal_model(): return model, audio_tok -def test_batch_decode_waveforms_empty_input_list(): +def test_batch_decode_waveforms_empty_input_list(mocker: MockerFixture): """Empty input list returns a single zero-length float32 tensor on model device.""" - model, _ = _minimal_model() + model, _ = _minimal_model(mocker) out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, []) assert len(out) == 1 assert out[0].dtype == torch.float32 @@ -88,9 +88,9 @@ def test_batch_decode_waveforms_empty_input_list(): assert out[0].device == model.device -def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes(): +def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes(mocker: MockerFixture): """Single and multi-request batches produce correctly shaped packed hidden states and trimmed waveforms.""" - model, audio_tok = _minimal_model() + model, audio_tok = _minimal_model(mocker) decoder = audio_tok.decoder # Single valid request: decoder output rank-3 for double squeeze path @@ -118,9 +118,9 @@ def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes(): assert out2[1].shape == (8 * _FTP,) -def test_batch_decode_waveforms_mixed_valid_invalid_requests(): +def test_batch_decode_waveforms_mixed_valid_invalid_requests(mocker: MockerFixture): """Mixed valid and invalid requests: invalid slots get empty tensors, valid slots get decoded waveforms.""" - model, audio_tok = _minimal_model() + model, audio_tok = _minimal_model(mocker) valid_a = _make_valid_flat_codes(1) valid_b = _make_valid_flat_codes(1) dummy = _make_dummy_code_tensor() @@ -151,9 +151,9 @@ def test_batch_decode_waveforms_mixed_valid_invalid_requests(): assert input_lengths.tolist() == [4, 4] -def test_batch_decode_waveforms_all_invalid_returns_per_request_empty(): +def test_batch_decode_waveforms_all_invalid_returns_per_request_empty(mocker: MockerFixture): """All-invalid batch skips decoder entirely and returns empty tensors for every slot.""" - model, audio_tok = _minimal_model() + model, audio_tok = _minimal_model(mocker) out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms( model, [None, _make_dummy_code_tensor(), torch.tensor([], dtype=torch.long)], @@ -163,9 +163,9 @@ def test_batch_decode_waveforms_all_invalid_returns_per_request_empty(): audio_tok.decoder.assert_not_called() -def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples(): +def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples(mocker: MockerFixture): """Decoder output longer than valid_len is trimmed to the exact expected waveform length.""" - model, audio_tok = _minimal_model() + model, audio_tok = _minimal_model(mocker) flat = _make_valid_flat_codes(1) # Longer than valid_len so branch wav = wav[:valid_len] runs audio_tok.decoder.return_value = torch.ones(1, 1, 10_000, dtype=torch.float32) @@ -175,9 
+175,9 @@ def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_sam assert out[0].dtype == torch.float32 -def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra(): +def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra(mocker: MockerFixture): """Else-branch split: per-request wav[:valid_len] when decoder pads each batch row.""" - model, audio_tok = _minimal_model() + model, audio_tok = _minimal_model(mocker) a = _make_valid_flat_codes(1) b = _make_valid_flat_codes(2) audio_tok.decoder.return_value = torch.ones(2, 1, 10_000, dtype=torch.float32) @@ -189,9 +189,9 @@ def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_return assert out[1].dtype == torch.float32 -def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices(): +def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices(mocker: MockerFixture): """Tensor packing order must match valid_indices when invalid requests are in the middle.""" - model, audio_tok = _minimal_model() + model, audio_tok = _minimal_model(mocker) first = _make_valid_flat_codes(1) last = _make_valid_flat_codes(2) inputs = [ @@ -212,9 +212,9 @@ def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices(): assert input_lengths.tolist() == [4, 8] -def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots(): +def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots(mocker: MockerFixture): """Every slot is a 1-D float32 vector (empty or waveform), matching downstream expectations.""" - model, audio_tok = _minimal_model() + model, audio_tok = _minimal_model(mocker) inputs = [_make_valid_flat_codes(1), None, _make_valid_flat_codes(1)] audio_tok.decoder.return_value = torch.ones(2, 1, 5000, dtype=torch.float32) out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, inputs) diff --git a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py index 8e04b04966..587e7f7f8b 100644 --- a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py +++ b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py @@ -10,10 +10,9 @@ - Interleaved (use_audio_in_video) should also work correctly. """ -from unittest.mock import Mock - import pytest import torch +from pytest_mock import MockerFixture from vllm.model_executor.models.qwen2_5_omni_thinker import ( check_interleaved_audio_video, merge_interleaved_embeddings, @@ -107,7 +106,7 @@ def test_interleaved(self): # --------------------------------------------------------------------------- -def make_mock_model(hidden: int = 8): +def make_mock_model(mocker: MockerFixture, hidden: int = 8): """ Return a minimal mock of Qwen2_5OmniThinkerForConditionalGeneration that has enough structure to run embed_input_ids. @@ -116,10 +115,10 @@ def make_mock_model(hidden: int = 8): Qwen2_5OmniThinkerForConditionalGeneration, ) - model = Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration) + model = mocker.Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration) # Config with token IDs - cfg = Mock() + cfg = mocker.Mock() cfg.video_token_index = VIDEO_TOKEN_ID cfg.audio_token_index = AUDIO_TOKEN_ID model.config = cfg @@ -130,9 +129,9 @@ def fake_lm_embed(ids: torch.Tensor) -> torch.Tensor: # view with shared memory, which masked_scatter_ cannot handle). 
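+ # Worked example of this stub: with hidden=4, ids = [0, 7] embed to rows
+ # [0., 0., 0., 0.] and [7., 7., 7., 7.]; each token id is broadcast across the
+ # hidden dim, so the tests can tell text and modality embeddings apart by value.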
return ids.float().unsqueeze(-1).expand(-1, hidden).clone() - lang_model = Mock() + lang_model = mocker.Mock() lang_model.embed_input_ids = fake_lm_embed - model.get_language_model = Mock(return_value=lang_model) + model.get_language_model = mocker.Mock(return_value=lang_model) from vllm.model_executor.models.interfaces import SupportsMultiModal @@ -169,7 +168,7 @@ def build_mm_embeds(audio_n, image_n, video_n, hidden, audio_val=10.0, image_val class TestEmbedInputIds: - def _run(self, audio_n, image_n, video_n, hidden=8): + def _run(self, mocker: MockerFixture, audio_n, image_n, video_n, hidden=8): """ Run embed_input_ids for a non-interleaved mixed-modality sequence. Returns (result_embeds, input_ids, is_multimodal). @@ -177,33 +176,33 @@ def _run(self, audio_n, image_n, video_n, hidden=8): input_ids, is_multimodal = make_token_seq(audio_n, image_n, video_n) mm_embeds = build_mm_embeds(audio_n, image_n, video_n, hidden) - model, _ = make_mock_model(hidden) + model, _ = make_mock_model(mocker, hidden) result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal) return result, input_ids, is_multimodal - def test_audio_only(self): + def test_audio_only(self, mocker: MockerFixture): """Audio-only: audio positions get audio embeddings.""" audio_n, hidden = 5, 8 audio_val = 10.0 - result, input_ids, is_multimodal = self._run(audio_n, 0, 0, hidden) + result, input_ids, is_multimodal = self._run(mocker, audio_n, 0, 0, hidden) audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0] assert result[audio_pos].allclose(torch.full((audio_n, hidden), audio_val)), ( "Audio positions should get audio embeddings" ) - def test_video_only(self): + def test_video_only(self, mocker: MockerFixture): """Video-only: video positions get video embeddings.""" video_n, hidden = 6, 8 video_val = 30.0 - result, input_ids, is_multimodal = self._run(0, 0, video_n, hidden) + result, input_ids, is_multimodal = self._run(mocker, 0, 0, video_n, hidden) video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0] assert result[video_pos].allclose(torch.full((video_n, hidden), video_val)), ( "Video positions should get video embeddings" ) - def test_mixed_modalities_audio_goes_to_audio_pos(self): + def test_mixed_modalities_audio_goes_to_audio_pos(self, mocker: MockerFixture): """ Regression test for GitHub issue #34506: With audio + image + video (non-interleaved), audio positions must @@ -212,7 +211,7 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self): audio_n, image_n, video_n, hidden = 5, 4, 6, 8 audio_val, image_val, video_val = 10.0, 20.0, 30.0 - result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden) + result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden) audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0] image_pos = (input_ids == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0] @@ -233,10 +232,10 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self): f"Video emb wrong: expected {video_val}, got mean={mean_v:.1f}" ) - def test_text_positions_unchanged(self): + def test_text_positions_unchanged(self, mocker: MockerFixture): """Text positions should keep their text embeddings.""" audio_n, image_n, video_n, hidden = 3, 2, 4, 8 - result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden) + result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden) text_pos = (~is_multimodal).nonzero(as_tuple=True)[0] # Text tokens have value TEXT_TOKEN_ID=0, so embed -> 
0.0 @@ -244,7 +243,7 @@ def test_text_positions_unchanged(self): "Text positions should keep text embeddings" ) - def test_interleaved_use_audio_in_video(self): + def test_interleaved_use_audio_in_video(self, mocker: MockerFixture): """ Interleaved (use_audio_in_video): video chunks interleaved with audio. Video embeddings must go to video positions, audio to audio positions. @@ -263,7 +262,7 @@ def test_interleaved_use_audio_in_video(self): torch.full((audio_n, hidden), audio_val), ] - model, _ = make_mock_model(hidden) + model, _ = make_mock_model(mocker, hidden) result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal) video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0] diff --git a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py index e2970dcb2d..8798cb3ca9 100644 --- a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py +++ b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py @@ -15,12 +15,13 @@ import os import sys import types -from unittest.mock import MagicMock, patch +import pytest import torch +from pytest_mock import MockerFixture # Direct file import to avoid vllm_omni.__init__ patch dependencies. -_BASE = os.path.join( +_MODELS = os.path.join( os.path.dirname(__file__), os.pardir, os.pardir, @@ -29,80 +30,116 @@ "vllm_omni", "model_executor", "models", - "qwen3_tts", ) +_BASE = os.path.join(_MODELS, "qwen3_tts") +_COMMON = os.path.join(_MODELS, "common") def _load_module(name: str, filename: str): path = os.path.abspath(os.path.join(_BASE, filename)) spec = importlib.util.spec_from_file_location(name, path) mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod # register before exec (needed for dataclasses etc.) 
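+ # Standard importlib recipe: spec_from_file_location builds the spec,
+ # module_from_spec creates the (still empty) module object, and exec_module runs
+ # the file's top-level code; pre-registering the entry in sys.modules lets
+ # imports that refer back to this module name during exec resolve.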
spec.loader.exec_module(mod) return mod -def _build_mock_modules() -> dict[str, object]: +def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]: """Build the dict of modules to inject into sys.modules.""" - platforms_mock = MagicMock() + platforms_mock = mocker.MagicMock() platforms_mock.current_omni_platform.supports_torch_inductor.return_value = False - logger_mock = MagicMock() - logger_mock.init_logger = lambda name: MagicMock() + logger_mock = mocker.MagicMock() + logger_mock.init_logger = lambda name: mocker.MagicMock() - vllm_config_mod = MagicMock() - vllm_config_mod.set_current_vllm_config = lambda cfg: MagicMock(__enter__=MagicMock(), __exit__=MagicMock()) + vllm_config_mod = mocker.MagicMock() + vllm_config_mod.set_current_vllm_config = lambda cfg: mocker.MagicMock( + __enter__=mocker.MagicMock(), + __exit__=mocker.MagicMock(), + ) - weight_utils_mock = MagicMock() + weight_utils_mock = mocker.MagicMock() weight_utils_mock.default_weight_loader = lambda p, w: None - pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts") - pkg.__path__ = [os.path.abspath(_BASE)] + tts_pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts") + tts_pkg.__path__ = [os.path.abspath(_BASE)] + + common_pkg = types.ModuleType("vllm_omni.model_executor.models.common") + common_pkg.__path__ = [os.path.abspath(_COMMON)] + + models_pkg = types.ModuleType("vllm_omni.model_executor.models") + models_pkg.__path__ = [os.path.abspath(_MODELS)] + + vllm_parallel_mock = mocker.MagicMock() + vllm_parallel_mock.VocabParallelEmbedding = torch.nn.Embedding return { - "vllm_omni": MagicMock(), + "vllm_omni": mocker.MagicMock(), "vllm_omni.platforms": platforms_mock, "vllm.logger": logger_mock, - "vllm.config": MagicMock(), + "vllm.config": mocker.MagicMock(), "vllm.config.vllm": vllm_config_mod, "vllm.model_executor.model_loader.weight_utils": weight_utils_mock, + "vllm.model_executor.layers.vocab_parallel_embedding": vllm_parallel_mock, "vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"), - "vllm_omni.model_executor.models": types.ModuleType("vllm_omni.model_executor.models"), - "vllm_omni.model_executor.models.qwen3_tts": pkg, + "vllm_omni.model_executor.models": models_pkg, + "vllm_omni.model_executor.models.common": common_pkg, + "vllm_omni.model_executor.models.qwen3_tts": tts_pkg, } -def _load_target_classes(): +def _load_target_classes(mocker: MockerFixture): """Load config and code predictor modules with mocked dependencies. - Uses patch.dict to ensure sys.modules is always restored, even on failure. + Uses mocker.patch.dict to ensure sys.modules is always restored, even on failure. 
""" - mocks = _build_mock_modules() - with patch.dict(sys.modules, mocks): - config_mod = _load_module( - "vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts", - "configuration_qwen3_tts.py", - ) - sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod + mocks = _build_mock_modules(mocker) + mocker.patch.dict(sys.modules, mocks) + config_mod = _load_module( + "vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts", + "configuration_qwen3_tts.py", + ) + sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod - cp_mod = _load_module( - "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm", - "qwen3_tts_code_predictor_vllm.py", - ) + # Load the shared common module (thin wrappers import from it) + common_cp_path = os.path.abspath(os.path.join(_COMMON, "qwen3_code_predictor.py")) + common_spec = importlib.util.spec_from_file_location( + "vllm_omni.model_executor.models.common.qwen3_code_predictor", common_cp_path + ) + common_cp_mod = importlib.util.module_from_spec(common_spec) + sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"] = common_cp_mod + common_spec.loader.exec_module(common_cp_mod) - return config_mod, cp_mod + cp_mod = _load_module( + "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm", + "qwen3_tts_code_predictor_vllm.py", + ) + return config_mod, cp_mod -_config_mod, _cp_mod = _load_target_classes() -Qwen3TTSTalkerCodePredictorConfig = _config_mod.Qwen3TTSTalkerCodePredictorConfig -Qwen3TTSTalkerConfig = _config_mod.Qwen3TTSTalkerConfig -CodePredictorWrapper = _cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM -CodePredictorModel = _cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM +@pytest.fixture +def loaded_target_classes(mocker: MockerFixture): + config_mod, cp_mod = _load_target_classes(mocker) + return ( + config_mod.Qwen3TTSTalkerCodePredictorConfig, + config_mod.Qwen3TTSTalkerConfig, + cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM, + cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM, + cp_mod.CodePredictorWrapperConfig, + ) -def _make_tiny_config() -> tuple: +def _make_tiny_config(loaded_target_classes) -> tuple: """Create minimal configs for a tiny code predictor model.""" - cp_config = Qwen3TTSTalkerCodePredictorConfig( + ( + qwen3_tts_talker_code_predictor_config, + qwen3_tts_talker_config, + _, + _, + _, + ) = loaded_target_classes + cp_config = qwen3_tts_talker_code_predictor_config( vocab_size=64, hidden_size=32, intermediate_size=64, @@ -113,16 +150,16 @@ def _make_tiny_config() -> tuple: num_code_groups=4, rms_norm_eps=1e-6, ) - talker_config = Qwen3TTSTalkerConfig( + talker_config = qwen3_tts_talker_config( hidden_size=32, num_code_groups=4, ) return cp_config, talker_config -def _make_vllm_config(max_num_seqs: int = 4) -> MagicMock: +def _make_vllm_config(mocker: MockerFixture, max_num_seqs: int = 4): """Create a mock VllmConfig with scheduler_config.""" - vllm_config = MagicMock() + vllm_config = mocker.MagicMock() vllm_config.scheduler_config.max_num_seqs = max_num_seqs return vllm_config @@ -130,32 +167,34 @@ def _make_vllm_config(max_num_seqs: int = 4) -> MagicMock: class TestCodePredictorDtypeAlignment: """Test that code predictor buffers match model parameter dtype.""" - def test_ensure_buffers_uses_given_dtype(self) -> None: + def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None: """_ensure_buffers should create proj_buf 
with the given dtype.""" - cp_config, talker_config = _make_tiny_config() - vllm_config = _make_vllm_config() + _, _, code_predictor_wrapper, _, _ = loaded_target_classes + cp_config, talker_config = _make_tiny_config(loaded_target_classes) + vllm_config = _make_vllm_config(mocker) - predictor = CodePredictorWrapper( + predictor = code_predictor_wrapper( vllm_config=vllm_config, config=cp_config, talker_config=talker_config, ) # Create buffer in float16 - predictor._ensure_buffers(torch.device("cpu"), torch.float16) + predictor._ensure_buffers(torch.device("cpu"), torch.float16, 4) assert predictor._proj_buf is not None assert predictor._proj_buf.dtype == torch.float16 # Re-create buffer in float32 (different dtype triggers re-allocation) - predictor._ensure_buffers(torch.device("cpu"), torch.float32) + predictor._ensure_buffers(torch.device("cpu"), torch.float32, 4) assert predictor._proj_buf.dtype == torch.float32 - def test_warmup_aligns_buffer_to_model_params(self) -> None: + def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loaded_target_classes) -> None: """_warmup_buckets should align proj_buf dtype to model parameters.""" - cp_config, talker_config = _make_tiny_config() - vllm_config = _make_vllm_config(max_num_seqs=2) + _, _, code_predictor_wrapper, _, _ = loaded_target_classes + cp_config, talker_config = _make_tiny_config(loaded_target_classes) + vllm_config = _make_vllm_config(mocker, max_num_seqs=2) - predictor = CodePredictorWrapper( + predictor = code_predictor_wrapper( vllm_config=vllm_config, config=cp_config, talker_config=talker_config, @@ -165,7 +204,7 @@ def test_warmup_aligns_buffer_to_model_params(self) -> None: predictor = predictor.to(torch.float16) # Pre-create proj_buf with WRONG dtype (float32) — simulating the bug - predictor._ensure_buffers(torch.device("cpu"), torch.float32) + predictor._ensure_buffers(torch.device("cpu"), torch.float32, 2) assert predictor._proj_buf.dtype == torch.float32 # Simulate _setup_compile having cached model dtype and compiled forward @@ -177,12 +216,13 @@ def test_warmup_aligns_buffer_to_model_params(self) -> None: assert predictor._proj_buf.dtype == torch.float16 - def test_setup_compile_caches_model_dtype(self) -> None: + def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None: """_setup_compile should cache model parameter dtype.""" - cp_config, talker_config = _make_tiny_config() - vllm_config = _make_vllm_config(max_num_seqs=2) + _, _, code_predictor_wrapper, _, _ = loaded_target_classes + cp_config, talker_config = _make_tiny_config(loaded_target_classes) + vllm_config = _make_vllm_config(mocker, max_num_seqs=2) - predictor = CodePredictorWrapper( + predictor = code_predictor_wrapper( vllm_config=vllm_config, config=cp_config, talker_config=talker_config, @@ -193,12 +233,13 @@ def test_setup_compile_caches_model_dtype(self) -> None: predictor._setup_compile() assert predictor._model_dtype == torch.float16 - def test_forward_with_mismatched_input_dtype(self) -> None: + def test_forward_with_mismatched_input_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None: """forward() should not crash when inputs are float32 but model is float16.""" - cp_config, talker_config = _make_tiny_config() - vllm_config = _make_vllm_config(max_num_seqs=2) + _, _, code_predictor_wrapper, _, _ = loaded_target_classes + cp_config, talker_config = _make_tiny_config(loaded_target_classes) + vllm_config = _make_vllm_config(mocker, max_num_seqs=2) - predictor = 
CodePredictorWrapper( + predictor = code_predictor_wrapper( vllm_config=vllm_config, config=cp_config, talker_config=talker_config, @@ -231,10 +272,11 @@ def test_forward_with_mismatched_input_dtype(self) -> None: class TestCodePredictorModelDtype: """Test the inner model forward with different dtypes.""" - def test_model_forward_float16(self) -> None: + def test_model_forward_float16(self, loaded_target_classes) -> None: """Inner model forward should work in float16.""" - cp_config, _ = _make_tiny_config() - model = CodePredictorModel(cp_config, talker_hidden_size=32).to(torch.float16) + _, _, _, code_predictor_model, _ = loaded_target_classes + cp_config, _ = _make_tiny_config(loaded_target_classes) + model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float16) bsz, seq_len = 1, 4 inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float16) @@ -244,10 +286,11 @@ def test_model_forward_float16(self) -> None: assert output.dtype == torch.float16 assert output.shape == (bsz, seq_len, 32) - def test_model_forward_float32(self) -> None: + def test_model_forward_float32(self, loaded_target_classes) -> None: """Inner model forward should work in float32.""" - cp_config, _ = _make_tiny_config() - model = CodePredictorModel(cp_config, talker_hidden_size=32).to(torch.float32) + _, _, _, code_predictor_model, _ = loaded_target_classes + cp_config, _ = _make_tiny_config(loaded_target_classes) + model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float32) bsz, seq_len = 1, 4 inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float32) @@ -256,3 +299,37 @@ def test_model_forward_float32(self) -> None: output = model(inputs, pos_ids) assert output.dtype == torch.float32 assert output.shape == (bsz, seq_len, 32) + + +class TestCodePredictorWrapperConfig: + """Test wrapper configuration for different models.""" + + def test_omni_config(self, loaded_target_classes) -> None: + """Qwen3-Omni uses correct wrapper config.""" + _, _, _, _, code_predictor_wrapper_config = loaded_target_classes + config = code_predictor_wrapper_config( + use_cuda_graphs=False, + use_parallel_embedding=True, + use_projection=False, + return_proj_buf=True, + sampling_mode="stored", + ) + assert config.use_cuda_graphs is False + assert config.use_parallel_embedding is True + assert config.return_proj_buf is True + assert config.sampling_mode == "stored" + + def test_tts_config(self, loaded_target_classes) -> None: + """Qwen3-TTS uses correct wrapper config.""" + _, _, _, _, code_predictor_wrapper_config = loaded_target_classes + config = code_predictor_wrapper_config( + use_cuda_graphs=True, + use_parallel_embedding=False, + use_projection=True, + return_proj_buf=False, + sampling_mode="per_call", + ) + assert config.use_cuda_graphs is True + assert config.use_parallel_embedding is False + assert config.return_proj_buf is False + assert config.sampling_mode == "per_call" diff --git a/tests/model_executor/models/test_encoder_quant_config.py b/tests/model_executor/models/test_encoder_quant_config.py new file mode 100644 index 0000000000..8020184986 --- /dev/null +++ b/tests/model_executor/models/test_encoder_quant_config.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Regression test for #2686: pre-quantized methods must not apply +quant config to vision / audio encoders. + +For modelopt FP8/FP4/MXFP8 checkpoints the Thinker LM is the only +quantized component. 
Vision and audio encoder weights are BF16 with no +FP8 scale tensors — passing quant_config to them causes FP8 kernels to +run on BF16 weights, producing garbage embeddings. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from vllm_omni.quantization.component_config import ( + PRE_QUANTIZED_METHODS, + ComponentQuantizationConfig, + resolve_encoder_quant_config, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +# --------------------------------------------------------------------------- +# resolve_encoder_quant_config — the core routing logic for encoder quant +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("method", sorted(PRE_QUANTIZED_METHODS)) +def test_pre_quantized_returns_none(method: str) -> None: + """visual_quant_config and audio_quant_config must be None for + pre-quantized methods (modelopt, modelopt_fp4, modelopt_mxfp8).""" + mock_config = MagicMock() + mock_config.get_name.return_value = method + + assert resolve_encoder_quant_config(mock_config) is None + + +@pytest.mark.parametrize("method", ["fp8", "awq", "gptq", "bitsandbytes"]) +def test_non_pre_quantized_preserves_config(method: str) -> None: + """Non-pre-quantized methods should pass through the original config.""" + mock_config = MagicMock() + mock_config.get_name.return_value = method + + assert resolve_encoder_quant_config(mock_config) is mock_config + + +def test_none_input_returns_none() -> None: + """No quantization → None for encoders.""" + assert resolve_encoder_quant_config(None) is None + + +def test_component_config_passed_through() -> None: + """ComponentQuantizationConfig should be returned as-is so the caller + can call .resolve() with the appropriate prefix.""" + inner = MagicMock() + inner.get_name.return_value = "modelopt" # would be None if not Component + component = ComponentQuantizationConfig( + component_configs={"language_model": inner}, + default_config=None, + ) + + result = resolve_encoder_quant_config(component) + assert result is component + + +# --------------------------------------------------------------------------- +# PRE_QUANTIZED_METHODS constant — exhaustiveness check +# --------------------------------------------------------------------------- + + +def test_pre_quantized_methods_contains_expected() -> None: + """Guard against accidental removal of a known pre-quantized method.""" + expected = {"modelopt", "modelopt_fp4", "modelopt_mxfp8"} + assert PRE_QUANTIZED_METHODS == expected diff --git a/tests/model_executor/models/test_fish_speech_voice_cache.py b/tests/model_executor/models/test_fish_speech_voice_cache.py index 8fe7a4a4d1..fef4b551ab 100644 --- a/tests/model_executor/models/test_fish_speech_voice_cache.py +++ b/tests/model_executor/models/test_fish_speech_voice_cache.py @@ -10,11 +10,11 @@ import os import tempfile -from unittest.mock import MagicMock, patch import numpy as np import pytest import torch +from pytest_mock import MockerFixture pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -61,18 +61,18 @@ class TestFishSpeechVoiceCacheIntegration: """Test the cache-hit / cache-miss / no-cache paths in the model.""" @pytest.fixture - def mock_model(self): + def mock_model(self, mocker: MockerFixture): """Create a mock FishSpeechSlowARForConditionalGeneration with cache.""" from vllm_omni.utils.voice_cache import VoiceEmbeddingCache - model = MagicMock() + model = mocker.MagicMock() model._voice_cache = VoiceEmbeddingCache(max_entries=4) 
model._semantic_begin_id = 151678 model._num_codebooks = 10 model._codebook_size = 4096 model.model_path = "/fake/model" - model.codebook_embeddings = MagicMock() - model.codebook_embeddings.weight = MagicMock() + model.codebook_embeddings = mocker.MagicMock() + model.codebook_embeddings.weight = mocker.MagicMock() model.codebook_embeddings.weight.device = torch.device("cpu") return model @@ -166,9 +166,9 @@ def test_created_at_zero_disables_cache(self, mock_model): class TestFishSpeechValidatorUploadedVoice: """Test _validate_fish_tts_request uploaded voice resolution.""" - def test_uploaded_voice_resolves_ref_audio(self): + def test_uploaded_voice_resolves_ref_audio(self, mocker: MockerFixture): """When voice matches an uploaded speaker, ref_audio should be auto-set.""" - request = MagicMock() + request = mocker.MagicMock() request.input = "Hello" request.voice = "alice" request.ref_audio = None @@ -185,17 +185,17 @@ def test_uploaded_voice_resolves_ref_audio(self): } # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL. - with patch("pathlib.Path.exists", return_value=True): - voice_lower = request.voice.lower() - assert voice_lower in uploaded_speakers + mocker.patch("pathlib.Path.exists", return_value=True) + voice_lower = request.voice.lower() + assert voice_lower in uploaded_speakers - speaker_info = uploaded_speakers[voice_lower] - ref_text_from_upload = speaker_info.get("ref_text") - assert ref_text_from_upload == "Hi this is Alice" + speaker_info = uploaded_speakers[voice_lower] + ref_text_from_upload = speaker_info.get("ref_text") + assert ref_text_from_upload == "Hi this is Alice" - def test_uploaded_voice_without_ref_text_uses_request_ref_text(self): + def test_uploaded_voice_without_ref_text_uses_request_ref_text(self, mocker: MockerFixture): """If upload has no ref_text but request provides it, use request's.""" - request = MagicMock() + request = mocker.MagicMock() request.input = "Hello" request.voice = "bob" request.ref_audio = None diff --git a/tests/model_executor/models/voxcpm2/__init__.py b/tests/model_executor/models/voxcpm2/__init__.py new file mode 100644 index 0000000000..208f01a7cb --- /dev/null +++ b/tests/model_executor/models/voxcpm2/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py new file mode 100644 index 0000000000..5d8a35636b --- /dev/null +++ b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Regression tests for VoxCPM2 talker per-request state lifecycle.""" + +from __future__ import annotations + +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("librosa") + +from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( # noqa: E402 + VoxCPM2TalkerForConditionalGeneration, + _RequestState, +) + + +def _make_bare_talker() -> VoxCPM2TalkerForConditionalGeneration: + talker = VoxCPM2TalkerForConditionalGeneration.__new__(VoxCPM2TalkerForConditionalGeneration) + talker._active_states = {} + talker._current_request_id = None + talker._pending_requests = [] + talker._results_queue = [] + talker._audio_queue = [] + talker._deferred_cleanup_ids = set() + talker._max_batch_size = 4 + talker._active_state_warn_threshold = 512 + 
talker._active_state_warned = False + return talker + + +def _seed_cached_decode(talker, req_id: str) -> _RequestState: + state = _RequestState(request_id=req_id) + state.prefill_completed = True + state.decode_step_count = 5 + talker._active_states[req_id] = state + return state + + +class TestStateEvictionContract: + def test_pending_requests_is_not_used_for_eviction(self) -> None: + talker = _make_bare_talker() + + cached_ids = [f"req-{i}" for i in range(4)] + for rid in cached_ids: + _seed_cached_decode(talker, rid) + + walked_so_far = ["req-new", cached_ids[0], cached_ids[1]] + talker._pending_requests = [(rid, False, None, 0) for rid in walked_so_far] + + for rid in cached_ids: + assert rid in talker._active_states + assert talker._active_states[rid].prefill_completed is True + + def test_on_requests_finished_defers_cleanup(self) -> None: + talker = _make_bare_talker() + _seed_cached_decode(talker, "req-A") + _seed_cached_decode(talker, "req-B") + + talker.on_requests_finished({"req-A"}) + + assert "req-A" in talker._active_states + assert "req-A" in talker._deferred_cleanup_ids + + def test_flush_deferred_cleanup_removes_only_finished(self) -> None: + talker = _make_bare_talker() + _seed_cached_decode(talker, "req-A") + _seed_cached_decode(talker, "req-B") + talker.on_requests_finished(["req-A"]) + + talker._flush_deferred_cleanup() + + assert "req-A" not in talker._active_states + assert "req-B" in talker._active_states + assert talker._deferred_cleanup_ids == set() + + def test_current_request_id_cleared_when_matching(self) -> None: + talker = _make_bare_talker() + _seed_cached_decode(talker, "req-A") + talker._current_request_id = "req-A" + + talker.on_requests_finished({"req-A"}) + talker._flush_deferred_cleanup() + + assert talker._current_request_id is None + + def test_current_request_id_preserved_when_not_finished(self) -> None: + talker = _make_bare_talker() + _seed_cached_decode(talker, "req-A") + _seed_cached_decode(talker, "req-B") + talker._current_request_id = "req-B" + + talker.on_requests_finished({"req-A"}) + talker._flush_deferred_cleanup() + + assert talker._current_request_id == "req-B" + + +class TestLeakWarnGuard: + def test_warn_fires_once_over_threshold(self, monkeypatch) -> None: + from vllm_omni.model_executor.models.voxcpm2 import voxcpm2_talker as tk + + calls: list[str] = [] + + def _capture(msg, *args, **kwargs): + calls.append(msg % args if args else msg) + + monkeypatch.setattr(tk.logger, "warning", _capture) + + talker = _make_bare_talker() + talker._active_state_warn_threshold = 3 + + for i in range(4): + talker._active_states[f"seed-{i}"] = _RequestState(request_id=f"seed-{i}") + + talker._get_or_create_state("new-1") + talker._get_or_create_state("new-2") + + leak_warnings = [m for m in calls if "cleanup path leak" in m] + assert len(leak_warnings) == 1 + assert talker._active_state_warned is True diff --git a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py index 6f072944d9..847adae06f 100644 --- a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py +++ b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py @@ -78,6 +78,13 @@ AudioSpecialTokens = _mod2.AudioSpecialTokens +class SyntheticAcousticTransformerArgs: + """Mimics AcousticTransformerArgs interface.""" + + def __init__(self): + self.n_decoding_steps = 7 + + class SyntheticModelArgs: """Mimics MultimodalAudioModelArgs 
interface.""" @@ -96,6 +103,7 @@ class SyntheticAcousticTransformer(nn.Module): def __init__(self): super().__init__() self.model_args = SyntheticModelArgs() + self.acoustic_transformer_args = SyntheticAcousticTransformerArgs() self.acoustic_embeddings_levels = ACOUSTIC_EMBEDDINGS_LEVELS # semantic_codebook_output: hidden_dim -> padded_codebook_size diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py new file mode 100644 index 0000000000..18972c91d5 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for Qwen3-Omni streaming thinker→talker / talker→codec helpers (PR #2581).""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import vllm_omni.model_executor.stage_input_processors.qwen3_omni as q3 + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture(autouse=True) +def _streaming_context() -> SimpleNamespace: + return SimpleNamespace(bridge_states={}) + + +def test_get_streaming_talker_tokens_first_segment(_streaming_context: SimpleNamespace) -> None: + inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens( + "r1", + [1, 2], + [10, 11], + streaming_context=_streaming_context, + ) + assert inc_p == [1, 2] + assert inc_o == [10, 11] + assert merged == [1, 2, 10, 11] + assert thinker_in == [1, 2] + + +def test_get_streaming_talker_tokens_second_segment_accumulates(_streaming_context: SimpleNamespace) -> None: + q3._get_streaming_talker_tokens("r2", [1, 2], [10, 11], streaming_context=_streaming_context) + inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens( + "r2", + [1, 2, 3, 4], + [10, 11, 12, 13], + streaming_context=_streaming_context, + ) + assert inc_p == [3, 4] + assert inc_o == [12, 13] + assert merged == [1, 2, 10, 3, 4, 12, 13] + assert thinker_in == [1, 2, 10, 3, 4] + + +def test_get_streaming_talker_tokens_new_prompt_len_snapshot_truncates( + _streaming_context: SimpleNamespace, +) -> None: + inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens( + "r3", + [1, 2, 3, 4, 5, 6], + [10], + new_prompt_len_snapshot=2, + streaming_context=_streaming_context, + ) + assert inc_p == [1, 2, 3, 4] + assert inc_o == [10] + assert merged == [1, 2, 3, 4, 10] + assert thinker_in == [1, 2, 3, 4] + + +def test_get_streaming_talker_tokens_clear_state(_streaming_context: SimpleNamespace) -> None: + q3._get_streaming_talker_tokens("r4", [1], [2], streaming_context=_streaming_context, clear_state=True) + state = q3._get_qwen3_streaming_state("r4", _streaming_context).thinker2talker + assert state.last_prompt_len == 0 + assert state.last_output_len == 0 + assert state.merged_sequences == [] + + +def test_get_streaming_codec_delta_len_increments_and_finishes(_streaming_context: SimpleNamespace) -> None: + d1 = q3._get_streaming_codec_delta_len(5, "c1", SimpleNamespace(finished=False), _streaming_context) + assert d1 == 5 + d2 = q3._get_streaming_codec_delta_len(8, "c1", SimpleNamespace(finished=False), _streaming_context) + assert d2 == 2 + # After d2, stored cursor is cur_seq_len + 1 == 9; next delta uses new cur_seq_len - 9. 
+ d3 = q3._get_streaming_codec_delta_len(10, "c1", SimpleNamespace(finished=True), _streaming_context) + assert d3 == 1 + state = q3._get_qwen3_streaming_state("c1", _streaming_context) + assert state.talker2code2wav_last_seq_len == 0 diff --git a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py new file mode 100644 index 0000000000..7d6fc6e74c --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""UTs for VoxCPM async-chunk stage input processing.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.voxcpm import ( + _VOXCPM_LATENT_MAGIC, + _coerce_finished_flag, + latent2vae_async_chunk, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _request(*, finished): + return SimpleNamespace(is_finished=lambda: finished) + + +def _decode_serialized_latent(codes: list[int]) -> torch.Tensor: + assert codes[0] == _VOXCPM_LATENT_MAGIC + latent_dim = codes[1] + time_dim = codes[2] + payload = torch.tensor(codes[3:], dtype=torch.int32).to(torch.uint16) + return payload.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, False), + (False, False), + (True, True), + (torch.tensor(False), False), + (torch.tensor(True), True), + ([torch.tensor(True)], True), + (([True],), True), + ([], False), + ], +) +def test_coerce_finished_flag(value, expected): + assert _coerce_finished_flag(value) is expected + + +def test_latent2vae_async_chunk_serializes_latent_payload(): + latent = torch.arange(6, dtype=torch.float32).reshape(2, 3) + + payload = latent2vae_async_chunk( + transfer_manager=None, + pooling_output={"latent_audio_feat": latent}, + request=_request(finished=False), + is_finished=torch.tensor(False), + ) + + assert payload is not None + assert torch.equal(payload["finished"], torch.tensor(False, dtype=torch.bool)) + recovered = _decode_serialized_latent(payload["code_predictor_codes"]) + torch.testing.assert_close(recovered, latent.to(torch.bfloat16).to(torch.float32).unsqueeze(0)) + + +def test_latent2vae_async_chunk_returns_terminal_marker_without_latent(): + payload = latent2vae_async_chunk( + transfer_manager=None, + pooling_output=None, + request=_request(finished=[torch.tensor(True)]), + is_finished=False, + ) + + assert payload == { + "code_predictor_codes": [], + "finished": torch.tensor(True, dtype=torch.bool), + } + + +def test_latent2vae_async_chunk_returns_none_for_nonterminal_empty_chunk(): + payload = latent2vae_async_chunk( + transfer_manager=None, + pooling_output={"latent_audio_feat": torch.zeros((0,), dtype=torch.float32)}, + request=_request(finished=False), + is_finished=False, + ) + + assert payload is None diff --git a/tests/test_arg_utils.py b/tests/test_arg_utils.py new file mode 100644 index 0000000000..dab5ed6878 --- /dev/null +++ b/tests/test_arg_utils.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for vllm_omni.engine.arg_utils — invariants that must +hold for the orchestrator/engine/server CLI flag partition.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, fields + +import pytest + +from 
vllm_omni.engine.arg_utils import ( + SHARED_FIELDS, + derive_server_dests_from_vllm_parser, + internal_blacklist_keys, + orchestrator_args_from_argparse, + orchestrator_field_names, + split_kwargs, +) + +# --------------------------------------------------------------------------- +# Fake engine class for unit testing — avoids pulling in the full vllm +# EngineArgs and its heavy __post_init__ at test time. +# --------------------------------------------------------------------------- + + +@dataclass +class _FakeEngineArgs: + """Stand-in for OmniEngineArgs with a representative subset of fields.""" + + model: str = "" + stage_id: int = 0 + max_num_seqs: int = 64 + gpu_memory_utilization: float = 0.9 + async_chunk: bool = False # also in OrchestratorArgs → shared + log_stats: bool = False # also in OrchestratorArgs → shared + stage_configs_path: str | None = None + + +# ============================================================================ +# Invariant 1 — OrchestratorArgs and engine must not ambiguously overlap. +# ============================================================================ + + +def test_no_ambiguous_overlap_with_fake_engine(): + """OrchestratorArgs ∩ engine fields must be ⊆ SHARED_FIELDS.""" + orch = orchestrator_field_names() + engine = {f.name for f in fields(_FakeEngineArgs)} + overlap = orch & engine + unexpected = overlap - SHARED_FIELDS + assert not unexpected, ( + f"Fields declared in both OrchestratorArgs and the engine class " + f"but not in SHARED_FIELDS: {sorted(unexpected)}. These cause " + f"double-routing — either remove the duplicate declaration or add " + f"to SHARED_FIELDS if sharing is intentional." + ) + + +def test_no_ambiguous_overlap_with_real_engine(): + """Same check, but against the real OmniEngineArgs.""" + try: + from vllm_omni.engine.arg_utils import OmniEngineArgs + except Exception as exc: + pytest.skip(f"OmniEngineArgs not importable: {exc}") + + orch = orchestrator_field_names() + engine = {f.name for f in fields(OmniEngineArgs)} + overlap = orch & engine + unexpected = overlap - SHARED_FIELDS + assert not unexpected, ( + f"Real OmniEngineArgs has ambiguous overlap with OrchestratorArgs: " + f"{sorted(unexpected)}. Update SHARED_FIELDS or remove duplication." + ) + + +# ============================================================================ +# Invariant 2 — split_kwargs partitions correctly. +# ============================================================================ + + +def test_split_orchestrator_only(): + """Pure orchestrator fields go to OrchestratorArgs, not engine_kwargs.""" + raw = {"stage_init_timeout": 500, "worker_backend": "ray"} + orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs) + assert orch.stage_init_timeout == 500 + assert orch.worker_backend == "ray" + assert "stage_init_timeout" not in engine + assert "worker_backend" not in engine + + +def test_split_engine_only(): + """Pure engine fields go to engine_kwargs, not OrchestratorArgs.""" + raw = {"max_num_seqs": 128, "gpu_memory_utilization": 0.85} + orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs) + assert engine["max_num_seqs"] == 128 + assert engine["gpu_memory_utilization"] == 0.85 + # These fields don't exist on OrchestratorArgs at all. 
+ + +def test_split_shared_fields_go_to_both(): + """Fields in SHARED_FIELDS are copied to both buckets.""" + raw = {"model": "Qwen/Qwen2.5-Omni-7B", "log_stats": True} + orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs) + assert orch.log_stats is True + assert engine["model"] == "Qwen/Qwen2.5-Omni-7B" + assert engine["log_stats"] is True + + +def test_split_drops_unclassified(): + """Unclassified fields (uvicorn/server) are dropped silently.""" + raw = { + "max_num_seqs": 64, # engine + "host": "0.0.0.0", # unclassified (server) + "port": 8091, # unclassified (server) + "ssl_keyfile": "key.pem", # unclassified (server) + } + orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs) + assert engine == {"max_num_seqs": 64} + assert "host" not in engine + assert "port" not in engine + assert "ssl_keyfile" not in engine + + +def test_split_mixed_real_world(): + """End-to-end: raw CLI kwargs with all three classes present.""" + raw = { + # orchestrator + "stage_init_timeout": 400, + "deploy_config": "/tmp/deploy.yaml", + "worker_backend": "multi_process", + "async_chunk": True, + # engine + "max_num_seqs": 32, + "gpu_memory_utilization": 0.8, + # shared + "model": "Qwen/Qwen3-Omni", + "log_stats": False, + # server / unclassified + "host": "0.0.0.0", + "port": 8091, + "api_key": "secret", + # None values + "ray_address": None, + } + orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs) + + # Orchestrator side + assert orch.stage_init_timeout == 400 + assert orch.deploy_config == "/tmp/deploy.yaml" + assert orch.worker_backend == "multi_process" + assert orch.async_chunk is True + assert orch.log_stats is False # shared, read from raw + assert orch.ray_address is None # default preserved + + # Engine side + assert engine["max_num_seqs"] == 32 + assert engine["gpu_memory_utilization"] == 0.8 + assert engine["model"] == "Qwen/Qwen3-Omni" + assert engine["log_stats"] is False + assert "host" not in engine + assert "port" not in engine + assert "api_key" not in engine + # orchestrator-only keys never reach engine + assert "stage_init_timeout" not in engine + assert "deploy_config" not in engine + assert "async_chunk" not in engine + + +# ============================================================================ +# Invariant 3 — user-typed unclassifiable flags warn (don't fail silently). +# ============================================================================ + + +def test_user_typed_unclassified_warns(caplog): + """If the user types a flag we can't route, warn — don't silently drop.""" + raw = {"bogus_flag": "value", "max_num_seqs": 64} + with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"): + split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed={"bogus_flag"}) + assert any("bogus_flag" in rec.message for rec in caplog.records), ( + f"Expected warning mentioning 'bogus_flag', got: {[rec.message for rec in caplog.records]}" + ) + + +def test_unclassified_without_user_typed_silent(caplog): + """Without user_typed, unclassified keys drop silently (argparse defaults + for server flags shouldn't spam logs on every launch).""" + raw = {"host": "0.0.0.0", "port": 8091, "max_num_seqs": 64} + with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"): + split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed=None) + # No warnings because we don't know these were user-typed. 
+ assert not any("host" in rec.message or "port" in rec.message for rec in caplog.records) + + +# ============================================================================ +# Invariant 4 — CLI flag classification completeness. +# Catches new flags added without updating OrchestratorArgs or OmniEngineArgs. +# ============================================================================ + + +def test_all_omni_cli_flags_classified(): + """Every vllm-omni-added CLI flag must be classifiable. + + Runs ``OmniServeCommand.subparser_init`` and checks that every new + argument (compared to vllm's base parser) is either: + - a field on OrchestratorArgs, OR + - a field on OmniEngineArgs, OR + - in SHARED_FIELDS + """ + try: + from vllm.utils.argparse_utils import FlexibleArgumentParser + + from vllm_omni.engine.arg_utils import OmniEngineArgs + from vllm_omni.entrypoints.cli.serve import OmniServeCommand + except Exception as exc: + pytest.skip(f"Cannot build parser in this environment: {exc}") + + # Build the serve parser + root = FlexibleArgumentParser() + subparsers = root.add_subparsers() + cmd = OmniServeCommand() + try: + parser = cmd.subparser_init(subparsers) + except Exception as exc: + pytest.skip(f"subparser_init failed (dev env issue): {exc}") + + all_dests = {a.dest for a in parser._actions if a.dest and a.dest not in {"help", "model_tag"}} + + orch = orchestrator_field_names() + engine = {f.name for f in fields(OmniEngineArgs)} + server_derived = derive_server_dests_from_vllm_parser() + + unclassified = all_dests - orch - engine - SHARED_FIELDS - server_derived + # Some argparse-internal dests (suppressed, private) may not match — + # filter those out. + unclassified = {d for d in unclassified if not d.startswith("_")} + + assert not unclassified, ( + f"These CLI flags are not classified as " + f"orchestrator/engine/shared/server: {sorted(unclassified)}. " + f"Add them to OrchestratorArgs (if consumed by orchestrator), " + f"OmniEngineArgs (if consumed by per-stage engine), or the known-server " + f"allowlist (if they're vllm frontend flags). " + f"If intentional (e.g. a new CLI-only flag that doesn't map to either " + f"dataclass), add it to a KNOWN_UNROUTED allowlist." + ) + + +# ============================================================================ +# argparse interop (Phase 3). +# ============================================================================ + + +def test_orchestrator_args_from_argparse(): + """Can build OrchestratorArgs from an argparse.Namespace.""" + import argparse + + ns = argparse.Namespace( + stage_init_timeout=500, + deploy_config="/tmp/x.yaml", + max_num_seqs=64, # engine field — ignored + host="0.0.0.0", # server field — ignored + ) + orch = orchestrator_args_from_argparse(ns) + assert orch.stage_init_timeout == 500 + assert orch.deploy_config == "/tmp/x.yaml" + assert orch.worker_backend == "multi_process" # default + + +def test_derive_server_dests_returns_frozenset(): + """Server-dest derivation returns a frozenset (possibly empty).""" + result = derive_server_dests_from_vllm_parser() + assert isinstance(result, frozenset) + + +# ============================================================================ +# internal_blacklist_keys — single source of truth for per-stage forwarding. +# ============================================================================ + + +def test_internal_blacklist_keys_derived_from_orchestrator(): + """Blacklist is exactly OrchestratorArgs fields minus SHARED_FIELDS. 
+ + This function replaces the old hardcoded INTERNAL_STAGE_OVERRIDE_KEYS + frozenset. Asserts the contract so future changes to OrchestratorArgs + automatically propagate to the blacklist. + """ + blacklist = internal_blacklist_keys() + assert blacklist == orchestrator_field_names() - SHARED_FIELDS + # Spot-check expected entries + assert "stage_init_timeout" in blacklist + assert "deploy_config" in blacklist + assert "async_chunk" in blacklist + # Shared fields must NOT appear — they flow to both orchestrator and engine + assert "model" not in blacklist + assert "log_stats" not in blacklist + + +# ============================================================================ +# Boundary value analysis — edge cases around split_kwargs. +# ============================================================================ + + +def test_split_empty_kwargs(): + """Empty kwargs yields default OrchestratorArgs and empty engine dict.""" + orch, engine = split_kwargs({}, engine_cls=_FakeEngineArgs) + assert orch.stage_init_timeout == 300 # dataclass default + assert orch.worker_backend == "multi_process" # dataclass default + assert engine == {} + + +def test_split_all_none_values_preserved_on_orchestrator(): + """None values for orchestrator fields are kept (represents 'not set').""" + raw = {"ray_address": None, "deploy_config": None, "max_num_seqs": None} + orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs) + assert orch.ray_address is None + assert orch.deploy_config is None + # Engine-side None still passes through; caller decides semantics downstream. + assert engine.get("max_num_seqs") is None + + +def test_split_user_typed_with_empty_kwargs_no_warn(caplog): + """user_typed non-empty but kwargs empty — no warnings emitted.""" + with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"): + split_kwargs({}, engine_cls=_FakeEngineArgs, user_typed={"nothing"}) + assert not caplog.records + + +def test_ambiguous_field_strict_raises(): + """strict=True raises ValueError on overlap outside SHARED_FIELDS.""" + + # deploy_config is on OrchestratorArgs; declaring it on the engine class + # too (without adding to SHARED_FIELDS) creates an ambiguous route. + @dataclass + class _AmbiguousEngine: + deploy_config: str | None = None + + with pytest.raises(ValueError, match="both OrchestratorArgs and"): + split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=True) + + +def test_ambiguous_field_non_strict_routes_to_orchestrator(caplog): + """strict=False logs ERROR but routes the ambiguous field to orchestrator.""" + + @dataclass + class _AmbiguousEngine: + deploy_config: str | None = None + + with caplog.at_level(logging.ERROR, logger="vllm_omni.engine.arg_utils"): + orch, engine = split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=False) + assert orch.deploy_config == "x" + assert "deploy_config" not in engine + assert any("both OrchestratorArgs" in r.message for r in caplog.records) diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index e284de48d0..1d65d3acd2 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -4,12 +4,26 @@ Unit tests for StageConfigFactory and related classes. 
""" +from dataclasses import dataclass +from pathlib import Path + +import pytest + from vllm_omni.config.stage_config import ( + _EXECUTION_TYPE_TO_SCHEDULER, + _PIPELINE_REGISTRY, ModelPipeline, + PipelineConfig, StageConfig, StageConfigFactory, + StageExecutionType, + StagePipelineConfig, StageType, + build_stage_runtime_overrides, + register_pipeline, + strip_parent_engine_args, ) +from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys class TestStageType: @@ -241,8 +255,9 @@ def test_default_diffusion_no_yaml(self): def test_default_diffusion_with_parallel_config(self): """Test diffusion config calculates devices from parallel_config.""" + @dataclass class MockParallelConfig: - world_size = 4 + world_size: int = 4 kwargs = { "parallel_config": MockParallelConfig(), @@ -270,7 +285,7 @@ def test_cli_override_forwards_engine_registered_args(self): stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[]) cli_overrides = { "gpu_memory_utilization": 0.9, # Well-known param - "custom_engine_flag": True, # Not in _INTERNAL_KEYS, so forwarded + "custom_engine_flag": True, # Not orchestrator-owned, so forwarded } overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides) @@ -311,6 +326,56 @@ def test_per_stage_override_excludes_internal_keys(self): assert "batch_timeout" not in overrides +class TestStageResolutionHelpers: + """Tests for shared stage override / filtering helpers.""" + + def test_build_stage_runtime_overrides_ignores_other_stage_and_internal_keys(self): + # Pass the same filter set the function uses by default + # (orchestrator-only fields plus SHARED_FIELDS so ``model`` is + # treated as not-per-stage-overridable). + overrides = build_stage_runtime_overrides( + 0, + { + "gpu_memory_utilization": 0.5, + "stage_0_gpu_memory_utilization": 0.9, + "stage_1_gpu_memory_utilization": 0.1, + "stage_0_model": "should_be_ignored", + "parallel_config": {"world_size": 2}, + }, + internal_keys=internal_blacklist_keys() | SHARED_FIELDS, + ) + + assert overrides["gpu_memory_utilization"] == 0.9 + assert "model" not in overrides + assert "parallel_config" not in overrides + + def test_strip_parent_engine_args_reports_only_surprising_parent_overrides(self): + from dataclasses import fields as dc_fields + + from vllm.engine.arg_utils import EngineArgs + + parent_fields = {f.name: f for f in dc_fields(EngineArgs)} + filtered, overridden = strip_parent_engine_args( + { + "model": "some/model", + "stage_configs_path": "/tmp/stages.yaml", + "tensor_parallel_size": 4, + "worker_extension_cls": "some.Extension", + "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"}, + }, + parent_fields=parent_fields, + keep_keys={"worker_extension_cls"}, + strip_keys={"stage_configs_path"}, + no_warn_keys={"model"}, + ) + + assert filtered == { + "worker_extension_cls": "some.Extension", + "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"}, + } + assert overridden == ["tensor_parallel_size"] + + class TestPipelineYamlParsing: """Tests for pipeline YAML file parsing (@ZJY0516).""" @@ -609,16 +674,617 @@ def test_parse_missing_async_chunk_defaults_false(self, tmp_path): assert pipeline.async_chunk is False -class TestArchitectureFallback: - """Tests for architecture-based model detection fallback.""" +class TestPipelineDiscovery: + """Tests for the central pipeline registry (``pipeline_registry._VLLM_OMNI_PIPELINES``).""" + + def test_registry_has_known_models(self): + """Built-in pipelines are lazy-loaded from the central declaration + on first 
access; no eager import or discovery walk needed.""" + # ``in`` triggers the lazy-map lookup without forcing a load. + assert "qwen2_5_omni" in _PIPELINE_REGISTRY + assert "qwen3_omni_moe" in _PIPELINE_REGISTRY + assert "qwen3_tts" in _PIPELINE_REGISTRY + + def test_registry_loads_pipeline_on_getitem(self): + """Looking up a registered model_type returns the matching PipelineConfig.""" + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + assert pipeline.model_type == "qwen3_omni_moe" + assert len(pipeline.stages) == 3 # thinker + talker + code2wav + + def test_registry_returns_none_for_unknown(self): + """Unknown model_types aren't found; ``get()`` returns None.""" + assert "definitely_not_a_real_model" not in _PIPELINE_REGISTRY + assert _PIPELINE_REGISTRY.get("definitely_not_a_real_model") is None + + def test_pipeline_config_supports_hf_architectures(self): + """PipelineConfig accepts hf_architectures for HF-arch fallback + (replaces the old _ARCHITECTURE_MODELS dict).""" + p = PipelineConfig( + model_type="custom_collide", + hf_architectures=("SomeCollidingArch",), + ) + assert p.hf_architectures == ("SomeCollidingArch",) + + +class TestStagePipelineConfig: + def test_frozen(self): + s = StagePipelineConfig(stage_id=0, model_stage="a") + with pytest.raises(AttributeError): + s.model_stage = "changed" + + def test_defaults(self): + s = StagePipelineConfig(stage_id=0, model_stage="a") + assert s.execution_type == StageExecutionType.LLM_AR + assert s.input_sources == () + assert s.final_output is False + assert s.sampling_constraints == {} + assert s.engine_output_type is None + + +class TestPipelineConfigNew: + def test_frozen(self): + p = PipelineConfig(model_type="t", model_arch="A") + with pytest.raises(AttributeError): + p.model_type = "changed" + + def test_validate_valid(self): + p = PipelineConfig( + model_type="t", + model_arch="A", + stages=( + StagePipelineConfig(stage_id=0, model_stage="a"), + StagePipelineConfig(stage_id=1, model_stage="b", input_sources=(0,)), + ), + ) + assert p.validate() == [] + + def test_validate_no_stages(self): + p = PipelineConfig(model_type="t", model_arch="A") + assert any("no stages" in e.lower() for e in p.validate()) + + def test_get_scheduler_cls(self): + p = PipelineConfig( + model_type="t", + model_arch="A", + stages=( + StagePipelineConfig(stage_id=0, model_stage="a", execution_type=StageExecutionType.LLM_AR), + StagePipelineConfig( + stage_id=1, model_stage="b", execution_type=StageExecutionType.LLM_GENERATION, input_sources=(0,) + ), + ), + ) + assert "OmniARScheduler" in p.get_scheduler_cls(0) + assert "OmniGenerationScheduler" in p.get_scheduler_cls(1) + + +class TestExecutionTypeToScheduler: + def test_all_types_mapped(self): + for et in StageExecutionType: + assert et in _EXECUTION_TYPE_TO_SCHEDULER + + +class TestPipelineRegistry: + def test_register_and_lookup(self): + p = PipelineConfig( + model_type="__test_only__", + model_arch="A", + stages=(StagePipelineConfig(stage_id=0, model_stage="a"),), + ) + register_pipeline(p) + assert _PIPELINE_REGISTRY["__test_only__"] is p + del _PIPELINE_REGISTRY["__test_only__"] + + +class TestDeployConfigLoading: + def test_load_deploy_config(self): + from pathlib import Path + + from vllm_omni.config.stage_config import load_deploy_config + + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + assert len(deploy.stages) == 3 + assert 
deploy.async_chunk is True + assert deploy.connectors is not None + assert deploy.platforms is not None + + def test_merge_pipeline_deploy(self): + from pathlib import Path + + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + stages = merge_pipeline_deploy(pipeline, deploy) + + assert len(stages) == 3 + s0 = stages[0] + assert s0.model_stage == "thinker" + assert s0.yaml_engine_args["model_arch"] == "Qwen3OmniMoeForConditionalGeneration" + assert s0.yaml_engine_args["engine_output_type"] == "latent" + assert s0.yaml_extras["default_sampling_params"]["detokenize"] is True + + +class TestQwen3OmniPipeline: + def test_registered(self): + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + + p = _PIPELINE_REGISTRY.get("qwen3_omni_moe") + assert p is not None + assert p.model_arch == "Qwen3OmniMoeForConditionalGeneration" + assert len(p.stages) == 3 + assert p.validate() == [] + + def test_thinker(self): + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(0) + assert s.model_stage == "thinker" + assert s.execution_type == StageExecutionType.LLM_AR + assert s.owns_tokenizer is True + assert s.engine_output_type == "latent" + assert s.sampling_constraints["detokenize"] is True + + def test_talker(self): + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(1) + assert s.input_sources == (0,) + assert s.sampling_constraints["stop_token_ids"] == [2150] + assert s.custom_process_input_func is not None + assert s.custom_process_next_stage_input_func is not None + + def test_code2wav(self): + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(2) + assert s.execution_type == StageExecutionType.LLM_GENERATION + assert s.final_output_type == "audio" + assert s.custom_process_input_func is not None + + +class TestQwen2_5OmniPipeline: + def test_registered(self): + import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401 + + p = _PIPELINE_REGISTRY.get("qwen2_5_omni") + assert p is not None + assert p.model_arch == "Qwen2_5OmniForConditionalGeneration" + assert len(p.stages) == 3 + assert p.validate() == [] + + def test_thinker(self): + import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(0) + assert s.model_stage == "thinker" + assert s.execution_type == StageExecutionType.LLM_AR + assert s.owns_tokenizer is True + assert s.engine_output_type == "latent" + assert s.requires_multimodal_data is True + + def test_talker(self): + import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(1) + assert s.input_sources == (0,) + assert s.sampling_constraints["stop_token_ids"] == [8294] + assert s.custom_process_input_func is not None + + def test_code2wav(self): + import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(2) + assert s.execution_type == 
StageExecutionType.LLM_GENERATION + assert s.final_output_type == "audio" + assert s.engine_output_type == "audio" + + +class TestQwen3TTSPipeline: + def test_registered(self): + import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401 + + p = _PIPELINE_REGISTRY.get("qwen3_tts") + assert p is not None + assert p.model_arch == "Qwen3TTSTalkerForConditionalGeneration" + assert len(p.stages) == 2 + assert p.validate() == [] + + def test_talker_stage(self): + import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(0) + assert s.model_stage == "qwen3_tts" + assert s.execution_type == StageExecutionType.LLM_AR + assert s.owns_tokenizer is True + assert s.engine_output_type == "latent" + assert s.sampling_constraints["stop_token_ids"] == [2150] + # Stage 0 inherits the pipeline-level model_arch + assert s.model_arch is None + + def test_code2wav_stage_has_per_stage_model_arch(self): + import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401 + + s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(1) + assert s.execution_type == StageExecutionType.LLM_GENERATION + assert s.final_output_type == "audio" + assert s.engine_output_type == "audio" + # Per-stage model_arch override (different from pipeline-level talker) + assert s.model_arch == "Qwen3TTSCode2Wav" + # tts_args is passed through via extras + assert s.extras["tts_args"]["max_instructions_length"] == 500 + + def test_per_stage_model_arch_flows_through_merge(self, tmp_path): + """Verify the new ps.model_arch override survives merge_pipeline_deploy.""" + import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401 + from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy + + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_tts.yaml" + if not deploy_path.exists(): + pytest.skip("qwen3_tts deploy yaml not found") + + deploy = load_deploy_config(deploy_path) + pipeline = _PIPELINE_REGISTRY["qwen3_tts"] + stages = merge_pipeline_deploy(pipeline, deploy) + + # Stage 0 inherits pipeline-level model_arch + assert stages[0].yaml_engine_args["model_arch"] == "Qwen3TTSTalkerForConditionalGeneration" + # Stage 1 uses its per-stage override + assert stages[1].yaml_engine_args["model_arch"] == "Qwen3TTSCode2Wav" + + +class TestBaseConfigInheritance: + """Test deploy YAML base_config inheritance.""" + + def test_ci_inherits_from_main(self): + from tests.utils import get_deploy_config_path + from vllm_omni.config.stage_config import load_deploy_config + + ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml")) + if not ci_path.exists(): + pytest.skip("CI deploy config not found") + + deploy = load_deploy_config(ci_path) + assert len(deploy.stages) == 3 + # CI overrides + assert deploy.stages[0].engine_extras.get("load_format") == "dummy" + assert deploy.stages[0].max_num_seqs == 5 + # Inherited from base + assert deploy.stages[0].gpu_memory_utilization == 0.9 + assert deploy.connectors is not None + assert "connector_of_shared_memory" in deploy.connectors + # CI overlay explicitly sets async_chunk: False (see + # tests/utils.py::_CI_OVERLAYS and PR #2383 discussion). Overlay + # bool overrides base even when the base yaml has async_chunk: true. 
+ assert deploy.async_chunk is False + + def test_ci_sampling_merge(self): + from tests.utils import get_deploy_config_path + from vllm_omni.config.stage_config import load_deploy_config + + ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml")) + if not ci_path.exists(): + pytest.skip("CI deploy config not found") + + deploy = load_deploy_config(ci_path) + s0 = deploy.stages[0].default_sampling_params + # CI overrides max_tokens + assert s0["max_tokens"] == 150 + # Inherited from base + assert s0["temperature"] == 0.4 + assert s0["seed"] == 42 + + def test_pure_inheritance_overlay(self, tmp_path): + """An overlay with only ``base_config`` inherits everything.""" + from vllm_omni.config.stage_config import load_deploy_config + + base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not base.exists(): + pytest.skip("Base deploy config not found") + + overlay = tmp_path / "overlay.yaml" + overlay.write_text(f"base_config: {base}\n") + + deploy = load_deploy_config(overlay) + assert len(deploy.stages) == 3 + assert deploy.stages[0].gpu_memory_utilization == 0.9 + + def test_single_field_overlay(self, tmp_path): + """An overlay overriding one stage field merges with the base.""" + from vllm_omni.config.stage_config import load_deploy_config + + base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not base.exists(): + pytest.skip("Base deploy config not found") + + overlay = tmp_path / "overlay.yaml" + overlay.write_text(f"base_config: {base}\nstages:\n - stage_id: 2\n max_num_batched_tokens: 1000000\n") + + deploy = load_deploy_config(overlay) + assert deploy.stages[2].max_num_batched_tokens == 1000000 + # Rest inherited + assert deploy.stages[0].gpu_memory_utilization == 0.9 + + +class TestPlatformOverrides: + """Test platform-specific deploy config overrides.""" + + def test_npu_overrides(self): + from pathlib import Path + + from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config + + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + deploy = _apply_platform_overrides(deploy, platform="npu") + + assert deploy.stages[0].gpu_memory_utilization == 0.6 + assert deploy.stages[0].tensor_parallel_size == 2 + assert deploy.stages[0].devices == "0,1" + # Stage 2 unaffected fields stay at base + assert deploy.stages[2].enforce_eager is True + + def test_xpu_overrides(self): + from pathlib import Path + + from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config + + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + deploy = _apply_platform_overrides(deploy, platform="xpu") + + assert deploy.stages[0].tensor_parallel_size == 4 + assert deploy.stages[0].devices == "0,1,2,3" + assert deploy.stages[0].engine_extras.get("max_cudagraph_capture_size") == 0 + + def test_unknown_platform_noop(self): + from pathlib import Path + + from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config + + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + original_mem = 
deploy.stages[0].gpu_memory_utilization + deploy = _apply_platform_overrides(deploy, platform="unknown_hw") + assert deploy.stages[0].gpu_memory_utilization == original_mem + + def test_platforms_deep_merge_inheritance(self, tmp_path): + """Overlay's platforms: block layers onto base's, per-stage.""" + from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config + + base = tmp_path / "base.yaml" + base.write_text( + "stages:\n" + " - stage_id: 0\n" + " gpu_memory_utilization: 0.9\n" + "platforms:\n" + " rocm:\n" + " stages:\n" + " - stage_id: 0\n" + " enforce_eager: true\n" + ) + overlay = tmp_path / "overlay.yaml" + overlay.write_text( + f"base_config: {base.name}\n" + "platforms:\n" + " rocm:\n" + " stages:\n" + " - stage_id: 0\n" + " max_num_seqs: 1\n" + ) + + deploy = load_deploy_config(overlay) + deploy = _apply_platform_overrides(deploy, platform="rocm") + # Both base's enforce_eager and overlay's max_num_seqs should apply. + assert deploy.stages[0].enforce_eager is True + assert deploy.stages[0].max_num_seqs == 1 + # Inherited stage default not touched by overlay platforms section. + assert deploy.stages[0].gpu_memory_utilization == 0.9 + + +class TestCLIOverrideFlow: + """Test --stage-overrides JSON merge into StageConfig.""" + + def test_stage_overrides_merge(self): + from pathlib import Path + + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + stages = merge_pipeline_deploy(pipeline, deploy) + + # Simulate --stage-overrides '{"0": {"gpu_memory_utilization": 0.5}}' + overrides = {"stage_0_gpu_memory_utilization": 0.5} + stages[0].runtime_overrides = StageConfigFactory._merge_cli_overrides(stages[0], overrides) + assert stages[0].runtime_overrides["gpu_memory_utilization"] == 0.5 + + def test_global_override_applies_to_all(self): + from pathlib import Path + + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") + + deploy = load_deploy_config(deploy_path) + stages = merge_pipeline_deploy(pipeline, deploy) + + overrides = {"enforce_eager": True} + for s in stages: + s.runtime_overrides = StageConfigFactory._merge_cli_overrides(s, overrides) + assert s.runtime_overrides["enforce_eager"] is True + + +class TestCLIExplicitPrecedence: + """Verify YAML > argparse defaults; explicit CLI args > YAML.""" + + def _stages(self, cli_overrides, cli_explicit_keys): + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + + return StageConfigFactory._create_from_registry( + "qwen3_omni_moe", + cli_overrides=cli_overrides, + cli_explicit_keys=cli_explicit_keys, + ) + + def test_explicit_cli_overrides_yaml(self): + """User-typed --max-num-seqs wins over the deploy YAML value.""" + stages = self._stages( + cli_overrides={"max_num_seqs": 999}, + cli_explicit_keys={"max_num_seqs"}, + ) + # Stage 2 yaml has max_num_seqs=1; explicit CLI must beat it. 
+ assert stages[2].runtime_overrides.get("max_num_seqs") == 999 + + def test_default_cli_does_not_override_yaml(self): + """Argparse defaults must NOT clobber values that are present in YAML.""" + stages = self._stages( + cli_overrides={"max_num_seqs": 256}, + cli_explicit_keys=set(), # user typed nothing + ) + # Stage 2's YAML value (1) should win because the user didn't type --max-num-seqs. + assert stages[2].runtime_overrides.get("max_num_seqs") != 256 + + def test_default_cli_fills_missing_yaml_field(self): + """Argparse defaults still fill fields the YAML doesn't set.""" + stages = self._stages( + cli_overrides={"some_unrelated_knob": "fallback"}, + cli_explicit_keys=set(), + ) + # Field absent from YAML → CLI default flows through as a fallback. + assert stages[0].runtime_overrides.get("some_unrelated_knob") == "fallback" + + def test_per_stage_overrides_always_explicit(self): + """``stage__*`` keys are always treated as explicit.""" + stages = self._stages( + cli_overrides={"stage_0_gpu_memory_utilization": 0.42}, + cli_explicit_keys=set(), # not in the explicit set, but per-stage + ) + assert stages[0].runtime_overrides.get("gpu_memory_utilization") == 0.42 + + def test_none_explicit_set_treats_all_as_explicit(self): + """Programmatic Omni() callers (cli_explicit_keys=None) keep current behavior.""" + stages = self._stages( + cli_overrides={"max_num_seqs": 999}, + cli_explicit_keys=None, + ) + assert stages[2].runtime_overrides.get("max_num_seqs") == 999 + + def test_explicit_async_chunk_false_overrides_yaml(self): + """``--no-async-chunk`` flips the deploy-level async_chunk to False even + when the YAML sets it to True. Verifies that the per-stage + ``async_chunk: True`` injection in ``merge_pipeline_deploy`` is skipped + and that ``async_chunk`` does not leak through ``_merge_cli_overrides``. + """ + stages = self._stages( + cli_overrides={"async_chunk": False}, + cli_explicit_keys={"async_chunk"}, + ) + # qwen3_omni_moe.yaml has `async_chunk: true`, so by default every + # stage's engine_args would carry it. With the explicit override, it + # must NOT show up. + for stage in stages: + assert stage.yaml_engine_args.get("async_chunk") is not True + assert stage.runtime_overrides.get("async_chunk") is None + + def test_default_async_chunk_leaves_yaml_alone(self): + """An unset ``--async-chunk`` (default None) must leave the YAML's True + in force on every stage.""" + stages = self._stages( + cli_overrides={"async_chunk": None}, + cli_explicit_keys=set(), + ) + # qwen3_omni_moe.yaml: `async_chunk: true` → injected on every stage. 
+ for stage in stages: + assert stage.yaml_engine_args.get("async_chunk") is True + + def test_explicit_enable_prefix_caching_overrides_yaml(self): + """``--enable-prefix-caching`` (global) flips every stage's + ``enable_prefix_caching`` to True regardless of the YAML default.""" + stages = self._stages( + cli_overrides={"enable_prefix_caching": True}, + cli_explicit_keys={"enable_prefix_caching"}, + ) + for stage in stages: + assert stage.runtime_overrides.get("enable_prefix_caching") is True + + def test_async_chunk_dispatches_processors(self): + """A single ``qwen3_tts`` pipeline picks per-chunk vs end-to-end + processors based on ``deploy.async_chunk``, without needing a + separate variant pipeline registration.""" + import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401 + from vllm_omni.config.stage_config import ( + _PIPELINE_REGISTRY, + DeployConfig, + merge_pipeline_deploy, + ) + + pipeline = _PIPELINE_REGISTRY["qwen3_tts"] + + # async_chunk=True → stage 0's per-chunk processor wires up, stage 1 + # has no sync input processor. + async_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=True)) + assert ( + async_stages[0] + .yaml_engine_args.get("custom_process_next_stage_input_func", "") + .endswith("talker2code2wav_async_chunk") + ) + assert async_stages[1].custom_process_input_func is None + + # async_chunk=False → stage 0 has no streaming processor, stage 1's + # batch-end processor wires up. + sync_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False)) + assert "custom_process_next_stage_input_func" not in sync_stages[0].yaml_engine_args + assert sync_stages[1].custom_process_input_func is not None + assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav") + + +class TestSamplingConstraintsPrecedence: + """Test that pipeline sampling_constraints override deploy defaults.""" + + def test_constraints_win(self): + from pathlib import Path + + import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401 + from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy + + pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"] + deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml" + if not deploy_path.exists(): + pytest.skip("Deploy config not found") - def test_architecture_models_mapping_exists(self): - """Test that _ARCHITECTURE_MODELS contains expected entries.""" - assert "MiMoAudioForConditionalGeneration" in StageConfigFactory._ARCHITECTURE_MODELS - assert StageConfigFactory._ARCHITECTURE_MODELS["MiMoAudioForConditionalGeneration"] == "mimo_audio" - assert "HunyuanImage3ForCausalMM" in StageConfigFactory._ARCHITECTURE_MODELS - assert StageConfigFactory._ARCHITECTURE_MODELS["HunyuanImage3ForCausalMM"] == "hunyuan_image3" + deploy = load_deploy_config(deploy_path) + stages = merge_pipeline_deploy(pipeline, deploy) - def test_mimo_audio_in_pipeline_models(self): - """Test that mimo_audio is registered in PIPELINE_MODELS.""" - assert "mimo_audio" in StageConfigFactory.PIPELINE_MODELS + # Pipeline says detokenize=True for thinker, deploy can't override + assert stages[0].yaml_extras["default_sampling_params"]["detokenize"] is True + # Pipeline says stop_token_ids=[2150] for talker + assert stages[1].yaml_extras["default_sampling_params"]["stop_token_ids"] == [2150] + # Deploy temperature still flows through + assert stages[0].yaml_extras["default_sampling_params"]["temperature"] == 0.4 diff --git a/tests/test_diffusion_config_fields.py 
b/tests/test_diffusion_config_fields.py new file mode 100644 index 0000000000..b87ceec1df --- /dev/null +++ b/tests/test_diffusion_config_fields.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Ensure diffusion stage YAML configs only use valid OmniDiffusionConfig fields. + +Regression test for https://github.com/vllm-project/vllm-omni/issues/2563 +""" + +from dataclasses import fields +from pathlib import Path + +import pytest +import yaml + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +try: + from vllm_omni.diffusion.data import OmniDiffusionConfig +except Exception: + OmniDiffusionConfig = None + + +@pytest.mark.skipif( + OmniDiffusionConfig is None, + reason="OmniDiffusionConfig could not be imported (missing torch?)", +) +def test_diffusion_stage_configs_only_contain_valid_fields(): + """Diffusion stage engine_args must only contain OmniDiffusionConfig fields. + + Regression test for https://github.com/vllm-project/vllm-omni/issues/2563 + """ + # Scan both main configs and test configs + repo_root = Path(__file__).parent.parent + config_dirs = [ + repo_root / "vllm_omni" / "model_executor" / "stage_configs", + ] + # Also scan test directories recursively + test_dir = repo_root / "tests" + + yaml_paths: list[Path] = [] + for config_dir in config_dirs: + yaml_paths.extend(sorted(config_dir.glob("*.yaml"))) + yaml_paths.extend(sorted(test_dir.rglob("*.yaml"))) + + valid_fields = {f.name for f in fields(OmniDiffusionConfig)} + # model_stage is consumed by the stage init layer, not OmniDiffusionConfig + valid_fields.add("model_stage") + # model_arch is consumed by the stage init layer for diffusion model class resolution + valid_fields.add("model_arch") + # "quantization" is mapped to "quantization_config" by from_kwargs() backwards-compat + valid_fields.add("quantization") + + invalid_entries: list[tuple[str, set[str]]] = [] + for yaml_path in yaml_paths: + with open(yaml_path) as fh: + config = yaml.safe_load(fh) + + stages = config.get("stage_args", config.get("stages", [])) + for stage in stages: + if stage.get("stage_type") != "diffusion": + continue + engine_args = stage.get("engine_args", {}) + invalid = set(engine_args.keys()) - valid_fields + if invalid: + invalid_entries.append((yaml_path.relative_to(repo_root), invalid)) + + assert not invalid_entries, "Diffusion stage configs contain fields not in OmniDiffusionConfig:\n" + "\n".join( + f" {name}: {sorted(bad)}" for name, bad in invalid_entries + ) diff --git a/tests/test_diffusion_config_propagation.py b/tests/test_diffusion_config_propagation.py index 7d6d9c43f0..eeb3505efe 100644 --- a/tests/test_diffusion_config_propagation.py +++ b/tests/test_diffusion_config_propagation.py @@ -15,6 +15,7 @@ DiffusionParallelConfig, OmniDiffusionConfig, ) +from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -109,3 +110,12 @@ def test_extra_kwargs_forwarded(self): ea = stages[0]["engine_args"] assert ea["enforce_eager"] is True assert ea["lora_path"] == "/tmp/lora" + + +def test_qwen_image_edit_plus_sets_generic_multimodal_limit(): + od_config = OmniDiffusionConfig(model="Qwen/Qwen-Image-Edit-2511", model_class_name="QwenImageEditPlusPipeline") + + od_config.update_multimodal_support() + + assert od_config.supports_multimodal_inputs is True + assert od_config.max_multimodal_image_inputs == QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES diff --git 
a/tests/test_fish_speech_voice_cache.py b/tests/test_fish_speech_voice_cache.py index 8fe7a4a4d1..1c299d8014 100644 --- a/tests/test_fish_speech_voice_cache.py +++ b/tests/test_fish_speech_voice_cache.py @@ -10,11 +10,12 @@ import os import tempfile -from unittest.mock import MagicMock, patch +from pathlib import Path import numpy as np import pytest import torch +from pytest_mock import MockerFixture pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -61,18 +62,18 @@ class TestFishSpeechVoiceCacheIntegration: """Test the cache-hit / cache-miss / no-cache paths in the model.""" @pytest.fixture - def mock_model(self): + def mock_model(self, mocker: MockerFixture): """Create a mock FishSpeechSlowARForConditionalGeneration with cache.""" from vllm_omni.utils.voice_cache import VoiceEmbeddingCache - model = MagicMock() + model = mocker.MagicMock() model._voice_cache = VoiceEmbeddingCache(max_entries=4) model._semantic_begin_id = 151678 model._num_codebooks = 10 model._codebook_size = 4096 model.model_path = "/fake/model" - model.codebook_embeddings = MagicMock() - model.codebook_embeddings.weight = MagicMock() + model.codebook_embeddings = mocker.MagicMock() + model.codebook_embeddings.weight = mocker.MagicMock() model.codebook_embeddings.weight.device = torch.device("cpu") return model @@ -166,9 +167,13 @@ def test_created_at_zero_disables_cache(self, mock_model): class TestFishSpeechValidatorUploadedVoice: """Test _validate_fish_tts_request uploaded voice resolution.""" - def test_uploaded_voice_resolves_ref_audio(self): + def test_uploaded_voice_resolves_ref_audio( + self, + monkeypatch: pytest.MonkeyPatch, + mocker: MockerFixture, + ): """When voice matches an uploaded speaker, ref_audio should be auto-set.""" - request = MagicMock() + request = mocker.MagicMock() request.input = "Hello" request.voice = "alice" request.ref_audio = None @@ -185,17 +190,21 @@ def test_uploaded_voice_resolves_ref_audio(self): } # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL. 
- with patch("pathlib.Path.exists", return_value=True): - voice_lower = request.voice.lower() - assert voice_lower in uploaded_speakers + monkeypatch.setattr(Path, "exists", lambda self: True) - speaker_info = uploaded_speakers[voice_lower] - ref_text_from_upload = speaker_info.get("ref_text") - assert ref_text_from_upload == "Hi this is Alice" + voice_lower = request.voice.lower() + assert voice_lower in uploaded_speakers + + speaker_info = uploaded_speakers[voice_lower] + ref_text_from_upload = speaker_info.get("ref_text") + assert ref_text_from_upload == "Hi this is Alice" - def test_uploaded_voice_without_ref_text_uses_request_ref_text(self): + def test_uploaded_voice_without_ref_text_uses_request_ref_text( + self, + mocker: MockerFixture, + ): """If upload has no ref_text but request provides it, use request's.""" - request = MagicMock() + request = mocker.MagicMock() request.input = "Hello" request.voice = "bob" request.ref_audio = None diff --git a/tests/test_generate_nightly_perf_excel.py b/tests/test_generate_nightly_perf_excel.py new file mode 100644 index 0000000000..9b05d6de0f --- /dev/null +++ b/tests/test_generate_nightly_perf_excel.py @@ -0,0 +1,71 @@ +import importlib.util +import json +import sys +from pathlib import Path + +import pytest +from openpyxl import load_workbook + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _load_excel_module(): + repo_root = Path(__file__).resolve().parents[1] + module_path = repo_root / "tools" / "nightly" / "generate_nightly_perf_excel.py" + spec = importlib.util.spec_from_file_location("generate_nightly_perf_excel", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _cell_value_by_header(ws, header_name: str, row_idx: int = 2): + headers = [c.value for c in ws[1]] + col_idx = headers.index(header_name) + 1 + return ws.cell(row=row_idx, column=col_idx).value + + +def test_generate_excel_report_with_perf_templates(tmp_path: Path): + module = _load_excel_module() + repo_root = Path(__file__).resolve().parents[1] + perf_scripts_dir = repo_root / "tests" / "dfx" / "perf" / "scripts" + + omni_template_path = perf_scripts_dir / "result_omni_template.json" + diffusion_template_path = perf_scripts_dir / "diffusion_result_template.json" + + omni_record = json.loads(omni_template_path.read_text(encoding="utf-8")) + diffusion_records = json.loads(diffusion_template_path.read_text(encoding="utf-8")) + + input_dir = tmp_path / "input" + diffusion_input_dir = tmp_path / "diffusion_input" + input_dir.mkdir() + diffusion_input_dir.mkdir() + + # Keep file names compatible with parser conventions in generate_nightly_perf_excel.py + omni_result_file = input_dir / "result_test_perf_random_1_4_in2500_out900_20260415-185642.json" + diffusion_result_file = diffusion_input_dir / "diffusion_result_qwen_image_edit_20260415-193200.json" + omni_result_file.write_text(json.dumps(omni_record, ensure_ascii=False, indent=2), encoding="utf-8") + diffusion_result_file.write_text(json.dumps(diffusion_records, ensure_ascii=False, indent=2), encoding="utf-8") + + output_file = tmp_path / "nightly_perf.xlsx" + module.generate_excel_report( + input_dir=str(input_dir), + diffusion_input_dir=str(diffusion_input_dir), + output_file=str(output_file), + commit_sha="test_commit_sha", + build_id="test_build_id", + build_url="https://example.com/build/123", + ) + + assert output_file.exists() + + wb = load_workbook(output_file) + assert set(wb.sheetnames) >= 
{"omni_summary", "diffusion_summary", "omni_raw", "diffusion_raw"} + + ws_omni_raw = wb["omni_raw"] + baseline_cell = _cell_value_by_header(ws_omni_raw, "baseline") + assert baseline_cell == json.dumps(omni_record["baseline"], ensure_ascii=False, sort_keys=True) + + ws_omni_summary = wb["omni_summary"] + assert _cell_value_by_header(ws_omni_summary, "commit_sha") == "test_commit_sha" + assert _cell_value_by_header(ws_omni_summary, "build_id") == "test_build_id" diff --git a/tests/test_generate_nightly_perf_html.py b/tests/test_generate_nightly_perf_html.py new file mode 100644 index 0000000000..4e77eb3adf --- /dev/null +++ b/tests/test_generate_nightly_perf_html.py @@ -0,0 +1,54 @@ +import importlib.util +import json +import sys +from pathlib import Path + +import pytest + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _load_html_module(): + repo_root = Path(__file__).resolve().parents[1] + module_path = repo_root / "tools" / "nightly" / "generate_nightly_perf_html.py" + spec = importlib.util.spec_from_file_location("generate_nightly_perf_html", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_generate_html_report_with_perf_templates(tmp_path: Path): + module = _load_html_module() + repo_root = Path(__file__).resolve().parents[1] + perf_scripts_dir = repo_root / "tests" / "dfx" / "perf" / "scripts" + + omni_template_path = perf_scripts_dir / "result_omni_template.json" + diffusion_template_path = perf_scripts_dir / "diffusion_result_template.json" + + omni_record = json.loads(omni_template_path.read_text(encoding="utf-8")) + diffusion_records = json.loads(diffusion_template_path.read_text(encoding="utf-8")) + + input_dir = tmp_path / "input" + diffusion_input_dir = tmp_path / "diffusion_input" + input_dir.mkdir() + diffusion_input_dir.mkdir() + + omni_result_file = input_dir / "result_test_perf_random_1_4_in2500_out900_20260415-185642.json" + diffusion_result_file = diffusion_input_dir / "diffusion_result_qwen_image_edit_20260415-193200.json" + omni_result_file.write_text(json.dumps(omni_record, ensure_ascii=False, indent=2), encoding="utf-8") + diffusion_result_file.write_text(json.dumps(diffusion_records, ensure_ascii=False, indent=2), encoding="utf-8") + + output_file = tmp_path / "nightly_perf_v2.html" + module.generate_html_report( + input_dir=str(input_dir), + diffusion_input_dir=str(diffusion_input_dir), + output_file=str(output_file), + ) + + assert output_file.exists() + html = output_file.read_text(encoding="utf-8") + assert "Nightly Performance Report" in html + assert "Omni records 1" in html + assert f"Diffusion records {len(diffusion_records)}" in html + assert "const DIFF_DATA =" in html diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 0000000000..07e622a7d1 --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for version compatibility warnings.""" + +import warnings +from unittest import mock + +import pytest + +from vllm_omni.version import warn_if_misaligned_vllm_version + + +@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0)) +@mock.patch("vllm_omni.version.__version__", "0.19.0") +@mock.patch("vllm.__version_tuple__", (0, 18, 0)) +@mock.patch("vllm.__version__", "0.18.0") +def test_version_mismatch_warning(): + """Ensure that we warn when vLLM and vLLM-Omni major/minor 
versions differ.""" + with pytest.warns(RuntimeWarning, match="mismatched major/minor versions"): + warn_if_misaligned_vllm_version() + + +@pytest.mark.parametrize( + "vllm_ver,vllm_tuple,omni_ver,omni_tuple", + [ + ("0.19.0", (0, 19, 0), "0.19.5", (0, 19, 5)), # Patch differs + ("0.18.0", (0, 18, 0), "dev", (0, 0, "dev")), # Omni dev + ("dev", (0, 0, "dev"), "0.19.0", (0, 19, 0)), # vLLM dev + # Ensure local identifies don't matter for the warning + ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5", (0, 19, 0)), + ("0.19.0", (0, 19, 0), "0.19.5+bar", (0, 19, 0, "bar")), + ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5+bar", (0, 19, 0, "bar")), + ], +) +def test_no_warning_cases(vllm_ver, vllm_tuple, omni_ver, omni_tuple): + """Ensure we don't warn when minor versions match or either is a dev build.""" + with ( + mock.patch.multiple("vllm", __version__=vllm_ver, __version_tuple__=vllm_tuple), + mock.patch.multiple("vllm_omni.version", __version__=omni_ver, __version_tuple__=omni_tuple), + ): + with warnings.catch_warnings(): + warnings.simplefilter("error") + warn_if_misaligned_vllm_version() + + +@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0)) +@mock.patch("vllm_omni.version.__version__", "0.19.0rc2.dev21") +@mock.patch("vllm.__version_tuple__", (0, 18, 0)) +@mock.patch("vllm.__version__", "0.18.0") +def test_warning_contains_version_strings(): + """Ensure that the warning contains the full version strings.""" + with pytest.warns(RuntimeWarning) as record: + warn_if_misaligned_vllm_version() + + assert len(record) == 1 + msg = str(record[0].message) + assert "0.19.0rc2.dev21" in msg + assert "0.18.0" in msg diff --git a/tests/utils.py b/tests/utils.py index 84edbbf3d1..d8137cf963 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,6 +11,7 @@ import time from collections.abc import Callable from contextlib import ExitStack, contextmanager, suppress +from pathlib import Path from typing import Any, Literal import cloudpickle @@ -24,6 +25,221 @@ _P = ParamSpec("_P") +_REPO_ROOT = Path(__file__).resolve().parent.parent +_DEPLOY_DIR = _REPO_ROOT / "vllm_omni" / "deploy" +_CI_GENERATED_DIR = _REPO_ROOT / "tests" / ".ci_generated" + + +# CI overlays as Python dicts (LSP-friendly). Materialized on demand to +# tests/.ci_generated/.yaml via get_deploy_config_path("ci/.yaml"). 
+_CI_OVERLAYS: dict[str, dict[str, Any]] = { + "qwen2_5_omni": { + "base_config": "qwen2_5_omni.yaml", + "async_chunk": False, + "stages": [ + { + "stage_id": 0, + "max_model_len": 16384, + "max_num_batched_tokens": 16384, + "max_num_seqs": 1, + "gpu_memory_utilization": 0.9, + "skip_mm_profiling": True, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 128}, + }, + { + "stage_id": 1, + "max_model_len": 16384, + "max_num_batched_tokens": 16384, + "max_num_seqs": 1, + "gpu_memory_utilization": 0.4, + "skip_mm_profiling": True, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 4096}, + }, + { + "stage_id": 2, + "max_num_seqs": 1, + "gpu_memory_utilization": 0.5, + "max_num_batched_tokens": 8192, + "max_model_len": 8192, + "load_format": "dummy", + "devices": "2", + "default_sampling_params": {"max_tokens": 8192}, + }, + ], + "platforms": { + "rocm": { + "stages": [ + {"stage_id": 0, "gpu_memory_utilization": 0.9}, + {"stage_id": 1, "gpu_memory_utilization": 0.4}, + {"stage_id": 2, "gpu_memory_utilization": 0.5, "devices": "2"}, + ], + }, + "xpu": { + "stages": [ + { + "stage_id": 0, + "gpu_memory_utilization": 0.9, + "max_num_batched_tokens": 16384, + "max_model_len": 16384, + }, + {"stage_id": 1, "gpu_memory_utilization": 0.5}, + { + "stage_id": 2, + "gpu_memory_utilization": 0.3, + "max_num_batched_tokens": 4096, + "max_model_len": 4096, + "devices": "2", + }, + ], + }, + }, + }, + "qwen3_omni_moe": { + "base_config": "qwen3_omni_moe.yaml", + "async_chunk": False, + "stages": [ + { + "stage_id": 0, + "max_num_seqs": 5, + "max_model_len": 32768, + "mm_processor_cache_gb": 0, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 150, "ignore_eos": False}, + }, + { + "stage_id": 1, + "gpu_memory_utilization": 0.5, + "max_num_seqs": 5, + "max_model_len": 32768, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 1000}, + }, + { + "stage_id": 2, + "max_num_seqs": 5, + "max_num_batched_tokens": 100000, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 2000}, + }, + ], + "platforms": { + "rocm": { + "stages": [ + {"stage_id": 0, "max_num_seqs": 1, "default_sampling_params": {"max_tokens": 100}}, + { + "stage_id": 1, + "max_num_seqs": 1, + "enforce_eager": True, + "default_sampling_params": {"max_tokens": 100}, + }, + { + "stage_id": 2, + "max_num_seqs": 1, + "max_num_batched_tokens": 1000000, + "default_sampling_params": {"max_tokens": 200}, + }, + ], + }, + "xpu": { + "stages": [ + { + "stage_id": 0, + "gpu_memory_utilization": 0.85, + "max_num_seqs": 1, + "tensor_parallel_size": 4, + "enforce_eager": True, + "max_num_batched_tokens": 4096, + "max_model_len": 4096, + "max_cudagraph_capture_size": 0, + "skip_mm_profiling": True, + "devices": "0,1,2,3", + "default_sampling_params": {"max_tokens": 100, "ignore_eos": False}, + }, + { + "stage_id": 1, + "gpu_memory_utilization": 0.6, + "max_num_seqs": 1, + "enforce_eager": True, + "max_num_batched_tokens": 4096, + "max_model_len": 4096, + "max_cudagraph_capture_size": 0, + "skip_mm_profiling": True, + "devices": "4", + }, + { + "stage_id": 2, + "gpu_memory_utilization": 0.3, + "max_num_seqs": 1, + "max_num_batched_tokens": 100000, + "max_cudagraph_capture_size": 0, + "skip_mm_profiling": True, + "devices": "5", + "default_sampling_params": {"max_tokens": 2000}, + }, + ], + }, + }, + }, + # Single-stage thinker-only topology for the abort test. 
+ "qwen2_5_omni_thinker_only": { + "async_chunk": False, + "pipeline": "qwen2_5_omni_thinker_only", + "stages": [ + { + "stage_id": 0, + "max_num_seqs": 1, + "gpu_memory_utilization": 0.9, + "enforce_eager": True, + "max_num_batched_tokens": 16384, + "max_model_len": 16384, + "skip_mm_profiling": True, + "mm_processor_cache_gb": 0, + "load_format": "dummy", + "devices": "0", + "default_sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": -1, + "max_tokens": 128, + "seed": 42, + "repetition_penalty": 1.1, + }, + }, + ], + }, +} + + +def _materialize_ci_overlay(model_type: str) -> Path: + import yaml + + if model_type not in _CI_OVERLAYS: + raise KeyError(f"No CI overlay registered for {model_type!r}. Available: {sorted(_CI_OVERLAYS)}") + + _CI_GENERATED_DIR.mkdir(parents=True, exist_ok=True) + out = _CI_GENERATED_DIR / f"{model_type}.yaml" + + overlay = {**_CI_OVERLAYS[model_type]} + base = overlay.get("base_config") + if base: + overlay["base_config"] = str(_DEPLOY_DIR / base) + + with open(out, "w", encoding="utf-8") as f: + yaml.safe_dump(overlay, f, sort_keys=False) + return out + + +def get_deploy_config_path(rel_path: str) -> str: + """Resolve a deploy yaml; ``ci/.yaml`` materializes from ``_CI_OVERLAYS``.""" + if rel_path.startswith("ci/") and rel_path.endswith(".yaml"): + model_type = rel_path[len("ci/") : -len(".yaml")] + if model_type in _CI_OVERLAYS: + return str(_materialize_ci_overlay(model_type)) + return str(_DEPLOY_DIR / rel_path) + + if current_platform.is_rocm(): from amdsmi import ( amdsmi_get_gpu_vram_usage, diff --git a/tests/utils/test_audio.py b/tests/utils/test_audio.py new file mode 100644 index 0000000000..0e483e6468 --- /dev/null +++ b/tests/utils/test_audio.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for vllm_omni.utils.audio.""" + +import numpy as np +import pytest +import torch + +from vllm_omni.utils.audio import mel_filter_bank, peak_normalize + +# Parameter combinations used across the codebase. 
+_PARAM_SETS = [ + # Qwen3-TTS talker / speaker encoder (sr=24000) + dict(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000), + # CosyVoice3 whisper encoder, Qwen3-TTS 25Hz tokenizer (sr=16000, 80 mels) + dict(sr=16000, n_fft=400, n_mels=80), + # CosyVoice3 whisper encoder (sr=16000, 128 mels) + dict(sr=16000, n_fft=400, n_mels=128), +] + +_parametrize_params = pytest.mark.parametrize( + "params", _PARAM_SETS, ids=lambda p: f"{p['sr']}_{p['n_fft']}_{p['n_mels']}" +) + + +class TestMelFilterBank: + @_parametrize_params + def test_output_shape(self, params): + fb = mel_filter_bank(**params) + n_freqs = params["n_fft"] // 2 + 1 + assert fb.shape == (params["n_mels"], n_freqs) + + @_parametrize_params + def test_non_negative(self, params): + fb = mel_filter_bank(**params) + assert (fb >= 0).all() + + def test_dtype_is_float(self): + fb = mel_filter_bank(sr=16000, n_fft=400, n_mels=80) + assert fb.dtype == torch.float32 + + def test_fmax_defaults_to_nyquist(self): + """When fmax is omitted it should equal sr / 2.""" + fb_default = mel_filter_bank(sr=16000, n_fft=400, n_mels=80) + fb_explicit = mel_filter_bank(sr=16000, n_fft=400, n_mels=80, fmax=8000.0) + torch.testing.assert_close(fb_default, fb_explicit) + + def test_each_mel_band_has_nonzero_energy(self): + """Every mel band should have at least one nonzero frequency bin.""" + fb = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000) + for i in range(fb.shape[0]): + assert fb[i].sum() > 0, f"mel band {i} is all zeros" + + def test_higher_fmax_extends_coverage(self): + """A higher fmax should produce nonzero weights at higher frequency bins.""" + fb_low = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=6000) + fb_high = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000) + # The highest nonzero column should be larger for fb_high. + last_nonzero_low = (fb_low.sum(dim=0) > 0).nonzero()[-1].item() + last_nonzero_high = (fb_high.sum(dim=0) > 0).nonzero()[-1].item() + assert last_nonzero_high > last_nonzero_low + + +class TestPeakNormalize: + def test_silence_unchanged(self): + """All-zero input should remain all-zero.""" + audio = np.zeros(1600, dtype=np.float32) + result = peak_normalize(audio, db_level=-6.0) + np.testing.assert_array_equal(result, audio) + + def test_peak_reaches_target(self): + """After normalization, peak amplitude should be at target dB.""" + rng = np.random.default_rng(7) + audio = rng.uniform(-0.4, 0.4, size=16000).astype(np.float32) + + result = peak_normalize(audio, db_level=-6.0) + peak_db = 20 * np.log10(np.abs(result).max()) + np.testing.assert_allclose(peak_db, -6.0, atol=1e-4) diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py new file mode 100644 index 0000000000..0e162a37e5 --- /dev/null +++ b/tests/worker/test_omni_connector_mixin.py @@ -0,0 +1,1419 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for OmniConnectorModelRunnerMixin. + +These tests use a mock connector (in-memory dict store) and do not require +GPU or vLLM runtime. 
+""" + +from __future__ import annotations + +import time +import unittest +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm_omni.outputs import OmniConnectorOutput +from vllm_omni.worker.omni_connector_model_runner_mixin import ( + OmniConnectorModelRunnerMixin, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +# ------------------------------------------------------------------ # +# Mock helpers +# ------------------------------------------------------------------ # + + +class MockConnector: + """In-memory connector for testing (mimics OmniConnectorBase).""" + + def __init__(self, stage_id: int = 0): + self.stage_id = stage_id + self._store: dict[str, Any] = {} + + def put(self, from_stage, to_stage, put_key, data): + key = f"{from_stage}_{to_stage}_{put_key}" + self._store[key] = data + return True, len(str(data)), None + + def get(self, from_stage, to_stage, get_key, metadata=None): + key = f"{from_stage}_{to_stage}_{get_key}" + data = self._store.pop(key, None) + if data is None: + return None + return data, len(str(data)) + + def close(self): + pass + + +def _make_model_config( + stage_id: int = 0, + async_chunk: bool = False, + worker_type: str = "ar", + custom_func: str | None = None, +) -> SimpleNamespace: + return SimpleNamespace( + stage_connector_config=None, + async_chunk=async_chunk, + worker_type=worker_type, + custom_process_next_stage_input_func=custom_func, + ) + + +def _make_request(req_id: str, external_req_id: str | None = None): + r = SimpleNamespace( + request_id=req_id, + external_req_id=external_req_id or req_id, + additional_information=None, + prompt_token_ids=[], + num_computed_tokens=0, + ) + return r + + +class MixinHost(OmniConnectorModelRunnerMixin): + """Minimal class that mixes in the mixin for testing.""" + + pass + + +class _FakeTPGroup: + def __init__(self, *, world_size: int, rank_in_group: int, follower_result: Any = None): + self.world_size = world_size + self.rank_in_group = rank_in_group + self.follower_result = follower_result + self.broadcast_inputs: list[Any] = [] + + def broadcast_object(self, obj: Any | None = None, src: int = 0): + self.broadcast_inputs.append(obj) + if self.rank_in_group == src: + return obj + return self.follower_result + + +# ------------------------------------------------------------------ # +# Test cases +# ------------------------------------------------------------------ # + + +class TestMixinAsyncChunkSendRecv(unittest.TestCase): + """Test 2: Async chunk send/recv + bg threads.""" + + def test_send_chunk_passes_is_finished_and_connector(self): + connector = MockConnector(stage_id=0) + + sender = MixinHost() + sender.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + sender._omni_connector = connector + sender._stage_id = 0 + sender._async_chunk = True + + seen = {} + + def mock_process(transfer_manager, pooling_output, request, is_finished=False): + seen["connector"] = transfer_manager.connector + seen["is_finished"] = is_finished + return {"data": pooling_output, "finished": is_finished} + + sender._custom_process_func = mock_process + + request = _make_request("req-1", "ext-req-1") + request.is_finished = lambda: True + sender._send_single_request( + { + "stage_id": 0, + "next_stage_id": 1, + "request_id": "ext-req-1", + "request": request, + "pooling_output": {"value": 42}, + } + ) + self.assertIs(seen["connector"], connector) + 
self.assertTrue(seen["is_finished"]) + + sender.shutdown_omni_connectors() + + def test_send_chunk_does_not_retry_real_type_error(self): + connector = MockConnector(stage_id=0) + + sender = MixinHost() + sender.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + sender._omni_connector = connector + sender._stage_id = 0 + sender._async_chunk = True + + seen = {"calls": 0} + + def broken_process(transfer_manager, pooling_output, request, is_finished=""): + seen["calls"] += 1 + return {"data": is_finished + "tail"} + + sender._custom_process_func = broken_process + + request = _make_request("req-1", "ext-req-1") + request.is_finished = lambda: True + ok = sender.send_chunk(request, pooling_output={"value": 42}) + self.assertFalse(ok) + self.assertEqual(seen["calls"], 1) + + sender.shutdown_omni_connectors() + + +class TestMixinKVCacheTransfer(unittest.TestCase): + """Test 3: KV cache delegation to OmniKVTransferManager.""" + + def test_send_kv_delegates(self): + mock_kvm = MagicMock() + mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1"] + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + result = host.send_kv_cache( + finished_reqs={"req-1": {"seq_len": 10, "block_ids": [0]}}, + kv_caches=[], + block_size=16, + cache_dtype="float16", + ) + self.assertEqual(result, ["req-1"]) + mock_kvm.handle_finished_requests_kv_transfer.assert_called_once() + + host.shutdown_omni_connectors() + + def test_recv_kv_delegates(self): + mock_kvm = MagicMock() + mock_kvm.receive_kv_cache_for_request.return_value = ({"layer_blocks": {}}, 100) + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + data, size = host.recv_kv_cache("req-1") + self.assertIsNotNone(data) + self.assertEqual(size, 100) + mock_kvm.receive_kv_cache_for_request.assert_called_once() + + host.shutdown_omni_connectors() + + def test_receive_multi_kv_fetches_companions_via_mixin(self): + mock_kvm = MagicMock() + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + host.recv_kv_cache = MagicMock( + side_effect=[({"layer_blocks": {"k": [1]}}, 64), ({"layer_blocks": {"k": [2]}}, 32)] + ) + seen = {} + + def collect_cfg(request_id, cfg_role_payloads): + seen["request_id"] = request_id + seen["cfg_role_payloads"] = cfg_role_payloads + return {"cfg_text_kv_metadata": {"seq_len": 3}} + + req = SimpleNamespace( + request_id="req-1", + sampling_params=SimpleNamespace(cfg_kv_request_ids={"cfg_text": "req-1__cfg_text"}), + ) + ok = host.receive_multi_kv_cache(req, cfg_kv_collect_func=collect_cfg) + self.assertTrue(ok) + host.recv_kv_cache.assert_any_call("req-1", target_device=None) + host.recv_kv_cache.assert_any_call("req-1__cfg_text", target_device=None) + mock_kvm.apply_kv_cache_to_request.assert_called_once_with(req, {"layer_blocks": {"k": [1]}}) + self.assertEqual(seen["request_id"], "req-1") + self.assertEqual( + seen["cfg_role_payloads"], + {"cfg_text": ({"layer_blocks": {"k": [2]}}, 32)}, + ) + self.assertEqual(req.sampling_params.cfg_text_kv_metadata, {"seq_len": 3}) + + host.shutdown_omni_connectors() + + def test_receive_multi_kv_skips_inactive_request(self): + mock_kvm = MagicMock() + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + 
model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + host.requests = {} + host.recv_kv_cache = MagicMock(return_value=({"layer_blocks": {"k": [1]}}, 64)) + req = SimpleNamespace(request_id="req-1", sampling_params=None) + + ok = host.receive_multi_kv_cache(req) + + self.assertFalse(ok) + host.recv_kv_cache.assert_not_called() + mock_kvm.apply_kv_cache_to_request.assert_not_called() + + host.shutdown_omni_connectors() + + +class TestOmniConnectorOutput(unittest.TestCase): + """Test 4: Output aggregation across transfer modes.""" + + def test_output_aggregation(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + + host._chunk_ready_req_ids.add("req-1") + host._chunk_finished_req_ids.add("req-2") + host._local_request_metadata["req-1"] = {"next_stage_prompt_len": 10} + host._stage_recv_req_ids.add("req-3") + + output = host.get_omni_connector_output() + self.assertIsInstance(output, OmniConnectorOutput) + self.assertEqual(output.chunk_ready_req_ids, {"req-1"}) + self.assertEqual(output.chunk_finished_req_ids, {"req-2"}) + self.assertEqual(output.request_metadata, {"req-1": {"next_stage_prompt_len": 10}}) + self.assertEqual(output.stage_recv_req_ids, {"req-3"}) + + output2 = host.get_omni_connector_output() + self.assertEqual(output2.chunk_ready_req_ids, set()) + self.assertEqual(output2.request_metadata, {}) + + host.shutdown_omni_connectors() + + +class TestMixinNoConnector(unittest.TestCase): + """Edge case: mixin works gracefully without a connector.""" + + def test_no_connector(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + self.assertIsNone(host._omni_connector) + + results = host.recv_full_payload_inputs(scheduler_output=None) + self.assertIsNone(results) + + sent = host.send_full_payload_outputs(None, {"req-1": {}}) + self.assertEqual(sent, []) + + ok = host.send_chunk(_make_request("req-1"), pooling_output={}) + self.assertFalse(ok) + + output = host.get_omni_connector_output() + self.assertIsInstance(output, OmniConnectorOutput) + + host.shutdown_omni_connectors() + + +class TestFinishedLoadReqsDrain(unittest.TestCase): + """Test A1 fix: get_omni_connector_output drains _finished_load_reqs.""" + + def test_finished_load_reqs_flow_to_chunk_ready(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + + host._finished_load_reqs.add("req-1") + host._finished_load_reqs.add("req-2") + + output = host.get_omni_connector_output() + self.assertIn("req-1", output.chunk_ready_req_ids) + self.assertIn("req-2", output.chunk_ready_req_ids) + + self.assertEqual(len(host._finished_load_reqs), 0) + self.assertEqual(len(host._chunk_ready_req_ids), 0) + + host.shutdown_omni_connectors() + + +class TestLoadCustomFuncSelection(unittest.TestCase): + def test_skips_legacy_stage_list_processors_for_full_payload_mode(self): + legacy_paths = [ + "vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav", + "vllm_omni.model_executor.stage_input_processors.mammoth_moda2.ar2dit", + "vllm_omni.model_executor.stage_input_processors.cosyvoice3.text2flow", + "vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion", + ] + + for func_path in legacy_paths: + selected_path, func = MixinHost._load_custom_func( + SimpleNamespace( + async_chunk=False, + custom_process_input_func=func_path, + custom_process_next_stage_input_func=None, + ) + ) + assert selected_path != 
func_path + assert func is None or MixinHost._is_connector_payload_builder(func) + + +class TestFullPayloadSendWithCustomFunc(unittest.TestCase): + """Test B4: send_full_payload_outputs with full_payload_mode custom process func.""" + + def test_full_payload_send_passes_is_finished_and_connector(self): + seen = {} + + def full_payload_func(transfer_manager, pooling_output, request, is_finished=False): + seen["connector"] = transfer_manager.connector + seen["is_finished"] = is_finished + seen["data"] = pooling_output + seen["rid"] = request.request_id if request else None + return {"processed": True, "finished": is_finished} + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + host._custom_process_func = full_payload_func + + req = _make_request("req-1") + req.is_finished = lambda: True + sent = host.send_full_payload_outputs( + scheduler_output=None, + outputs={"req-1": ({"raw": 100}, req)}, + ) + self.assertEqual(sent, ["req-1"]) + self.assertEqual( + seen, + { + "connector": host._omni_connector, + "is_finished": True, + "data": {"raw": 100}, + "rid": "req-1", + }, + ) + + host.shutdown_omni_connectors() + + def test_accumulate_and_flush(self): + call_log = [] + + def full_payload_func(transfer_manager, pooling_output, request): + call_log.append(request.request_id if request else None) + return {"processed": True} + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + host._custom_process_func = full_payload_func + + req = _make_request("req-1") + host.accumulate_full_payload_output("req-1", {"raw": 42}, req) + self.assertEqual(len(host._pending_full_payload_send), 1) + + host.flush_full_payload_outputs({"req-1"}) + self.assertEqual(len(host._pending_full_payload_send), 0) + self.assertEqual(len(call_log), 1) + self.assertEqual(call_log[0], "req-1") + + time.sleep(0.1) + host.shutdown_omni_connectors() + + +class TestKVSentReqIdsAccumulation(unittest.TestCase): + """Test that kv_sent_req_ids accumulates results from send_kv_cache.""" + + def test_kv_sent_accumulation(self): + mock_kvm = MagicMock() + mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1", "req-2"] + + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(), + kv_transfer_manager=mock_kvm, + ) + + host.send_kv_cache( + finished_reqs={"req-1": {}, "req-2": {}}, + kv_caches=[], + block_size=16, + cache_dtype="float16", + ) + + output = host.get_omni_connector_output() + self.assertIn("req-1", output.kv_sent_req_ids) + self.assertIn("req-2", output.kv_sent_req_ids) + + output2 = host.get_omni_connector_output() + self.assertEqual(output2.kv_sent_req_ids, []) + + host.shutdown_omni_connectors() + + +class TestChunkStreamCompletedGuard(unittest.TestCase): + """Test that register_chunk_recv is skipped after finish sentinel. + + This validates the fix for the race condition where the scheduling + coordinator re-registers a request for chunk polling after its + upstream chunk stream has already finished (is_finished sentinel + received), causing the bg recv thread to poll for a non-existent + shared-memory segment (e.g. ``_0_7`` when only 7 chunks 0–6 exist). 
+ """ + + def _make_host(self, stage_id: int = 1) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=stage_id, async_chunk=True), + ) + host._omni_connector = MockConnector(stage_id=stage_id) + host._stage_id = stage_id + host._async_chunk = True + return host + + def test_register_blocked_after_finish_sentinel(self): + """register_chunk_recv must be a no-op after the finish sentinel.""" + host = self._make_host(stage_id=1) + + req = _make_request("req-1", "ext-req-1") + + # Simulate the bg thread having received the finish sentinel: + with host._lock: + host._chunk_stream_completed.add("req-1") + + # Now try to re-register — this mimics the coordinator asking + # the model runner to poll for the next (non-existent) chunk. + host.register_chunk_recv(req) + + # The request must NOT appear in _pending_load_reqs + self.assertNotIn( + "req-1", + host._pending_load_reqs, + "register_chunk_recv should skip requests whose chunk stream is already complete", + ) + + host.shutdown_omni_connectors() + + def test_register_allowed_before_finish(self): + """register_chunk_recv works normally before finish sentinel.""" + host = self._make_host(stage_id=1) + req = _make_request("req-1", "ext-req-1") + + host.register_chunk_recv(req) + self.assertIn( + "req-1", + host._pending_load_reqs, + "register_chunk_recv should add request to pending when stream is not yet complete", + ) + + host.shutdown_omni_connectors() + + def test_finish_sentinel_populates_completed_set(self): + """Receiving is_finished=True adds to _chunk_stream_completed.""" + host = self._make_host(stage_id=1) + + # Simulate _poll_single_request receiving is_finished=True + req_id = "req-1" + with host._lock: + host._chunk_finished_req_ids.add(req_id) + host._chunk_stream_completed.add(req_id) + host._local_stage_payload_cache[req_id] = {"finished": True} + host._local_request_metadata[req_id] = {} + host._finished_load_reqs.add(req_id) + host._pending_load_reqs.pop(req_id, None) + + self.assertIn(req_id, host._chunk_stream_completed) + + # Subsequent register_chunk_recv should be blocked + req = _make_request(req_id, f"ext-{req_id}") + host.register_chunk_recv(req) + self.assertNotIn(req_id, host._pending_load_reqs) + + host.shutdown_omni_connectors() + + def test_stage_0_always_skipped(self): + """Stage-0 has no upstream, register_chunk_recv is always no-op.""" + host = self._make_host(stage_id=0) + host._stage_id = 0 + + req = _make_request("req-1") + host.register_chunk_recv(req) + self.assertNotIn("req-1", host._pending_load_reqs) + + host.shutdown_omni_connectors() + + def test_full_payload_recv_guard_still_works(self): + """Pre-existing guard: staged full-payload results prevent registration.""" + host = self._make_host(stage_id=1) + + with host._lock: + host._stage_recv_req_ids.add("req-1") + + req = _make_request("req-1", "ext-req-1") + host.register_chunk_recv(req) + self.assertNotIn("req-1", host._pending_load_reqs) + + host.shutdown_omni_connectors() + + +class TestCleanupFinishedRequest(unittest.TestCase): + """Test cleanup_finished_request frees per-request mixin state.""" + + def _make_host(self, stage_id: int = 1) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=stage_id, async_chunk=True), + ) + host._omni_connector = MockConnector(stage_id=stage_id) + host._stage_id = stage_id + host._async_chunk = True + return host + + def test_cleanup_removes_all_state(self): + 
"""cleanup_finished_request removes all tracking dicts/sets.""" + host = self._make_host(stage_id=1) + req_id = "req-1" + ext_id = "ext-req-1" + + # Simulate state accumulated during a request's lifetime + host._request_ids_mapping[req_id] = ext_id + host._put_req_chunk[ext_id] = 5 + host._get_req_chunk[req_id] = 3 + host._send_side_request_payload[ext_id] = {"some": "data"} + host._code_prompt_token_ids[ext_id] = [[1, 2, 3]] + host._chunk_stream_completed.add(req_id) + host._stage_recv_req_ids.add(req_id) + host._local_stage_payload_cache[req_id] = {"engine_inputs": {}} + host._local_request_metadata[req_id] = {"prompt_len": 10} + + # Cleanup + host.cleanup_finished_request(req_id) + + # All state should be gone + self.assertNotIn(req_id, host._request_ids_mapping) + self.assertNotIn(ext_id, host._put_req_chunk) + self.assertNotIn(req_id, host._get_req_chunk) + self.assertNotIn(ext_id, host._send_side_request_payload) + self.assertNotIn(ext_id, host._code_prompt_token_ids) + self.assertNotIn(req_id, host._chunk_stream_completed) + self.assertNotIn(req_id, host._stage_recv_req_ids) + self.assertNotIn(req_id, host._local_stage_payload_cache) + self.assertNotIn(req_id, host._local_request_metadata) + + host.shutdown_omni_connectors() + + def test_cleanup_removes_per_cycle_ready_state(self): + """cleanup_finished_request clears ready/finished carry-over for req-id reuse.""" + host = self._make_host(stage_id=1) + req_id = "req-1" + + host._pending_load_reqs[req_id] = _make_request(req_id, "ext-req-1") + host._finished_load_reqs.add(req_id) + host._chunk_ready_req_ids.add(req_id) + host._chunk_finished_req_ids.add(req_id) + + host.cleanup_finished_request(req_id) + + self.assertNotIn(req_id, host._pending_load_reqs) + self.assertNotIn(req_id, host._finished_load_reqs) + self.assertNotIn(req_id, host._chunk_ready_req_ids) + self.assertNotIn(req_id, host._chunk_finished_req_ids) + + host.shutdown_omni_connectors() + + def test_cleanup_without_mapping(self): + """cleanup works for Stage-0 where _request_ids_mapping isn't set.""" + host = self._make_host(stage_id=0) + host._stage_id = 0 + req_id = "req-1" + + # Stage-0 uses req_id directly (no ext_id mapping) + host._put_req_chunk[req_id] = 3 + host._get_req_chunk[req_id] = 0 + + host.cleanup_finished_request(req_id) + + self.assertNotIn(req_id, host._put_req_chunk) + self.assertNotIn(req_id, host._get_req_chunk) + + host.shutdown_omni_connectors() + + def test_prune_inactive_requests_cleans_stale_state_but_keeps_active(self): + """Inactive request IDs should be pruned without touching active ones.""" + host = self._make_host(stage_id=1) + active_req_id = "req-active" + stale_req_id = "req-stale" + stale_ext_id = "ext-stale" + + host._request_ids_mapping[active_req_id] = "ext-active" + host._request_ids_mapping[stale_req_id] = stale_ext_id + host._put_req_chunk[stale_ext_id] = 2 + host._get_req_chunk[stale_req_id] = 1 + host._finished_load_reqs.add(stale_req_id) + host._chunk_ready_req_ids.update({active_req_id, stale_req_id}) + host._chunk_finished_req_ids.add(stale_req_id) + host._chunk_stream_completed.add(stale_req_id) + host._stage_recv_req_ids.add(active_req_id) + host._send_side_request_payload[stale_ext_id] = {"stale": True} + host._code_prompt_token_ids[stale_ext_id] = [[1, 2, 3]] + + pruned = host.prune_inactive_requests({active_req_id}) + + self.assertEqual(pruned, {stale_req_id}) + self.assertIn(active_req_id, host._request_ids_mapping) + self.assertIn(active_req_id, host._chunk_ready_req_ids) + self.assertIn(active_req_id, 
host._stage_recv_req_ids) + self.assertNotIn(stale_req_id, host._request_ids_mapping) + self.assertNotIn(stale_ext_id, host._put_req_chunk) + self.assertNotIn(stale_req_id, host._get_req_chunk) + self.assertNotIn(stale_req_id, host._pending_load_reqs) + self.assertNotIn(stale_req_id, host._finished_load_reqs) + self.assertNotIn(stale_req_id, host._chunk_ready_req_ids) + self.assertNotIn(stale_req_id, host._chunk_finished_req_ids) + self.assertNotIn(stale_req_id, host._chunk_stream_completed) + self.assertNotIn(stale_req_id, host._stage_recv_req_ids) + self.assertNotIn(stale_ext_id, host._send_side_request_payload) + self.assertNotIn(stale_ext_id, host._code_prompt_token_ids) + + host.shutdown_omni_connectors() + + def test_prune_inactive_requests_keeps_recently_received_full_payload_state(self): + """Late bg-thread receives must survive until the scheduler catches up.""" + host = self._make_host(stage_id=1) + req_id = "req-recv-race" + ext_id = "ext-recv-race" + + host._request_ids_mapping[req_id] = ext_id + host._put_req_chunk[ext_id] = 1 + host._local_stage_payload_cache[req_id] = {"engine_inputs": {"ids": [1, 2, 3]}} + host._local_request_metadata[req_id] = {"next_stage_prompt_len": 3} + host._stage_recv_req_ids.add(req_id) + + pruned = host.prune_inactive_requests(set()) + + self.assertEqual(pruned, set()) + self.assertIn(req_id, host._request_ids_mapping) + self.assertIn(req_id, host._local_stage_payload_cache) + self.assertIn(req_id, host._local_request_metadata) + self.assertIn(req_id, host._stage_recv_req_ids) + self.assertIn(ext_id, host._put_req_chunk) + + # Once the scheduler has consumed the wake-up and the request really + # disappears from all protected sets, prune should clean it up. + host._stage_recv_req_ids.clear() + host._local_stage_payload_cache.clear() + host._local_request_metadata.clear() + + pruned = host.prune_inactive_requests(set()) + + self.assertEqual(pruned, {req_id}) + self.assertNotIn(req_id, host._request_ids_mapping) + self.assertNotIn(ext_id, host._put_req_chunk) + + host.shutdown_omni_connectors() + + +class TestSendChunkCachesMapping(unittest.TestCase): + """Test that send_chunk caches internal→external req ID mapping.""" + + def test_send_chunk_populates_request_ids_mapping(self): + """send_chunk should cache the internal→external mapping.""" + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + host._async_chunk = True + + def mock_process(transfer_manager, pooling_output, request): + return {"data": "test", "finished": False} + + host._custom_process_func = mock_process + + request = _make_request("internal-1", "external-1") + host.send_chunk(request, pooling_output={"v": 1}) + + # The mapping should be cached + self.assertEqual( + host._request_ids_mapping.get("internal-1"), + "external-1", + ) + + time.sleep(0.1) + host.shutdown_omni_connectors() + + +class TestLocalPayloadCacheLifecycle(unittest.TestCase): + """Unit tests for the local payload cache API (RFC §2.4).""" + + def _make_host(self) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0), + ) + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + return host + + def test_put_get_pop(self): + host = self._make_host() + payload = {"engine_inputs": {"ids": [1, 2, 3]}} + host.put_local_stage_payload("r1", payload) + + 
self.assertEqual(host.get_local_stage_payload("r1"), payload) + popped = host.pop_local_stage_payload("r1") + self.assertEqual(popped, payload) + self.assertIsNone(host.get_local_stage_payload("r1")) + host.shutdown_omni_connectors() + + def test_recv_full_payload_inputs_populates_local_cache(self): + host = self._make_host() + host._omni_connector = MockConnector(stage_id=0) + host._stage_id = 0 + + # Simulate a full payload already staged by the bg recv path + with host._lock: + host._local_stage_payload_cache["r1"] = {"tok": [10]} + host._stage_recv_req_ids.add("r1") + + host.recv_full_payload_inputs(scheduler_output=None) + self.assertEqual(host.get_local_stage_payload("r1"), {"tok": [10]}) + host.shutdown_omni_connectors() + + def test_rank0_only_polls_connector_for_tp_full_payload(self): + host = self._make_host() + host._omni_connector = MagicMock() + host._stage_id = 2 + host._local_rank = 0 + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + payload = {"tok": [10], "finished": torch.tensor(True)} + connector_result = (payload, 123) + host._omni_connector.get.return_value = connector_result + tp_group = _FakeTPGroup(world_size=2, rank_in_group=0) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertTrue(made_progress) + host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0") + self.assertEqual(tp_group.broadcast_inputs, []) + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertIn("r1", host._full_payload_pending_broadcast_req_ids) + self.assertNotIn("r1", host._stage_recv_req_ids) + self.assertIsNone(host.get_local_request_metadata("r1")) + host.shutdown_omni_connectors() + + def test_tp_follower_skips_connector_poll_for_full_payload(self): + host = self._make_host() + host._omni_connector = MagicMock() + host._stage_id = 2 + host._local_rank = 1 + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + self.assertEqual(tp_group.broadcast_inputs, []) + self.assertNotIn("r1", host._local_stage_payload_cache) + host.shutdown_omni_connectors() + + def test_recv_full_payload_inputs_broadcasts_tp_leader_results_to_followers(self): + host = self._make_host() + host._omni_connector = MagicMock() + host._stage_id = 2 + host._local_rank = 1 + host._pending_load_reqs["r1"] = object() + payload = {"tok": [10], "finished": torch.tensor(True)} + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result={"r1": payload}) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + results = host.recv_full_payload_inputs(scheduler_output=None) + + self.assertEqual(results, {"r1": payload}) + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertEqual(host.get_local_request_metadata("r1"), {}) + self.assertEqual(host._stage_recv_req_ids, {"r1"}) + self.assertNotIn("r1", host._pending_load_reqs) + self.assertEqual(tp_group.broadcast_inputs, [None]) + host.shutdown_omni_connectors() + + +class TestTPAsyncChunkFanout(unittest.TestCase): + def _make_host(self, rank: int) -> MixinHost: + host = MixinHost() + 
host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"), + ) + host._omni_connector = MagicMock() + host._stage_id = 2 + host._async_chunk = True + host._model_mode = "gen" + host._local_rank = rank + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + return host + + def test_rank0_only_polls_connector_for_tp_async_chunk(self): + host = self._make_host(rank=0) + payload = { + "code_predictor_codes": [10, 11], + "left_context_size": 0, + "finished": torch.tensor(False), + } + host._omni_connector.get.return_value = (payload, 123) + tp_group = _FakeTPGroup(world_size=2, rank_in_group=0) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertTrue(made_progress) + host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0") + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertIn("r1", host._finished_load_reqs) + self.assertIn("r1", host._async_chunk_updated_req_ids) + self.assertEqual(tp_group.broadcast_inputs, []) + host.shutdown_omni_connectors() + + def test_tp_follower_skips_connector_poll_for_async_chunk(self): + host = self._make_host(rank=1) + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + self.assertIsNone(host.get_local_stage_payload("r1")) + self.assertEqual(tp_group.broadcast_inputs, []) + host.shutdown_omni_connectors() + + def test_get_output_broadcasts_tp_async_chunk_payloads_to_followers(self): + host = self._make_host(rank=1) + host._pending_load_reqs["r1"] = object() + payload = { + "code_predictor_codes": [10, 11], + "left_context_size": 0, + "finished": torch.tensor(True), + } + packet = { + "staged_payloads": {"r1": payload}, + "request_metadata": {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}}, + "newly_finished": {"r1"}, + "chunk_finished": {"r1"}, + } + tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result=packet) + + with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group): + output = host.get_omni_connector_output() + + self.assertEqual(output.chunk_ready_req_ids, {"r1"}) + self.assertEqual(output.chunk_finished_req_ids, {"r1"}) + self.assertEqual( + output.request_metadata, + {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}}, + ) + self.assertEqual(host.get_local_stage_payload("r1"), payload) + self.assertNotIn("r1", host._pending_load_reqs) + self.assertIn("r1", host._chunk_stream_completed) + self.assertEqual(tp_group.broadcast_inputs, [None]) + host.shutdown_omni_connectors() + + +class TestKVTransferLifecycle(unittest.TestCase): + """Unit tests for KV transfer lifecycle methods.""" + + def _make_host(self) -> MixinHost: + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0), + ) + return host + + def test_mark_drain_ack_complete(self): + host = self._make_host() + self.assertFalse(host.has_pending_kv_work()) + + host.mark_kv_transfer("r1", seq_len=100, block_ids=[0, 1, 2]) + self.assertTrue(host.has_pending_kv_work()) + self.assertTrue(host.is_kv_transfer_triggered("r1")) + + # Drain moves pending → active + pending = 
host.drain_pending_kv_transfers() + self.assertEqual(pending, {"r1": {"seq_len": 100, "block_ids": [0, 1, 2]}}) + self.assertIn("r1", host._kv_active_transfers) + self.assertTrue(host.has_pending_kv_work()) + + # Ack moves active → completed + host.ack_kv_transfers(["r1"]) + self.assertNotIn("r1", host._kv_active_transfers) + self.assertIn("r1", host._kv_completed_transfers) + + # Drain completed + completed = host.drain_completed_kv_transfers() + self.assertEqual(completed, {"r1"}) + self.assertFalse(host.has_pending_kv_work()) + host.shutdown_omni_connectors() + + def test_mark_dedup(self): + host = self._make_host() + host.mark_kv_transfer("r1", seq_len=100, block_ids=[0]) + host.mark_kv_transfer("r1", seq_len=200, block_ids=[0, 1]) + # Second mark is a no-op + self.assertEqual(host._kv_pending_transfers["r1"]["seq_len"], 100) + host.shutdown_omni_connectors() + + def test_cleanup_removes_kv_state(self): + host = self._make_host() + host.mark_kv_transfer("r1", seq_len=50, block_ids=[0]) + host.drain_pending_kv_transfers() + host.cleanup_finished_request("r1") + self.assertFalse(host.is_kv_transfer_triggered("r1")) + self.assertNotIn("r1", host._kv_active_transfers) + self.assertFalse(host.has_pending_kv_work()) + host.shutdown_omni_connectors() + + +class TestAsyncPayloadLifecycle(unittest.TestCase): + """Regression tests for async payload delivery lifecycle.""" + + def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + host._request_ids_mapping["r1"] = "r1" + payload = { + "thinker_decode_embeddings": torch.ones(1, 2), + "thinker_output_token_ids": [1], + "override_keys": ["thinker_decode_embeddings", "thinker_output_token_ids"], + "finished": torch.tensor(False), + } + + host._accumulate_payload("r1", dict(payload)) + with host._lock: + host._finished_load_reqs.add("r1") + + host.get_omni_connector_output() + self.assertIn("r1", host._send_side_request_payload) + host.shutdown_omni_connectors() + + def test_payload_consumable_ignores_token_horizon_only_updates(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + payload = { + "thinker_output_token_ids": [1, 2, 3], + "finished": torch.tensor(False), + "override_keys": [ + "thinker_output_token_ids", + "thinker_decode_embeddings_token_start", + "thinker_decode_embeddings_token_end", + ], + "thinker_decode_embeddings_token_start": 2, + "thinker_decode_embeddings_token_end": 3, + } + self.assertFalse(host._payload_is_consumable(payload)) + host.shutdown_omni_connectors() + + def test_payload_consumable_accepts_decode_embeddings(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + payload = { + "thinker_output_token_ids": [1, 2, 3], + "thinker_decode_embeddings": torch.ones(1, 2), + "finished": torch.tensor(False), + } + self.assertTrue(host._payload_is_consumable(payload)) + host.shutdown_omni_connectors() + + def test_ar_metadata_only_followup_chunk_does_not_rewake_request(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"), + ) + host._omni_connector = MagicMock() + host._stage_id = 1 + host._async_chunk = True + 
host._model_mode = "ar" + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 0 + + host._omni_connector.get.side_effect = [ + ( + { + "thinker_decode_embeddings": torch.ones(1, 2), + "finished": torch.tensor(False), + }, + 1, + ), + ( + { + "next_stage_prompt_len": 7, + "finished": torch.tensor(False), + }, + 1, + ), + ] + + host._poll_single_request("r1") + output1 = host.get_omni_connector_output() + self.assertEqual(output1.chunk_ready_req_ids, {"r1"}) + + host._poll_single_request("r1") + output2 = host.get_omni_connector_output() + self.assertEqual(output2.chunk_ready_req_ids, set()) + self.assertEqual(output2.request_metadata, {"r1": {"next_stage_prompt_len": 7}}) + + host.shutdown_omni_connectors() + + def test_non_ar_recv_does_not_overwrite_unconsumed_staged_chunk(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"), + ) + host._omni_connector = MagicMock() + host._stage_id = 2 + host._async_chunk = True + host._model_mode = "gen" + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 1 + host._local_stage_payload_cache["r1"] = { + "code_predictor_codes": [1, 2, 3], + "left_context_size": 0, + "finished": torch.tensor(False), + } + + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + self.assertEqual(host._get_req_chunk["r1"], 1) + + host.shutdown_omni_connectors() + + def test_non_ar_recv_waits_for_scheduler_handoff_before_fetching_next_chunk(self): + host = MixinHost() + host.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"), + ) + host._omni_connector = MagicMock() + host._stage_id = 2 + host._async_chunk = True + host._model_mode = "gen" + host._request_ids_mapping["r1"] = "ext-r1" + host._get_req_chunk["r1"] = 1 + host._local_request_metadata["r1"] = { + "code_predictor_codes": [10, 11, 12], + "left_context_size": 0, + } + host._finished_load_reqs.add("r1") + + made_progress = host._poll_single_request("r1") + + self.assertFalse(made_progress) + host._omni_connector.get.assert_not_called() + self.assertEqual(host._get_req_chunk["r1"], 1) + + output = host.get_omni_connector_output() + self.assertEqual(output.request_metadata["r1"]["code_predictor_codes"], [10, 11, 12]) + self.assertEqual(output.chunk_ready_req_ids, {"r1"}) + + host._omni_connector.get.return_value = ( + { + "code_predictor_codes": [20, 21, 22], + "left_context_size": 0, + "finished": torch.tensor(False), + }, + 1, + ) + made_progress = host._poll_single_request("r1") + + self.assertTrue(made_progress) + host._omni_connector.get.assert_called_once() + self.assertEqual(host._get_req_chunk["r1"], 2) + + host.shutdown_omni_connectors() + + +class TestRankAwareKVRouting(unittest.TestCase): + def _make_host(self, *, from_tp: int, to_tp: int, local_rank: int) -> MixinHost: + host = MixinHost() + host.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=1)) + host._from_tp = from_tp + host._to_tp = to_tp + host._local_rank = local_rank + return host + + def test_recv_keys_use_remote_rank_as_from_rank(self): + host = self._make_host(from_tp=4, to_tp=2, local_rank=1) + self.assertEqual( + host.get_rank_aware_kv_keys("req", from_stage=0), + ["req_0_0_2_1", "req_0_0_3_1"], + ) + host.shutdown_omni_connectors() + + def test_send_keys_route_from_rank_gt_to_rank(self): + host = 
self._make_host(from_tp=4, to_tp=2, local_rank=3) + self.assertEqual(host.get_rank_aware_kv_send_keys("req", from_stage=0), ["req_0_0_3_1"]) + host.shutdown_omni_connectors() + + def test_invalid_recv_rank_mapping_raises(self): + host = self._make_host(from_tp=3, to_tp=2, local_rank=1) + with self.assertRaises(ValueError): + host.get_rank_aware_kv_keys("req", from_stage=0) + host.shutdown_omni_connectors() + + def test_invalid_send_rank_mapping_raises(self): + host = self._make_host(from_tp=3, to_tp=2, local_rank=1) + with self.assertRaises(ValueError): + host.get_rank_aware_kv_send_keys("req", from_stage=0) + host.shutdown_omni_connectors() + + def test_merge_rank_sharded_payloads_concatenates_head_dimension(self): + host = self._make_host(from_tp=4, to_tp=2, local_rank=0) + payloads = [ + {"layer_blocks": {"key_cache": [torch.ones(2, 1, 3)], "value_cache": [torch.ones(2, 1, 3)]}}, + {"layer_blocks": {"key_cache": [torch.full((2, 1, 3), 2.0)], "value_cache": [torch.full((2, 1, 3), 2.0)]}}, + ] + merged = host._merge_rank_sharded_kv_payloads(payloads) + self.assertEqual(tuple(merged["layer_blocks"]["key_cache"][0].shape), (2, 2, 3)) + self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 0], torch.ones(2, 3))) + self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 1], torch.full((2, 3), 2.0))) + host.shutdown_omni_connectors() + + def test_slice_rank_sharded_payload_splits_head_dimension(self): + host = self._make_host(from_tp=2, to_tp=4, local_rank=1) + payload = { + "layer_blocks": { + "key_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)], + "value_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)], + }, + "metadata": {}, + } + sliced = host._slice_rank_sharded_kv_payload(payload) + self.assertEqual(tuple(sliced["layer_blocks"]["key_cache"][0].shape), (2, 2, 3)) + expected = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)[:, 2:4, :] + self.assertTrue(torch.equal(sliced["layer_blocks"]["key_cache"][0], expected)) + host.shutdown_omni_connectors() + + +class TestAttachOmniConnectorOutput(unittest.TestCase): + def test_wraps_empty_model_runner_output_when_signals_exist(self): + from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT + + host = MixinHost() + host.get_omni_connector_output = lambda: OmniConnectorOutput(chunk_ready_req_ids={"req-1"}) + + wrapped = host.attach_omni_connector_output(EMPTY_MODEL_RUNNER_OUTPUT) + + self.assertIsNot(wrapped, EMPTY_MODEL_RUNNER_OUTPUT) + self.assertEqual(wrapped.omni_connector_output.chunk_ready_req_ids, {"req-1"}) + + +class TestConnectorConfigValidation(unittest.TestCase): + def test_invalid_connector_name_raises(self): + host = MixinHost() + model_config = _make_model_config(stage_id=1) + model_config.stage_connector_config = {"name": " "} + + with self.assertRaisesRegex(RuntimeError, "missing connector name"): + host.init_omni_connectors(vllm_config=None, model_config=model_config) + + +class _FailingConnector: + """Connector whose put() fails a configurable number of times.""" + + def __init__(self, fail_count: int = 1, raise_on_fail: bool = False): + self._fail_count = fail_count + self._raise_on_fail = raise_on_fail + self.attempt = 0 + + def put(self, from_stage, to_stage, put_key, data): + self.attempt += 1 + if self.attempt <= self._fail_count: + if self._raise_on_fail: + raise ConnectionError("transient connector error") + return False, 0, None + return True, len(str(data)), None + + def get(self, *a, **kw): + return None + + def close(self): + pass + + 
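+# Summary of the retry contract exercised by TestSendRetry below (illustrative
+# comments only, not part of the mixin under test):
+#   * a failed connector put() must NOT decrement _pending_save_counts;
+#   * _requeue_or_drop_failed_send() re-enqueues the task into _pending_save_reqs
+#     with an incremented _retry_count;
+#   * once _retry_count reaches _MAX_SEND_RETRIES the task is dropped and its
+#     pending count is cleaned up.
+#
+# Minimal sketch of one retry cycle (connector fails once, then succeeds):
+#     connector = _FailingConnector(fail_count=1)
+#     if not sender._send_single_request(task):
+#         sender._requeue_or_drop_failed_send(task)   # task["_retry_count"] == 1
+#     sender._send_single_request(task)               # second attempt succeeds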
+class TestSendRetry(unittest.TestCase): + """Tests for P1-2: failed connector sends must be retried.""" + + def _make_sender(self, connector): + sender = MixinHost() + sender.init_omni_connectors( + vllm_config=None, + model_config=_make_model_config(stage_id=0, async_chunk=True), + ) + sender._omni_connector = connector + sender._stage_id = 0 + sender._async_chunk = True + return sender + + def _make_task(self, req_id="r1"): + return { + "stage_id": 0, + "next_stage_id": 1, + "request_id": req_id, + "data": {"payload": "test"}, + } + + def test_send_single_request_returns_false_on_put_failure(self): + connector = _FailingConnector(fail_count=999) + sender = self._make_sender(connector) + + result = sender._send_single_request(self._make_task()) + self.assertFalse(result) + sender.shutdown_omni_connectors() + + def test_send_single_request_does_not_decrement_on_failure(self): + connector = _FailingConnector(fail_count=999) + sender = self._make_sender(connector) + sender._pending_save_counts["r1"] = 1 + + sender._send_single_request(self._make_task()) + self.assertEqual(sender._pending_save_counts.get("r1"), 1, "pending count must NOT be decremented on failure") + sender.shutdown_omni_connectors() + + def test_send_single_request_decrements_on_success(self): + connector = MockConnector(stage_id=0) + sender = self._make_sender(connector) + sender._pending_save_counts["r1"] = 1 + + result = sender._send_single_request(self._make_task()) + self.assertTrue(result) + self.assertNotIn("r1", sender._pending_save_counts, "pending count should be zero/removed on success") + sender.shutdown_omni_connectors() + + def test_requeue_or_drop_requeues_on_first_failure(self): + sender = self._make_sender(MockConnector(stage_id=0)) + task = self._make_task() + + sender._requeue_or_drop_failed_send(task) + + self.assertEqual(task.get("_retry_count"), 1) + with sender._lock: + dq = sender._pending_save_reqs.get("r1") + self.assertIsNotNone(dq) + self.assertEqual(len(dq), 1) + sender.shutdown_omni_connectors() + + def test_requeue_or_drop_drops_after_max_retries(self): + sender = self._make_sender(MockConnector(stage_id=0)) + sender._pending_save_counts["r1"] = 1 + task = self._make_task() + task["_retry_count"] = sender._MAX_SEND_RETRIES # already at max + + sender._requeue_or_drop_failed_send(task) + + with sender._lock: + dq = sender._pending_save_reqs.get("r1") + self.assertTrue(dq is None or len(dq) == 0, "task should NOT be re-enqueued after max retries") + self.assertNotIn("r1", sender._pending_save_counts, "pending count should be cleaned up on final drop") + sender.shutdown_omni_connectors() + + def test_save_loop_retries_on_exception(self): + """Integration: _save_loop retries a task when put() raises.""" + from collections import deque + + connector = _FailingConnector(fail_count=1, raise_on_fail=True) + sender = self._make_sender(connector) + task = self._make_task() + + with sender._lock: + sender._pending_save_reqs["r1"] = deque([task]) + sender._pending_save_counts["r1"] = 1 + + sender._stop_event.clear() + + def run_one_loop(): + sender._save_loop() + + sender._stop_event.set() # will exit after one iteration + # Run manually instead of threading + # Simulate: pop task, send fails, requeue + popped_task = None + with sender._lock: + dq = sender._pending_save_reqs.get("r1") + if dq: + popped_task = dq.popleft() + if not dq: + del sender._pending_save_reqs["r1"] + + if popped_task is not None: + success = False + try: + success = sender._send_single_request(popped_task) + except Exception: + 
pass + if not success: + sender._requeue_or_drop_failed_send(popped_task) + + # After first failure, task should be re-enqueued + with sender._lock: + dq = sender._pending_save_reqs.get("r1") + self.assertIsNotNone(dq) + self.assertEqual(len(dq), 1) + requeued = dq[0] + self.assertEqual(requeued.get("_retry_count"), 1) + + # Second attempt should succeed (connector now returns True) + success = sender._send_single_request(requeued) + self.assertTrue(success) + sender.shutdown_omni_connectors() + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/nightly/generate_nightly_perf_excel.py b/tools/nightly/generate_nightly_perf_excel.py index 817f37f664..6ba1d1eef0 100644 --- a/tools/nightly/generate_nightly_perf_excel.py +++ b/tools/nightly/generate_nightly_perf_excel.py @@ -23,6 +23,22 @@ GREY_BLOCK_FILL = PatternFill(start_color="D3D3D3", fill_type="solid") # Diffusion sheet columns (Qwen-Image diffusion benchmark). +# Per-stage latency metrics. Unpack from stage_durations_mean/p50/p99 dicts +DIFFUSION_STAGE_LATENCY_COLUMNS: tuple[str, ...] = ( + # "vae.encode_mean", + # "vae.encode_p50", + # "vae.encode_p99", + "vae.decode_mean", + "vae.decode_p50", + "vae.decode_p99", + "diffuse_mean", + "diffuse_p50", + "diffuse_p99", + "text_encoder.forward_mean", + "text_encoder.forward_p50", + "text_encoder.forward_p99", +) + DIFFUSION_BENCHMARK_COLUMNS: tuple[str, ...] = ( "duration", "completed_requests", @@ -36,7 +52,7 @@ "peak_memory_mb_mean", "peak_memory_mb_median", "slo_attainment_rate", -) +) + DIFFUSION_STAGE_LATENCY_COLUMNS DIFFUSION_NUMERIC_FORMAT_COLUMNS: tuple[str, ...] = DIFFUSION_BENCHMARK_COLUMNS @@ -63,7 +79,7 @@ "build_id", "build_url", "source_file", -) +) + DIFFUSION_STAGE_LATENCY_COLUMNS # Benchmark metric columns: grey the latest row's cell when value changed vs previous date. BENCHMARK_COLUMNS: tuple[str, ...] = ( @@ -106,7 +122,7 @@ _COLUMNS_FILENAME = "nightly_perf_summary_columns.txt" _RESULT_JSON_PREFIX = "result_test_" -_DIFFUSION_JSON_PREFIX = "diffusion_perf_" +_DIFFUSION_RESULT_PREFIX = "diffusion_result_" DEFAULT_INPUT_DIR = os.getenv("DEFAULT_INPUT_DIR") if os.getenv("DEFAULT_INPUT_DIR") else "tests" DEFAULT_OUTPUT_DIR = os.getenv("DEFAULT_OUTPUT_DIR") if os.getenv("DEFAULT_OUTPUT_DIR") else "tests" DEFAULT_DIFFUSION_INPUT_DIR = os.getenv("DIFFUSION_BENCHMARK_DIR") @@ -252,7 +268,7 @@ def parse_args() -> argparse.Namespace: type=str, default=None, help=( - "Directory containing diffusion_perf_*.json files; default is " + "Directory containing diffusion_result_*.json files; default is " "DIFFUSION_BENCHMARK_DIR, fallback to --input-dir." ), ) @@ -286,7 +302,7 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def _load_json_file(path: str) -> dict[str, Any] | None: +def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None: """Safely load a single JSON file; return None and log a warning on failure.""" try: with open(path, encoding="utf-8") as f: @@ -295,18 +311,18 @@ def _load_json_file(path: str) -> dict[str, Any] | None: LOGGER.warning("failed to load json '%s': %s", path, exc) return None - if not isinstance(data, dict): - LOGGER.warning("json root in '%s' is not an object, skip", path) + if not isinstance(data, (dict, list)): + LOGGER.warning("json root in '%s' is not a dict or list, skip", path) return None return data def _parse_from_filename(filename: str) -> dict[str, Any]: - """Parse test-related metadata from a result JSON filename. + """Parse test-related metadata from a ``result_test_*.json`` filename. 
- Expected pattern (after prefix/suffix stripped): - ____ + Matches ``tests/dfx/perf/scripts/run_benchmark.py`` naming, including optional + ``_in{X}_out{Y}_`` before the timestamp (``na`` when unset). """ name, ext = os.path.splitext(filename) if ext != ".json" or not name.startswith(_RESULT_JSON_PREFIX): @@ -315,22 +331,42 @@ def _parse_from_filename(filename: str) -> dict[str, Any]: core = name[len(_RESULT_JSON_PREFIX) :] parts = core.split("_") if len(parts) < 5: - LOGGER.warning("filename '%s' does not match expected pattern, skip parsing test metadata", filename) + LOGGER.warning( + "filename '%s' does not match expected pattern (need >= 5 segments), skip parsing", + filename, + ) return {} - timestamp = parts[-1] - num_prompts_str = parts[-2] - max_concurrency_str = parts[-3] - dataset_name = parts[-4] - test_name = "_".join(parts[:-4]) if parts[:-4] else "" + idx = len(parts) - 1 + timestamp = parts[idx] + idx -= 1 parsed: dict[str, Any] = {} - if len(timestamp) >= 15: parsed["date"] = timestamp - if dataset_name in DATASET_NAME_ALLOWED: - parsed["dataset_name"] = dataset_name + if idx >= 0 and parts[idx].startswith("out"): + parsed["random_output_len"] = parts[idx][3:] + idx -= 1 + if idx >= 0 and parts[idx].startswith("in"): + parsed["random_input_len"] = parts[idx][2:] + idx -= 1 + + if idx < 3: + LOGGER.warning( + "filename '%s' has too few segments after timestamp / optional in-out (idx=%s)", + filename, + idx, + ) + return parsed + + num_prompts_str = parts[idx] + idx -= 1 + flow_str = parts[idx] + idx -= 1 + dataset_name = parts[idx] + idx -= 1 + test_name = "_".join(parts[: idx + 1]) if idx >= 0 else "" try: parsed["num_prompts"] = int(num_prompts_str) @@ -338,13 +374,16 @@ def _parse_from_filename(filename: str) -> dict[str, Any]: pass try: - parsed["max_concurrency"] = int(max_concurrency_str) + parsed["max_concurrency"] = int(flow_str) except (TypeError, ValueError): pass if test_name: parsed["test_name"] = test_name + if dataset_name in DATASET_NAME_ALLOWED: + parsed["dataset_name"] = dataset_name + return parsed @@ -396,27 +435,29 @@ def _iter_omni_json_records(input_dir: str) -> Iterable[dict[str, Any]]: yield record -def _parse_diffusion_from_filename(filename: str) -> dict[str, Any]: - """Parse diffusion test_name/date from filename: diffusion_perf__.json""" +def _parse_diffusion_result_from_filename(filename: str) -> dict[str, Any]: + """Parse test_name/date from filename: diffusion_result__.json""" name, ext = os.path.splitext(filename) - if ext != ".json" or not name.startswith(_DIFFUSION_JSON_PREFIX): + if ext != ".json" or not name.startswith(_DIFFUSION_RESULT_PREFIX): return {} - core = name[len(_DIFFUSION_JSON_PREFIX) :] + core = name[len(_DIFFUSION_RESULT_PREFIX) :] parts = core.split("_") if len(parts) < 2: return {} timestamp = parts[-1] - test_name = "_".join(parts[:-1]) if parts[:-1] else "" parsed: dict[str, Any] = {} if len(timestamp) >= 15: parsed["date"] = timestamp - if test_name: - parsed["test_name"] = test_name return parsed -def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]: - """Iterate over diffusion_perf_*.json files and yield normalized diffusion records.""" +def _iter_diffusion_records(input_dir: str) -> Iterable[dict[str, Any]]: + """Iterate over diffusion_result_*.json files and yield normalized records. + + Unlike omni format where each JSON file contains one test case, diffusion format + produces a single JSON file containing a list of all test case records. 
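+    Filenames are expected to follow ``diffusion_result_<name>_<timestamp>.json``; each list
+    element may nest its benchmark metrics under a ``result`` dict.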
+ Test params (feature toggles) are NOT embedded in the filename. + """ if not os.path.isdir(input_dir): LOGGER.warning("diffusion input dir '%s' does not exist or is not a directory", input_dir) return @@ -424,7 +465,7 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]: for entry in sorted(os.listdir(input_dir)): if not entry.endswith(".json"): continue - if not entry.startswith(_DIFFUSION_JSON_PREFIX): + if not entry.startswith(_DIFFUSION_RESULT_PREFIX): continue full_path = os.path.join(input_dir, entry) if not os.path.isfile(full_path): @@ -434,23 +475,63 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]: if data is None: continue - record: dict[str, Any] = dict(data) - filename_meta = _parse_diffusion_from_filename(os.path.basename(full_path)) - if "date" not in record or not record.get("date"): - record["date"] = filename_meta.get("date") or datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") - if "test_name" not in record or not record.get("test_name"): - if "test_name" in filename_meta: - record["test_name"] = filename_meta["test_name"] - record["source_file"] = os.path.basename(full_path) - yield record + filename_meta = _parse_diffusion_result_from_filename(os.path.basename(full_path)) + if not isinstance(data, list): + LOGGER.warning("diffusion result file '%s' root is not a list, skip", full_path) + continue + + for record in data: + if not isinstance(record, dict): + continue + record = dict(record) + if "date" not in record or not record.get("date"): + record["date"] = filename_meta.get("date") or datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") + record["source_file"] = os.path.basename(full_path) + yield record -def _collect_records(input_dir: str) -> list[dict[str, Any]]: + +def _collect_omni_records(input_dir: str) -> list[dict[str, Any]]: return list(_iter_omni_json_records(input_dir)) def _collect_diffusion_records(diffusion_input_dir: str) -> list[dict[str, Any]]: - return list(_iter_diffusion_json_records(diffusion_input_dir)) + """Collect diffusion records from diffusion_result_*.json files. + Their format is different from omni JSON files. 
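+    Each raw record is passed through ``_process_diffusion_record`` to merge the nested
+    ``result`` dict and flatten per-stage latency metrics before the rows are written out.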
+ """ + return [_process_diffusion_record(r) for r in _iter_diffusion_records(diffusion_input_dir)] + + +def _flatten_stage_durations(record: dict[str, Any]) -> dict[str, Any]: + """Flatten stage_durations dict into individual columns matching DIFFUSION_STAGE_LATENCY_COLUMNS.""" + result = dict(record) + + for prefix in ("stage_durations_mean", "stage_durations_p50", "stage_durations_p99"): + durations = result.pop(prefix, None) + if not isinstance(durations, dict): + continue + + suffix = prefix.replace("stage_durations_", "") # "mean", "p50", "p99" + + for stage_key, value in durations.items(): # e.g., "SomePipeline.vae.decode_mean": 100.0 + stage_key = stage_key.split(".", 1)[-1] # "decode_mean" + col_name = f"{stage_key}_{suffix}" + if col_name not in DIFFUSION_STAGE_LATENCY_COLUMNS: + print(f"skipping stage_key: {col_name}") + continue + result[col_name] = value + + return result + + +def _process_diffusion_record(record: dict[str, Any]) -> dict[str, Any]: + """Normalize a diffusion record by merging `result` and flattening stage metrics.""" + flat = record.copy() + flat.update(flat.pop("result", {})) + flat = _flatten_stage_durations(flat) + flat.pop("benchmark_params", None) + flat.pop("server_params", None) + return flat def _apply_build_metadata_to_latest_only( @@ -493,7 +574,7 @@ def _apply_build_metadata_to_latest_only( def _sort_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, Any]]: """Sort so that same test configuration is grouped, newest date first within each group.""" - by_date_desc = sorted(records, key=lambda r: (r.get("date") or ""), reverse=True) + by_date_desc = sorted(records, key=lambda r: r.get("date") or "", reverse=True) return sorted( by_date_desc, key=_omni_group_key, @@ -501,7 +582,7 @@ def _sort_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, A def _sort_diffusion_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, Any]]: - by_date_desc = sorted(records, key=lambda r: (r.get("date") or ""), reverse=True) + by_date_desc = sorted(records, key=lambda r: r.get("date") or "", reverse=True) return sorted(by_date_desc, key=_diffusion_group_key) @@ -584,6 +665,21 @@ def _to_float_if_numeric(value: Any) -> Any: return value +def _to_excel_compatible(value: Any) -> Any: + """Convert non-scalar objects to JSON string for openpyxl cell values.""" + if isinstance(value, (dict, list, tuple)): + try: + return json.dumps(value, ensure_ascii=False, sort_keys=True) + except (TypeError, ValueError): + return str(value) + if isinstance(value, set): + try: + return json.dumps(sorted(value), ensure_ascii=False) + except (TypeError, ValueError): + return str(value) + return value + + def _write_sheet( ws, columns: Sequence[str], @@ -599,6 +695,7 @@ def _write_sheet( v = record.get(col) if col in numeric_set: v = _to_float_if_numeric(v) + v = _to_excel_compatible(v) row_values.append(v) ws.append(row_values) @@ -678,7 +775,7 @@ def generate_excel_report( script_dir = os.path.dirname(os.path.abspath(__file__)) omni_summary_columns = _ensure_omni_summary_columns(_load_summary_columns(script_dir)) - omni_records = _collect_records(input_dir) + omni_records = _collect_omni_records(input_dir) diffusion_records = _collect_diffusion_records(diffusion_input_dir) if not omni_records: diff --git a/tools/nightly/generate_nightly_perf_html.py b/tools/nightly/generate_nightly_perf_html.py index 05dc48d717..a50a462550 100644 --- a/tools/nightly/generate_nightly_perf_html.py +++ b/tools/nightly/generate_nightly_perf_html.py @@ -17,7 +17,7 
@@ LOGGER = logging.getLogger(__name__) _RESULT_JSON_PREFIX = "result_test_" -_DIFFUSION_JSON_PREFIX = "diffusion_perf_" +_DIFFUSION_JSON_PREFIXES = ("diffusion_perf_", "diffusion_result_") DEFAULT_INPUT_DIR = os.getenv("DEFAULT_INPUT_DIR") or "tests" DEFAULT_OUTPUT_DIR = os.getenv("DEFAULT_OUTPUT_DIR") or "tests" DEFAULT_DIFFUSION_INPUT_DIR = os.getenv("DIFFUSION_BENCHMARK_DIR") @@ -51,7 +51,7 @@ def _default_diffusion_input_dir(input_dir: str) -> str: return DEFAULT_DIFFUSION_INPUT_DIR if DEFAULT_DIFFUSION_INPUT_DIR else input_dir -def _load_json_file(path: str) -> dict[str, Any] | None: +def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None: try: with open(path, encoding="utf-8") as f: data = json.load(f) @@ -59,14 +59,15 @@ def _load_json_file(path: str) -> dict[str, Any] | None: LOGGER.warning("failed to load json '%s': %s", path, exc) return None - if not isinstance(data, dict): - LOGGER.warning("json root in '%s' is not an object, skip", path) + if not isinstance(data, (dict, list)): + LOGGER.warning("json root in '%s' is not an object or list, skip", path) return None return data def _parse_from_filename(filename: str) -> dict[str, Any]: + """Parse ``result_test_*.json`` filenames; same rules as ``generate_nightly_perf_excel``.""" name, ext = os.path.splitext(filename) if ext != ".json" or not name.startswith(_RESULT_JSON_PREFIX): return {} @@ -75,32 +76,58 @@ def _parse_from_filename(filename: str) -> dict[str, Any]: parts = core.split("_") if len(parts) < 5: LOGGER.warning( - "filename '%s' does not match expected pattern, skip parsing test metadata", + "filename '%s' does not match expected pattern (need >= 5 segments), skip parsing", filename, ) return {} - timestamp = parts[-1] - num_prompts_str = parts[-2] - max_concurrency_str = parts[-3] - dataset_name = parts[-4] - test_name = "_".join(parts[:-4]) if parts[:-4] else "" + idx = len(parts) - 1 + timestamp = parts[idx] + idx -= 1 parsed: dict[str, Any] = {} if len(timestamp) >= 15: parsed["date"] = timestamp - if dataset_name in ("random", "random-mm"): - parsed["dataset_name"] = dataset_name + + if idx >= 0 and parts[idx].startswith("out"): + parsed["random_output_len"] = parts[idx][3:] + idx -= 1 + if idx >= 0 and parts[idx].startswith("in"): + parsed["random_input_len"] = parts[idx][2:] + idx -= 1 + + if idx < 3: + LOGGER.warning( + "filename '%s' has too few segments after timestamp / optional in-out (idx=%s)", + filename, + idx, + ) + return parsed + + num_prompts_str = parts[idx] + idx -= 1 + flow_str = parts[idx] + idx -= 1 + dataset_name = parts[idx] + idx -= 1 + test_name = "_".join(parts[: idx + 1]) if idx >= 0 else "" + try: parsed["num_prompts"] = int(num_prompts_str) except (TypeError, ValueError): pass + try: - parsed["max_concurrency"] = int(max_concurrency_str) + parsed["max_concurrency"] = int(flow_str) except (TypeError, ValueError): pass + if test_name: parsed["test_name"] = test_name + + if dataset_name in ("random", "random-mm"): + parsed["dataset_name"] = dataset_name + return parsed @@ -143,9 +170,10 @@ def _iter_omni_json_records(input_dir: str) -> Iterable[dict[str, Any]]: def _parse_diffusion_from_filename(filename: str) -> dict[str, Any]: name, ext = os.path.splitext(filename) - if ext != ".json" or not name.startswith(_DIFFUSION_JSON_PREFIX): + if ext != ".json" or not any(name.startswith(prefix) for prefix in _DIFFUSION_JSON_PREFIXES): return {} - core = name[len(_DIFFUSION_JSON_PREFIX) :] + matched_prefix = next(prefix for prefix in _DIFFUSION_JSON_PREFIXES if 
name.startswith(prefix)) + core = name[len(matched_prefix) :] parts = core.split("_") if len(parts) < 2: return {} @@ -168,7 +196,7 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]: return for entry in sorted(os.listdir(input_dir)): - if not entry.endswith(".json") or not entry.startswith(_DIFFUSION_JSON_PREFIX): + if not entry.endswith(".json") or not any(entry.startswith(prefix) for prefix in _DIFFUSION_JSON_PREFIXES): continue full_path = os.path.join(input_dir, entry) if not os.path.isfile(full_path): @@ -177,17 +205,32 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]: if data is None: continue - record: dict[str, Any] = dict(data) filename_meta = _parse_diffusion_from_filename(os.path.basename(full_path)) - if "date" not in record or not record.get("date"): - record["date"] = filename_meta.get("date") or datetime.now( - timezone.utc, - ).strftime("%Y%m%d-%H%M%S") - if "test_name" not in record or not record.get("test_name"): - if "test_name" in filename_meta: - record["test_name"] = filename_meta["test_name"] - record["source_file"] = os.path.basename(full_path) - yield record + if isinstance(data, dict): + records = [data] + elif isinstance(data, list): + records = [r for r in data if isinstance(r, dict)] + else: + records = [] + + if not records: + LOGGER.warning("diffusion json '%s' has no valid records, skip", full_path) + continue + + for record in records: + flat: dict[str, Any] = dict(record) + result = flat.pop("result", None) + if isinstance(result, dict): + flat.update(result) + if "date" not in flat or not flat.get("date"): + flat["date"] = filename_meta.get("date") or datetime.now( + timezone.utc, + ).strftime("%Y%m%d-%H%M%S") + if "test_name" not in flat or not flat.get("test_name"): + if "test_name" in filename_meta: + flat["test_name"] = filename_meta["test_name"] + flat["source_file"] = os.path.basename(full_path) + yield flat def _collect_omni_records(input_dir: str) -> list[dict[str, Any]]: diff --git a/tools/wan22/assemble_wan22_i2v_diffusers.py b/tools/wan22/assemble_wan22_i2v_diffusers.py new file mode 100644 index 0000000000..8e14ca3c26 --- /dev/null +++ b/tools/wan22/assemble_wan22_i2v_diffusers.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +""" +Assemble a Wan2.2-I2V-A14B-Diffusers-style model directory using a Diffusers +skeleton and optional replacement transformer checkpoints. + +This tool does NOT run any external conversion step. 
You can use it in two +ways: +- keep the original weights from the Diffusers skeleton +- replace transformer/transformer_2 with converted checkpoints such as + LightX2V outputs +- use legacy LightX2V arg names (--high-noise-weight/--low-noise-weight), + which are accepted as aliases + +Typical use: + python tools/wan22/assemble_wan22_i2v_diffusers.py \ + --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \ + --transformer-weight /path/to/high_noise_out/diffusion_pytorch_model.safetensors \ + --transformer-2-weight /path/to/low_noise_out/diffusion_pytorch_model.safetensors \ + --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers +""" + +from __future__ import annotations + +import argparse +import json +import shutil +import sys +from dataclasses import dataclass +from pathlib import Path + +WEIGHT_CANDIDATES = ( + "diffusion_pytorch_model.safetensors", + "diffusion_pytorch_model.bin", + "diffusion_pytorch_model.pt", + "model.safetensors", + "pytorch_model.bin", + "model.pt", +) +WEIGHT_INDEX_CANDIDATES = ( + "diffusion_pytorch_model.safetensors.index.json", + "model.safetensors.index.json", + "pytorch_model.bin.index.json", +) + +ROOT_REQUIRED_FILES = ("model_index.json",) +ROOT_REQUIRED_DIRS = ("tokenizer", "text_encoder", "vae", "transformer", "transformer_2") +OPTIONAL_DIRS = ("image_encoder", "image_processor", "scheduler", "feature_extractor") + + +class AssembleError(RuntimeError): + pass + + +@dataclass(frozen=True) +class WeightSpec: + kind: str # "single" | "sharded" + single_file: Path | None = None + index_file: Path | None = None + shard_files: tuple[Path, ...] = () + + +def _load_shard_files_from_index(index_file: Path, role: str) -> tuple[Path, ...]: + try: + with index_file.open(encoding="utf-8") as f: + payload = json.load(f) + except Exception as exc: + raise AssembleError(f"Failed to parse {role} index file: {index_file}. error={exc}") from exc + + weight_map = payload.get("weight_map") + if not isinstance(weight_map, dict) or not weight_map: + raise AssembleError(f"Invalid {role} index file (missing/empty weight_map): {index_file}") + + shard_names = sorted({str(v) for v in weight_map.values()}) + shard_paths: list[Path] = [] + missing: list[str] = [] + for shard_name in shard_names: + shard_path = index_file.parent / shard_name + if not shard_path.is_file(): + missing.append(str(shard_path)) + else: + shard_paths.append(shard_path) + + if missing: + raise AssembleError(f"{role} index references missing shard file(s): " + ", ".join(missing)) + + if not shard_paths: + raise AssembleError(f"No shard files referenced by {role} index: {index_file}") + + return tuple(shard_paths) + + +def _resolve_weight_spec(path: Path, role: str) -> WeightSpec: + if path.is_file(): + return WeightSpec(kind="single", single_file=path) + + if path.is_dir(): + for name in WEIGHT_CANDIDATES: + candidate = path / name + if candidate.is_file(): + return WeightSpec(kind="single", single_file=candidate) + + for index_name in WEIGHT_INDEX_CANDIDATES: + index_file = path / index_name + if not index_file.is_file(): + continue + shard_files = _load_shard_files_from_index(index_file, role=role) + return WeightSpec( + kind="sharded", + index_file=index_file, + shard_files=shard_files, + ) + + shard_candidates = sorted(path.glob("diffusion_pytorch_model-*.safetensors")) + if shard_candidates: + raise AssembleError( + f"Detected sharded {role} files under {path}, but index json is missing. 
" + f"Expected one of: {', '.join(WEIGHT_INDEX_CANDIDATES)}" + ) + + raise AssembleError( + f"Cannot find {role} weight under directory: {path}. " + f"Expected one of single files [{', '.join(WEIGHT_CANDIDATES)}] " + f"or sharded index files [{', '.join(WEIGHT_INDEX_CANDIDATES)}]." + ) + + raise AssembleError(f"{role} path does not exist: {path}") + + +def _canonical_weight_name(weight_file: Path) -> str: + suffix = weight_file.suffix.lower() + if suffix == ".safetensors": + return "diffusion_pytorch_model.safetensors" + if suffix == ".bin": + return "diffusion_pytorch_model.bin" + if suffix == ".pt": + return "diffusion_pytorch_model.pt" + return weight_file.name + + +def _validate_skeleton(skeleton: Path) -> None: + if not skeleton.is_dir(): + raise AssembleError(f"--diffusers-skeleton is not a directory: {skeleton}") + + for file_name in ROOT_REQUIRED_FILES: + if not (skeleton / file_name).is_file(): + raise AssembleError(f"Missing required file in skeleton: {skeleton / file_name}") + + for dir_name in ROOT_REQUIRED_DIRS: + if not (skeleton / dir_name).is_dir(): + raise AssembleError(f"Missing required directory in skeleton: {skeleton / dir_name}") + + if not (skeleton / "transformer" / "config.json").is_file(): + raise AssembleError(f"Missing transformer config: {skeleton / 'transformer/config.json'}") + + if not (skeleton / "transformer_2" / "config.json").is_file(): + raise AssembleError(f"Missing transformer_2 config: {skeleton / 'transformer_2/config.json'}") + + +def _ensure_clean_output(output_dir: Path, overwrite: bool) -> None: + if output_dir.exists(): + if not overwrite: + raise AssembleError( + f"Output directory already exists: {output_dir}. Use --overwrite to remove and recreate it." + ) + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + +def _copy_or_link_dir(src: Path, dst: Path, asset_mode: str) -> None: + if asset_mode == "copy": + shutil.copytree(src, dst) + elif asset_mode == "symlink": + dst.symlink_to(src, target_is_directory=True) + else: + raise AssembleError(f"Unknown asset mode: {asset_mode}") + + +def _materialize_weight(weight: WeightSpec, dst_dir: Path, role: str) -> tuple[Path, ...]: + if weight.kind == "single": + assert weight.single_file is not None + dst = dst_dir / _canonical_weight_name(weight.single_file) + shutil.copy2(weight.single_file, dst) + return (dst,) + + if weight.kind == "sharded": + assert weight.index_file is not None + copied: list[Path] = [] + index_dst = dst_dir / weight.index_file.name + shutil.copy2(weight.index_file, index_dst) + copied.append(index_dst) + for shard_file in weight.shard_files: + shard_dst = dst_dir / shard_file.name + shutil.copy2(shard_file, shard_dst) + copied.append(shard_dst) + return tuple(copied) + + raise AssembleError(f"Unknown {role} weight kind: {weight.kind}") + + +def _assemble( + skeleton: Path, + output_dir: Path, + transformer_weight: WeightSpec, + transformer_2_weight: WeightSpec, + asset_mode: str, +) -> tuple[tuple[Path, ...], tuple[Path, ...]]: + shutil.copy2(skeleton / "model_index.json", output_dir / "model_index.json") + + for dir_name in ROOT_REQUIRED_DIRS: + if dir_name in ("transformer", "transformer_2"): + continue + _copy_or_link_dir(skeleton / dir_name, output_dir / dir_name, asset_mode) + + for dir_name in OPTIONAL_DIRS: + src_dir = skeleton / dir_name + if src_dir.is_dir(): + _copy_or_link_dir(src_dir, output_dir / dir_name, asset_mode) + + (output_dir / "transformer").mkdir(parents=True, exist_ok=True) + (output_dir / "transformer_2").mkdir(parents=True, 
exist_ok=True) + + shutil.copy2(skeleton / "transformer" / "config.json", output_dir / "transformer" / "config.json") + shutil.copy2(skeleton / "transformer_2" / "config.json", output_dir / "transformer_2" / "config.json") + + transformer_copied = _materialize_weight(transformer_weight, output_dir / "transformer", role="transformer") + transformer_2_copied = _materialize_weight( + transformer_2_weight, + output_dir / "transformer_2", + role="transformer_2", + ) + + return transformer_copied, transformer_2_copied + + +def _validate_output( + output_dir: Path, + transformer_copied: tuple[Path, ...], + transformer_2_copied: tuple[Path, ...], +) -> None: + if not (output_dir / "model_index.json").is_file(): + raise AssembleError("Output validation failed: model_index.json missing") + + required_paths = ( + output_dir / "tokenizer", + output_dir / "text_encoder", + output_dir / "vae", + output_dir / "transformer" / "config.json", + output_dir / "transformer_2" / "config.json", + *transformer_copied, + *transformer_2_copied, + ) + missing = [str(p) for p in required_paths if not p.exists()] + if missing: + raise AssembleError("Output validation failed, missing: " + ", ".join(missing)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Assemble a Wan2.2-I2V-A14B-Diffusers directory while optionally " + "replacing transformer and transformer_2 weights." + ) + ) + parser.add_argument( + "--diffusers-skeleton", + type=Path, + required=True, + help="Path to a local Wan-AI/Wan2.2-I2V-A14B-Diffusers directory.", + ) + parser.add_argument( + "--transformer-weight", + type=Path, + help=( + "Optional checkpoint file, or directory containing either a single-file " + "weight or sharded index+shards for transformer/. If omitted, keep the " + "skeleton's original transformer weights." + ), + ) + parser.add_argument( + "--transformer-2-weight", + type=Path, + help=( + "Optional checkpoint file, or directory containing either a single-file " + "weight or sharded index+shards for transformer_2/. If omitted, keep the " + "skeleton's original transformer_2 weights." + ), + ) + parser.add_argument( + "--high-noise-weight", + type=Path, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--low-noise-weight", + type=Path, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for the assembled model.", + ) + parser.add_argument( + "--asset-mode", + choices=("symlink", "copy"), + default="symlink", + help=( + "How to materialize non-transformer assets (tokenizer/text_encoder/vae/optional dirs). " + "symlink saves disk and is default." 
+ ), + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite output-dir if it exists.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + + skeleton = args.diffusers_skeleton.resolve() + output_dir = args.output_dir.resolve() + + if args.transformer_weight is not None and args.high_noise_weight is not None: + print( + "[ERROR] --transformer-weight and --high-noise-weight are aliases; please provide only one.", + file=sys.stderr, + ) + return 2 + if args.transformer_2_weight is not None and args.low_noise_weight is not None: + print( + "[ERROR] --transformer-2-weight and --low-noise-weight are aliases; please provide only one.", + file=sys.stderr, + ) + return 2 + + transformer_weight_arg = args.transformer_weight if args.transformer_weight is not None else args.high_noise_weight + transformer_2_weight_arg = ( + args.transformer_2_weight if args.transformer_2_weight is not None else args.low_noise_weight + ) + + transformer_input = ( + transformer_weight_arg.resolve() if transformer_weight_arg is not None else skeleton / "transformer" + ) + transformer_2_input = ( + transformer_2_weight_arg.resolve() if transformer_2_weight_arg is not None else skeleton / "transformer_2" + ) + + try: + _validate_skeleton(skeleton) + transformer_weight = _resolve_weight_spec(transformer_input, role="transformer") + transformer_2_weight = _resolve_weight_spec(transformer_2_input, role="transformer_2") + + _ensure_clean_output(output_dir, overwrite=args.overwrite) + transformer_copied, transformer_2_copied = _assemble( + skeleton=skeleton, + output_dir=output_dir, + transformer_weight=transformer_weight, + transformer_2_weight=transformer_2_weight, + asset_mode=args.asset_mode, + ) + _validate_output(output_dir, transformer_copied, transformer_2_copied) + except AssembleError as exc: + print(f"[ERROR] {exc}", file=sys.stderr) + return 2 + + def _weight_summary(copied: tuple[Path, ...]) -> str: + if len(copied) == 1: + return copied[0].name + return f"{copied[0].name} + {len(copied) - 1} shard files" + + print("[OK] Assembled Wan2.2 I2V Diffusers directory:") + print(f" output_dir: {output_dir}") + print(f" transformer weight: {_weight_summary(transformer_copied)}") + print(f" transformer_2 weight: {_weight_summary(transformer_2_copied)}") + print("\nUse it with vLLM-Omni, for example:") + print(f" vllm serve {output_dir} --omni --port 8091") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py index cec8b0af7e..65ad79c725 100644 --- a/vllm_omni/__init__.py +++ b/vllm_omni/__init__.py @@ -12,6 +12,12 @@ processing """ +# We import version early, because it will warn if vLLM / vLLM Omni +# are not using the same major + minor version (if vLLM is installed). +# We should do this before applying patch, because vLLM imports might +# throw in patch if the versions differ. +from .version import __version__, __version_tuple__ # isort:skip # noqa: F401 + try: from . 
import patch # noqa: F401 except ModuleNotFoundError as exc: # pragma: no cover - optional dependency @@ -25,8 +31,6 @@ from .config import OmniModelConfig -from .version import __version__, __version_tuple__ # isort:skip - def __getattr__(name: str): # Lazy import for AsyncOmni and Omni to avoid pulling in heavy diff --git a/vllm_omni/assets/video.py b/vllm_omni/assets/video.py index 98b1f7e4e2..6a5f3204a9 100644 --- a/vllm_omni/assets/video.py +++ b/vllm_omni/assets/video.py @@ -1,6 +1,6 @@ -import librosa import numpy as np from vllm.assets.video import VideoAsset +from vllm.multimodal.media.audio import load_audio def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndarray: @@ -12,5 +12,5 @@ def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndar """ if not path: path = VideoAsset(name="baby_reading").video_path - audio_signal, sr = librosa.load(path, sr=sampling_rate) + audio_signal, sr = load_audio(path, sr=sampling_rate) return audio_signal diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py b/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py new file mode 100644 index 0000000000..01b86d0fd1 --- /dev/null +++ b/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py @@ -0,0 +1,887 @@ +"""Daily-Omni Dataset loader for benchmark. + +Daily-Omni is an audio-visual reasoning benchmark with 684 videos +and 1,197 multiple-choice QA pairs across 6 major task types. + +Dataset source: https://huggingface.co/datasets/liarliar/Daily-Omni + +Supports loading QA metadata from: +- Local JSON file (``qa_json_path``): recommended for offline/air-gapped environments +- HuggingFace datasets (``dataset_path``): legacy online mode + +The videos must be separately downloaded and extracted from Videos.tar. + +Why ``BenchmarkDataset`` instead of ``HuggingFaceDataset``? + vLLM's ``HuggingFaceDataset`` is a thin wrapper whose ``__init__`` always ends by calling + ``load_data()`` → ``datasets.load_dataset(...)`` with a required Hub id and split. That + contract fits "Hub-only" benches, but Daily-Omni also needs **offline QA metadata** from a + local ``qa.json`` without touching the network. Subclassing ``HuggingFaceDataset`` would + mean fighting the parent constructor (fake ``dataset_path``, reordering ``load_data``, or + duplicating half the parent) and would still imply ``datasets`` is always relevant. + + This class therefore inherits only ``BenchmarkDataset`` (minimal: ``dataset_path``, + ``random_seed``, ``self.data``) and implements **two explicit loaders**: + ``_load_from_local_json`` (default path for air-gapped runs) and ``_load_from_huggingface`` + (optional legacy path for users who prefer ``datasets`` + Hub cache). The latter is **not** + inheritance; it is the same Hub rows as before, factored into a helper so one class can + serve both deployment modes without mandatory ``datasets`` when using ``qa_json_path``. 
+ +Usage: + from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniDataset + + # Local JSON mode (recommended) + dataset = DailyOmniDataset( + qa_json_path="/path/to/qa.json", + video_dir="/path/to/Videos", + random_seed=42, + ) + + # HuggingFace mode (legacy, requires network) + dataset = DailyOmniDataset( + dataset_path="liarliar/Daily-Omni", + dataset_split="train", + random_seed=42, + ) + requests = dataset.sample( + tokenizer=tokenizer, + num_requests=100, + output_len=256, + ) +""" + +import base64 +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +try: + from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest +except ImportError: + # Fallback: if BenchmarkDataset not available, use base class from same module + from vllm.benchmarks.datasets import HuggingFaceDataset as BenchmarkDataset + from vllm.benchmarks.datasets import SampleRequest +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.hf import get_cached_tokenizer + +try: + from datasets import load_dataset +except ImportError: + load_dataset = None + +logger = logging.getLogger(__name__) + + +class _ListDatasetIterator: + """Simple iterator wrapper around a list to mimic HuggingFace streaming dataset behavior.""" + + def __init__(self, data: list[dict[str, Any]]) -> None: + self._data = data + self._index = 0 + + def __iter__(self): + self._index = 0 + return self + + def __next__(self) -> dict[str, Any]: + if self._index >= len(self._data): + raise StopIteration + item = self._data[self._index] + self._index += 1 + return item + + def __len__(self) -> int: + return len(self._data) + + def __getitem__(self, idx: int | slice) -> dict[str, Any] | list[dict[str, Any]]: + return self._data[idx] + + +# Aligns with Lliar-liar/Daily-Omni CLI ``--input_mode`` (test_model/*/testmodel.py). +DailyOmniInputMode = Literal["all", "visual", "audio"] + +# ``build_conversation()`` in Daily-Omni ``test_model/Qwen2.5-Omni/testmodel.py`` (verbatim). +DAILY_OMNI_SYSTEM_TEXT = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " + "capable of perceiving auditory and visual inputs, as well as generating text and speech." +) + + +@dataclass +class DailyOmniSampleRequest(SampleRequest): + """``SampleRequest`` with Daily-Omni gold labels for post-run accuracy scoring.""" + + daily_omni_gold_answer: str = "" + daily_omni_video_id: str = "" + daily_omni_task_type: str = "" + #: Official qa.json ``video_duration`` (e.g. ``30s``, ``60s``) for leaderboard-style breakdown. + daily_omni_video_duration: str = "" + #: Official ``video_category`` (YouTube-style category string) for per-category accuracy. + daily_omni_video_category: str = "" + #: Extra JSON fields merged into chat-completions ``extra_body`` (e.g. ``mm_processor_kwargs``). + omni_extra_body: dict[str, Any] | None = None + #: Full OpenAI ``messages`` (system + user) mirroring upstream Daily-Omni conversation. + omni_chat_messages: list[dict[str, Any]] | None = None + #: Used only when ``omni_chat_messages`` is None (non-Daily-Omni-style requests). + omni_chat_mm_position: Literal["first", "last"] = "last" + + +class DailyOmniDataset(BenchmarkDataset): + """Daily-Omni audio-visual QA dataset for benchmarking. + + Inherits ``BenchmarkDataset`` only (not ``HuggingFaceDataset``): see module docstring for why + Hub loading lives in ``_load_from_huggingface`` instead of subclassing the HF base class. 
+ + The dataset includes: + - 684 videos from daily life scenarios (available in Videos.tar) + - 1,197 multiple-choice QA pairs in qa.json + - 6 major task categories + + QA metadata can be loaded from: + - Local JSON file (``qa_json_path``): recommended for offline/air-gapped environments + - HuggingFace datasets (``dataset_path``): legacy online mode + + The videos must be separately downloaded and extracted from Videos.tar. + + Args: + qa_json_path: Path to local qa.json file (offline mode, preferred). When provided, + ``dataset_path`` and ``dataset_split`` are ignored. + dataset_path: HuggingFace dataset path (e.g., "liarliar/Daily-Omni"). Used only if + ``qa_json_path`` is not provided (legacy online mode). + dataset_split: Dataset split to use (default: "train"). Used only in online mode. + random_seed: Random seed for shuffling + video_dir: Directory containing extracted video files (default: None) + input_mode: Which modalities to send, matching upstream Daily-Omni ``--input_mode``: + ``all`` — video + WAV (default; official audio-visual protocol); + ``visual`` — video only; + ``audio`` — extracted WAV only (requires ``{video_id}/{video_id}_audio.wav`` under ``video_dir``). + max_duration_seconds: Reserved for future ffprobe-based filtering; currently **not applied** + when building requests (metadata ``video_duration`` is still passed through for eval). + dataset_subset: Optional HuggingFace subset name (``load_dataset(..., name=...)``); used by bench + ``--hf-subset`` / patch. + no_stream: If True, load the Hub split non-streaming (matches bench ``--no-stream``). + inline_local_video: If True, embed local MP4 as ``data:video/mp4;base64,...`` in requests so + the API server does not need ``--allowed-local-media-path`` (large JSON; use for small runs). + When ``input_mode`` is ``audio`` or ``all``, local WAV is embedded the same way + (``data:audio/wav;base64,...``). + trust_remote_code: Whether to trust remote code when loading HuggingFace dataset + (online mode only). + """ + + SUPPORTED_DATASET_PATHS: set[str] = { + "liarliar/Daily-Omni", + } + #: Default Hub id for synthetic video URLs when ``qa_json_path`` is used (``dataset_path`` None). + DEFAULT_HF_DATASET_ID = "liarliar/Daily-Omni" + IS_MULTIMODAL = True + DEFAULT_OUTPUT_LEN = 256 + + def __init__( + self, + qa_json_path: str | None = None, + dataset_path: str | None = None, + dataset_split: str = "train", + random_seed: int = 0, + video_dir: str | None = None, + input_mode: DailyOmniInputMode = "all", + inline_local_video: bool = False, + trust_remote_code: bool = False, + max_duration_seconds: float | None = None, + dataset_subset: str | None = None, + no_stream: bool = False, + **kwargs, + ) -> None: + if input_mode not in ("all", "visual", "audio"): + raise ValueError(f"input_mode must be 'all', 'visual', or 'audio', got {input_mode!r}") + + # Validate arguments: need either local JSON or HF path + if qa_json_path is None and dataset_path is None: + raise ValueError( + "Either 'qa_json_path' (local JSON) or 'dataset_path' (HuggingFace) must be provided. " + "For offline/air-gapped environments, download qa.json and use qa_json_path." + ) + + # Store configuration + self.qa_json_path = Path(qa_json_path) if qa_json_path else None + self.dataset_path = dataset_path + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + #: Match vLLM ``HuggingFaceDataset`` / bench CLI ``--no-stream``. 
+ self._hf_streaming = not no_stream + self.video_dir = Path(video_dir) if video_dir else None + self.inline_local_video = inline_local_video + self.input_mode: DailyOmniInputMode = input_mode + self.max_duration_seconds = max_duration_seconds + self.trust_remote_code = trust_remote_code + + #: In-process cache of ffprobe durations only (no disk persistence). + self._video_durations: dict[str, float] = {} + + # Initialize parent BenchmarkDataset + super().__init__( + dataset_path=dataset_path if qa_json_path is None else None, + random_seed=random_seed, + **kwargs, + ) + + # Load data based on mode + self.load_data() + + # Verify dataset info + logger.info( + "Loaded Daily-Omni dataset: mode=%s, source=%s, random_seed=%d, input_mode=%s, max_duration=%s", + "local_json" if self.qa_json_path else "huggingface", + str(self.qa_json_path) if self.qa_json_path else f"{dataset_path}/{dataset_split}", + random_seed, + input_mode, + f"{max_duration_seconds}s" if max_duration_seconds else "unlimited", + ) + + def load_data(self) -> None: + """Populate ``self.data`` from either local JSON or the Hub. + + See module docstring: we do not subclass ``HuggingFaceDataset`` because Daily-Omni needs + a first-class offline path; Hub loading is an optional branch implemented below. + """ + if self.qa_json_path is not None: + self._load_from_local_json() + else: + self._load_from_huggingface() + + def _load_from_local_json(self) -> None: + """Load QA data from local JSON file.""" + if not self.qa_json_path.exists(): + raise FileNotFoundError(f"QA JSON file not found: {self.qa_json_path}") + + with open(self.qa_json_path, encoding="utf-8") as f: + data = json.load(f) + + # Support both list format and dict with "train"/"test" splits + if isinstance(data, dict): + # Try to get the requested split, fallback to first available + split_data = data.get(self.dataset_split) + if split_data is None: + available = list(data.keys()) + if available: + logger.warning( + "Split '%s' not found in %s, using '%s' instead", + self.dataset_split, + self.qa_json_path, + available[0], + ) + split_data = data[available[0]] + else: + split_data = [] + data = split_data + + if not isinstance(data, list): + raise ValueError(f"Expected list of QA items in JSON, got {type(data).__name__}") + + # Shuffle if requested + if not getattr(self, "disable_shuffle", False) and self.random_seed is not None: + import random + + rng = random.Random(self.random_seed) + shuffled = data[:] + rng.shuffle(shuffled) + data = shuffled + + # Create an iterator-like wrapper for compatibility + self.data = _ListDatasetIterator(data) + + def _load_from_huggingface(self) -> None: + """Load QA rows via ``datasets.load_dataset`` (legacy / convenience path). + + Kept for backward compatibility: callers can still pass ``dataset_path=liarliar/Daily-Omni`` + and get the same parquet-backed rows as the Hub dataset card, with streaming (or + non-streaming if ``no_stream=True``) and shuffle. + + This is intentionally **not** implemented by subclassing ``HuggingFaceDataset``: that base + always runs Hub ``load_dataset`` from its constructor and expects a Hub id as the primary + API; Daily-Omni instead chooses the source in ``load_data()`` (JSON vs Hub) while sharing + one ``sample()`` / request-building implementation for both. + """ + if load_dataset is None: + raise ImportError( + "datasets library is required for HuggingFace mode. " + "Install with: pip install datasets, or use local JSON mode instead." 
+ ) + + ds = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=self._hf_streaming, + trust_remote_code=self.trust_remote_code, + ) + if not getattr(self, "disable_shuffle", False): + ds = ds.shuffle(seed=self.random_seed) + self.data = ds + + def get_task_statistics(self) -> dict[str, int]: + """Get distribution of task types in the dataset. + + Returns: + Dict mapping task type to count + """ + stats: dict[str, int] = {} + for item in self.data: + row = self._coerce_row(item) + fields = self._normalize_qa_fields(row) + task_type = fields["task_type"] or "unknown" + stats[task_type] = stats.get(task_type, 0) + 1 + return stats + + @staticmethod + def _coerce_row(item: Any) -> dict[str, Any]: + """Turn a dataset row into a plain dict (Arrow / Mapping).""" + if isinstance(item, dict): + return item + if hasattr(item, "as_py"): + return dict(item.as_py()) # pyarrow Row + try: + return dict(item) + except (TypeError, ValueError): + return {k: item[k] for k in item} # type: ignore[misc] + + @staticmethod + def _normalize_qa_fields(row: dict[str, Any]) -> dict[str, Any]: + """Map official Daily-Omni qa.json / Hub schema to internal fields. + + Official fields (see liarliar/Daily-Omni ``qa.json``): ``Question``, ``Choice`` (list), + ``Answer``, ``video_id``, ``Type``, ``video_duration`` (``30s`` / ``60s``), ``video_category``, + plus other category columns. Legacy aliases (lowercase / older loaders) are still accepted. + """ + out: dict[str, Any] = {} + + out["question"] = str(row.get("Question") or row.get("question") or "").strip() + vid = row.get("video_id") if row.get("video_id") is not None else row.get("video") + out["video_id"] = str(vid).strip() if vid is not None else "" + out["task_type"] = str(row.get("Type") or row.get("task_type") or row.get("type") or "").strip() + vc = row.get("video_category") if row.get("video_category") is not None else row.get("videoCategory") + out["video_category"] = str(vc).strip() if vc is not None else "" + vd = row.get("video_duration") if row.get("video_duration") is not None else row.get("videoDuration") + out["video_duration"] = str(vd).strip() if vd is not None else "" + out["answer"] = str(row.get("Answer") or row.get("answer") or "").strip() + vu = row.get("video_url") if row.get("video_url") is not None else row.get("Video_URL") + out["video_url"] = str(vu).strip() if vu is not None and str(vu).strip() else None + + choice = row.get("Choice") + if choice is None: + choice = row.get("options") or row.get("choice") + out["choice"] = choice + + return out + + def sample( + self, + tokenizer: TokenizerLike, + num_requests: int, + output_len: int | None = None, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list[SampleRequest]: + """Sample requests from Daily-Omni dataset. 
+ + Args: + tokenizer: Tokenizer for computing prompt length + num_requests: Number of requests to sample + output_len: Target output length in tokens (default: 256) + request_id_prefix: Prefix for request IDs + no_oversample: If True, do not oversample if fewer examples available + **kwargs: Additional arguments (ignored) + + Returns: + List of SampleRequest objects with video URLs and prompts + """ + if output_len is None: + output_len = self.DEFAULT_OUTPUT_LEN + + sampled_requests: list[SampleRequest] = [] + ind = 0 + cached_tokenizer = get_cached_tokenizer(tokenizer) + + # Iterate over shuffled dataset + for item in self.data: + if len(sampled_requests) >= num_requests: + break + + request = self._create_sample_request( + self._coerce_row(item), cached_tokenizer, output_len, request_id_prefix, ind + ) + if request: + sampled_requests.append(request) + ind += 1 + + logger.info("Created %d sample requests from Daily-Omni dataset", len(sampled_requests)) + + # Handle oversampling if needed + self.maybe_oversample_requests(sampled_requests, num_requests, request_id_prefix, no_oversample) + + return sampled_requests + + def _create_sample_request( + self, + qa_item: dict[str, Any], + tokenizer: TokenizerLike, + output_len: int, + request_id_prefix: str, + index: int, + ) -> SampleRequest | None: + """Create a SampleRequest from a QA item. + + Args: + qa_item: QA pair from the dataset + tokenizer: Tokenizer + output_len: Target output length + request_id_prefix: Prefix for request ID + index: Request index + + Returns: + SampleRequest or None if invalid + """ + fields = self._normalize_qa_fields(qa_item) + video_id = fields["video_id"] + question = fields["question"] + choice = fields["choice"] + task_type = fields["task_type"] + video_url = fields["video_url"] + video_duration = fields.get("video_duration") or "" + video_category = fields.get("video_category") or "" + + if not video_id and not video_url: + logger.warning("Skipping item: no video_id / video_url") + return None + + if not question: + logger.warning("Skipping item: no question found") + return None + + # Official layout after extracting Videos.tar (see Lliar-liar/Daily-Omni test_model): + # {video_base_dir}/{video_id}/{video_id}_video.mp4 + mm_payload, omni_extra, mm_pos = self._compose_daily_omni_multimodal(video_id, video_url) + if not mm_payload: + return None + + messages = self._build_daily_omni_openai_messages(mm_payload, question, choice) + user_text = self._official_daily_omni_user_prompt(question, choice) + # Text-only length estimate (same as before: no MM token count in bench). 
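+        # (Multi-modal placeholder tokens are excluded, so this under-counts video/audio prompts.)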
+ prompt_len = len(tokenizer.encode(f"{DAILY_OMNI_SYSTEM_TEXT}\n{user_text}")) + + return DailyOmniSampleRequest( + prompt=user_text, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + request_id=f"{request_id_prefix}{index}", + daily_omni_gold_answer=fields["answer"], + daily_omni_video_id=video_id, + daily_omni_task_type=task_type, + daily_omni_video_duration=video_duration, + daily_omni_video_category=video_category, + omni_extra_body=omni_extra, + omni_chat_messages=messages, + omni_chat_mm_position=mm_pos, + ) + + @staticmethod + def _official_video_relpath(video_id: str) -> str: + """Relative path inside extracted ``Videos/`` per upstream Daily-Omni scripts.""" + return f"{video_id}/{video_id}_video.mp4" + + @staticmethod + def _official_audio_relpath(video_id: str) -> str: + """Relative path for extracted WAV per upstream ``get_audio_path``.""" + return f"{video_id}/{video_id}_audio.wav" + + def _resolve_local_video_path(self, video_id: str) -> Path | None: + """Pick an existing file under ``video_dir`` (official layout + flat fallback).""" + if not self.video_dir or not video_id: + return None + + candidates = [ + self.video_dir / self._official_video_relpath(video_id), + self.video_dir / f"{video_id}.mp4", # flat layout (custom mirrors / outdated docs) + ] + seen: set[Path] = set() + for p in candidates: + rp = p.resolve() + if rp in seen: + continue + seen.add(rp) + if p.exists(): + return p + return None + + def _resolve_local_audio_path(self, video_id: str) -> Path | None: + """Pick an existing WAV under ``video_dir`` (official layout + flat fallback).""" + if not self.video_dir or not video_id: + return None + candidates = [ + self.video_dir / self._official_audio_relpath(video_id), + self.video_dir / f"{video_id}.wav", + ] + seen: set[Path] = set() + for p in candidates: + rp = p.resolve() + if rp in seen: + continue + seen.add(rp) + if p.exists(): + return p + return None + + def _local_file_to_video_url_payload(self, video_path: Path) -> dict[str, Any]: + """Build OpenAI-style video_url part for a resolved local file. + + vLLM rejects ``file://`` unless the server was started with + ``--allowed-local-media-path`` set to a directory that **contains** the file + (typically the extracted ``Videos`` root). Use ``inline_local_video=True`` to + send base64 data URLs instead (no server path allowlist; larger requests). + """ + path = video_path.expanduser().resolve() + if self.inline_local_video: + raw = path.read_bytes() + b64 = base64.b64encode(raw).decode("ascii") + return { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{b64}"}, + } + return { + "type": "video_url", + "video_url": {"url": path.as_uri()}, + } + + def _local_file_to_audio_url_payload(self, audio_path: Path) -> dict[str, Any]: + """Build OpenAI-style ``audio_url`` part for a resolved local WAV file.""" + path = audio_path.expanduser().resolve() + if self.inline_local_video: + raw = path.read_bytes() + b64 = base64.b64encode(raw).decode("ascii") + return { + "type": "audio_url", + "audio_url": {"url": f"data:audio/wav;base64,{b64}"}, + } + return { + "type": "audio_url", + "audio_url": {"url": path.as_uri()}, + } + + def _get_video_content( + self, + video_id: str, + video_url: str | None, + ) -> dict[str, Any] | None: + """Resolve video for OpenAI-style ``video_url`` content. + + Upstream uses ``get_video_path(video_id, base) -> base/video_id/video_id_video.mp4``. 
+ The Hub repo only publishes ``Videos.tar``; use ``--daily-omni-video-dir`` pointing + at the extracted ``Videos`` folder (parent of per-``video_id`` subdirs). + + For ``file://`` URLs, start ``vllm serve`` with e.g. + ``--allowed-local-media-path /same/path/as/daily-omni-video-dir``. + """ + if video_url: + url = video_url + if not url.startswith(("http://", "https://", "file://")): + url = f"https://{url.lstrip('/')}" + return {"type": "video_url", "video_url": {"url": url}} + + if self.video_dir and video_id: + video_path = self._resolve_local_video_path(video_id) + if video_path is not None: + return self._local_file_to_video_url_payload(video_path) + logger.warning( + "Video not found under video_dir=%s for video_id=%r (expected %s or %s)", + self.video_dir, + video_id, + self._official_video_relpath(video_id), + f"{video_id}.mp4", + ) + + if video_id: + repo = self.dataset_path or self.DEFAULT_HF_DATASET_ID + rel = self._official_video_relpath(video_id) + hf_video_url = f"https://huggingface.co/datasets/{repo}/resolve/main/Videos/{rel}" + logger.debug( + "Using HF video URL (likely 404 — Hub ships Videos.tar only): %s", + hf_video_url, + ) + return {"type": "video_url", "video_url": {"url": hf_video_url}} + + logger.error("Could not determine video source for video_id=%r", video_id) + return None + + def _get_audio_content(self, video_id: str) -> dict[str, Any] | None: + """Resolve extracted WAV for OpenAI-style ``audio_url`` (local files only).""" + if not self.video_dir or not video_id: + logger.warning( + "Daily-Omni input_mode %r requires --daily-omni-video-dir with %s", + self.input_mode, + self._official_audio_relpath(video_id), + ) + return None + audio_path = self._resolve_local_audio_path(video_id) + if audio_path is not None: + return self._local_file_to_audio_url_payload(audio_path) + logger.warning( + "Audio not found under video_dir=%s for video_id=%r (expected %s or %s)", + self.video_dir, + video_id, + self._official_audio_relpath(video_id), + f"{video_id}.wav", + ) + return None + + def _compose_daily_omni_multimodal( + self, + video_id: str, + video_url: str | None, + ) -> tuple[dict[str, Any] | list[dict[str, Any]] | None, dict[str, Any] | None, Literal["first", "last"]]: + """Build ``multi_modal_data`` and request extras for the active ``input_mode``. + + Mirrors upstream Daily-Omni: separate video + WAV with ``use_audio_in_video=False``. + """ + extra: dict[str, Any] = {"mm_processor_kwargs": {"use_audio_in_video": False}} + mode = self.input_mode + + if mode == "visual": + v = self._get_video_content(video_id, video_url) + return v, extra, "last" + + if mode == "audio": + a = self._get_audio_content(video_id) + return a, extra, "first" + + v = self._get_video_content(video_id, video_url) + a = self._get_audio_content(video_id) + if not v or not a: + return None, None, "first" + return [v, a], extra, "first" + + @staticmethod + def _media_desc_for_official_prompt(mode: DailyOmniInputMode) -> str: + """``media_desc`` in upstream ``build_conversation``.""" + if mode == "audio": + return "given audio" + if mode == "all": + return "given video and audio together" + return "given video" + + @staticmethod + def _choices_repr_for_official_prompt(choice: Any) -> str: + """Format ``Choice`` from qa.json for the model (one option per line when possible). + + Using ``str(list)`` embeds Python list brackets and quotes, which is poor for MCQ + reading; lists/tuples are joined with newlines instead. 
Other shapes fall back to + ``str(choice)`` for parity with exotic upstream payloads. + """ + if choice is None: + return "" + if isinstance(choice, (list, tuple)): + lines = [str(x).strip() for x in choice if str(x).strip()] + return "\n".join(lines) + if isinstance(choice, dict): + return "\n".join(f"{k}. {v}" for k, v in choice.items()) + return str(choice) + + def _official_daily_omni_user_prompt(self, question: str, choice: Any) -> str: + """User text block from Daily-Omni ``build_conversation`` (after media parts).""" + task_prompt = self._media_desc_for_official_prompt(self.input_mode) + choices = self._choices_repr_for_official_prompt(choice) + # Single f-string with explicit newlines avoids accidental implicit concatenation + # gluing sentences (e.g. ``...media_desc.Select...``) when editing. + return ( + "Your task is to accurately answer multiple-choice questions " + f"based on the {task_prompt}.\n" + "Select the single most accurate answer from the given choices.\n" + f"Question: {question}\n" + f"Choices: {choices}\n" + "Your answer should be a capital letter representing your choice: " + "A, B, C, or D. Don't generate any other text.\n" + ) + + def _build_daily_omni_openai_messages( + self, + mm_payload: dict[str, Any] | list[dict[str, Any]], + question: str, + choice: Any, + ) -> list[dict[str, Any]]: + """Map upstream conversation to OpenAI Chat Completions ``messages`` (video_url / audio_url parts).""" + user_text = self._official_daily_omni_user_prompt(question, choice) + mm_list: list[dict[str, Any]] = mm_payload if isinstance(mm_payload, list) else [mm_payload] + user_content: list[dict[str, Any]] = [*mm_list, {"type": "text", "text": user_text}] + return [ + {"role": "system", "content": [{"type": "text", "text": DAILY_OMNI_SYSTEM_TEXT}]}, + {"role": "user", "content": user_content}, + ] + + def sample_by_task_type( + self, + tokenizer: TokenizerLike, + task_type: str, + num_samples: int, + output_len: int | None = None, + request_id_prefix: str = "", + **kwargs, + ) -> list[SampleRequest]: + """Sample requests filtered by task type. 
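+
+        Illustrative call (the concrete task-type strings come from the dataset; use
+        ``get_task_statistics()`` to list them)::
+
+            stats = dataset.get_task_statistics()
+            reqs = dataset.sample_by_task_type(tok, task_type=next(iter(stats)), num_samples=20)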
+ + Args: + tokenizer: Tokenizer + task_type: Task type to filter by + num_samples: Number of samples + output_len: Target output length + request_id_prefix: Prefix for request IDs + **kwargs: Additional sampling arguments + + Returns: + List of SampleRequest objects matching the task type + """ + if output_len is None: + output_len = self.DEFAULT_OUTPUT_LEN + + filtered = [ + item for item in self.data if self._normalize_qa_fields(self._coerce_row(item))["task_type"] == task_type + ] + + available = len(filtered) + if available < num_samples: + logger.warning( + "Only %d samples available for task type '%s', requested %d", + available, + task_type, + num_samples, + ) + num_samples = available + + sampled_requests: list[SampleRequest] = [] + cached_tokenizer = get_cached_tokenizer(tokenizer) + + for i, item in enumerate(filtered[:num_samples]): + request = self._create_sample_request(item, cached_tokenizer, output_len, request_id_prefix, i) + if request: + sampled_requests.append(request) + + return sampled_requests + + def __repr__(self) -> str: + return ( + f"DailyOmniDataset(" + f"dataset_path={self.dataset_path!r}, " + f"dataset_split={self.dataset_split!r}, " + f"video_dir={self.video_dir!r}, " + f"input_mode={self.input_mode!r}, " + f"inline_local_video={self.inline_local_video!r}, " + f"max_duration_seconds={self.max_duration_seconds}, " + f"random_seed={self.random_seed}" + f")" + ) + + +def load_daily_omni_dataset( + qa_json_path: str | None = None, + dataset_path: str | None = None, + dataset_split: str = "train", + random_seed: int = 0, + video_dir: str | None = None, + input_mode: DailyOmniInputMode = "all", + max_duration_seconds: float | None = None, + dataset_subset: str | None = None, + no_stream: bool = False, + **kwargs, +) -> DailyOmniDataset: + """Convenience function to load Daily-Omni dataset. + + Args: + qa_json_path: Path to local qa.json file (recommended for offline/air-gapped environments). + When provided, ``dataset_path`` is ignored. + dataset_path: HuggingFace dataset path (default: liarliar/Daily-Omni). Used only if + ``qa_json_path`` is not provided (legacy online mode). + dataset_split: Dataset split to use (default: "train") + random_seed: Random seed for shuffling + video_dir: Directory containing extracted ``Videos/`` tree (MP4 and, for ``all``/``audio``, WAV) + input_mode: ``visual`` | ``audio`` | ``all`` (same semantics as upstream Daily-Omni) + max_duration_seconds: Maximum video duration in seconds (e.g., 30 for 30s subset, 60 for 60s subset); + uses ffprobe on local files under ``video_dir`` (in-memory cache only for this process). + **kwargs: Additional arguments passed to DailyOmniDataset + + Returns: + DailyOmniDataset instance + + Example: + >>> from vllm_omni.benchmarks.data_modules.daily_omni_dataset import load_daily_omni_dataset + + # Local JSON mode (recommended for offline) + >>> dataset = load_daily_omni_dataset( + ... qa_json_path="/path/to/qa.json", + ... video_dir="/path/to/Daily-Omni/Videos", + ... random_seed=42, + ... max_duration_seconds=30, + ... ) + + # HuggingFace mode (legacy online) + >>> dataset = load_daily_omni_dataset( + ... dataset_path="liarliar/Daily-Omni", + ... video_dir="/path/to/Daily-Omni/Videos", + ... random_seed=42, + ... 
) + >>> requests = dataset.sample(tokenizer, num_requests=100) + """ + return DailyOmniDataset( + qa_json_path=qa_json_path, + dataset_path=dataset_path, + dataset_split=dataset_split, + random_seed=random_seed, + video_dir=video_dir, + input_mode=input_mode, + max_duration_seconds=max_duration_seconds, + dataset_subset=dataset_subset, + no_stream=no_stream, + **kwargs, + ) + + +def get_daily_omni_statistics( + qa_json_path: str | None = None, + dataset_path: str | None = DailyOmniDataset.DEFAULT_HF_DATASET_ID, + dataset_split: str = "train", +) -> dict[str, Any]: + """Get statistics about the Daily-Omni dataset. + + Args: + qa_json_path: Path to local qa.json file (recommended for offline/air-gapped environments). + When provided, ``dataset_path`` is ignored. + dataset_path: HuggingFace dataset path. Defaults to ``DailyOmniDataset.DEFAULT_HF_DATASET_ID`` + when ``qa_json_path`` is omitted. Pass ``None`` only together with ``qa_json_path``. + dataset_split: Dataset split to use (default: "train") + + Returns: + Statistics dict with task type distribution and other info + + Example: + >>> from vllm_omni.benchmarks.data_modules.daily_omni_dataset import get_daily_omni_statistics + + # Local JSON mode + >>> stats = get_daily_omni_statistics(qa_json_path="/path/to/qa.json") + + # HuggingFace mode + >>> stats = get_daily_omni_statistics(dataset_path="liarliar/Daily-Omni") + >>> print(f"Total QA pairs: {stats['total_qa_pairs']}") + >>> print(f"Task distribution: {stats['task_distribution']}") + """ + dataset = DailyOmniDataset( + qa_json_path=qa_json_path, + dataset_path=dataset_path, + dataset_split=dataset_split, + ) + task_stats = dataset.get_task_statistics() + + source = str(qa_json_path) if qa_json_path else f"{dataset_path}/{dataset_split}" + return { + "source": source, + "total_qa_pairs": len(list(dataset.data)), + "task_distribution": task_stats, + } diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_eval.py b/vllm_omni/benchmarks/data_modules/daily_omni_eval.py new file mode 100644 index 0000000000..ecc9edc844 --- /dev/null +++ b/vllm_omni/benchmarks/data_modules/daily_omni_eval.py @@ -0,0 +1,406 @@ +"""Daily-Omni multiple-choice accuracy scoring for vLLM-Omni bench serve. + +Compares model ``generated_text`` to dataset ``Answer`` (A/B/C/D). + +**Alignment with open-source** (`Lliar-liar/Daily-Omni` ``test_model/.../testmodel.py``): + +- Answer extraction defaults to the same rules as ``extract_choice_letter`` (strip after an + ``assistant`` marker, then leading ``A``–``D``, else first ``\\b[A-D]\\b``). Set env + ``DAILY_OMNI_EXTRACT_MODE=relaxed`` to use the older vLLM-Omni heuristics (last ``answer:``, + tail scan, etc.). +- Overall accuracy comparable to the official script uses **successful HTTP responses only** as + the denominator (their ``valid_questions = total - failed`` excludes inference / I/O skips). + We also report ``daily_omni_accuracy_incl_http_fail`` where each failed request counts as a + wrong answer in the denominator (stricter throughput-bench view). +- **By video length:** mirrors upstream ``--- Accuracy by Video Duration ---`` for ``30s`` / + ``60s`` (``qa.json`` ``video_duration``): ``daily_omni_per_duration*`` metrics and a printed block. +- **By video category:** mirrors ``--- Accuracy by Video Category ---`` using ``video_category`` + from ``qa.json`` (``daily_omni_per_category*``; empty category is bucketed as ``unknown``). 
+- **Correctness:** uses the same ``evaluate_answer`` rule as upstream (truthy extracted letter vs + raw ``Answer`` string, both ``strip().upper()``). Rows with empty ``Answer`` are skipped + (``no_gold``), matching missing-field skips in the official loop. +""" + +from __future__ import annotations + +import os +import re +from typing import Any + +from vllm.benchmarks.lib.endpoint_request_func import RequestFuncOutput + +from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniSampleRequest + +_VALID = frozenset("ABCD") + +# Official ``testmodel.py`` buckets (``qa.json`` ``video_duration``). +DAILY_OMNI_DURATION_KEYS: tuple[str, ...] = ("30s", "60s") + + +def extract_choice_letter_official(text: str | None) -> str | None: + """Port of Daily-Omni ``extract_choice_letter`` (first A–D, assistant-tail semantics).""" + if not text: + return None + raw = str(text).strip() + if not raw: + return None + match = re.search(r"assistant\s*([\s\S]*)$", raw, flags=re.IGNORECASE) + candidate = match.group(1).strip() if match else raw + direct = re.match(r"(?i)^\s*([A-D])(?:[\s\.\)::]|$)", candidate) + if direct: + return direct.group(1).upper() + fallback = re.search(r"\b([A-D])\b", candidate.upper()) + if fallback: + return fallback.group(1) + return None + + +def evaluate_answer_official(model_answer: str | None, correct_answer: str) -> bool: + """Port of Daily-Omni ``evaluate_answer`` (strict string match after strip/upper).""" + if not model_answer: + return False + return model_answer.strip().upper() == (correct_answer or "").strip().upper() + + +def normalize_gold_answer(gold: str) -> str | None: + """Best-effort single letter from ``Answer`` (for ``gold_normalized`` in saved items only).""" + g = (gold or "").strip().upper() + if len(g) == 1 and g in _VALID: + return g + m = re.search(r"([ABCD])\b", g) + if m: + return m.group(1).upper() + return None + + +def _extract_predicted_choice_relaxed(text: str) -> str | None: + """Legacy vLLM-Omni heuristics (last ``answer:`` patterns, tail scan).""" + if not text or not str(text).strip(): + return None + t = str(text).strip() + + strong_patterns = [ + r"(?i)\*\*answer\*\*\s*[::]?\s*\(?([ABCD])\)?", + r"(?i)\banswer\s*[::]?\s*\(?([ABCD])\)?", + r"(?i)\bfinal\s+answer\s*[::]?\s*\(?([ABCD])\)?", + r"(?i)\bcorrect\s+(?:answer|option)\s*[::]?\s*\(?([ABCD])\)?", + r"(?i)\bthe\s+(?:correct\s+)?option\s+(?:is|would\s+be)\s*\(?([ABCD])\)?", + r"(?i)\bI\s+(?:would\s+)?(?:choose|select|pick)\s*\(?([ABCD])\)?", + ] + last_letter: str | None = None + for pat in strong_patterns: + for m in re.finditer(pat, t): + last_letter = m.group(1).upper() + if last_letter: + return last_letter + + # Weaker phrases: first match can be spurious; still prefer last occurrence. + weak_patterns = [ + r"(?i)\boption\s*[::]?\s*\(?([ABCD])\)?", + r"(?i)\bchoice\s*[::]?\s*\(?([ABCD])\)?", + ] + for pat in weak_patterns: + for m in re.finditer(pat, t): + last_letter = m.group(1).upper() + if last_letter: + return last_letter + + paren = list(re.finditer(r"\(([ABCD])\)", t)) + if paren: + return paren[-1].group(1).upper() + + # First line sometimes is just "B" or "B." — allow if whole output is short + one_line = t.split("\n", 1)[0].strip() + if len(t) < 120 and len(one_line) <= 6: + m0 = re.match(r"^([ABCD])\s*[.:\)]?\s*$", one_line, re.I) + if m0: + return m0.group(1).upper() + + # Tail-only: avoids matching echoed "A. ..." 
option blocks at the start + tail_len = min(500, len(t)) + tail = t[-tail_len:] + # ``\b`` after the letter avoids "Because"/"Definitely" false positives + m = re.search(r"(?:^|[^\w])([ABCD])\b", tail, re.I) + if m: + return m.group(1).upper() + + return None + + +def extract_predicted_choice(text: str | None) -> str | None: + """Parse model output to A–D (official Daily-Omni rules by default).""" + if not text or not str(text).strip(): + return None + mode = os.environ.get("DAILY_OMNI_EXTRACT_MODE", "official").strip().lower() + if mode in ("relaxed", "heuristic", "legacy"): + return _extract_predicted_choice_relaxed(str(text)) + return extract_choice_letter_official(text) + + +def compute_daily_omni_accuracy_metrics( + input_requests: list[Any], + outputs: list[RequestFuncOutput], + *, + include_per_item: bool = False, +) -> dict[str, Any] | None: + """If all requests are :class:`DailyOmniSampleRequest`, compute accuracy stats. + + Rows with empty ``Answer`` (after strip) are skipped as ``no_gold``, like upstream missing + ``correct_answer``. + + **Denominators:** The open-source script excludes items that hit inference / I/O failures + from ``valid_questions``; we mirror that with ``daily_omni_accuracy`` (= correct / + successful responses). Failed HTTP requests are also tracked and used in + ``daily_omni_accuracy_incl_http_fail`` (each failure counts as incorrect in the + denominator). + """ + if not input_requests or len(input_requests) != len(outputs): + return None + if not all(isinstance(r, DailyOmniSampleRequest) for r in input_requests): + return None + + # total / correct: all rows with gold (incl. HTTP fail in total) + # total_ok / correct_ok: successful HTTP only (GitHub-style per-type denominator) + per_task: dict[str, dict[str, int]] = {} + per_category: dict[str, dict[str, int]] = {} + per_duration: dict[str, dict[str, int]] = { + k: {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0} for k in DAILY_OMNI_DURATION_KEYS + } + items: list[dict[str, Any]] = [] + correct = 0 + evaluated = 0 + no_gold = 0 + request_failed = 0 + parse_failed = 0 # success but could not extract A–D + + for req, out in zip(input_requests, outputs, strict=True): + assert isinstance(req, DailyOmniSampleRequest) + gold_raw = (req.daily_omni_gold_answer or "").strip() + gold_norm = normalize_gold_answer(req.daily_omni_gold_answer) + tt = (req.daily_omni_task_type or "unknown").strip() or "unknown" + dur_key = (req.daily_omni_video_duration or "").strip() + dur_active = dur_key in per_duration + cat_key = (req.daily_omni_video_category or "").strip() or "unknown" + if tt not in per_task: + per_task[tt] = {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0} + if cat_key not in per_category: + per_category[cat_key] = {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0} + + if not gold_raw: + no_gold += 1 + items.append( + { + "request_id": req.request_id, + "skipped": True, + "reason": "no_gold", + "task_type": tt, + "video_id": req.daily_omni_video_id, + "video_duration": dur_key or None, + "video_category": cat_key if cat_key != "unknown" else None, + } + ) + continue + + if not out.success: + request_failed += 1 + evaluated += 1 + per_task[tt]["total"] += 1 + per_category[cat_key]["total"] += 1 + if dur_active: + per_duration[dur_key]["total"] += 1 + # GitHub: failed inference not in valid_questions — do not increment total_ok + items.append( + { + "request_id": req.request_id, + "gold": gold_raw, + "gold_normalized": gold_norm, + "predicted": None, + "correct": False, + "task_type": 
tt, + "video_id": req.daily_omni_video_id, + "video_duration": dur_key or None, + "video_category": cat_key if cat_key != "unknown" else None, + "error": (out.error or "")[:500], + } + ) + continue + + pred = extract_predicted_choice(out.generated_text) + evaluated += 1 + per_task[tt]["total"] += 1 + per_task[tt]["total_ok"] += 1 + per_category[cat_key]["total"] += 1 + per_category[cat_key]["total_ok"] += 1 + if dur_active: + per_duration[dur_key]["total"] += 1 + per_duration[dur_key]["total_ok"] += 1 + if pred is None: + parse_failed += 1 + is_correct = evaluate_answer_official(pred, req.daily_omni_gold_answer) + if is_correct: + correct += 1 + per_task[tt]["correct"] += 1 + per_task[tt]["correct_ok"] += 1 + per_category[cat_key]["correct"] += 1 + per_category[cat_key]["correct_ok"] += 1 + if dur_active: + per_duration[dur_key]["correct"] += 1 + per_duration[dur_key]["correct_ok"] += 1 + + items.append( + { + "request_id": req.request_id, + "gold": gold_raw, + "gold_normalized": gold_norm, + "predicted": pred, + "correct": is_correct, + "parse_failed": pred is None, + "task_type": tt, + "video_id": req.daily_omni_video_id, + "video_duration": dur_key or None, + "video_category": cat_key if cat_key != "unknown" else None, + } + ) + + evaluated_ok = evaluated - request_failed + accuracy_github = (correct / evaluated_ok) if evaluated_ok else None + accuracy_incl_fail = (correct / evaluated) if evaluated else None + + per_task_accuracy: dict[str, float | None] = {} + per_task_accuracy_github: dict[str, float | None] = {} + for name, st in per_task.items(): + tot = st["total"] + per_task_accuracy[name] = (st["correct"] / tot) if tot else None + tok = st["total_ok"] + per_task_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None + + per_category_accuracy: dict[str, float | None] = {} + per_category_accuracy_github: dict[str, float | None] = {} + for name, st in per_category.items(): + tot = st["total"] + per_category_accuracy[name] = (st["correct"] / tot) if tot else None + tok = st["total_ok"] + per_category_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None + + per_duration_accuracy: dict[str, float | None] = {} + per_duration_accuracy_github: dict[str, float | None] = {} + for name, st in per_duration.items(): + tot = st["total"] + per_duration_accuracy[name] = (st["correct"] / tot) if tot else None + tok = st["total_ok"] + per_duration_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None + + out: dict[str, Any] = { + # Comparable to GitHub testmodel.py: correct / successful inferences + "daily_omni_accuracy": accuracy_github, + "daily_omni_accuracy_incl_http_fail": accuracy_incl_fail, + "daily_omni_correct": correct, + "daily_omni_evaluated": evaluated, + "daily_omni_evaluated_ok": evaluated_ok, + "daily_omni_no_gold": no_gold, + "daily_omni_request_failed": request_failed, + "daily_omni_parse_failed": parse_failed, + "daily_omni_per_task": {k: dict(v) for k, v in per_task.items()}, + "daily_omni_per_task_accuracy": per_task_accuracy, + "daily_omni_per_task_accuracy_github_style": per_task_accuracy_github, + "daily_omni_per_category": {k: dict(v) for k, v in per_category.items()}, + "daily_omni_per_category_accuracy": per_category_accuracy, + "daily_omni_per_category_accuracy_github_style": per_category_accuracy_github, + "daily_omni_per_duration": {k: dict(v) for k, v in per_duration.items()}, + "daily_omni_per_duration_accuracy": per_duration_accuracy, + "daily_omni_per_duration_accuracy_github_style": per_duration_accuracy_github, + } + if 
include_per_item: + out["daily_omni_eval_items"] = items + return out + + +def print_daily_omni_accuracy_summary(metrics: dict[str, Any]) -> None: + """Pretty-print accuracy block (stdout).""" + acc = metrics.get("daily_omni_accuracy") + acc_fail = metrics.get("daily_omni_accuracy_incl_http_fail") + if acc is None and acc_fail is None and metrics.get("daily_omni_evaluated", 0) == 0: + return + print("{s:{c}^{n}}".format(s=" Daily-Omni accuracy (MCQ) ", n=50, c="=")) + ok = int(metrics.get("daily_omni_evaluated_ok", 0) or 0) + cor = int(metrics.get("daily_omni_correct", 0) or 0) + if ok > 0 and acc is not None: + print(f"Overall Accuracy: {cor}/{ok} = {acc:.2%}") + elif int(metrics.get("daily_omni_evaluated", 0) or 0) > 0: + print("Overall Accuracy: 0/0 = N/A (no successful HTTP responses)") + print( + "{:<40} {:<10}".format( + "Submitted (gold present):", + metrics.get("daily_omni_evaluated", 0), + ) + ) + print( + "{:<40} {:<10}".format( + "Successful HTTP (GitHub denom.):", + metrics.get("daily_omni_evaluated_ok", 0), + ) + ) + print("{:<40} {:<10}".format("Correct:", metrics.get("daily_omni_correct", 0))) + if acc is not None: + print("{:<40} {:<10.4f}".format("Accuracy (ratio, same as above):", acc)) + if acc_fail is not None and metrics.get("daily_omni_request_failed", 0): + print( + "{:<40} {:<10.4f}".format( + "Accuracy (incl. HTTP as wrong):", + acc_fail, + ) + ) + print("{:<40} {:<10}".format("Skipped (no gold):", metrics.get("daily_omni_no_gold", 0))) + print( + "{:<40} {:<10}".format( + "HTTP failed (excl. from GitHub acc.):", + metrics.get("daily_omni_request_failed", 0), + ) + ) + print( + "{:<40} {:<10}".format( + "Parsed OK but no A–D found:", + metrics.get("daily_omni_parse_failed", 0), + ) + ) + pt = metrics.get("daily_omni_per_task") or {} + pta = metrics.get("daily_omni_per_task_accuracy_github_style") or {} + if pta: + print("\n--- Accuracy by QA Type ---") + for name in sorted(pta.keys()): + a = pta[name] + st = pt.get(name) or {} + tok = int(st.get("total_ok", 0) or 0) + cok = int(st.get("correct_ok", 0) or 0) + if tok and a is not None: + print(f"{name}: {cok}/{tok} = {a:.2%}") + else: + print(f"{name}: 0/0 = N/A") + + pc = metrics.get("daily_omni_per_category") or {} + ptc = metrics.get("daily_omni_per_category_accuracy_github_style") or {} + if ptc: + print("\n--- Accuracy by Video Category ---") + for name in sorted(ptc.keys()): + a = ptc[name] + st = pc.get(name) or {} + tok = int(st.get("total_ok", 0) or 0) + cok = int(st.get("correct_ok", 0) or 0) + if tok and a is not None: + print(f"{name}: {cok}/{tok} = {a:.2%}") + else: + print(f"{name}: 0/0 = N/A") + + pdf = metrics.get("daily_omni_per_duration_accuracy_github_style") or {} + if pdf: + print("\n--- Accuracy by Video Duration ---") + for name in DAILY_OMNI_DURATION_KEYS: + a = pdf.get(name) + st = (metrics.get("daily_omni_per_duration") or {}).get(name) or {} + tok = int(st.get("total_ok", 0) or 0) + cor = int(st.get("correct_ok", 0) or 0) + if tok and a is not None: + print(f"{name} Duration: {cor}/{tok} = {a:.2%}") + else: + print(f"{name} Duration: 0/0 = N/A") + print("=" * 50) diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py b/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py new file mode 100644 index 0000000000..69fbe026bd --- /dev/null +++ b/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py @@ -0,0 +1,255 @@ +"""Daily-Omni: optional consistency check between text stream and generated speech. + +The benchmark MCQ accuracy uses ``generated_text`` only. 
When the omni server also +streams ``modality=audio`` (TTS), this module can transcribe the concatenated WAV +with Whisper and compare the inferred option letter to the one parsed from text. + +Requires ``openai-whisper`` (``pip install openai-whisper``). Enable via env +``DAILY_OMNI_TEXT_AUDIO_CONSISTENCY=1`` or CLI ``--daily-omni-text-audio-consistency``. + +Whisper model name defaults to ``tiny`` (override with ``DAILY_OMNI_WHISPER_MODEL``). +""" + +from __future__ import annotations + +import logging +import os +import re +import threading +from typing import Any + +from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniSampleRequest +from vllm_omni.benchmarks.data_modules.daily_omni_eval import extract_predicted_choice + +logger = logging.getLogger(__name__) + +_whisper_model = None +_whisper_model_name: str | None = None +_whisper_lock = threading.Lock() + + +def env_text_audio_check_enabled() -> bool: + return os.environ.get("DAILY_OMNI_TEXT_AUDIO_CONSISTENCY", "").lower() in ( + "1", + "true", + "yes", + ) + + +def extract_choice_from_asr_transcript(transcript: str) -> str | None: + """Parse A–D from ASR text; extends :func:`extract_predicted_choice` with spoken Chinese phrases.""" + c = extract_predicted_choice(transcript) + if c: + return c + t = transcript or "" + for pat in ( + r"(?i)选项\s*([ABCD])\b", + r"(?i)选\s*([ABCD])\b", + r"(?i)答案\s*是\s*([ABCD])\b", + r"(?i)答案\s*([ABCD])\b", + ): + m = re.search(pat, t) + if m: + return m.group(1).upper() + return None + + +def _get_whisper_model(model_name: str): + global _whisper_model, _whisper_model_name + with _whisper_lock: + if _whisper_model is None or _whisper_model_name != model_name: + import whisper + + logger.warning( + "Loading Whisper model %r for Daily-Omni text/audio consistency (one-time)...", + model_name, + ) + _whisper_model = whisper.load_model(model_name) + _whisper_model_name = model_name + return _whisper_model + + +def transcribe_wav_bytes( + wav_bytes: bytes, + *, + language: str | None = None, + model_name: str | None = None, +) -> tuple[str | None, str | None]: + """Transcribe WAV bytes. Returns ``(transcript, error)`` — one of them is set. + + Args: + wav_bytes: RIFF WAV file bytes. + language: Optional Whisper language code (e.g. ``en``, ``zh``); improves accuracy/latency. + model_name: Override model id; else ``DAILY_OMNI_WHISPER_MODEL`` or ``tiny``. + """ + if not wav_bytes: + return None, "empty_wav" + if model_name is None or not str(model_name).strip(): + model_name = os.environ.get("DAILY_OMNI_WHISPER_MODEL") or "tiny" + model_name = str(model_name).strip() or "tiny" + path: str | None = None + try: + import tempfile + + model = _get_whisper_model(model_name) + fd, path = tempfile.mkstemp(suffix=".wav") + with os.fdopen(fd, "wb") as fp: + fp.write(wav_bytes) + kwargs: dict = {} + if language: + kwargs["language"] = language + result = model.transcribe(path, **kwargs) + text = (result.get("text") or "").strip() + return (text if text else None), None + except ImportError: + return None, "openai-whisper is not installed (pip install openai-whisper)" + except Exception as e: + return None, str(e)[:500] + finally: + if path: + try: + os.unlink(path) + except OSError: + pass + + +def compute_daily_omni_text_audio_consistency_metrics( + input_requests: list[Any], + outputs: list[Any], + *, + include_per_item: bool = False, +) -> dict[str, Any] | None: + """Compare option letter from ``generated_text`` vs Whisper transcript of output audio. 
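+
+    Illustrative use after a bench run (``requests`` / ``outputs`` are the benchmark's
+    sampled requests and their responses)::
+
+        ta = compute_daily_omni_text_audio_consistency_metrics(requests, outputs)
+        if ta is not None:
+            print(ta["daily_omni_ta_consistency_rate"])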
+ + Only considers requests where ``outputs[i]`` has ``generated_audio_wav_bytes`` set + (populated by the omni benchmark when TA check is enabled). + """ + if not input_requests or len(input_requests) != len(outputs): + return None + if not all(isinstance(r, DailyOmniSampleRequest) for r in input_requests): + return None + + ta_no_wav = 0 + ta_asr_failed = 0 + ta_text_unparsed = 0 + ta_audio_unparsed = 0 + ta_consistent = 0 + ta_mismatch = 0 + ta_both_parsed = 0 + items: list[dict[str, Any]] = [] + + for req, out in zip(input_requests, outputs, strict=True): + assert isinstance(req, DailyOmniSampleRequest) + rid = req.request_id + if not getattr(out, "success", False): + if include_per_item: + items.append( + { + "request_id": rid, + "skipped": True, + "reason": "request_not_success", + } + ) + continue + + wav = getattr(out, "generated_audio_wav_bytes", None) + if not wav: + ta_no_wav += 1 + if include_per_item: + items.append( + { + "request_id": rid, + "skipped": False, + "reason": "no_output_audio", + "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""), + } + ) + continue + + transcript, asr_err = transcribe_wav_bytes(wav) + if asr_err: + ta_asr_failed += 1 + if include_per_item: + items.append( + { + "request_id": rid, + "asr_error": asr_err, + "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""), + } + ) + continue + + text_choice = extract_predicted_choice(getattr(out, "generated_text", "") or "") + audio_choice = extract_choice_from_asr_transcript(transcript or "") + + if text_choice is None: + ta_text_unparsed += 1 + if audio_choice is None: + ta_audio_unparsed += 1 + + if text_choice is not None and audio_choice is not None: + ta_both_parsed += 1 + if text_choice == audio_choice: + ta_consistent += 1 + else: + ta_mismatch += 1 + + if include_per_item: + consistent: bool | None + if text_choice is None or audio_choice is None: + consistent = None + else: + consistent = text_choice == audio_choice + items.append( + { + "request_id": rid, + "text_choice": text_choice, + "audio_choice": audio_choice, + "asr_transcript": (transcript or "")[:500], + "text_audio_consistent": consistent, + } + ) + + comparable = ta_consistent + ta_mismatch + rate = (ta_consistent / comparable) if comparable else None + + out: dict[str, Any] = { + "daily_omni_ta_enabled": True, + "daily_omni_ta_no_output_audio": ta_no_wav, + "daily_omni_ta_asr_failed": ta_asr_failed, + "daily_omni_ta_text_unparsed": ta_text_unparsed, + "daily_omni_ta_audio_unparsed": ta_audio_unparsed, + "daily_omni_ta_both_parsed": ta_both_parsed, + "daily_omni_ta_consistent": ta_consistent, + "daily_omni_ta_mismatch": ta_mismatch, + "daily_omni_ta_consistency_rate": rate, + } + if include_per_item: + out["daily_omni_ta_items"] = items + return out + + +def print_daily_omni_text_audio_summary(metrics: dict[str, Any]) -> None: + if not metrics.get("daily_omni_ta_enabled"): + return + print("{s:{c}^{n}}".format(s=" Daily-Omni text vs audio (ASR) ", n=50, c="=")) + print("{:<40} {:<10}".format("No output audio captured:", metrics.get("daily_omni_ta_no_output_audio", 0))) + print("{:<40} {:<10}".format("ASR failed:", metrics.get("daily_omni_ta_asr_failed", 0))) + print("{:<40} {:<10}".format("Both text+audio letter parsed:", metrics.get("daily_omni_ta_both_parsed", 0))) + print("{:<40} {:<10}".format("Consistent (same letter):", metrics.get("daily_omni_ta_consistent", 0))) + print("{:<40} {:<10}".format("Mismatch:", metrics.get("daily_omni_ta_mismatch", 0))) + r = 
metrics.get("daily_omni_ta_consistency_rate") + if r is not None: + print("{:<40} {:<10.4f}".format("Consistency rate (of both parsed):", r)) + print( + "{:<40} {:<10}".format( + "Text unparsed (among w/ audio):", + metrics.get("daily_omni_ta_text_unparsed", 0), + ) + ) + print( + "{:<40} {:<10}".format( + "Audio unparsed (among w/ audio):", + metrics.get("daily_omni_ta_audio_unparsed", 0), + ) + ) diff --git a/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py b/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py new file mode 100644 index 0000000000..ca6de4cb20 --- /dev/null +++ b/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py @@ -0,0 +1,272 @@ +"""Seed-TTS zero-shot evaluation-style prompts for ``vllm bench serve``. + +Loads rows from the `meta.lst` format used in `BytedanceSpeech/seed-tts-eval`_ (or any +HuggingFace dataset repo with the same layout):: + + utt_id|prompt_transcript|prompt_wav_relative_path|text_to_synthesize + +Each benchmark request supplies target text plus ``ref_text`` / ``ref_audio`` (Qwen3-TTS ``Base`` / +voice clone), merged into the JSON body. By default ``ref_audio`` is an inline ``data:`` URL so +the server does not need ``--allowed-local-media-path``. Use ``--seed-tts-file-ref-audio`` for +``file://`` (smaller bodies; requires that flag). Use ``--backend openai-audio-speech`` +(``/v1/audio/speech``) or ``--backend openai-chat-omni`` (``/v1/chat/completions`` with the same +fields on the body plus a Qwen3-Omni-style ``system`` message and the target text as ``user`` content). + +.. _BytedanceSpeech/seed-tts-eval: https://github.com/BytedanceSpeech/seed-tts-eval +""" + +from __future__ import annotations + +import base64 +import logging +import random +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest +from vllm.tokenizers import TokenizerLike +from vllm.tokenizers.hf import get_cached_tokenizer + +logger = logging.getLogger(__name__) + +# Matches Qwen3-Omni serving examples (``openai_chat_completion_client_for_multimodal_generation`` / +# ``qwen3_omni/gradio_demo``) plus explicit TTS / voice-clone instructions for chat completions. +SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT = ( + "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, " + "capable of perceiving auditory and visual inputs, as well as generating text and speech.\n" + "For this request you act as a text-to-speech engine with zero-shot voice cloning: " + "the API provides reference audio and its transcript (ref_audio, ref_text) and task_type Base. " + "The user message is the exact text you must speak. " + "Synthesize natural speech in the same language as that user text, " + "matching the timbre, prosody, and speaking style of the reference audio while reading the new content clearly." +) + + +@dataclass +class SeedTTSSampleRequest(SampleRequest): + """``SampleRequest`` with per-row fields merged into ``/v1/audio/speech`` JSON.""" + + #: Shallow-merged into ``RequestFuncInput.extra_body`` (ref_audio, ref_text, task_type, …). + seed_tts_speech_extra: dict[str, Any] | None = None + seed_tts_utterance_id: str = "" + seed_tts_locale: str = "" + #: For ``openai-chat-omni``: becomes the chat ``system`` message (Qwen3-Omni + TTS behavior). + seed_tts_system_prompt: str = "" + #: Local path to reference prompt WAV (for SIM vs. synthesized PCM in ``seed_tts_eval``). 
+ seed_tts_ref_wav_path: str = "" + + +@dataclass +class _SeedTTSRow: + utterance_id: str + ref_text: str + prompt_wav_rel: str + target_text: str + + +def _parse_meta_line(line: str) -> _SeedTTSRow | None: + line = line.strip() + if not line or line.startswith("#"): + return None + parts = line.split("|") + if len(parts) < 4: + logger.warning("Skipping malformed meta.lst line (need 4 '|'-fields): %r", line[:120]) + return None + utt_id, ref_text, wav_rel, target = parts[0], parts[1], parts[2], parts[3] + if not target.strip(): + return None + return _SeedTTSRow( + utterance_id=utt_id.strip(), + ref_text=ref_text.strip(), + prompt_wav_rel=wav_rel.strip(), + target_text=target.strip(), + ) + + +def _load_meta_rows(meta_file: Path) -> list[_SeedTTSRow]: + text = meta_file.read_text(encoding="utf-8") + rows: list[_SeedTTSRow] = [] + for line in text.splitlines(): + r = _parse_meta_line(line) + if r is not None: + rows.append(r) + return rows + + +def resolve_seed_tts_root(dataset_path: str | None, *, explicit_root: str | None) -> Path: + """Return directory containing ``{locale}/meta.lst`` and ``{locale}/prompt-wavs/``.""" + if explicit_root: + root = Path(explicit_root).expanduser().resolve() + if not root.is_dir(): + raise FileNotFoundError(f"--seed-tts-root is not a directory: {root}") + return root + + if not dataset_path: + raise ValueError("Seed-TTS requires --dataset-path (HF repo id or local root) or --seed-tts-root.") + + p = Path(dataset_path).expanduser() + if p.exists() and p.is_dir(): + return p.resolve() + + repo_id = dataset_path.strip() + try: + from huggingface_hub import snapshot_download + except ImportError as e: + raise ImportError( + "Install huggingface_hub to download Seed-TTS from the Hub, or clone the dataset " + "locally and pass --dataset-path / --seed-tts-root to that directory." + ) from e + cache = snapshot_download(repo_id=repo_id, repo_type="dataset") + return Path(cache).resolve() + + +def _ref_audio_payload(wav_path: Path, *, inline: bool) -> str: + if inline: + raw = wav_path.read_bytes() + b64 = base64.b64encode(raw).decode("ascii") + return f"data:audio/wav;base64,{b64}" + return wav_path.expanduser().resolve().as_uri() + + +class SeedTTSDataset(BenchmarkDataset): + """Seed-TTS-style zero-shot TTS rows for throughput/latency benchmarking. + + Args: + dataset_path: HuggingFace dataset repo id (``org/dataset``) or local directory with + ``en/meta.lst`` (and ``zh/meta.lst`` if using zh). + locale: ``en`` or ``zh`` — which subfolder under the root to read. + inline_ref_audio: If True (default), embed prompt WAV as ``data:audio/wav;base64,...`` + so Qwen3-TTS / ``/v1/audio/speech`` works without server + ``--allowed-local-media-path``. If False, use ``file://`` (smaller + requests; server must set ``--allowed-local-media-path`` to the dataset root). + seed_tts_root: Optional override for the root directory (same layout as HF dataset). + system_prompt: Optional override for the chat system message when using + ``--backend openai-chat-omni``; defaults to :data:`SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT`. 
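+
+    Example (paths and tokenizer are placeholders; assumes a locally extracted seed-tts-eval layout)::
+
+        ds = SeedTTSDataset(dataset_path="/data/seed-tts-eval", locale="en", random_seed=0)
+        reqs = ds.sample(tokenizer, num_requests=16)
+        reqs[0].seed_tts_speech_extra["ref_text"]   # prompt transcript merged into the request body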
+ """ + + IS_MULTIMODAL = False + DEFAULT_OUTPUT_LEN = 2048 + + def __init__( + self, + dataset_path: str, + random_seed: int = 0, + locale: str = "en", + inline_ref_audio: bool = True, + seed_tts_root: str | None = None, + system_prompt: str | None = None, + disable_shuffle: bool = False, + **kwargs: Any, + ) -> None: + if locale not in ("en", "zh"): + raise ValueError("locale must be 'en' or 'zh'") + self.locale = locale + self.inline_ref_audio = inline_ref_audio + self._explicit_root = seed_tts_root + sp = (system_prompt or "").strip() + self._system_prompt = sp if sp else SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT + super().__init__( + dataset_path=dataset_path, + random_seed=random_seed, + disable_shuffle=disable_shuffle, + **kwargs, + ) + self._root = resolve_seed_tts_root(self.dataset_path, explicit_root=self._explicit_root) + self._rows: list[_SeedTTSRow] = [] + self.load_data() + + def load_data(self) -> None: + meta = self._root / self.locale / "meta.lst" + if not meta.is_file(): + raise FileNotFoundError( + f"Seed-TTS meta not found: {meta}. " + f"Expected layout from seed-tts-eval (e.g. {self._root}/{self.locale}/meta.lst)." + ) + self._rows = _load_meta_rows(meta) + if not self._rows: + raise ValueError(f"No valid rows in {meta}") + if not self.disable_shuffle: + rng = random.Random(self.random_seed) + rng.shuffle(self._rows) + self.data = self._rows + logger.info( + "Loaded Seed-TTS: root=%s locale=%s rows=%d inline_ref_audio=%s", + self._root, + self.locale, + len(self._rows), + self.inline_ref_audio, + ) + + def sample( + self, + tokenizer: TokenizerLike, + num_requests: int, + output_len: int | None = None, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs: Any, + ) -> list[SampleRequest]: + if output_len is None: + output_len = self.DEFAULT_OUTPUT_LEN + + tok = get_cached_tokenizer(tokenizer) + out: list[SampleRequest] = [] + for i, row in enumerate(self._rows): + if len(out) >= num_requests: + break + wav_path = (self._root / self.locale / row.prompt_wav_rel).resolve() + if not wav_path.is_file(): + logger.warning("Missing prompt wav for %s: %s", row.utterance_id, wav_path) + continue + + target = row.target_text + prompt_len = len(tok.encode(f"{self._system_prompt}\n{target}")) + lang = "English" if self.locale == "en" else "Chinese" + ref_uri = _ref_audio_payload(wav_path, inline=self.inline_ref_audio) + speech_extra: dict[str, Any] = { + "ref_audio": ref_uri, + "ref_text": row.ref_text, + "task_type": "Base", + "language": lang, + "max_new_tokens": output_len, + } + + out.append( + SeedTTSSampleRequest( + prompt=target, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + request_id=f"{request_id_prefix}{i}", + seed_tts_speech_extra=speech_extra, + seed_tts_utterance_id=row.utterance_id, + seed_tts_locale=self.locale, + seed_tts_system_prompt=self._system_prompt, + seed_tts_ref_wav_path=str(wav_path), + ) + ) + + logger.info("Seed-TTS: built %d requests (asked %d)", len(out), num_requests) + self.maybe_oversample_requests(out, num_requests, request_id_prefix, no_oversample) + return out + + +def load_seed_tts_dataset( + dataset_path: str, + random_seed: int = 0, + locale: str = "en", + inline_ref_audio: bool = True, + seed_tts_root: str | None = None, + system_prompt: str | None = None, + **kwargs: Any, +) -> SeedTTSDataset: + return SeedTTSDataset( + dataset_path=dataset_path, + random_seed=random_seed, + locale=locale, + inline_ref_audio=inline_ref_audio, + seed_tts_root=seed_tts_root, + system_prompt=system_prompt, + 
**kwargs, + ) diff --git a/vllm_omni/benchmarks/data_modules/seed_tts_eval.py b/vllm_omni/benchmarks/data_modules/seed_tts_eval.py new file mode 100644 index 0000000000..d5f1b64709 --- /dev/null +++ b/vllm_omni/benchmarks/data_modules/seed_tts_eval.py @@ -0,0 +1,729 @@ +"""Seed-TTS WER aligned with Bytedance ``seed-tts-eval`` / ``run_wer.py``. + +Matches the published protocol (see Hugging Face dataset card and +https://github.com/BytedanceSpeech/seed-tts-eval): + +- **EN**: ``openai/whisper-large-v3`` via ``transformers``, audio resampled to **16 kHz** + (same as ``run_wer.py``). +- **ZH**: ``funasr`` **paraformer-zh**, hypothesis converted with **zhconv** to zh-cn. +- **WER**: ``jiwer`` after punctuation stripping (``zhon.hanzi.punctuation`` + ``string.punctuation``, + preserving ``'``) and EN lowercasing / ZH per-character spacing. Supports jiwer 3.x + (``compute_measures``) and 4.x (``process_words``). + +- **SIM** (speaker similarity proxy): cosine similarity of L2-normalized mean-pooled **WavLM** + embeddings (reference prompt WAV vs. synthesized PCM), 16 kHz. Official ``cal_sim.sh`` uses + UniSpeech ``verification_pair_list_v2.py`` with a **fine-tuned** WavLM SV checkpoint — set + ``SEED_TTS_WAVLM_MODEL`` to another HF id if you need closer parity. Disable with + ``SEED_TTS_SIM_EVAL=0``. Optional: ``SEED_TTS_SIM_DEVICE`` (e.g. ``cpu``) to avoid GPU + issues when Whisper already uses CUDA; ``SEED_TTS_WAVLM_MIN_SAMPLES`` pads very short + waveforms so the WavLM CNN front-end does not fail. + +- **UTMOS** (predicted MOS from TorchScript): default ``balacoon/utmos`` → ``utmos.jit`` + (Sarulab-style demo export). Uses ``torch`` + ``huggingface_hub`` only. Aggregate metrics + are over **all requests with captured PCM** (independent of ASR/WER). Non-finite scores are + dropped and counted as failures. Override repo/file via ``SEED_TTS_UTMOS_HF_REPO`` / + ``SEED_TTS_UTMOS_JIT_FILE``. **Device**: defaults to **CPU** when ``SEED_TTS_UTMOS_DEVICE`` + is unset; set ``SEED_TTS_UTMOS_DEVICE=cuda:0`` (or ``cuda:1`` etc.) to run on GPU. The JIT + model is loaded directly onto the target device via ``map_location`` to avoid cross-device + issues (some PyTorch builds/Windows have problems moving TorchScript modules after load). + Forward uses **float32** waveform in ``[-1, 1]`` (same as the WER resampled array) so + tensor dtypes match JIT weights; using int16 triggers + ``RuntimeError: input type and weight type should be same`` on common exports. Disable + with ``SEED_TTS_UTMOS_EVAL=0``. + +Enable with ``SEED_TTS_WER_EVAL=1`` or ``--seed-tts-wer-eval``. Install optional deps:: + + pip install 'vllm-omni[seed-tts-eval]' + +Env: ``SEED_TTS_EVAL_DEVICE`` (e.g. ``cuda:0``, ``cpu``); ``SEED_TTS_HF_WHISPER_MODEL`` +defaults to ``openai/whisper-large-v3`` (override for debugging only). 
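+
+Illustrative invocation (a hedged sketch: the flags are the ones wired up in
+``benchmarks/patch/patch.py`` and ``benchmarks/serve.py``; the entrypoint placeholder and the
+dataset path below are not real names, so substitute whatever exposes the serve benchmark in
+your install)::
+
+    <serve-benchmark-entrypoint> \
+        --backend openai-audio-speech \
+        --dataset-name seed-tts \
+        --dataset-path /data/seed-tts-eval \
+        --seed-tts-locale zh \
+        --num-prompts 200 \
+        --seed-tts-wer-eval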
+""" + +from __future__ import annotations + +import io +import logging +import math +import os +import statistics +import string +import tempfile +import threading +import wave +from typing import Any + +import numpy as np +from vllm.benchmarks.datasets import SampleRequest + +from vllm_omni.benchmarks.data_modules.seed_tts_dataset import SeedTTSSampleRequest + +logger = logging.getLogger(__name__) + +# Mirrors seed-tts-eval/run_wer.py +OFFICIAL_WHISPER_HF_ID = "openai/whisper-large-v3" +PARAFORMER_MODEL_ID = "paraformer-zh" + +_lock = threading.Lock() +_device: str | None = None +_en_processor = None +_en_model = None +_zh_paraformer = None +_wavlm_model = None +_wavlm_processor = None +_wavlm_device: str | None = None +_utmos_jit_model = None +_utmos_jit_device: str | None = None +_utmos_jit_load_failed = False +_utmos_forward_warned = False + + +def pcm_s16le_mono_to_wav_bytes(pcm: bytes, *, sample_rate: int = 24000) -> bytes: + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(sample_rate) + wf.writeframes(pcm) + return buf.getvalue() + + +def _get_eval_device() -> str: + explicit = os.environ.get("SEED_TTS_EVAL_DEVICE", "").strip() + if explicit: + return explicit + try: + import torch + + return "cuda:0" if torch.cuda.is_available() else "cpu" + except ImportError: + return "cpu" + + +def _punctuation_all() -> str: + from zhon.hanzi import punctuation + + return punctuation + string.punctuation + + +def _jiwer_wer(reference: str, hypothesis: str) -> float: + """Word-level WER; strings are normalized like ``run_wer.process_one``. + + jiwer 4.x removed ``compute_measures`` (``ImportError``); fall back to ``process_words``. + """ + try: + from jiwer import compute_measures + + return float(compute_measures(reference, hypothesis)["wer"]) + except ImportError: + import jiwer + + out = jiwer.process_words(reference, hypothesis) + return float(out.wer) + + +def process_one_official(hypo: str, truth: str, lang: str) -> tuple[float, str, str]: + """Same normalization + ``jiwer`` call as ``run_wer.process_one`` (hypo=ASR, truth=reference).""" + raw_truth = truth + raw_hypo = hypo + truth_n = truth + hypo_n = hypo + for x in _punctuation_all(): + if x == "'": + continue + truth_n = truth_n.replace(x, "") + hypo_n = hypo_n.replace(x, "") + truth_n = truth_n.replace(" ", " ") + hypo_n = hypo_n.replace(" ", " ") + if lang == "zh": + truth_n = " ".join([x for x in truth_n]) + hypo_n = " ".join([x for x in hypo_n]) + elif lang == "en": + truth_n = truth_n.lower() + hypo_n = hypo_n.lower() + else: + raise ValueError(f"unsupported lang {lang!r}") + wer = _jiwer_wer(truth_n, hypo_n) + return wer, raw_truth, raw_hypo + + +def _pcm_s16le_to_f32_16k(pcm: bytes, pcm_sample_rate: int = 24000) -> np.ndarray: + import scipy.signal + + if not pcm: + return np.zeros(0, dtype=np.float32) + raw = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0 + target_len = int(len(raw) * 16000 / pcm_sample_rate) + if target_len <= 0: + return np.zeros(0, dtype=np.float32) + return scipy.signal.resample(raw, target_len).astype(np.float32) + + +def _eval_submetric_enabled(env_name: str, *, default: bool = True) -> bool: + raw = os.environ.get(env_name, "").strip().lower() + if raw in ("0", "false", "no", "off"): + return False + if raw in ("1", "true", "yes", "on"): + return True + return default + + +def _audio_path_to_f32_16k(path: str) -> np.ndarray: + import scipy.signal + import soundfile as sf + + data, sr = sf.read(path, dtype="float32", 
always_2d=True) + mono = np.mean(data, axis=1).astype(np.float32) + if int(sr) == 16000: + return mono + target_len = max(1, int(len(mono) * 16000 / int(sr))) + return scipy.signal.resample(mono, target_len).astype(np.float32) + + +def _ensure_wavlm_sim() -> None: + global _wavlm_model, _wavlm_processor, _wavlm_device + with _lock: + if _wavlm_model is not None: + return + from transformers import AutoFeatureExtractor, AutoModel + + mid = os.environ.get("SEED_TTS_WAVLM_MODEL", "microsoft/wavlm-base-plus").strip() or "microsoft/wavlm-base-plus" + _wavlm_device = os.environ.get("SEED_TTS_SIM_DEVICE", "").strip() or _get_eval_device() + logger.warning( + "Loading WavLM %r on %s for Seed-TTS SIM (embedding cosine; not identical to " + "seed-tts-eval UniSpeech SV checkpoint).", + mid, + _wavlm_device, + ) + _wavlm_processor = AutoFeatureExtractor.from_pretrained(mid) + _wavlm_model = AutoModel.from_pretrained(mid).to(_wavlm_device) + _wavlm_model.eval() + + +def _wavlm_prepare_waveform(wav: np.ndarray) -> np.ndarray: + """Trim, pad to a minimum length WavLM/Wav2Vec2 CNN stack accepts, float32 mono.""" + max_sec = float(os.environ.get("SEED_TTS_WAVLM_MAX_SECONDS", "30")) + cap = int(max_sec * 16000) + w = np.asarray(wav, dtype=np.float32).reshape(-1) + if len(w) == 0: + return w + if len(w) > cap: + w = w[:cap].copy() + # Very short clips make the strided conv front-end fail (shape / empty time dim). + min_samples = int(os.environ.get("SEED_TTS_WAVLM_MIN_SAMPLES", "4000")) + if len(w) < min_samples: + w = np.pad(w, (0, min_samples - len(w)), mode="constant") + return w + + +def _wavlm_mean_embedding_f32_16k(wav: np.ndarray) -> np.ndarray | None: + import torch + + _ensure_wavlm_sim() + w = _wavlm_prepare_waveform(wav) + if len(w) == 0: + return None + assert _wavlm_processor is not None and _wavlm_model is not None and _wavlm_device is not None + # Single utterance: avoid padding=True (adds zeros that distort mean pooling). Still pass + # attention_mask when the extractor provides it (sample-level; do not mix with hidden length). 
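+    # The TypeError fallback below covers feature extractors whose __call__ does not accept
+    # return_attention_mask (assumption: padding=False itself is accepted by both branches).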
+ try: + inputs = _wavlm_processor( + w, + sampling_rate=16000, + return_tensors="pt", + padding=False, + return_attention_mask=True, + ) + except TypeError: + inputs = _wavlm_processor( + w, + sampling_rate=16000, + return_tensors="pt", + padding=False, + ) + iv = inputs["input_values"].to(_wavlm_device) + am = inputs.get("attention_mask") + if am is not None: + am = am.to(_wavlm_device) + with torch.inference_mode(): + out = _wavlm_model(iv, attention_mask=am) + h = out.last_hidden_state + v = h.mean(dim=1).squeeze(0).float().cpu().numpy() + n = float(np.linalg.norm(v)) + if not np.isfinite(n) or n < 1e-8: + return None + return (v / n).astype(np.float32) + + +def _cosine_similarity_unit_vectors(a: np.ndarray, b: np.ndarray) -> float: + return float(np.dot(a, b)) + + +def _ensure_utmos_jit_model() -> Any | None: + """Load UTMOS as TorchScript (``balacoon/utmos`` style): no ``import utmos`` / fairseq.""" + global _utmos_jit_model, _utmos_jit_device, _utmos_jit_load_failed + with _lock: + if _utmos_jit_load_failed: + return None + if _utmos_jit_model is not None: + return _utmos_jit_model + try: + import torch + from huggingface_hub import hf_hub_download + + repo = os.environ.get("SEED_TTS_UTMOS_HF_REPO", "balacoon/utmos").strip() or "balacoon/utmos" + fname = os.environ.get("SEED_TTS_UTMOS_JIT_FILE", "utmos.jit").strip() or "utmos.jit" + logger.warning( + "Loading UTMOS TorchScript from Hugging Face %r file %r (one-time download/cache)...", + repo, + fname, + ) + path = hf_hub_download(repo_id=repo, filename=fname, repo_type="model") + + # TODO The model weights in UTMOS must be loaded in cuda:0; otherwise, the model execution will fail. + want = "cuda:0" + if want.startswith("cuda") and torch.cuda.is_available(): + idx = want.split(":")[-1] if ":" in want else "0" + target_dev = f"cuda:{idx}" + else: + target_dev = "cpu" + + try: + m = torch.jit.load(path, map_location=target_dev) + m.eval() + _utmos_jit_device = target_dev + except Exception as load_e: + if target_dev.startswith("cuda"): + logger.warning( + "UTMOS JIT load on %s failed (%s), retrying on CPU...", + target_dev, + load_e, + ) + m = torch.jit.load(path, map_location="cpu") + m.eval() + _utmos_jit_device = "cpu" + else: + raise + _utmos_jit_model = m + except Exception as e: + logger.warning( + "UTMOS JIT unavailable (install torch + huggingface_hub; check HF access): %s", + e, + ) + _utmos_jit_load_failed = True + return None + return _utmos_jit_model + + +def _utmos_predict_f32_16k(wav_f32: np.ndarray) -> float | None: + """MOS from JIT model; input is float32 mono @ 16 kHz in ``[-1, 1]`` (WER pipeline). + + ``balacoon/utmos`` demos sometimes use int16 numpy, but the exported ``.jit`` weights are + float32; passing int16 tensors causes: "RuntimeError: ... input type and weight type + should be same". + """ + import torch + + if len(wav_f32) == 0: + return None + model = _ensure_utmos_jit_model() + if model is None: + return None + # Infer model's device from its first parameter/buffer to guarantee input sits with weights. 
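+    # (torch.jit.load returns an nn.Module subclass, so parameters()/buffers() are available
+    # on the TorchScript model as well.)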
+ try: + model_dev = next(model.parameters()).device + except StopIteration: + try: + model_dev = next(model.buffers()).device + except StopIteration: + model_dev = torch.device("cpu") + w = np.ascontiguousarray(wav_f32, dtype=np.float32) + x = torch.from_numpy(w).unsqueeze(0).to(device=model_dev, dtype=torch.float32) + with torch.no_grad(): + out = model(x) + val = float(out.reshape(-1)[0].item()) + if not math.isfinite(val): + return None + return val + + +def _ensure_en_asr() -> None: + global _en_processor, _en_model, _device + with _lock: + if _en_processor is not None: + return + from transformers import WhisperForConditionalGeneration, WhisperProcessor + + _device = _get_eval_device() + mid = os.environ.get("SEED_TTS_HF_WHISPER_MODEL", OFFICIAL_WHISPER_HF_ID).strip() or OFFICIAL_WHISPER_HF_ID + logger.warning( + "Loading Seed-TTS eval Whisper HF model %r on %s (one-time, seed-tts-eval protocol)...", + mid, + _device, + ) + _en_processor = WhisperProcessor.from_pretrained(mid) + _en_model = WhisperForConditionalGeneration.from_pretrained(mid).to(_device) + _en_model.eval() + + +def _ensure_zh_asr() -> None: + global _zh_paraformer, _device + with _lock: + if _zh_paraformer is not None: + return + from funasr import AutoModel + + _device = _get_eval_device() + logger.warning( + "Loading Seed-TTS eval Paraformer %r on %s (one-time, seed-tts-eval protocol)...", + PARAFORMER_MODEL_ID, + _device, + ) + try: + _zh_paraformer = AutoModel(model=PARAFORMER_MODEL_ID, device=_device) + except TypeError: + _zh_paraformer = AutoModel(model=PARAFORMER_MODEL_ID) + + +def _transcribe_en_f32_16k(wav_f32: np.ndarray) -> str: + import torch + + _ensure_en_asr() + if len(wav_f32) == 0: + return "" + with _lock: + assert _en_processor is not None and _en_model is not None and _device is not None + inputs = _en_processor(wav_f32, sampling_rate=16000, return_tensors="pt") + input_features = inputs.input_features.to(_device) + with torch.no_grad(): + try: + forced = _en_processor.get_decoder_prompt_ids(language="english", task="transcribe") + predicted_ids = _en_model.generate(input_features, forced_decoder_ids=forced) + except Exception: + predicted_ids = _en_model.generate( + input_features, + language="english", + task="transcribe", + ) + text = _en_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + return (text or "").strip() + + +def _transcribe_zh_wav_path(wav_path: str) -> str: + import zhconv + + _ensure_zh_asr() + with _lock: + assert _zh_paraformer is not None + res = _zh_paraformer.generate(input=wav_path, batch_size_s=300) + transcription = res[0]["text"] if res else "" + return zhconv.convert(transcription, "zh-cn").strip() + + +def _missing_deps_message(lang: str) -> str | None: + try: + import jiwer # noqa: F401 + from zhon.hanzi import punctuation # noqa: F401 + except ImportError as e: + return f"Seed-TTS WER eval needs jiwer and zhon ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'" + try: + import scipy.signal # noqa: F401 + import soundfile # noqa: F401 + except ImportError as e: + return f"Seed-TTS WER eval needs scipy and soundfile ({e!s})." + if lang == "en": + try: + import torch # noqa: F401 + from transformers import WhisperForConditionalGeneration # noqa: F401 + except ImportError as e: + return f"English WER needs torch and transformers ({e!s}). 
Install: pip install 'vllm-omni[seed-tts-eval]'" + else: + try: + import zhconv # noqa: F401 + from funasr import AutoModel # noqa: F401 + except ImportError as e: + return f"Chinese WER needs funasr and zhconv ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'" + return None + + +def compute_seed_tts_wer_metrics( + input_requests: list[SampleRequest], + outputs: list[Any], + *, + include_per_item: bool = False, +) -> dict[str, Any] | None: + """If all requests are :class:`SeedTTSSampleRequest`, run seed-tts-eval-style WER.""" + global _utmos_forward_warned + if not input_requests or len(input_requests) != len(outputs): + return None + if not all(isinstance(r, SeedTTSSampleRequest) for r in input_requests): + return None + + first = input_requests[0] + assert isinstance(first, SeedTTSSampleRequest) + lang = "zh" if (first.seed_tts_locale or "en").lower().startswith("zh") else "en" + + setup_err = _missing_deps_message(lang) + if setup_err: + logger.error("%s", setup_err) + return { + "seed_tts_eval_setup_error": setup_err, + "seed_tts_eval_protocol": "seed-tts-eval", + "seed_tts_content_evaluated": 0, + "seed_tts_content_error_mean": None, + "seed_tts_content_error_median": None, + "seed_tts_request_failed": 0, + "seed_tts_no_pcm": 0, + "seed_tts_asr_failed": 0, + "seed_tts_content_metric": "wer", + } + + import soundfile as sf + + errs: list[float] = [] + items: list[dict[str, Any]] = [] + asr_failed = 0 + no_pcm = 0 + request_failed = 0 + sim_values: list[float] = [] + utmos_values: list[float] = [] + sim_failed = 0 + sim_skipped_no_ref = 0 + utmos_failed = 0 + utmos_on = _eval_submetric_enabled("SEED_TTS_UTMOS_EVAL", default=True) + + for req, out in zip(input_requests, outputs, strict=True): + assert isinstance(req, SeedTTSSampleRequest) + ref = req.prompt + locale = req.seed_tts_locale or "en" + row_lang = "zh" if locale.lower().startswith("zh") else "en" + utmos_v: float | None = None + + if not out.success: + request_failed += 1 + if include_per_item: + items.append( + { + "utterance_id": req.seed_tts_utterance_id, + "locale": locale, + "error": "request_failed", + "detail": (out.error or "")[:500], + } + ) + continue + + pcm = getattr(out, "tts_output_pcm_bytes", None) + if not pcm: + no_pcm += 1 + if include_per_item: + items.append( + { + "utterance_id": req.seed_tts_utterance_id, + "locale": locale, + "error": "no_pcm", + } + ) + continue + + wav_16k = _pcm_s16le_to_f32_16k(pcm) + if len(wav_16k) == 0: + asr_failed += 1 + if include_per_item: + items.append( + { + "utterance_id": req.seed_tts_utterance_id, + "locale": locale, + "error": "empty_audio", + } + ) + continue + + # UTMOS scores synthesized audio only; do not gate on ASR/WER (those can fail independently). + if utmos_on: + try: + utmos_v = _utmos_predict_f32_16k(wav_16k) + if utmos_v is not None: + utmos_values.append(utmos_v) + elif not _utmos_jit_load_failed: + utmos_failed += 1 + except Exception: + if not _utmos_forward_warned: + _utmos_forward_warned = True + logger.warning( + "UTMOS JIT forward failed (first utterance=%s; set logging DEBUG for " + "full trace). 
Check sample rate (16 kHz), input shape, or " + "SEED_TTS_UTMOS_DEVICE.", + req.seed_tts_utterance_id, + exc_info=True, + ) + else: + logger.debug( + "UTMOS forward failed for %s", + req.seed_tts_utterance_id, + exc_info=True, + ) + utmos_failed += 1 + + try: + if row_lang == "en": + hyp = _transcribe_en_f32_16k(wav_16k) + else: + fd, tmp_wav = tempfile.mkstemp(suffix=".wav") + os.close(fd) + try: + sf.write(tmp_wav, wav_16k, 16000, subtype="PCM_16") + hyp = _transcribe_zh_wav_path(tmp_wav) + finally: + try: + os.unlink(tmp_wav) + except OSError: + pass + except Exception as e: + logger.exception("Seed-TTS ASR failed for %s", req.seed_tts_utterance_id) + asr_failed += 1 + if include_per_item: + items.append( + { + "utterance_id": req.seed_tts_utterance_id, + "locale": locale, + "error": "asr_exception", + "detail": str(e)[:500], + } + ) + continue + + if not hyp: + asr_failed += 1 + if include_per_item: + items.append( + { + "utterance_id": req.seed_tts_utterance_id, + "locale": locale, + "error": "empty_asr", + } + ) + continue + + try: + wer, raw_truth, raw_hypo = process_one_official(hyp, ref, row_lang) + except Exception as e: + logger.warning("jiwer/normalize failed for %s: %s", req.seed_tts_utterance_id, e) + asr_failed += 1 + if include_per_item: + items.append( + { + "utterance_id": req.seed_tts_utterance_id, + "locale": locale, + "error": "wer_compute_failed", + "detail": str(e)[:500], + } + ) + continue + + errs.append(wer) + sim_v: float | None = None + + if _eval_submetric_enabled("SEED_TTS_SIM_EVAL", default=True): + ref_path = getattr(req, "seed_tts_ref_wav_path", "") or "" + if ref_path and os.path.isfile(ref_path): + try: + ref_wav = _audio_path_to_f32_16k(ref_path) + e_ref = _wavlm_mean_embedding_f32_16k(ref_wav) + e_hyp = _wavlm_mean_embedding_f32_16k(wav_16k) + if e_ref is not None and e_hyp is not None: + sim_v = _cosine_similarity_unit_vectors(e_ref, e_hyp) + sim_values.append(sim_v) + except Exception as e: + logger.warning( + "SIM embedding failed for utterance=%s: %s: %s", + req.seed_tts_utterance_id, + type(e).__name__, + e, + ) + sim_failed += 1 + else: + sim_skipped_no_ref += 1 + + if include_per_item: + row: dict[str, Any] = { + "utterance_id": req.seed_tts_utterance_id, + "locale": locale, + "wer": wer, + "reference_raw": raw_truth, + "asr_raw": raw_hypo, + } + if sim_v is not None: + row["sim"] = sim_v + if utmos_v is not None: + row["utmos"] = utmos_v + items.append(row) + + result: dict[str, Any] = { + "seed_tts_eval_protocol": "seed-tts-eval", + "seed_tts_content_evaluated": len(errs), + "seed_tts_content_error_mean": statistics.fmean(errs) if errs else None, + "seed_tts_content_error_median": statistics.median(errs) if errs else None, + "seed_tts_request_failed": request_failed, + "seed_tts_no_pcm": no_pcm, + "seed_tts_asr_failed": asr_failed, + "seed_tts_content_metric": "wer", + "seed_tts_sim_evaluated": len(sim_values), + "seed_tts_sim_mean": statistics.fmean(sim_values) if sim_values else None, + "seed_tts_sim_median": statistics.median(sim_values) if sim_values else None, + "seed_tts_sim_failed": sim_failed, + "seed_tts_sim_skipped_no_ref": sim_skipped_no_ref, + "seed_tts_utmos_evaluated": len(utmos_values), + "seed_tts_utmos_mean": statistics.fmean(utmos_values) if utmos_values else None, + "seed_tts_utmos_median": statistics.median(utmos_values) if utmos_values else None, + "seed_tts_utmos_failed": utmos_failed, + } + if include_per_item: + result["seed_tts_wer_eval_items"] = items + return result + + +def print_seed_tts_wer_summary(metrics: 
dict[str, Any]) -> None: + setup = metrics.get("seed_tts_eval_setup_error") + if setup: + print("{s:{c}^{n}}".format(s=" Seed-TTS eval (seed-tts-eval protocol) ", n=50, c="=")) + print(setup) + return + + ev = int(metrics.get("seed_tts_content_evaluated", 0) or 0) + rf = int(metrics.get("seed_tts_request_failed", 0) or 0) + npc = int(metrics.get("seed_tts_no_pcm", 0) or 0) + af = int(metrics.get("seed_tts_asr_failed", 0) or 0) + sim_ev = int(metrics.get("seed_tts_sim_evaluated", 0) or 0) + ut_ev = int(metrics.get("seed_tts_utmos_evaluated", 0) or 0) + if ev == 0 and rf == 0 and npc == 0 and af == 0 and sim_ev == 0 and ut_ev == 0: + return + print("{s:{c}^{n}}".format(s=" Seed-TTS eval (seed-tts-eval protocol) ", n=50, c="=")) + print("{:<40} {:<10}".format("Evaluated (WER, lower is better):", ev)) + mean = metrics.get("seed_tts_content_error_mean") + if mean is not None: + print("{:<40} {:<10.4f}".format("Mean WER:", float(mean))) + med = metrics.get("seed_tts_content_error_median") + if med is not None: + print("{:<40} {:<10.4f}".format("Median WER:", float(med))) + print("{:<40} {:<10}".format("Request failed:", metrics.get("seed_tts_request_failed", 0))) + print("{:<40} {:<10}".format("No PCM captured:", metrics.get("seed_tts_no_pcm", 0))) + print("{:<40} {:<10}".format("ASR / WER failed:", metrics.get("seed_tts_asr_failed", 0))) + if sim_ev or metrics.get("seed_tts_sim_skipped_no_ref") or metrics.get("seed_tts_sim_failed"): + print("{:<40} {:<10}".format("SIM evaluated (higher ~ closer):", sim_ev)) + sm = metrics.get("seed_tts_sim_mean") + if sm is not None: + print("{:<40} {:<10.4f}".format("Mean SIM:", float(sm))) + s_med = metrics.get("seed_tts_sim_median") + if s_med is not None: + print("{:<40} {:<10.4f}".format("Median SIM:", float(s_med))) + print("{:<40} {:<10}".format("SIM skipped (no ref path):", metrics.get("seed_tts_sim_skipped_no_ref", 0))) + print("{:<40} {:<10}".format("SIM embedding errors:", metrics.get("seed_tts_sim_failed", 0))) + if ut_ev or metrics.get("seed_tts_utmos_failed"): + print("{:<40} {:<10}".format("UTMOS evaluated (JIT MOS, higher better):", ut_ev)) + um = metrics.get("seed_tts_utmos_mean") + if um is not None: + print("{:<40} {:<10.4f}".format("Mean UTMOS:", float(um))) + u_med = metrics.get("seed_tts_utmos_median") + if u_med is not None: + print("{:<40} {:<10.4f}".format("Median UTMOS:", float(u_med))) + print("{:<40} {:<10}".format("UTMOS errors:", metrics.get("seed_tts_utmos_failed", 0))) + print("=" * 50) diff --git a/vllm_omni/benchmarks/patch/__init__.py b/vllm_omni/benchmarks/patch/__init__.py index e69de29bb2..ca6b41ba8f 100644 --- a/vllm_omni/benchmarks/patch/__init__.py +++ b/vllm_omni/benchmarks/patch/__init__.py @@ -0,0 +1,3 @@ +"""Omni benchmark monkey-patches (side effects in ``patch.patch``).""" + +from . 
import patch as _patch_module # noqa: F401 diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py index 343655df20..41aed09423 100644 --- a/vllm_omni/benchmarks/patch/patch.py +++ b/vllm_omni/benchmarks/patch/patch.py @@ -6,6 +6,7 @@ import os import random import ssl +import sys import time import traceback from collections.abc import Iterable @@ -33,15 +34,245 @@ from vllm.tokenizers import TokenizerLike logger = init_logger(__name__) + +from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniDataset, DailyOmniSampleRequest from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import OmniRandomMultiModalDataset +from vllm_omni.benchmarks.data_modules.seed_tts_dataset import ( + SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT, + SeedTTSDataset, + SeedTTSSampleRequest, +) get_samples_old = datasets.get_samples +_DEFAULT_DAILY_OMNI_REPO = "liarliar/Daily-Omni" + + +def _seed_tts_capture_pcm_for_wer() -> bool: + return os.environ.get("SEED_TTS_WER_EVAL", "").lower() in ( + "1", + "true", + "yes", + ) + + +def _merge_extra_body_mm_kwargs(base: dict | None, overlay: dict | None) -> dict | None: + """Shallow-merge ``extra_body`` dicts; deep-merge ``mm_processor_kwargs`` if both set.""" + if not base and not overlay: + return None + out = dict(base or {}) + if not overlay: + return out + for k, v in overlay.items(): + if k == "mm_processor_kwargs" and isinstance(v, dict): + prev = out.get("mm_processor_kwargs") + merged_kw = {**(prev if isinstance(prev, dict) else {}), **v} + out["mm_processor_kwargs"] = merged_kw + else: + out[k] = v + return out + + +def _attach_daily_omni_to_request_func_input(sample: SampleRequest, rfi: RequestFuncInput) -> None: + """Apply per-request OpenAI fields (``mm_processor_kwargs``, messages) for Daily-Omni.""" + if not isinstance(sample, DailyOmniSampleRequest): + return + rfi.extra_body = _merge_extra_body_mm_kwargs(rfi.extra_body, sample.omni_extra_body) + if sample.omni_chat_messages is not None: + setattr(rfi, "omni_chat_messages", sample.omni_chat_messages) + else: + setattr(rfi, "mm_position", sample.omni_chat_mm_position) + + +def _attach_seed_tts_to_request_func_input(sample: SampleRequest, rfi: RequestFuncInput) -> None: + """Merge Seed-TTS per-row TTS fields (ref_audio, ref_text, task_type, …) into ``extra_body``. + + Used by both ``/v1/audio/speech`` and ``/v1/chat/completions`` (flattened into JSON body). + For ``openai-chat-omni``, also sets ``omni_chat_messages`` (system + user) so Qwen3-Omni + follows the same role layout as official TTS / multimodal demos. ``/v1/audio/speech`` ignores + ``messages`` and only uses ``input`` + body fields. + Flags ``openai-chat-omni`` to request audio output and optionally export PCM for WER. + """ + if not isinstance(sample, SeedTTSSampleRequest): + return + ex = sample.seed_tts_speech_extra + if not ex: + return + base = dict(rfi.extra_body) if rfi.extra_body else {} + base.update(ex) + rfi.extra_body = base + # Used by request funcs to force streaming TTS behavior and to export PCM when WER is on. + setattr(rfi, "seed_tts_row", True) + sys_prompt = (sample.seed_tts_system_prompt or "").strip() or SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT + setattr( + rfi, + "omni_chat_messages", + [ + {"role": "system", "content": [{"type": "text", "text": sys_prompt}]}, + {"role": "user", "content": [{"type": "text", "text": sample.prompt}]}, + ], + ) + + +def _daily_omni_repo_from_args(args) -> str | None: + """Resolve HuggingFace repo id for Daily-Omni from CLI args. 
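+
+    For example, ``--dataset-name hf --hf-name liarliar/Daily-Omni --dataset-path /data/videos``
+    resolves to ``"liarliar/Daily-Omni"`` (assuming that repo id is listed in
+    ``DailyOmniDataset.SUPPORTED_DATASET_PATHS``); if neither field matches, ``None``.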
+ + vLLM allows ``--dataset-path`` to be a local path while the real HF id is + passed via ``--hf-name``. Upstream ``get_samples`` for ``hf`` only matches + a fixed elif-chain and never discovers Omni's loader, so we must detect + Daily-Omni here using either field. + """ + dp = getattr(args, "dataset_path", None) + hn = getattr(args, "hf_name", None) + if dp in DailyOmniDataset.SUPPORTED_DATASET_PATHS: + return dp + if hn in DailyOmniDataset.SUPPORTED_DATASET_PATHS: + return hn + return None + def get_samples(args, tokenizer): - if args.backend not in ["openai-chat-omni", "openai-audio-speech"]: + # Daily-Omni: explicit dataset name, or hf + matching path/hf-name + is_daily_omni = args.dataset_name == "daily-omni" or ( + args.dataset_name == "hf" and _daily_omni_repo_from_args(args) is not None + ) + is_seed_tts = args.dataset_name == "seed-tts" + + # Check if we need to handle omni-related backends/datasets + is_omni_backend = args.backend in ["openai-chat-omni", "openai-audio-speech", "daily-omni"] + is_omni_dataset = is_daily_omni or is_seed_tts or args.dataset_name == "random-mm" + + if not is_omni_backend and not is_omni_dataset: + # Not an omni-related request, delegate to original implementation return get_samples_old(args, tokenizer) - elif args.dataset_name == "random-mm": + + # Handle Daily-Omni dataset + if is_daily_omni: + # Support: + # --dataset-name daily-omni [--dataset-path liarliar/Daily-Omni] + # --dataset-name daily-omni --daily-omni-qa-json /path/to/qa.json (offline QA) + # --dataset-name hf --dataset-path liarliar/Daily-Omni + # --dataset-name hf --hf-name liarliar/Daily-Omni (dataset-path may be local) + + # Validate backend supports multimodal (video) + if args.backend not in ["openai-chat-omni", "daily-omni"]: + raise ValueError( + f"Daily-Omni dataset requires a multimodal backend that supports video. " + f"Got backend='{args.backend}'. Please use '--backend openai-chat-omni'" + ) + + # Determine video directory if specified (for local video files) + video_dir = getattr(args, "daily_omni_video_dir", None) + + # Get HF split (default to "train"; unused when loading from local qa.json) + dataset_split = getattr(args, "hf_split", None) or "train" + + qa_json = getattr(args, "daily_omni_qa_json", None) + if isinstance(qa_json, str): + qa_json = qa_json.strip() or None + + if qa_json is not None: + logger.info( + "Loading Daily-Omni dataset: qa_json=%s, video_dir=%s (Hub not used for QA)", + qa_json, + video_dir, + ) + dataset = DailyOmniDataset( + qa_json_path=qa_json, + dataset_path=None, + dataset_split=dataset_split, + random_seed=args.seed, + video_dir=video_dir, + input_mode=getattr(args, "daily_omni_input_mode", "all"), + inline_local_video=getattr(args, "daily_omni_inline_local_video", False), + trust_remote_code=getattr(args, "trust_remote_code", False), + disable_shuffle=getattr(args, "disable_shuffle", False), + ) + else: + repo_id = _daily_omni_repo_from_args(args) + if args.dataset_name == "daily-omni": + if repo_id is None: + repo_id = _DEFAULT_DAILY_OMNI_REPO + elif repo_id is None: + raise ValueError( + "Daily-Omni with --dataset-name hf requires " + f"--dataset-path {_DEFAULT_DAILY_OMNI_REPO} or " + f"--hf-name {_DEFAULT_DAILY_OMNI_REPO}." 
+ ) + + logger.info( + "Loading Daily-Omni dataset: hf_repo=%s, split=%s, video_dir=%s", + repo_id, + dataset_split, + video_dir, + ) + + dataset = DailyOmniDataset( + dataset_path=repo_id, + dataset_split=dataset_split, + dataset_subset=getattr(args, "hf_subset", None), + random_seed=args.seed, + video_dir=video_dir, + input_mode=getattr(args, "daily_omni_input_mode", "all"), + inline_local_video=getattr(args, "daily_omni_inline_local_video", False), + trust_remote_code=getattr(args, "trust_remote_code", False), + no_stream=getattr(args, "no_stream", False), + disable_shuffle=getattr(args, "disable_shuffle", False), + ) + + out_len = getattr(args, "output_len", None) + if out_len is None: + out_len = getattr(args, "hf_output_len", None) + if out_len is None: + out_len = DailyOmniDataset.DEFAULT_OUTPUT_LEN + + input_requests = dataset.sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=out_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ) + return input_requests + + if is_seed_tts: + if args.backend not in ("openai-audio-speech", "openai-chat-omni"): + raise ValueError( + "Seed-TTS requires --backend openai-audio-speech (POST /v1/audio/speech) or " + "--backend openai-chat-omni (POST /v1/chat/completions with ref_audio/ref_text). " + f"Got backend={args.backend!r}." + ) + repo_id = getattr(args, "dataset_path", None) or getattr(args, "hf_name", None) + if not repo_id: + raise ValueError( + "Seed-TTS requires --dataset-path (HF dataset repo id or local directory) or " + "--hf-name for the Hub dataset id." + ) + + dataset = SeedTTSDataset( + dataset_path=repo_id, + random_seed=args.seed, + locale=getattr(args, "seed_tts_locale", "en"), + inline_ref_audio=not getattr(args, "seed_tts_file_ref_audio", False), + seed_tts_root=getattr(args, "seed_tts_root", None), + system_prompt=getattr(args, "seed_tts_system_prompt", None), + disable_shuffle=getattr(args, "disable_shuffle", False), + ) + out_len = getattr(args, "output_len", None) + if out_len is None: + out_len = getattr(args, "hf_output_len", None) + if out_len is None: + out_len = SeedTTSDataset.DEFAULT_OUTPUT_LEN + return dataset.sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=out_len, + request_id_prefix=args.request_id_prefix, + no_oversample=args.no_oversample, + ) + + # Handle random-mm dataset (Omni's synthetic multimodal dataset) + if args.dataset_name == "random-mm": dataset = OmniRandomMultiModalDataset(random_seed=args.seed, dataset_path=args.dataset_path) input_requests = dataset.sample( tokenizer=tokenizer, @@ -64,6 +295,10 @@ def get_samples(args, tokenizer): datasets.get_samples = get_samples +_serve_mod = sys.modules.get("vllm.benchmarks.serve") +if _serve_mod is not None: + _serve_mod.get_samples = get_samples + @dataclass class MixRequestFuncOutput(RequestFuncOutput): @@ -72,6 +307,9 @@ class MixRequestFuncOutput(RequestFuncOutput): audio_frames: int = 0 audio_rtf: float = 0.0 text_latency: float = 0.0 + #: Raw PCM s16le mono at 24 kHz for Seed-TTS WER: from ``/v1/audio/speech`` stream or + #: resampled export after ``openai-chat-omni`` audio deltas. 
+ tts_output_pcm_bytes: bytes | None = None async def async_request_openai_chat_omni_completions( @@ -83,13 +321,17 @@ async def async_request_openai_chat_omni_completions( api_url = request_func_input.api_url _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions") - content = _get_chat_content(request_func_input, mm_position=mm_position) + omni_messages = getattr(request_func_input, "omni_chat_messages", None) + if omni_messages is not None: + messages_payload = omni_messages + else: + effective_mm_position = getattr(request_func_input, "mm_position", mm_position) + content = _get_chat_content(request_func_input, mm_position=effective_mm_position) + messages_payload = [{"role": "user", "content": content}] payload = { "model": request_func_input.model_name if request_func_input.model_name else request_func_input.model, - "messages": [ - {"role": "user", "content": content}, - ], + "messages": messages_payload, "temperature": 0.0, "max_tokens": request_func_input.output_len, "stream": True, @@ -98,6 +340,10 @@ async def async_request_openai_chat_omni_completions( }, } _update_payload_common(payload, request_func_input) + # Seed-TTS via chat: voice-clone fields live on the body; ensure audio is streamed. + if getattr(request_func_input, "seed_tts_row", False): + if payload.get("modalities") is None: + payload["modalities"] = ["text", "audio"] response_format = payload.get("response_format", "wav") if response_format == "pcm": @@ -143,7 +389,11 @@ async def async_request_openai_chat_omni_completions( if response.status == 200: handler = StreamedResponseHandler() async for chunk_bytes in response.content.iter_any(): - chunk_bytes = chunk_bytes.strip() + # NOTE: Do NOT strip() here; TCP may fragment the SSE messages, + # so stripping here can cause problems depending on how it is split. + # + # Simple example: [b'data: ', b'{json}\n\n'] <- stripping the first + # chunk will break SSE parsing because the space after 'data:' is required. if not chunk_bytes: continue @@ -163,7 +413,10 @@ async def async_request_openai_chat_omni_completions( data = json.loads(chunk) if choices := data.get("choices"): modality = data.get("modality") - content = choices[0]["delta"].get("content") + delta = choices[0].get("delta") or {} + content = delta.get("content") + if not content and isinstance(delta.get("audio"), dict): + content = delta["audio"].get("data") if modality == "text": # First token if ttft == 0.0: @@ -178,7 +431,7 @@ async def async_request_openai_chat_omni_completions( if output.audio_ttfp == 0.0: output.audio_ttfp = timestamp - st audio_generate_time = timestamp - st - if content != "": + if content: audio_bytes = base64.b64decode(content) seg = AudioSegment.from_file(io.BytesIO(audio_bytes)) if seg is not None: @@ -210,6 +463,12 @@ async def async_request_openai_chat_omni_completions( else: output.audio_rtf = 0 logger.warning("Audio duration is zero") + if _seed_tts_capture_pcm_for_wer() and getattr(request_func_input, "seed_tts_row", False): + try: + seg = generated_audio.set_frame_rate(24000).set_channels(1).set_sample_width(2) + output.tts_output_pcm_bytes = bytes(seg.raw_data) + except Exception as ex: + logger.warning("seed_tts WER PCM export failed: %s", ex) output.success = True else: output.error = response.reason or "" @@ -264,6 +523,10 @@ async def async_request_openai_audio_speech( "response_format": "pcm", } _update_payload_common(payload, request_func_input) + # Seed-TTS + WER: ``--extra-body`` may set stream=false / other formats; speech must stream PCM. 
+ if getattr(request_func_input, "seed_tts_row", False) and _seed_tts_capture_pcm_for_wer(): + payload["stream"] = True + payload["response_format"] = "pcm" headers = { "Content-Type": "application/json", @@ -282,6 +545,8 @@ async def async_request_openai_audio_speech( st = time.perf_counter() output.start_time = st total_pcm_bytes = 0 + capture_wer_pcm = _seed_tts_capture_pcm_for_wer() and getattr(request_func_input, "seed_tts_row", False) + pcm_capture = bytearray() if capture_wer_pcm else None try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: @@ -293,6 +558,8 @@ async def async_request_openai_audio_speech( output.audio_ttfp = timestamp - st output.ttft = output.audio_ttfp total_pcm_bytes += len(chunk) + if pcm_capture is not None: + pcm_capture.extend(chunk) end_time = time.perf_counter() output.latency = end_time - st @@ -305,6 +572,16 @@ async def async_request_openai_audio_speech( else: output.audio_rtf = 0 logger.warning("Audio duration is zero") + if pcm_capture is not None and pcm_capture: + output.tts_output_pcm_bytes = bytes(pcm_capture) + elif capture_wer_pcm: + ct = response.headers.get("Content-Type", "") + logger.warning( + "Seed-TTS WER: HTTP 200 but no PCM bytes (Content-Type=%r, url=%s). " + "Check stream=true and response_format=pcm on the server.", + ct, + api_url, + ) output.success = True else: output.error = response.reason or "" @@ -327,6 +604,12 @@ async def async_request_openai_audio_speech( if "openai-audio-speech" not in OPENAI_COMPATIBLE_BACKENDS: OPENAI_COMPATIBLE_BACKENDS.append("openai-audio-speech") +# Daily-Omni backend for audio-visual reasoning benchmark +# Reuses openai-chat-omni completions for video+text understanding +ASYNC_REQUEST_FUNCS["daily-omni"] = async_request_openai_chat_omni_completions +if "daily-omni" not in OPENAI_COMPATIBLE_BACKENDS: + OPENAI_COMPATIBLE_BACKENDS.append("daily-omni") + # ruff: noqa: E402 # Prevent import order from causing patch failures from vllm.benchmarks import serve @@ -418,6 +701,8 @@ async def benchmark( extra_headers=extra_headers, extra_body=extra_body, ) + _attach_daily_omni_to_request_func_input(input_requests[0], test_input) + _attach_seed_tts_to_request_func_input(input_requests[0], test_input) if ready_check_timeout_sec > 0: test_output = await wait_for_endpoint( @@ -480,6 +765,8 @@ async def warmup_limited_request_func(): extra_headers=extra_headers, extra_body=extra_body, ) + _attach_daily_omni_to_request_func_input(input_requests[0], profile_input) + _attach_seed_tts_to_request_func_input(input_requests[0], profile_input) profile_output = await request_func(request_func_input=profile_input, session=session) if profile_output.success: print("Profiler started") @@ -560,6 +847,8 @@ async def limited_request_func(request_func_input, session, pbar): extra_body=extra_body, request_id=request_id, ) + _attach_daily_omni_to_request_func_input(request, request_func_input) + _attach_seed_tts_to_request_func_input(request, request_func_input) tasks.append( asyncio.create_task(limited_request_func(request_func_input=request_func_input, session=session, pbar=pbar)) ) @@ -627,6 +916,37 @@ async def limited_request_func(request_func_input, session, pbar): "errors": [output.error for output in outputs], } + from vllm_omni.benchmarks.data_modules.daily_omni_eval import ( + compute_daily_omni_accuracy_metrics, + print_daily_omni_accuracy_summary, + ) + + _save_items = os.environ.get("DAILY_OMNI_SAVE_EVAL_ITEMS", "").lower() in ( + "1", + "true", + "yes", + ) + 
_daily_acc = compute_daily_omni_accuracy_metrics(input_requests, outputs, include_per_item=_save_items) + if _daily_acc is not None: + result.update(_daily_acc) + print_daily_omni_accuracy_summary(_daily_acc) + + if _seed_tts_capture_pcm_for_wer(): + from vllm_omni.benchmarks.data_modules.seed_tts_eval import ( + compute_seed_tts_wer_metrics, + print_seed_tts_wer_summary, + ) + + _save_wer = os.environ.get("SEED_TTS_WER_SAVE_ITEMS", "").lower() in ( + "1", + "true", + "yes", + ) + _wer_m = compute_seed_tts_wer_metrics(input_requests, outputs, include_per_item=_save_wer) + if _wer_m is not None: + result.update(_wer_m) + print_seed_tts_wer_summary(_wer_m) + if rps_change_events: result["rps_change_events"] = rps_change_events diff --git a/vllm_omni/benchmarks/serve.py b/vllm_omni/benchmarks/serve.py index fe94603693..d3f3510c56 100644 --- a/vllm_omni/benchmarks/serve.py +++ b/vllm_omni/benchmarks/serve.py @@ -1,9 +1,21 @@ import argparse import asyncio +import os from typing import Any from vllm.benchmarks.serve import main_async +# Import patch to register daily-omni dataset and omni backends +# This monkey-patches vllm.benchmarks.datasets.get_samples before it's used +# Must be imported before any vllm.benchmarks module usage +import vllm_omni.benchmarks.patch.patch # noqa: F401 + def main(args: argparse.Namespace) -> dict[str, Any]: + if getattr(args, "seed_tts_wer_eval", False): + os.environ["SEED_TTS_WER_EVAL"] = "1" + if getattr(args, "seed_tts_wer_save_items", False): + os.environ["SEED_TTS_WER_SAVE_ITEMS"] = "1" + if getattr(args, "daily_omni_save_eval_items", False): + os.environ["DAILY_OMNI_SAVE_EVAL_ITEMS"] = "1" return asyncio.run(main_async(args)) diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py index 2aa236e69f..f02c075880 100644 --- a/vllm_omni/config/__init__.py +++ b/vllm_omni/config/__init__.py @@ -5,10 +5,18 @@ from vllm_omni.config.lora import LoRAConfig from vllm_omni.config.model import OmniModelConfig from vllm_omni.config.stage_config import ( + DeployConfig, ModelPipeline, + PipelineConfig, StageConfig, StageConfigFactory, + StageDeployConfig, + StageExecutionType, + StagePipelineConfig, StageType, + load_deploy_config, + merge_pipeline_deploy, + register_pipeline, ) from vllm_omni.config.yaml_util import ( create_config, @@ -24,6 +32,14 @@ "StageConfigFactory", "ModelPipeline", "StageType", + "StageExecutionType", + "StagePipelineConfig", + "PipelineConfig", + "StageDeployConfig", + "DeployConfig", + "load_deploy_config", + "merge_pipeline_deploy", + "register_pipeline", "create_config", "load_yaml_config", "merge_configs", diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py new file mode 100644 index 0000000000..c07bc2610c --- /dev/null +++ b/vllm_omni/config/pipeline_registry.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Central declarative registry of all vllm-omni pipelines. + +Mirrors the pattern in ``vllm/model_executor/models/registry.py``: each entry +is ``model_type -> (module_path, variable_name)``, and the module is imported +lazily on first lookup (see ``_LazyPipelineRegistry`` in +``vllm_omni/config/stage_config.py``). Keeping every pipeline declared in one +file makes it easy to spot a missing registration, which was the original +motivation in https://github.com/vllm-project/vllm-omni/issues/2887 (item 4). 
+ +Per-model ``pipeline.py`` modules still define the ``PipelineConfig`` instance; +they just no longer need to self-register via ``register_pipeline(...)``. + +Adding a new pipeline: + 1. Define the ``PipelineConfig`` instance as a module-level variable in + ``vllm_omni/.../pipeline.py``. + 2. Add one line to ``_OMNI_PIPELINES`` or ``_DIFFUSION_PIPELINES`` below. + +``register_pipeline(config)`` in ``stage_config`` is still supported for +out-of-tree plugins and tests that create pipelines at runtime; those override +the entries declared here. +""" + +from __future__ import annotations + +# --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) --- +_OMNI_PIPELINES: dict[str, tuple[str, str]] = { + # model_type -> (module_path, variable_name) + "qwen2_5_omni": ( + "vllm_omni.model_executor.models.qwen2_5_omni.pipeline", + "QWEN2_5_OMNI_PIPELINE", + ), + "qwen2_5_omni_thinker_only": ( + "vllm_omni.model_executor.models.qwen2_5_omni.pipeline", + "QWEN2_5_OMNI_THINKER_ONLY_PIPELINE", + ), + "qwen3_omni_moe": ( + "vllm_omni.model_executor.models.qwen3_omni.pipeline", + "QWEN3_OMNI_PIPELINE", + ), + "qwen3_tts": ( + "vllm_omni.model_executor.models.qwen3_tts.pipeline", + "QWEN3_TTS_PIPELINE", + ), +} + +# --- Single-stage diffusion pipelines (populated in PR 3/N) --- +_DIFFUSION_PIPELINES: dict[str, tuple[str, str]] = {} + +# Union view used by ``_LazyPipelineRegistry``; don't mutate at runtime. +_VLLM_OMNI_PIPELINES: dict[str, tuple[str, str]] = { + **_OMNI_PIPELINES, + **_DIFFUSION_PIPELINES, +} diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index a4e186c3bd..392a550be6 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -1,18 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Stage Configuration System for vLLM-Omni. - -Pipeline structure (stages, types, data-flow) is defined in per-model YAML -files and is set by model developers at integration time. -Runtime parameters (gpu_memory_utilization, tp_size, etc.) come from CLI. -""" +"""Stage configuration system for vLLM-Omni.""" from __future__ import annotations +import dataclasses import re import warnings -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from enum import Enum from pathlib import Path from typing import Any @@ -20,76 +15,818 @@ from vllm.logger import init_logger from vllm_omni.config.yaml_util import create_config, load_yaml_config, to_dict +from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler +from vllm_omni.core.sched.omni_generation_scheduler import OmniGenerationScheduler -# Pipeline YAMLs live alongside model code in model_executor/models// _MODELS_DIR = Path(__file__).resolve().parent.parent / "model_executor" / "models" def get_pipeline_path(model_dir: str, filename: str) -> Path: - """Return the full path to a pipeline YAML file. + return _MODELS_DIR / model_dir / filename + + +logger = init_logger(__name__) + + +_STAGE_OVERRIDE_PATTERN = re.compile(r"^stage_(\d+)_(.+)$") - Args: - model_dir: Model subdirectory name (e.g., "qwen3_omni"). - filename: Name of the YAML file (e.g., "pipeline.yaml"). - Returns: - Absolute path to the file. +def build_stage_runtime_overrides( + stage_id: int, + cli_overrides: dict[str, Any], + *, + internal_keys: set[str] | frozenset[str] | None = None, +) -> dict[str, Any]: + """Build per-stage runtime overrides from global and ``stage__*`` kwargs. 
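+
+    For example (assuming none of these keys are internal or shared), ``stage_id=1`` with
+    ``{"max_num_seqs": 8, "stage_1_gpu_memory_utilization": 0.6, "stage_0_max_model_len": 4096}``
+    yields ``{"max_num_seqs": 8, "gpu_memory_utilization": 0.6}``: un-prefixed keys pass through,
+    matching ``stage_1_*`` keys are de-prefixed, and other stages' overrides are dropped.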
+ + ``internal_keys`` defaults to the union of + ``arg_utils.internal_blacklist_keys()`` and ``arg_utils.SHARED_FIELDS`` + so that neither orchestrator-only fields nor shared-pipeline fields + (``model`` / ``stage_configs_path`` / ``log_stats`` / ``stage_id``) leak + into a stage's per-stage runtime overrides — the orchestrator sets those + uniformly for every stage, they are not per-stage knobs. Callers can + pass an explicit set for tests or specialized flows. """ - return _MODELS_DIR / model_dir / filename + if internal_keys is None: + from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys + internal_keys = internal_blacklist_keys() | SHARED_FIELDS -logger = init_logger(__name__) + result: dict[str, Any] = {} + + for key, value in cli_overrides.items(): + if value is None or key in internal_keys: + continue + + match = _STAGE_OVERRIDE_PATTERN.match(key) + if match is not None: + override_stage_id = int(match.group(1)) + param_name = match.group(2) + if override_stage_id == stage_id and param_name not in internal_keys: + result[param_name] = value + continue + + result[key] = value + + return result + + +def strip_parent_engine_args( + kwargs: dict[str, Any], + *, + parent_fields: dict[str, dataclasses.Field], + keep_keys: set[str] | frozenset[str] = frozenset(), + strip_keys: set[str] | frozenset[str] = frozenset(), + no_warn_keys: set[str] | frozenset[str] = frozenset(), +) -> tuple[dict[str, Any], list[str]]: + """Strip parent ``EngineArgs`` fields before merging into stage YAML.""" + overridden: list[str] = [] + result: dict[str, Any] = {} + + for key, value in kwargs.items(): + if key in strip_keys: + continue + + if key not in parent_fields or key in keep_keys: + result[key] = value + continue + + field_def = parent_fields[key] + if field_def.default is not dataclasses.MISSING: + default = field_def.default + elif field_def.default_factory is not dataclasses.MISSING: + default = field_def.default_factory() + else: + default = dataclasses.MISSING + + if default is dataclasses.MISSING or value is None: + continue + + if dataclasses.is_dataclass(default) and not isinstance(default, type): + default = asdict(default) + + if value != default and key not in no_warn_keys: + overridden.append(key) + + return result, sorted(overridden) class StageType(str, Enum): """Type of processing stage in the Omni pipeline.""" + # TODO(@lishunyang12): remove once all models migrate to StageExecutionType LLM = "llm" DIFFUSION = "diffusion" +class StageExecutionType(str, Enum): + """Merged StageType + WorkerType — 3 combinations today.""" + + LLM_AR = "llm_ar" + LLM_GENERATION = "llm_generation" + DIFFUSION = "diffusion" + + +# Mapping class refs (not dotted-path strings) so module/class renames fail +# at import time instead of lazily at scheduler resolution. YAML overrides +# and downstream serialization still use the dotted-path string form; the +# conversion happens at the map lookup site via _scheduler_path(). 
+_EXECUTION_TYPE_TO_SCHEDULER: dict[StageExecutionType, type | None] = { + StageExecutionType.LLM_AR: OmniARScheduler, + StageExecutionType.LLM_GENERATION: OmniGenerationScheduler, + StageExecutionType.DIFFUSION: None, +} + + +def _scheduler_path(cls: type | None) -> str | None: + """Return the dotted import path for a scheduler class (``None`` passes through).""" + if cls is None: + return None + return f"{cls.__module__}.{cls.__qualname__}" + + +@dataclass(frozen=True) +class StagePipelineConfig: + """Fixed topology for one stage (frozen, not user-configurable).""" + + stage_id: int + model_stage: str + execution_type: StageExecutionType = StageExecutionType.LLM_AR + input_sources: tuple[int, ...] = () + final_output: bool = False + final_output_type: str | None = None + owns_tokenizer: bool = False + requires_multimodal_data: bool = False + hf_config_name: str | None = None + engine_output_type: str | None = None + model_arch: str | None = None + sampling_constraints: dict[str, Any] = field(default_factory=dict) + custom_process_input_func: str | None = None + custom_process_next_stage_input_func: str | None = None + # Alternates picked by ``merge_pipeline_deploy`` based on ``deploy.async_chunk``. + async_chunk_process_next_stage_input_func: str | None = None + sync_process_input_func: str | None = None + prompt_expand_func: str | None = None + cfg_kv_collect_func: str | None = None + omni_kv_config: dict[str, Any] | None = None + extras: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class PipelineConfig: + """Complete pipeline topology for a model (frozen).""" + + model_type: str + model_arch: str = "" + stages: tuple[StagePipelineConfig, ...] = () + # HF architecture aliases: used by StageConfigFactory when the model's + # HF config reports a generic model_type that collides with a different + # model (e.g. MiMo Audio reports model_type="qwen2"). The factory + # matches ``hf_config.architectures[*]`` against this tuple to route + # to the correct pipeline. Leave empty for models with unique model_type. + hf_architectures: tuple[str, ...] = () + + def get_stage(self, stage_id: int) -> StagePipelineConfig | None: + """Look up a stage by its ID.""" + for stage in self.stages: + if stage.stage_id == stage_id: + return stage + return None + + def get_scheduler_cls(self, stage_id: int) -> str | None: + """Return the inferred scheduler class path for a stage. + + Returns ``None`` for DIFFUSION stages (no vLLM scheduler). Raises + ``ValueError`` if ``stage_id`` doesn't exist in this pipeline, and + ``KeyError`` if ``execution_type`` isn't in the scheduler map. 
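+
+        For instance, a stage with ``execution_type=StageExecutionType.LLM_AR`` resolves via
+        ``_scheduler_path`` to ``"vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler"``
+        (assuming the class is defined in the module it is imported from here), while a
+        DIFFUSION stage returns ``None``.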
+ """ + stage = self.get_stage(stage_id) + if stage is None: + raise ValueError(f"Pipeline {self.model_type!r} has no stage with id {stage_id}") + return _scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER[stage.execution_type]) + + def validate(self) -> list[str]: + """Return list of topology errors (empty if valid).""" + errors: list[str] = [] + if not self.stages: + errors.append("Pipeline has no stages defined") + return errors + stage_ids = [s.stage_id for s in self.stages] + if len(stage_ids) != len(set(stage_ids)): + errors.append("Duplicate stage IDs found") + stage_id_set = set(stage_ids) + for stage in self.stages: + for src in stage.input_sources: + if src not in stage_id_set: + errors.append(f"Stage {stage.stage_id} references non-existent input source {src}") + if src == stage.stage_id: + errors.append(f"Stage {stage.stage_id} references itself") + if not any(not s.input_sources for s in self.stages): + errors.append("No entry point (stage with empty input_sources)") + return errors + + +class _LazyPipelineRegistry: + """Dict-like registry that lazy-loads pipelines from the central declaration. + + In-tree pipelines are declared once in + ``vllm_omni/config/pipeline_registry.py`` as + ``model_type -> (module_path, variable_name)`` entries; the module is + imported only when the pipeline is first looked up. This mirrors the + pattern in ``vllm/model_executor/models/registry.py`` and addresses + https://github.com/vllm-project/vllm-omni/issues/2887 (item 4): having + every registration in one file makes a missing entry easy to spot. + + Out-of-tree / dynamic registrations via ``register_pipeline()`` are stored + directly in ``_loaded`` and take precedence over the lazy-map entry with + the same ``model_type``. + + The class exposes the subset of ``dict`` operations the rest of this + module relies on (``__contains__``, ``__getitem__``, ``__setitem__``, + ``get``, ``keys``, ``values``, ``items``, ``__iter__``), so existing call + sites don't need to change. + """ + + def __init__(self) -> None: + self._loaded: dict[str, PipelineConfig] = {} + # Populated lazily to avoid a circular import at module init time. 
+ self._lazy_map: dict[str, tuple[str, str]] | None = None + + def _get_lazy_map(self) -> dict[str, tuple[str, str]]: + if self._lazy_map is None: + from vllm_omni.config.pipeline_registry import _VLLM_OMNI_PIPELINES + + self._lazy_map = _VLLM_OMNI_PIPELINES + return self._lazy_map + + def _load_lazy(self, model_type: str) -> PipelineConfig | None: + entry = self._get_lazy_map().get(model_type) + if entry is None: + return None + module_path, var_name = entry + import importlib + + try: + module = importlib.import_module(module_path) + except ImportError as exc: + logger.error( + "Failed to import pipeline module %r for %r: %s", + module_path, + model_type, + exc, + ) + return None + pipeline = getattr(module, var_name, None) + if pipeline is None: + logger.error( + "Pipeline variable %r not found in module %r (registered for %r)", + var_name, + module_path, + model_type, + ) + return None + errors = pipeline.validate() + if errors: + logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors) + self._loaded[model_type] = pipeline + return pipeline + + def __contains__(self, model_type: str) -> bool: + if model_type in self._loaded: + return True + return model_type in self._get_lazy_map() + + def __getitem__(self, model_type: str) -> PipelineConfig: + if model_type in self._loaded: + return self._loaded[model_type] + pipeline = self._load_lazy(model_type) + if pipeline is None: + raise KeyError(model_type) + return pipeline + + def get(self, model_type: str, default: PipelineConfig | None = None) -> PipelineConfig | None: + if model_type in self._loaded: + return self._loaded[model_type] + pipeline = self._load_lazy(model_type) + return pipeline if pipeline is not None else default + + def __setitem__(self, model_type: str, pipeline: PipelineConfig) -> None: + self._loaded[model_type] = pipeline + + def __delitem__(self, model_type: str) -> None: + """Remove a dynamically-registered pipeline. + + Only the dynamic-cache side of the registry can be mutated; the + central declarative registry is immutable at runtime. Calling ``del`` + on a model_type that only exists in the central registry raises + ``KeyError``. + """ + if model_type in self._loaded: + del self._loaded[model_type] + return + if model_type in self._get_lazy_map(): + raise KeyError( + f"{model_type!r} is declared in the central pipeline_registry and " + "cannot be removed at runtime. Edit " + "vllm_omni/config/pipeline_registry.py to delete a built-in entry." + ) + raise KeyError(model_type) + + def keys(self) -> set[str]: + return set(self._get_lazy_map().keys()) | set(self._loaded.keys()) + + def values(self): + # Iterating values forces load of every lazy pipeline. + for key in self.keys(): + yield self[key] + + def items(self): + for key in self.keys(): + yield key, self[key] + + def __iter__(self): + return iter(self.keys()) + + +_PIPELINE_REGISTRY = _LazyPipelineRegistry() + + +def register_pipeline(pipeline: PipelineConfig) -> None: + """Register a pipeline config dynamically. + + In-tree pipelines are declared in ``pipeline_registry._VLLM_OMNI_PIPELINES`` + and loaded lazily; calling ``register_pipeline`` is only needed for + out-of-tree plugins or tests that build a ``PipelineConfig`` at runtime. + A dynamic registration overrides the central-registry entry with the same + ``model_type``. 
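+
+    A minimal out-of-tree sketch (``my_model`` and its single ``thinker`` stage are
+    hypothetical)::
+
+        register_pipeline(
+            PipelineConfig(
+                model_type="my_model",
+                stages=(StagePipelineConfig(stage_id=0, model_stage="thinker"),),
+            )
+        )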
+ """ + errors = pipeline.validate() + if errors: + logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors) + _PIPELINE_REGISTRY[pipeline.model_type] = pipeline + + +_DEPLOY_DIR = Path(__file__).resolve().parent.parent / "deploy" + + +@dataclass +class StageDeployConfig: + """Per-stage deployment knobs. + + Only fields whose value legitimately varies across stages of the same + pipeline live here (e.g. ``max_num_seqs`` on thinker vs talker, + ``devices`` for GPU placement). Pipeline-wide settings + (``trust_remote_code``, ``distributed_executor_backend``, ``dtype``, + ``quantization``, prefix/chunked prefill, DP/PP sizes) are declared at + the top level of ``DeployConfig`` and propagated to every stage. + """ + + stage_id: int + max_num_seqs: int = 64 + gpu_memory_utilization: float = 0.9 + tensor_parallel_size: int = 1 + enforce_eager: bool = False + max_num_batched_tokens: int = 32768 + max_model_len: int | None = None + async_scheduling: bool | None = None + devices: str = "0" + output_connectors: dict[str, str] | None = None + input_connectors: dict[str, str] | None = None + default_sampling_params: dict[str, Any] | None = None + engine_extras: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class DeployConfig: + """Loaded from deploy/.yaml — the only config file users edit. + + Top-level fields (``trust_remote_code``, ``distributed_executor_backend``, + ``dtype``, ``quantization``, ``enable_prefix_caching``, + ``enable_chunked_prefill``, ``data_parallel_size``, + ``pipeline_parallel_size``) are pipeline-wide: they apply uniformly to + every stage. Fields that legitimately vary per stage live in the + individual ``StageDeployConfig`` entries under ``stages:``. + """ + + async_chunk: bool = True + connectors: dict[str, Any] | None = None + edges: list[dict[str, Any]] | None = None + stages: list[StageDeployConfig] = field(default_factory=list) + platforms: dict[str, Any] | None = None + # Overrides the auto-detected pipeline registry key for structural variants. 
+ pipeline: str | None = None + + # === Pipeline-wide engine settings (applied uniformly to every stage) === + trust_remote_code: bool = True + distributed_executor_backend: str = "mp" + dtype: str | None = None + quantization: str | None = None + enable_prefix_caching: bool = False + enable_chunked_prefill: bool | None = None + data_parallel_size: int = 1 + pipeline_parallel_size: int = 1 + + +_STAGE_NON_ENGINE_KEYS = frozenset( + { + "stage_id", + "devices", + "output_connectors", + "input_connectors", + "default_sampling_params", + "engine_extras", + } +) + +# Fields on StageDeployConfig that are populated from engine_args dict +_STAGE_DEPLOY_FIELDS = {f.name: f for f in fields(StageDeployConfig) if f.name not in _STAGE_NON_ENGINE_KEYS} + + +def _parse_stage_deploy(stage_data: dict[str, Any]) -> StageDeployConfig: + """Parse a single stage entry from deploy YAML into StageDeployConfig.""" + if "engine_args" in stage_data: + engine_args = dict(stage_data["engine_args"]) + devices = stage_data.get("runtime", {}).get("devices", stage_data.get("devices", "0")) + else: + engine_args = {k: v for k, v in stage_data.items() if k not in _STAGE_NON_ENGINE_KEYS and k != "stage_id"} + devices = stage_data.get("devices", "0") + + kwargs: dict[str, Any] = {"stage_id": stage_data["stage_id"], "devices": devices} + for name, f in _STAGE_DEPLOY_FIELDS.items(): + if name in engine_args: + kwargs[name] = engine_args.pop(name) + + kwargs["output_connectors"] = stage_data.get("output_connectors") + kwargs["input_connectors"] = stage_data.get("input_connectors") + kwargs["default_sampling_params"] = stage_data.get("default_sampling_params") + kwargs["engine_extras"] = engine_args + return StageDeployConfig(**kwargs) + + +_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "engine_extras", "engine_args"}) + + +def _deep_merge_stage(base: dict, overlay: dict) -> dict: + """Deep-merge ``_DEEP_MERGE_KEYS`` so thin overlays don't drop base keys.""" + merged = dict(base) + for k, v in overlay.items(): + if k in _DEEP_MERGE_KEYS: + base_val = merged.get(k) + if isinstance(v, dict) and isinstance(base_val, dict): + merged[k] = {**base_val, **v} + continue + # Deep-merge key but at least one side isn't a dict: surface the + # silent clobber so mismatched YAML types don't get past review. 
+ if base_val is not None: + logger.warning( + "Deep-merge key %r has non-dict value (base=%s, overlay=%s); " + "overlay will fully replace base instead of merging.", + k, + type(base_val).__name__, + type(v).__name__, + ) + merged[k] = v + return merged + + +def _merge_stage_lists( + base_stages: list[dict[str, Any]] | None, + overlay_stages: list[dict[str, Any]] | None, +) -> list[dict[str, Any]]: + """Merge two ``stages:`` lists by ``stage_id`` (overlay wins per field).""" + by_id: dict[int, dict[str, Any]] = {s["stage_id"]: s for s in (base_stages or [])} + for overlay_stage in overlay_stages or []: + sid = overlay_stage["stage_id"] + if sid in by_id: + by_id[sid] = _deep_merge_stage(by_id[sid], overlay_stage) + else: + by_id[sid] = overlay_stage + return list(by_id.values()) + + +def _merge_platforms( + base: dict[str, Any] | None, + overlay: dict[str, Any] | None, +) -> dict[str, Any] | None: + """Deep-merge two ``platforms:`` blocks per-platform, per-stage_id.""" + if not base and not overlay: + return None + base = base or {} + overlay = overlay or {} + merged: dict[str, Any] = {} + for plat in set(base) | set(overlay): + bp = base.get(plat) or {} + op = overlay.get(plat) or {} + merged_plat = {**bp, **{k: v for k, v in op.items() if k != "stages"}} + merged_plat["stages"] = _merge_stage_lists(bp.get("stages"), op.get("stages")) + merged[plat] = merged_plat + return merged + + +def resolve_deploy_yaml(path: str | Path) -> dict[str, Any]: + """Load a deploy YAML with optional ``base_config`` inheritance.""" + raw_dict = to_dict(load_yaml_config(path)) + + base_path = raw_dict.pop("base_config", None) + if base_path is None: + return raw_dict + + # Resolve relative to the overlay file's directory + base_path = Path(path).parent / base_path + base_dict = resolve_deploy_yaml(base_path) + + # Merge top-level scalars: overlay wins. ``stages:`` and ``platforms:`` + # are deep-merged below so an overlay can layer on top of the base. + merged = { + **base_dict, + **{k: v for k, v in raw_dict.items() if k not in ("stages", "platforms")}, + } + merged["stages"] = _merge_stage_lists(base_dict.get("stages"), raw_dict.get("stages")) + merged_platforms = _merge_platforms(base_dict.get("platforms"), raw_dict.get("platforms")) + if merged_platforms is not None: + merged["platforms"] = merged_platforms + + return merged + + +def load_deploy_config(path: str | Path) -> DeployConfig: + """Load a deploy YAML (with optional base_config inheritance).""" + raw_dict = resolve_deploy_yaml(path) + + stages = [_parse_stage_deploy(s) for s in raw_dict.get("stages", [])] + + kwargs: dict[str, Any] = { + "async_chunk": raw_dict.get("async_chunk", True), + "connectors": raw_dict.get("connectors", None), + "edges": raw_dict.get("edges", None), + "stages": stages, + "platforms": raw_dict.get("platforms", None), + "pipeline": raw_dict.get("pipeline", None), + } + # Pipeline-wide engine settings: only set if explicitly present in YAML + # so the DeployConfig dataclass defaults take effect otherwise. 
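To make the base_config/overlay semantics implemented above concrete, here is a hypothetical pair of deploy configs written as plain Python dicts (standing in for the YAML files) together with the result the merge rules would produce; the field values are invented for illustration.

```python
# Hypothetical base and overlay deploy configs illustrating the merge
# rules above: top-level scalars are overlay-wins, stages are merged by
# stage_id, and engine_args is deep-merged key-wise.
base = {
    "async_chunk": True,
    "stages": [
        {"stage_id": 0, "max_num_seqs": 64,
         "engine_args": {"gpu_memory_utilization": 0.9, "dtype": "bfloat16"}},
        {"stage_id": 1, "max_num_seqs": 8},
    ],
}
overlay = {
    "async_chunk": False,  # top-level scalar: overlay wins
    "stages": [
        # Only stage 0 is overridden; stage 1 is inherited unchanged.
        {"stage_id": 0, "engine_args": {"gpu_memory_utilization": 0.8}},
    ],
}

# Expected outcome of resolve_deploy_yaml()-style merging:
#   async_chunk -> False
#   stage 0     -> max_num_seqs=64 kept from the base,
#                  engine_args={"gpu_memory_utilization": 0.8, "dtype": "bfloat16"}
#                  (deep-merge keeps "dtype" from the base)
#   stage 1     -> inherited from the base as-is
```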
+ for name in ( + "trust_remote_code", + "distributed_executor_backend", + "dtype", + "quantization", + "enable_prefix_caching", + "enable_chunked_prefill", + "data_parallel_size", + "pipeline_parallel_size", + ): + if name in raw_dict: + kwargs[name] = raw_dict[name] + return DeployConfig(**kwargs) + + +def _detect_platform() -> str | None: + """Return "npu", "rocm", "xpu", or None (CUDA default).""" + try: + from vllm.platforms import current_platform + + name = current_platform.device_name.lower() + if "npu" in name: + return "npu" + if "rocm" in name or "amd" in name: + return "rocm" + if "xpu" in name: + return "xpu" + except Exception as e: + logger.debug("Platform auto-detect failed, falling back to CUDA: %s", e) + return None + + +def _extract_platform_overrides(ps: dict[str, Any]) -> tuple[dict[str, Any], str | None]: + """Return ``(overrides, devices)`` from a platform stage entry. + + Handles both the nested layout (``engine_args:`` / ``runtime.devices``) and + the flat layout. ``devices`` is ``None`` when no override is set. + """ + if "engine_args" in ps: + return dict(ps["engine_args"]), ps.get("runtime", {}).get("devices") + overrides = {k: v for k, v in ps.items() if k not in ("stage_id", "devices")} + return overrides, ps.get("devices") + + +def _apply_platform_overrides( + deploy: DeployConfig, + platform: str | None = None, +) -> DeployConfig: + """Merge platform-specific stage overrides into deploy config.""" + if platform is None: + platform = _detect_platform() + if platform is None or deploy.platforms is None: + return deploy + platform_section = deploy.platforms.get(platform) + if platform_section is None: + return deploy + + platform_stages = platform_section.get("stages", []) + base_by_id = {s.stage_id: s for s in deploy.stages} + + for ps in platform_stages: + base = base_by_id.get(ps["stage_id"]) + if base is None: + continue + overrides, devices = _extract_platform_overrides(ps) + if devices is not None: + base.devices = devices + for key, val in overrides.items(): + if hasattr(base, key): + setattr(base, key, val) + else: + base.engine_extras[key] = val + + return deploy + + +_EXECUTION_TYPE_TO_STAGE_WORKER: dict[StageExecutionType, tuple[StageType, str | None]] = { + StageExecutionType.LLM_AR: (StageType.LLM, "ar"), + StageExecutionType.LLM_GENERATION: (StageType.LLM, "generation"), + StageExecutionType.DIFFUSION: (StageType.DIFFUSION, None), +} + + +def _resolve_execution_mode( + execution_type: StageExecutionType, +) -> tuple[StageType, str | None]: + """Map ``execution_type`` → ``(stage_type, worker_type)`` legacy tuple.""" + return _EXECUTION_TYPE_TO_STAGE_WORKER.get(execution_type, (StageType.LLM, None)) + + +def _select_processor_funcs( + ps: StagePipelineConfig, + async_chunk: bool, +) -> tuple[str | None, str | None]: + """Pick ``(input_proc, next_stage_proc)`` based on the async_chunk mode.""" + next_stage_proc = ps.custom_process_next_stage_input_func + input_proc = ps.custom_process_input_func + if async_chunk and ps.async_chunk_process_next_stage_input_func: + next_stage_proc = ps.async_chunk_process_next_stage_input_func + elif not async_chunk and ps.sync_process_input_func: + input_proc = ps.sync_process_input_func + return input_proc, next_stage_proc + + +# Pipeline-wide DeployConfig fields that are propagated to every stage's +# engine args during merge. These live at top level of the deploy YAML. +_PIPELINE_WIDE_ENGINE_FIELDS: tuple[str, ...] 
= ( + "trust_remote_code", + "distributed_executor_backend", + "dtype", + "quantization", + "enable_prefix_caching", + "enable_chunked_prefill", + "data_parallel_size", + "pipeline_parallel_size", +) + + +def _build_engine_args( + ps: StagePipelineConfig, + ds: StageDeployConfig | None, + pipeline: PipelineConfig, + deploy: DeployConfig, + next_stage_proc: str | None, +) -> dict[str, Any]: + """Assemble the flat ``yaml_engine_args`` dict for one stage. + + Pipeline-wide DeployConfig fields are applied uniformly to every stage; + per-stage StageDeployConfig overrides take precedence when present (e.g. + ``engine_extras`` can still carry a stage-specific ``dtype``). + """ + engine_args: dict[str, Any] = {"model_arch": ps.model_arch or pipeline.model_arch} + if ps.engine_output_type: + engine_args["engine_output_type"] = ps.engine_output_type + if next_stage_proc: + engine_args["custom_process_next_stage_input_func"] = next_stage_proc + + # Pipeline-wide top-level DeployConfig settings, applied to every stage. + for name in _PIPELINE_WIDE_ENGINE_FIELDS: + value = getattr(deploy, name) + if value is not None: + engine_args[name] = value + + # Per-stage StageDeployConfig values override pipeline-wide settings. + if ds is not None: + for k, v in asdict(ds).items(): + if k in _STAGE_NON_ENGINE_KEYS or v is None: + continue + engine_args[k] = v + engine_args.update(ds.engine_extras) + if deploy.async_chunk: + engine_args["async_chunk"] = True + return engine_args + + +def _build_extras( + ps: StagePipelineConfig, + ds: StageDeployConfig | None, +) -> dict[str, Any]: + """Assemble ``yaml_extras`` (sampling + connectors + pipeline extras).""" + extras: dict[str, Any] = {} + sampling: dict[str, Any] = {} + if ds is not None and ds.default_sampling_params: + sampling.update(ds.default_sampling_params) + sampling.update(ps.sampling_constraints) + if sampling: + extras["default_sampling_params"] = sampling + if ds is not None and ds.output_connectors: + extras["output_connectors"] = dict(ds.output_connectors) + if ds is not None and ds.input_connectors: + extras["input_connectors"] = dict(ds.input_connectors) + if ps.extras: + extras.update(ps.extras) + return extras + + +def merge_pipeline_deploy( + pipeline: PipelineConfig, + deploy: DeployConfig, + cli_overrides: dict[str, Any] | None = None, +) -> list[StageConfig]: + """Merge pipeline + deploy + platform overrides → list[StageConfig].""" + if cli_overrides is None: + cli_overrides = {} + + deploy = _apply_platform_overrides(deploy) + deploy_by_id = {s.stage_id: s for s in deploy.stages} + + # A pipeline supports async_chunk if any stage has either an explicit + # async-chunk-only processor slot OR a custom next-stage processor (some + # pipelines like qwen3_omni wire async-chunk processing directly through + # ``custom_process_next_stage_input_func``). Only raise when neither is + # present — that's the "user enabled async_chunk but pipeline has no + # inter-stage processing at all" case. + if deploy.async_chunk and not any( + ps.async_chunk_process_next_stage_input_func or ps.custom_process_next_stage_input_func + for ps in pipeline.stages + ): + raise ValueError( + f"Pipeline {pipeline.model_type!r} has async_chunk=True in deploy but no stage " + "declares a next-stage input processor " + "(``async_chunk_process_next_stage_input_func`` or ``custom_process_next_stage_input_func``). " + "Either set async_chunk=False or implement an async-chunk processor on the pipeline." 
+ ) + + result: list[StageConfig] = [] + for ps in pipeline.stages: + ds = deploy_by_id.get(ps.stage_id) + stage_type, worker_type = _resolve_execution_mode(ps.execution_type) + input_proc, next_stage_proc = _select_processor_funcs(ps, deploy.async_chunk) + engine_args = _build_engine_args(ps, ds, pipeline, deploy, next_stage_proc) + extras = _build_extras(ps, ds) + runtime: dict[str, Any] = {"process": True} + if ds is not None: + runtime["devices"] = ds.devices + + result.append( + StageConfig( + stage_id=ps.stage_id, + model_stage=ps.model_stage, + stage_type=stage_type, + input_sources=list(ps.input_sources), + custom_process_input_func=input_proc, + final_output=ps.final_output, + final_output_type=ps.final_output_type, + worker_type=worker_type, + scheduler_cls=_scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER.get(ps.execution_type)), + hf_config_name=ps.hf_config_name, + is_comprehension=ps.owns_tokenizer, + yaml_engine_args=engine_args, + yaml_runtime=runtime, + yaml_extras=extras, + ) + ) + return result + + @dataclass class StageConfig: - """Per-stage configuration from pipeline YAML. + """Per-stage config (legacy path). Used by both new and legacy loaders. - Topology fields (stage_id, input_sources, etc.) define the DAG. - Engine and runtime defaults come from the YAML; CLI overrides take - precedence via ``runtime_overrides``. + TODO(@lishunyang12): replace with ResolvedStageConfig once all models are migrated. """ - # Identity stage_id: int model_stage: str - - # Stage type stage_type: StageType = StageType.LLM - input_sources: list[int] = field(default_factory=list) custom_process_input_func: str | None = None final_output: bool = False - final_output_type: str | None = None # "text", "audio", "image" - worker_type: str | None = None # "ar" or "generation" + final_output_type: str | None = None + worker_type: str | None = None scheduler_cls: str | None = None hf_config_name: str | None = None is_comprehension: bool = False - - # Per-stage engine args from pipeline YAML (defaults) yaml_engine_args: dict[str, Any] = field(default_factory=dict) - # Per-stage runtime config from pipeline YAML (devices, etc.) yaml_runtime: dict[str, Any] = field(default_factory=dict) - # Pass-through fields from pipeline YAML (default_sampling_params, - # output_connectors, input_connectors, tts_args, etc.) yaml_extras: dict[str, Any] = field(default_factory=dict) - - # Runtime overrides (populated from CLI, not from pipeline YAML) runtime_overrides: dict[str, Any] = field(default_factory=dict) def to_omegaconf(self) -> Any: - """Convert to OmegaConf for backward compatibility with OmniStage. - - Returns: - OmegaConf DictConfig with stage configuration in legacy format. - """ + """TODO(@lishunyang12): remove once engine consumes ResolvedStageConfig directly.""" # Start with YAML engine_args defaults engine_args: dict[str, Any] = dict(self.yaml_engine_args) @@ -152,9 +889,9 @@ def to_omegaconf(self) -> Any: @dataclass class ModelPipeline: - """Complete pipeline definition for a multi-stage model. + """Complete pipeline definition for a multi-stage model (legacy). - Defined by model developers, bundled with the model, not user-editable. + TODO(@lishunyang12): remove once all models migrate to PipelineConfig. """ model_type: str @@ -225,49 +962,55 @@ class StageConfigFactory: """Factory that loads pipeline YAML and merges CLI overrides. Handles both single-stage and multi-stage models. - """ - # Mapping of model types to directories under model_executor/models/. 
- PIPELINE_MODELS: dict[str, str] = { - "qwen3_omni_moe": "qwen3_omni", - "qwen2_5_omni": "qwen2_5_omni", - "bagel": "bagel", - "qwen3_tts": "qwen3_tts", - "voxtral_tts": "voxtral_tts", - "mimo_audio": "mimo_audio", - "glm-image": "glm_image", - "cosyvoice3": "cosyvoice3", - "mammothmoda2": "mammoth_moda2", - } - - # Fallback: map HF architecture class names to pipeline dirs. - # Used when model_type collides with another model (e.g. MiMo Audio - # reports model_type="qwen2" which matches plain Qwen2, not our pipeline). - _ARCHITECTURE_MODELS: dict[str, str] = { - "MiMoAudioForConditionalGeneration": "mimo_audio", - "HunyuanImage3ForCausalMM": "hunyuan_image3", - } + Pipelines are declared in ``vllm_omni/config/pipeline_registry.py`` and + loaded lazily via ``_PIPELINE_REGISTRY``; no hardcoded model-type → + directory mapping is maintained here. Models with generic HF + ``model_type`` collisions (e.g. MiMo Audio reports ``qwen2``) should + declare ``hf_architectures=(...)`` on their ``PipelineConfig`` so the + factory can disambiguate via ``hf_config.architectures``. + """ @classmethod def create_from_model( cls, model: str, cli_overrides: dict[str, Any] | None = None, + deploy_config_path: str | None = None, + cli_explicit_keys: set[str] | None = None, ) -> list[StageConfig] | None: - """Load pipeline YAML, merge with CLI overrides. + """Load pipeline + deploy config, merge with CLI overrides. - Args: - model: Model name or path. - cli_overrides: CLI overrides from VllmConfig/OmniDiffusionConfig. + Checks _PIPELINE_REGISTRY first (new path), falls back to legacy YAML. - Returns: - List of StageConfig objects with CLI overrides applied, - or None if no pipeline definition was found for this model. + ``cli_explicit_keys`` is the set of CLI keys the user actually typed + (captured at the parser layer in ``vllm serve``). When ``None`` — + which is the case for programmatic ``Omni()`` callers — every kwarg + in ``cli_overrides`` is treated as explicit. """ if cli_overrides is None: cli_overrides = {} trust_remote_code = cli_overrides.get("trust_remote_code", True) + + # --- New path: check pipeline registry by model_type first --- + model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code) + if model_type and model_type in _PIPELINE_REGISTRY: + return cls._create_from_registry(model_type, cli_overrides, deploy_config_path, cli_explicit_keys) + + # --- HF architecture fallback: some models report a generic + # model_type that collides with another model. Match by the + # hf_architectures declared on each registered PipelineConfig. + if hf_config is not None: + hf_archs = set(getattr(hf_config, "architectures", []) or []) + if hf_archs: + for registered in _PIPELINE_REGISTRY.values(): + if hf_archs.intersection(registered.hf_architectures): + return cls._create_from_registry( + registered.model_type, cli_overrides, deploy_config_path, cli_explicit_keys + ) + + # --- Legacy path: load from pipeline YAML --- pipeline = cls._load_pipeline(model, trust_remote_code=trust_remote_code) if pipeline is None: @@ -295,6 +1038,78 @@ def create_from_model( return result + @classmethod + def _create_from_registry( + cls, + model_type: str, + cli_overrides: dict[str, Any], + deploy_config_path: str | None = None, + cli_explicit_keys: set[str] | None = None, + ) -> list[StageConfig]: + """Create StageConfigs from pipeline registry + deploy YAML. 
+ + Precedence (high → low): + explicit CLI args > deploy YAML > parser default CLI values + + ``cli_explicit_keys`` carries the set of long-option attribute names + the user actually typed (captured in ``OmniServeCommand.cmd``). Any + kwarg whose key is not in that set is treated as a parser default + and is only used to fill fields YAML doesn't already cover. When the + set is ``None`` (programmatic ``Omni()`` callers, which have no + argparse layer), every kwarg is treated as explicit. + """ + # Resolve deploy config path + if deploy_config_path is None: + deploy_path = _DEPLOY_DIR / f"{model_type}.yaml" + else: + deploy_path = Path(deploy_config_path) + + if not deploy_path.exists(): + logger.warning( + "Deploy config not found: %s — using pipeline defaults only", + deploy_path, + ) + deploy_cfg = DeployConfig() + else: + deploy_cfg = load_deploy_config(deploy_path) + + cli_async_chunk = cli_overrides.get("async_chunk") + if cli_async_chunk is not None and (cli_explicit_keys is None or "async_chunk" in cli_explicit_keys): + deploy_cfg.async_chunk = bool(cli_async_chunk) + + pipeline_key = deploy_cfg.pipeline or model_type + if pipeline_key not in _PIPELINE_REGISTRY: + raise KeyError( + f"Pipeline {pipeline_key!r} not in registry " + f"(resolved from {deploy_path.name!r}). Available: " + f"{sorted(_PIPELINE_REGISTRY.keys())}" + ) + pipeline_cfg = _PIPELINE_REGISTRY[pipeline_key] + + stages = merge_pipeline_deploy(pipeline_cfg, deploy_cfg, cli_overrides) + + # Precedence: explicit CLI > yaml > parser-default CLI. + # Per-stage (``stage_N_*``) keys are always treated as explicit. + explicit_overrides: dict[str, Any] = {} + default_overrides: dict[str, Any] = {} + for key, value in cli_overrides.items(): + if value is None: + continue + is_per_stage = bool(re.match(r"stage_\d+_", key)) + is_explicit = cli_explicit_keys is None or key in cli_explicit_keys or is_per_stage + if is_explicit: + explicit_overrides[key] = value + else: + default_overrides[key] = value + + for stage in stages: + yaml_keys = set(stage.yaml_engine_args) + fallback = {k: v for k, v in default_overrides.items() if k not in yaml_keys} + merged = {**fallback, **explicit_overrides} + stage.runtime_overrides = cls._merge_cli_overrides(stage, merged) + + return stages + @classmethod def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]]: """Single-stage diffusion - no YAML needed. @@ -322,9 +1137,16 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any] continue engine_args[key] = value - # Serialize parallel_config as dict for OmegaConf compatibility + # Serialize parallel_config as dict for OmegaConf. Test helpers + # sometimes pass SimpleNamespace rather than a dataclass instance. 
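A small, self-contained sketch of the precedence rule implemented in `_create_from_registry` above (explicit CLI > deploy YAML > parser-default CLI); the option names and values are hypothetical, and the per-stage and `None`-filtering details are omitted.

```python
# Hypothetical values illustrating the override precedence: the user
# only typed --max-num-seqs, so gpu_memory_utilization in cli_overrides
# is just the parser default and must not shadow the YAML value.
cli_overrides = {"max_num_seqs": 256, "gpu_memory_utilization": 0.9}
cli_explicit_keys = {"max_num_seqs"}
yaml_engine_args = {"gpu_memory_utilization": 0.8, "max_num_seqs": 64}

explicit = {k: v for k, v in cli_overrides.items() if k in cli_explicit_keys}
defaults = {k: v for k, v in cli_overrides.items() if k not in cli_explicit_keys}

# Parser defaults only fill fields the YAML does not already cover.
fallback = {k: v for k, v in defaults.items() if k not in yaml_engine_args}
merged = {**fallback, **explicit}

print(merged)
# {'max_num_seqs': 256}: max_num_seqs comes from the explicit CLI flag,
# while gpu_memory_utilization stays at the YAML value (0.8) because the
# CLI 0.9 was only a parser default.
```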
if "parallel_config" in kwargs: - engine_args["parallel_config"] = asdict(kwargs["parallel_config"]) + parallel_config = kwargs["parallel_config"] + if dataclasses.is_dataclass(parallel_config) and not isinstance(parallel_config, type): + engine_args["parallel_config"] = asdict(parallel_config) + elif hasattr(parallel_config, "__dict__"): + engine_args["parallel_config"] = dict(vars(parallel_config)) + else: + engine_args["parallel_config"] = parallel_config engine_args.setdefault("cache_backend", "none") engine_args["model_stage"] = "diffusion" @@ -351,40 +1173,49 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any] @classmethod def _load_pipeline(cls, model: str, trust_remote_code: bool = True) -> ModelPipeline | None: - """Load pipeline YAML for the model. + """Load a legacy ``pipeline.yaml`` for the model. - Args: - model: Model name or path. - trust_remote_code: Whether to trust remote code for HF config loading. + Searches ``model_executor/models//pipeline.yaml`` by trying + (a) the raw ``model_type`` as the directory name, then + (b) ``model_type`` with hyphens replaced by underscores, + and finally (c) scanning every ``pipeline.yaml`` for one that + declares a matching ``model_type`` or ``hf_architectures``. - Returns: - ModelPipeline if found, None otherwise. + Returns None if no pipeline.yaml is found — caller handles the + ``resolve_model_config_path`` fallback via stage_configs/ YAMLs. """ model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code) if model_type is None: return None - pipeline_dir = cls.PIPELINE_MODELS.get(model_type) - - # Fallback: check HF architectures when model_type doesn't match - if pipeline_dir is None and hf_config is not None: - for arch in getattr(hf_config, "architectures", []) or []: - pipeline_dir = cls._ARCHITECTURE_MODELS.get(arch) - if pipeline_dir is not None: - model_type = pipeline_dir - break - - if pipeline_dir is None: - logger.debug(f"No pipeline mapping for model_type: {model_type}") - return None - - pipeline_path = get_pipeline_path(pipeline_dir, "pipeline.yaml") - - if not pipeline_path.exists(): - logger.debug(f"Pipeline file not found: {pipeline_path}") - return None + # Direct lookups by convention + candidates = [model_type, model_type.replace("-", "_")] + for dir_name in candidates: + pipeline_path = get_pipeline_path(dir_name, "pipeline.yaml") + if pipeline_path.exists(): + return cls._parse_pipeline_yaml(pipeline_path, model_type) + + # Scan fallback: read every pipeline.yaml and match on declared fields + hf_archs = set(getattr(hf_config, "architectures", []) or []) if hf_config else set() + if _MODELS_DIR.exists(): + for subdir in sorted(_MODELS_DIR.iterdir()): + if not subdir.is_dir(): + continue + pipeline_path = subdir / "pipeline.yaml" + if not pipeline_path.exists(): + continue + try: + cfg = load_yaml_config(pipeline_path) + except Exception as exc: + logger.debug("Skip %s: %s", pipeline_path, exc) + continue + declared_type = getattr(cfg, "model_type", None) + declared_archs = set(getattr(cfg, "hf_architectures", None) or []) + if declared_type == model_type or (hf_archs and hf_archs.intersection(declared_archs)): + return cls._parse_pipeline_yaml(pipeline_path, declared_type or model_type) - return cls._parse_pipeline_yaml(pipeline_path, model_type) + logger.debug("No pipeline.yaml found for model_type %s (archs=%s)", model_type, sorted(hf_archs)) + return None # Keys consumed as explicit StageConfig fields — everything else is # passed through via 
yaml_extras. @@ -542,66 +1373,17 @@ def _auto_detect_model_type(cls, model: str, trust_remote_code: bool = True) -> return None, None - # Keys that should never be forwarded as engine overrides (internal / - # orchestrator-only knobs, complex objects, etc.). - _INTERNAL_KEYS: set[str] = { - "model", - "stage_configs_path", - "stage_id", - "stage_init_timeout", - "init_timeout", - "shm_threshold_bytes", - "worker_backend", - "ray_address", - "batch_timeout", - "log_stats", - "tokenizer", - "parallel_config", - } - @classmethod def _merge_cli_overrides( cls, stage: StageConfig, cli_overrides: dict[str, Any], ) -> dict[str, Any]: - """Merge CLI overrides into stage runtime config. + """Merge global and per-stage (``stage_N_*``) CLI overrides. - All CLI arguments registered by engine config classes (e.g. - EngineArgs / OmniDiffusionConfig) are accepted as overrides - unless they appear in ``_INTERNAL_KEYS``. - - Handles: - - Global overrides (apply to all stages) - - Per-stage overrides (--stage-N-* format, take precedence) - - Args: - stage: The stage to merge overrides into. - cli_overrides: CLI arguments from VllmConfig/OmniDiffusionConfig. - - Returns: - Dict of runtime overrides for this stage. + Orchestrator-owned keys are filtered by ``build_stage_runtime_overrides`` + using ``OrchestratorArgs`` as the single source of truth; unknown + server/uvicorn keys are dropped downstream by + ``filter_dataclass_kwargs(OmniEngineArgs, ...)``. """ - result: dict[str, Any] = {} - - # Apply global overrides – any key not in the internal blocklist - # is forwarded so that engine-registered params work out of the box. - for key, value in cli_overrides.items(): - if key in cls._INTERNAL_KEYS: - continue - if re.match(r"stage_\d+_", key): - # Per-stage keys handled below - continue - if value is not None: - result[key] = value - - # Apply per-stage overrides (--stage-N-* format, take precedence) - stage_prefix = f"stage_{stage.stage_id}_" - for key, value in cli_overrides.items(): - if key.startswith(stage_prefix) and value is not None: - param_name = key[len(stage_prefix) :] - if param_name in cls._INTERNAL_KEYS: - continue - result[param_name] = value - - return result + return build_stage_runtime_overrides(stage.stage_id, cli_overrides) diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py new file mode 100644 index 0000000000..69e7346c4c --- /dev/null +++ b/vllm_omni/core/prefix_cache.py @@ -0,0 +1,264 @@ +""" +Utilities for Prefix Caching in Omni models. +""" + +import torch +from vllm.logger import init_logger +from vllm.v1.worker.gpu_input_batch import InputBatch + +from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element + +logger = init_logger(__name__) + + +class OmniTensorPrefixCache: + """Prefix cache for hidden states (model outputs) and model specific + multimodal outputs. + + This class implements prefix caching in a non-invasive way on top of + vLLM by leveraging the same slot mappings that the vLLM scheduler uses + for the KV Cache. + + Conceptually, this means we are mapping vLLM's cache mapping: + (num_blocks, block_size) + + to 3D tensors of shape: + (num_blocks, block_size, feature_size) + + Note that feature_size may vary across multimodal_outputs. 
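A toy example of the slot-mapping idea this docstring describes may help: the 3D cache is viewed as a flat (num_slots, feature_size) tensor so vLLM's per-token slot mapping can be used directly as row indices, and prefix hits are read back by gathering whole cached blocks. The sizes and slot assignments below are invented for illustration.

```python
# Toy illustration (hypothetical sizes) of the slot-mapping trick:
# write hidden states into the same slots the KV cache uses, then read
# a prefix hit back block-by-block.
import torch

num_blocks, block_size, hidden = 4, 2, 3
cache = torch.zeros(num_blocks, block_size, hidden)

# Suppose 3 tokens were scheduled and the KV-cache scheduler assigned
# them slots 2, 3 and 5 (block 1 rows 0-1, block 2 row 1).
slot_mapping = torch.tensor([2, 3, 5])
hidden_states = torch.randn(3, hidden)

flat = cache.view(-1, hidden)       # (num_blocks * block_size, hidden)
flat[slot_mapping] = hidden_states  # writes through to `cache`

# Reading back a prefix hit: gather the fully cached blocks and flatten
# them to one row per token, as _get_merged_tensors does below.
cached_block_ids = torch.tensor([1])                     # block 1 is fully cached
cached_hs = cache[cached_block_ids].reshape(-1, hidden)  # rows for slots 2 and 3
assert torch.equal(cached_hs, hidden_states[:2])
```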
+ """ + + def __init__( + self, + num_blocks: int, + block_size: int, + hidden_size: int, + hs_dtype: torch.dtype, + ): + self.num_blocks = num_blocks + self.block_size = block_size + self.default_hidden_size = hidden_size + + # Initialize the hidden states cache immediately + self.hidden_states_cache = self._get_cache_tensor(dtype=hs_dtype) + + # Defer initialization of the mm_outputs_cache until we + # actually see mm output tensors dependent on num tokens. + self.mm_outputs_cache = {} + self.mm_cache_keys = set() + self._new_req_cache_hit_ids: set[str] = set() + + def maybe_init_missing_mm_cache_keys(self, multimodal_outputs: dict, seq_len: int): + """Given multimodal outputs from executing the model, dynamically + determine which multimodal outputs are tensors depending on sequence + length and should be cached, and initialize the cache tensors + accordingly. + + NOTE: This is done to avoid the need for explicit specification of + cache keys for every model/stage and aligns with the current way + that we slice the multimodal outputs based on the first dimension. + + This will usually be called by the first forward pass, i.e., + determined by the warmup. + """ + for key, val in multimodal_outputs.items(): + if isinstance(val, torch.Tensor) and val.shape[0] == seq_len and key not in self.mm_cache_keys: + feat_dim = val.shape[-1] + self.mm_outputs_cache[key] = self._get_cache_tensor( + dtype=val.dtype, + hidden_size=feat_dim, + ) + self.mm_cache_keys.add(key) + new_tensor_shape = self.mm_outputs_cache[key].shape + logger.info("Initializing multimodal output cache of size %s for key: %s", list(new_tensor_shape), key) + + def _get_cache_tensor(self, dtype: torch.dtype, hidden_size: int | None = None) -> torch.Tensor: + """Allocate a CPU cache tensor for a specific key.""" + actual_hidden_size = hidden_size if hidden_size is not None else self.default_hidden_size + return torch.zeros( + (self.num_blocks, self.block_size, actual_hidden_size), + dtype=dtype, + device="cpu", + ) + + def add_prefix_cached_new_req_id(self, req_id: str): + """Adds a new request ID to the set of prefix cache hits on the batch.""" + self._new_req_cache_hit_ids.add(req_id) + + def reset_prefix_cached_new_req_ids(self): + """Clears the cache hit IDs to prepare for a new engine step.""" + self._new_req_cache_hit_ids.clear() + + @staticmethod + def _coerce_to_cpu_tensor(maybe_gpu_tensor: torch.Tensor) -> torch.Tensor: + """Convert GPU tensors -> contiguous CPU tensors if needed.""" + return maybe_gpu_tensor.detach().cpu().contiguous() + + def update_omni_tensor_prefix_cache( + self, + hidden_states: torch.Tensor | None, + multimodal_outputs: dict[str, torch.Tensor] | None, + num_tokens_unpadded: int, + slot_mapping: torch.Tensor, + num_tokens_padded: int | None = None, + ): + """Updates the hidden cache state for the provided hidden states and multimodal outputs. 
+ + Args: + hidden_states: Hidden states tensor to cache (if any) + multimodal_outputs: Multimodal dict whose tensors may be cached + num_tokens_unpadded: Number of tokens without padding + slot_mapping: Slot mapping for the input sequence + num_tokens_padded: Total number of tokens including padding + """ + unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded] + if num_tokens_padded is None: + num_tokens_padded = num_tokens_unpadded + + if hidden_states is not None: + # Slice to unpadded portion before caching + hidden_states = hidden_states[:num_tokens_unpadded] + # Ensure that hidden states are on the CPU + hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) + # View the cache as 2D so that we can treat our slots as row indices + flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1]) + flat_cache[unpadded_slot_mapping] = hidden_states + logger.debug("Writing to hidden states for %s tokens", num_tokens_unpadded) + + # Do the same for the stage's cached multimodal outputs + if multimodal_outputs is not None: + # If we haven't initialized the keys already, do it now + # We check against the padded token count since we haven't sliced yet + self.maybe_init_missing_mm_cache_keys( + multimodal_outputs, + seq_len=num_tokens_padded, + ) + + for mm_out_key, mm_cache in self.mm_outputs_cache.items(): + if mm_out_key in multimodal_outputs: + # Slice to unpadded portion before caching + mm_state = multimodal_outputs[mm_out_key][:num_tokens_unpadded] + mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state) + flat_cache = mm_cache.view(-1, mm_cache.shape[-1]) + flat_cache[unpadded_slot_mapping] = mm_state + logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded) + + def _coerce_to_payload_dict( + self, + element: object, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, object]: + """Build the multimodal passthrough data per request for + the object under consideration. This is identical to the case + for no prefix cache when we tensor does have a first dimension + matching the seq len. 
+ """ + elem_dict = {} + for req_id in input_batch.req_ids: + req_idx = input_batch.req_id_to_index[req_id] + start = query_start_loc[req_idx] + end = start + num_scheduled_tokens[req_id] + elem_dict[req_id] = to_payload_element( + element, req_idx, start=start, end=end, pass_lists_through=True, seq_len=None + ) + return elem_dict + + def get_merged_multimodal_states( + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + multimodal_outputs: dict, + num_scheduled_tokens: dict[str, int], + ): + """Get the merged multimodal states if hidden state prefix caching is enabled.""" + combined_multimodal_outputs = {} + # First get the prefix cached tensors that are present in the mm data + for mm_key in self.mm_cache_keys: + if mm_key in multimodal_outputs: + combined_multimodal_outputs[mm_key] = self._get_merged_tensors( + query_start_loc=query_start_loc, + input_batch=input_batch, + cache=self.mm_outputs_cache[mm_key], + hidden_states=multimodal_outputs[mm_key], + num_scheduled_tokens=num_scheduled_tokens, + ) + + # Then, get everything else (passthrough data); first, convert to CPU + # tensors similarly to the non prefix cached path, and then populate + # the subdicts mapping request IDs -> payload objects + passthrough_keys = set(multimodal_outputs.keys()) - self.mm_cache_keys + passthrough_mm_data = {k: v for k, v in multimodal_outputs.items() if k in passthrough_keys} + mm_cpu = build_mm_cpu(multimodal_outputs=passthrough_mm_data) + + for mm_key, mm_val in mm_cpu.items(): + combined_multimodal_outputs[mm_key] = self._coerce_to_payload_dict( + element=mm_val, + query_start_loc=query_start_loc, + input_batch=input_batch, + num_scheduled_tokens=num_scheduled_tokens, + ) + return combined_multimodal_outputs + + def get_merged_hidden_states(self, *args, **kwargs) -> dict[str, torch.Tensor]: + """Get the merged hidden states.""" + return self._get_merged_tensors( + *args, + **kwargs, + cache=self.hidden_states_cache, + ) + + def _get_merged_tensors( + self, + query_start_loc: torch.Tensor, + input_batch: InputBatch, + cache: torch.Tensor, + hidden_states: torch.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, torch.Tensor]: + """When hidden state caching is enabled, takes the input hidden_states, + which only correspond to the scheduled tokens, and returns a mapping + from request IDs to their full hidden states. This is accomplished by + looking up the block IDs & scheduled token counts to split the + hidden_states. + """ + # We do not support hybrid caches at the moment. + if len(input_batch.block_table.block_tables) > 1: + logger.warning_once( + "Omni prefix caching is enabled, but the batch block table appears to" + " have multiple kv groups; only the first group will be used!" 
+ ) + + combined_hidden_states = {} + hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states) + for req_id in input_batch.req_ids: + req_idx = input_batch.req_id_to_index[req_id] + + if req_id in self._new_req_cache_hit_ids: + block_ids = self._get_cached_block_ids(req_idx, input_batch) + cached_hs = cache[block_ids].reshape(-1, cache.shape[-1]) + + # Slice the hidden states corresponding to this request; + # we do this by using the query start + start = query_start_loc[req_idx] + new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] + combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0) + else: + # cache miss for this request, pass through normally + start = query_start_loc[req_idx] + new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]] + combined_hidden_states[req_id] = new_hs + + return combined_hidden_states + + def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch.Tensor: + """Given an input batch and request index in the batch (not ID), get the + block IDs corresponding to the cache hit. + """ + num_computed = input_batch.num_computed_tokens_cpu[req_idx] + # NOTE: vLLM only caches full blocks + num_cached_blocks = num_computed // self.block_size + # Get the block IDs attached to this cache hit and reindex into + # the flattened cached hidden states (i.e., 1 row per token). + return input_batch.block_table[0].block_table.cpu[req_idx, :num_cached_blocks] diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index eac737b6e6..a5579dd464 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -15,9 +15,10 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.perf import PerfStats from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.request import Request, RequestStatus +from vllm.v1.request import Request, RequestStatus, StreamingUpdate from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.output import OmniSchedulerOutput from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, @@ -38,7 +39,7 @@ def to_dict(self) -> dict[str, Any]: return asdict(self) -class OmniARScheduler(VLLMScheduler): +class OmniARScheduler(OmniSchedulerMixin, VLLMScheduler): """ OmniARScheduler: Scheduler for vLLM-Omni multimodal processing. @@ -59,6 +60,11 @@ def __init__(self, *args, **kwargs): # Track ACTIVE transfers (submitted to runner but not yet acked via kv_extracted_req_ids) self.active_kv_transfers: set[str] = set() + # Requests marked for deferred stop: keep running until KV extraction + # completes so that kv_ready can be emitted while the request is still + # alive. Stopped on the first scheduler step after extraction ack. 
+ self.pending_stop_after_extraction: set[str] = set() + # [Omni] Pre-parse KV transfer criteria self.kv_transfer_criteria = self._get_kv_transfer_criteria() @@ -71,6 +77,8 @@ def __init__(self, *args, **kwargs): self.chunk_transfer_adapter = None if getattr(model_config, "async_chunk", False): self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config) + # Snapshot prompt length for each streaming input update + self._new_prompt_len_snapshot: dict[str, int] = {} def _get_kv_transfer_criteria(self) -> dict | None: # Note: vllm_config is available in Scheduler after super().__init__ @@ -126,11 +134,16 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int stop_decode_on_trigger = self.kv_transfer_criteria.get("stop_after_transfer", True) if request.request_id in self.transfer_triggered_requests: - # Already triggered. When stop_decode_on_trigger is True AND - # transfer was actually queued, the request was already stopped - # at trigger time (see below). Any request that reaches this - # point either has stop_decode_on_trigger=False (continue - # decoding) or was not actually queued (should not be stopped). + # Deferred stop: once KV extraction is complete (no longer in + # active_kv_transfers), stop the request. This guarantees the + # kv_ready signal was emitted while the request was still alive. + if ( + request.request_id in self.pending_stop_after_extraction + and request.request_id not in self.active_kv_transfers + ): + self.pending_stop_after_extraction.discard(request.request_id) + request.status = RequestStatus.FINISHED_STOPPED + return True return False if criteria_type == "prefill_finished": @@ -140,14 +153,11 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int actually_queued = request.request_id in self.requests_needing_kv_transfer if stop_decode_on_trigger and actually_queued: - # Stop immediately so the request is NOT scheduled in - # the next step, freeing scheduling budget for companion - # requests whose chunked-prefill boundaries must be - # deterministic. waiting_for_transfer_free keeps blocks - # alive until the model runner finishes KV extraction. - self.waiting_for_transfer_free.add(request.request_id) - request.status = RequestStatus.FINISHED_STOPPED - return True + # Defer the stop until KV extraction completes so that + # the kv_ready signal can be emitted while the request + # is still alive. The request will be stopped on the + # next scheduler step after extraction ack arrives. + self.pending_stop_after_extraction.add(request.request_id) return False @@ -167,9 +177,7 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int actually_queued = request.request_id in self.requests_needing_kv_transfer if stop_decode_on_trigger and actually_queued: - self.waiting_for_transfer_free.add(request.request_id) - request.status = RequestStatus.FINISHED_STOPPED - return True + self.pending_stop_after_extraction.add(request.request_id) return False @@ -268,6 +276,26 @@ def update_from_output( num_scheduled_tokens, ) + # Pre-process KV extraction acks so that the per-request loop below + # can see up-to-date active_kv_transfers state and emit kv_ready + # signals while requests are still alive (before any deferred stop). 
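The ordering guarantee behind the deferred stop can be sketched with plain sets; these are toy stand-ins for the scheduler's bookkeeping, not the real Request or EngineCoreOutput types.

```python
# Toy sketch of the deferred-stop ordering introduced above: a request
# that triggers a KV transfer keeps running until the extraction ack
# arrives, so kv_ready is always emitted while the request is alive.
active_kv_transfers = {"req-0"}             # transfer submitted to the runner
pending_stop_after_extraction = {"req-0"}   # stop deferred until extraction ack
alive = {"req-0"}
events = []


def scheduler_step(kv_extracted_req_ids):
    # 1) Process extraction acks first and emit kv_ready for live requests.
    for req_id in kv_extracted_req_ids:
        active_kv_transfers.discard(req_id)
        if req_id in alive:
            events.append((req_id, "kv_ready"))
    # 2) Then apply any deferred stop whose extraction has completed.
    for req_id in list(pending_stop_after_extraction):
        if req_id not in active_kv_transfers:
            pending_stop_after_extraction.discard(req_id)
            alive.discard(req_id)
            events.append((req_id, "stopped"))


scheduler_step(kv_extracted_req_ids=[])         # still extracting: nothing happens
scheduler_step(kv_extracted_req_ids=["req-0"])  # ack arrives: kv_ready, then stop
print(events)  # [('req-0', 'kv_ready'), ('req-0', 'stopped')]
```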
+ kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None) + if kv_extracted_ids: + for req_id in kv_extracted_ids: + try: + self.active_kv_transfers.discard(req_id) + req = self.requests.get(req_id) + if req is not None and not req.is_finished(): + outputs[req.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=[], + kv_transfer_params={"kv_ready": True}, + ) + ) + except Exception: + init_logger(__name__).exception("Failed to pre-process KV extraction for %s", req_id) + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best # to avoid expensive operations inside the loop. @@ -313,6 +341,7 @@ def update_from_output( ) stopped = False + is_segment_finished = False new_logprobs = None new_token_ids = generated_token_ids pooler_output = pooler_outputs[req_index] if pooler_outputs else None @@ -341,6 +370,7 @@ def update_from_output( # Capture finish_reason BEFORE _handle_stopped_request, which may # reset the status to WAITING for streaming requests that continue. finish_reason = request.get_finished_reason() + is_segment_finished = request.is_finished() and request.resumable finished = self._handle_stopped_request(request) if finished: kv_transfer_params = self._free_request(request) @@ -393,6 +423,8 @@ def update_from_output( num_external_computed_tokens=request.num_external_computed_tokens, routed_experts=routed_experts, num_nans_in_logits=request.num_nans_in_logits, + is_segment_finished=is_segment_finished, + new_prompt_len_snapshot=self._new_prompt_len_snapshot.get(req_id, None), ) ) if self.chunk_transfer_adapter is not None: @@ -436,6 +468,7 @@ def update_from_output( self.transfer_triggered_requests.remove(req.request_id) if req.request_id in self.active_kv_transfers: self.active_kv_transfers.remove(req.request_id) + self.pending_stop_after_extraction.discard(req.request_id) # Same for preempted for req in stopped_preempted_reqs: @@ -444,6 +477,8 @@ def update_from_output( self.transfer_triggered_requests.remove(req.request_id) if req.request_id in self.active_kv_transfers: self.active_kv_transfers.remove(req.request_id) + self.pending_stop_after_extraction.discard(req.request_id) + # KV Connector: update state for finished KV Transfers. if kv_connector_output: self._update_from_kv_xfer_finished(kv_connector_output) @@ -489,35 +524,12 @@ def update_from_output( engine_core_outputs[0] = eco = EngineCoreOutputs() eco.scheduler_stats = stats - # This is where we free blocks that were held for transfer - try: - kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None) - if kv_extracted_ids: - for req_id in kv_extracted_ids: - # Emit a kv_ready signal so the orchestrator can forward - # the request to the DiT stage immediately after KV - # extraction, without waiting for AR decode to finish. 
- req = self.requests.get(req_id) - if req is not None and not req.is_finished(): - eco = engine_core_outputs.get(req.client_index) - if eco is None: - eco = EngineCoreOutputs() - engine_core_outputs[req.client_index] = eco - eco.outputs.append( - EngineCoreOutput( - request_id=req_id, - new_token_ids=[], - kv_transfer_params={"kv_ready": True}, - ) - ) - - # Mark transfer as finished - if req_id in self.active_kv_transfers: - self.active_kv_transfers.remove(req_id) - logger.debug(f"[Omni] KV Transfer finished for {req_id}") - + # Free blocks that were held for transfer (kv_ready and + # active_kv_transfers updates already done before the per-request loop). + if kv_extracted_ids: + for req_id in kv_extracted_ids: + try: if req_id in self.waiting_for_transfer_free: - # Now it's safe to free blocks req = self.requests.get(req_id) if req: self.kv_cache_manager.free(req) @@ -525,16 +537,48 @@ def update_from_output( del self.requests[req_id] if req_id in self.transfer_triggered_requests: self.transfer_triggered_requests.remove(req_id) - if req_id in self.active_kv_transfers: - self.active_kv_transfers.remove(req_id) - + self.active_kv_transfers.discard(req_id) + self.pending_stop_after_extraction.discard(req_id) logger.debug(f"Freed blocks for {req_id} after transfer extraction") self.waiting_for_transfer_free.remove(req_id) - except Exception: - init_logger(__name__).exception("Failed to process finished transfer requests") + except Exception: + init_logger(__name__).exception("Failed to free blocks for %s after transfer", req_id) return engine_core_outputs + def finish_requests(self, request_ids: Any, finished_status: RequestStatus) -> list[tuple[str, int]]: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. + + If request_ids is None, all requests will be finished. + + Returns: + Tuple of (req_id, client_index) for requests that were aborted. Will not + include any that were already finished. + """ + + if self.chunk_transfer_adapter: + self.chunk_transfer_adapter.finish_requests(request_ids, finished_status, self.requests) + + return super().finish_requests(request_ids, finished_status) + + def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None: + """ + Override: Only extend prompt at stage 0, and replace + the existing session with the next streaming update at other stages. + + Discards the last sampled output token from the prior input chunk at stage 0. + """ + req_id = session.request_id + self._new_prompt_len_snapshot[req_id] = len(update.prompt_token_ids) + if self.vllm_config.model_config.stage_id != 0: + self._replace_session_with_streaming_update(session, update) + + else: + super()._update_request_as_session(session, update) + def _free_request(self, request: Request, delay_free_blocks: bool = False) -> dict[str, Any] | None: # TODO(wzliu)! 
for offline mode, we should not end process until all data is transferred """Mark a request as finished and free its resources.""" @@ -548,6 +592,7 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di self.encoder_cache_manager.free(request) request_id = request.request_id self.finished_req_ids.add(request_id) + self._new_prompt_len_snapshot.pop(request_id, None) if self.finished_req_ids_dict is not None: self.finished_req_ids_dict[request.client_index].add(request_id) @@ -564,8 +609,7 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di kv_xfer_params = None return kv_xfer_params elif request_id in self.waiting_for_transfer_free: - # Stopped immediately by stop_decode_on_trigger; blocks are - # held until KV extraction completes in a future step. + # Blocks held until KV extraction completes in a future step. return None else: logger.debug( diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py index 1c4356d4f5..81f0b7fc2b 100644 --- a/vllm_omni/core/sched/omni_generation_scheduler.py +++ b/vllm_omni/core/sched/omni_generation_scheduler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import time from collections import defaultdict @@ -11,11 +13,16 @@ from vllm.v1.core.sched.request_queue import create_request_queue from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler from vllm.v1.core.sched.utils import remove_all -from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs +from vllm.v1.engine import ( + EngineCoreEventType, + EngineCoreOutput, + EngineCoreOutputs, +) from vllm.v1.metrics.perf import PerfStats -from vllm.v1.request import Request, RequestStatus +from vllm.v1.request import Request, RequestStatus, StreamingUpdate from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import ( OmniChunkTransferAdapter, @@ -25,7 +32,7 @@ logger = init_logger(__name__) -class OmniGenerationScheduler(VLLMScheduler): +class OmniGenerationScheduler(OmniSchedulerMixin, VLLMScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) model_config = self.vllm_config.model_config @@ -324,6 +331,24 @@ def schedule(self) -> SchedulerOutput: return scheduler_output + def finish_requests(self, request_ids, finished_status: RequestStatus) -> list[tuple[str, int]]: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. + + If request_ids is None, all requests will be finished. + + Returns: + Tuple of (req_id, client_index) for requests that were aborted. Will not + include any that were already finished. + """ + + if self.chunk_transfer_adapter: + self.chunk_transfer_adapter.finish_requests(request_ids, finished_status, self.requests) + + return super().finish_requests(request_ids, finished_status) + """ Scheduler for the diffusion model. This scheduler is modified to stop the request immediately for the diffusion model. @@ -581,3 +606,11 @@ def update_from_output( eco.scheduler_stats = stats return engine_core_outputs + + def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None: + """ + Override: Just replace the existing session with the next streaming update. 
+ + Do not expend prompt id using update. + """ + self._replace_session_with_streaming_update(session, update) diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py new file mode 100644 index 0000000000..36080e63ac --- /dev/null +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from vllm.v1.engine import EngineCoreEventType +from vllm.v1.request import Request, RequestStatus, StreamingUpdate + + +class OmniSchedulerMixin: + """Shared scheduler helpers for omni-specific request handling.""" + + def _replace_session_with_streaming_update( + self, + session: Request, + update: StreamingUpdate, + ) -> None: + """For streaming input: Replace an existing streaming session payload with the latest update.""" + session._output_token_ids.clear() + session._all_token_ids.clear() + new_prompt = update.prompt_token_ids or () + session._all_token_ids.extend(new_prompt) + session.num_computed_tokens = 0 + session.prompt_token_ids = update.prompt_token_ids or () + session.additional_information = update.additional_information or None + # Update block hashes for the new tokens. + session.update_block_hashes() + session.num_prompt_tokens = len(session.prompt_token_ids) + session.arrival_time = update.arrival_time + session.sampling_params = update.sampling_params + if session.status == RequestStatus.WAITING_FOR_STREAMING_REQ: + self.num_waiting_for_streaming_input -= 1 + session.status = RequestStatus.WAITING + + if self.log_stats: + session.record_event(EngineCoreEventType.QUEUED) diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py new file mode 100644 index 0000000000..c9d891afb4 --- /dev/null +++ b/vllm_omni/core/sched/omni_scheduling_coordinator.py @@ -0,0 +1,380 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Scheduling-side coordination for chunk and full_payload input waiting. + +Manages WAITING_FOR_CHUNK and WAITING_FOR_INPUT state transitions +based on readiness signals from OmniConnectorOutput, without ever +calling connector.put()/get(). + +This replaces the scheduling half of OmniChunkTransferAdapter; the +transport half lives in OmniConnectorModelRunnerMixin. +""" + +from __future__ import annotations + +import time +from collections import deque +from typing import Any + +from vllm.logger import init_logger +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class OmniSchedulingCoordinator: + """Pure-scheduling coordinator for chunk and full_payload input waiting. + + The Scheduler owns an instance of this class. It consumes readiness + signals produced by the Model Runner's ``OmniConnectorModelRunnerMixin`` + (via ``OmniConnectorOutput``) and manages ``WAITING_FOR_CHUNK`` and + ``WAITING_FOR_INPUT`` state transitions accordingly. + """ + + def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: bool = False): + self._stage_id = stage_id + self._scheduler_max_num_seqs = scheduler_max_num_seqs + self._async_chunk = async_chunk + + self.finished_requests: set[str] = set() + self.requests_with_ready_chunks: set[str] = set() + self._full_payload_input_received: set[str] = set() + + self._waiting_for_chunk_waiting: deque[Any] = deque() + self._waiting_for_chunk_running: deque[Any] = deque() + + # Request IDs that were newly registered for chunk recv this cycle. 
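As a toy illustration of the two streaming-update behaviours described in the docstrings above (stage 0 extends the prompt and discards the stale sampled token, while later stages replace the session outright), using plain lists instead of the real Request bookkeeping; block hashes, num_computed_tokens and event recording are handled by the mixin and omitted here.

```python
# Toy token lists, not real Request state.
session_prompt = [1, 2, 3]   # tokens already held by the streaming session
session_outputs = [7]        # last sampled token from the prior input chunk
update_prompt = [4, 5]       # next streaming input chunk


def apply_update(stage_id):
    if stage_id == 0:
        # Stage 0: extend the existing prompt with the new chunk and drop
        # the stale sampled token from the prior chunk.
        return session_prompt + update_prompt, []
    # Other stages: the new chunk replaces the whole prompt and all prior
    # output tokens are cleared (num_computed_tokens resets to 0).
    return list(update_prompt), []


print(apply_update(0))  # ([1, 2, 3, 4, 5], []) -> prompt extended, output 7 discarded
print(apply_update(1))  # ([4, 5], [])          -> session replaced outright
```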
+ # The engine/Model Runner should call register_chunk_recv() for these + # so the bg thread starts polling. + self.pending_chunk_registrations: list[Any] = [] + + # Requests waiting for full_payload stage input (WAITING_FOR_INPUT). + self._waiting_for_input: deque[Any] = deque() + self.pending_input_registrations: list[Any] = [] + + # Monotonic timestamp recording when each request first entered + # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. Used by + # collect_timed_out_request_ids() to detect orphaned waits. + self._waiting_since: dict[str, float] = {} + + # ------------------------------------------------------------------ # + # Core scheduling methods + # ------------------------------------------------------------------ # + + def process_pending_chunks( + self, + waiting_queue: Any, + running_queue: list[Request], + chunk_ready_req_ids: set[str], + chunk_finished_req_ids: set[str], + ) -> None: + """Transition requests whose chunks have arrived. + + Args: + waiting_queue: Scheduler's waiting request queue. + running_queue: Scheduler's running request list. + chunk_ready_req_ids: IDs with a newly arrived chunk this cycle. + chunk_finished_req_ids: IDs whose final chunk has arrived. + """ + if self._stage_id == 0 or not self._async_chunk: + return + + terminal_ready_req_ids = chunk_ready_req_ids.intersection(chunk_finished_req_ids) + self.finished_requests.update(chunk_finished_req_ids - terminal_ready_req_ids) + self.pending_chunk_registrations = [] + + self._process_chunk_queue( + waiting_queue, + self._waiting_for_chunk_waiting, + RequestStatus.WAITING, + chunk_ready_req_ids, + ) + self._process_chunk_queue( + running_queue, + self._waiting_for_chunk_running, + RequestStatus.RUNNING, + chunk_ready_req_ids, + ) + self.finished_requests.update(terminal_ready_req_ids) + + while len(running_queue) > self._scheduler_max_num_seqs: + request = running_queue.pop() + # Must reset status to WAITING so the scheduler treats it as + # schedulable work. KV blocks are NOT freed here (unlike a + # real preemption), so PREEMPTED would be incorrect. + request.status = RequestStatus.WAITING + waiting_queue.prepend_requests([request]) + + def process_pending_full_payload_inputs( + self, + waiting_queue: Any, + running_queue: list[Request], + stage_recv_req_ids: set[str], + ) -> None: + """Manage WAITING_FOR_INPUT lifecycle for full_payload_mode. + + For non-Stage-0 stages in full_payload_mode (``async_chunk=False``): + 1. Fresh WAITING requests are transitioned to WAITING_FOR_INPUT + and registered for bg-thread polling. + 2. WAITING_FOR_INPUT requests whose data has arrived (in + ``stage_recv_req_ids``) are transitioned back to WAITING. 
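For orientation, the two-step lifecycle described above can be reduced to a toy sketch (editor's illustration, not part of the patch; plain lists and strings stand in for the real Request/RequestStatus objects):

    # Illustrative stand-ins for the scheduler's queues.
    waiting: list[str] = ["req-a", "req-b"]   # fresh WAITING requests
    waiting_for_input: list[str] = []         # parked until stage input arrives

    # Step 1: park fresh WAITING requests and register them for bg-thread polling.
    for req in list(waiting):
        waiting.remove(req)
        waiting_for_input.append(req)

    # Step 2: a later scheduling cycle reports which payloads have arrived.
    stage_recv_req_ids = {"req-a"}
    for req in list(waiting_for_input):
        if req in stage_recv_req_ids:
            waiting_for_input.remove(req)
            waiting.append(req)               # back to WAITING, schedulable again

    assert waiting == ["req-a"] and waiting_for_input == ["req-b"]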
+ """ + if self._stage_id == 0: + return + + self._full_payload_input_received.update(stage_recv_req_ids) + if not self._async_chunk and stage_recv_req_ids: + self.finished_requests.update(stage_recv_req_ids) + logger.debug( + "[Coordinator stage-%s] full_payload recv -> finished_requests: %s", + self._stage_id, + stage_recv_req_ids, + ) + self.pending_input_registrations = [] + + remaining: deque[Any] = deque() + for request in self._waiting_for_input: + if request.request_id in stage_recv_req_ids: + request.status = RequestStatus.WAITING + self._waiting_since.pop(request.request_id, None) + waiting_queue.add_request(request) + else: + remaining.append(request) + self._waiting_for_input = remaining + + if not self._async_chunk: + to_remove: list[Any] = [] + queue_snapshot = list(waiting_queue) + for request in queue_snapshot: + if request.status == RequestStatus.WAITING: + if request.request_id in self._full_payload_input_received: + continue + if request.request_id in self.requests_with_ready_chunks: + continue + if request.request_id in self.finished_requests: + continue + request.status = RequestStatus.WAITING_FOR_INPUT + self._waiting_since.setdefault(request.request_id, time.monotonic()) + to_remove.append(request) + self._waiting_for_input.append(request) + self.pending_input_registrations.append(request) + elif request.status == RequestStatus.WAITING_FOR_INPUT: + if request.request_id in stage_recv_req_ids: + request.status = RequestStatus.WAITING + self._waiting_since.pop(request.request_id, None) + else: + to_remove.append(request) + self._waiting_for_input.append(request) + self.pending_input_registrations.append(request) + for request in to_remove: + waiting_queue.remove(request) + + def process_pending_full_payload_inputs_legacy( + self, + waiting_queue: Any, + running_queue: list[Request], + stage_recv_req_ids: set[str], + ) -> None: + """Compatibility wrapper for ``process_pending_full_payload_inputs``.""" + self.process_pending_full_payload_inputs(waiting_queue, running_queue, stage_recv_req_ids) + + def free_finished_request(self, request_id: str) -> None: + """Prune internal tracking sets for a freed request to prevent unbounded growth.""" + self._full_payload_input_received.discard(request_id) + self.finished_requests.discard(request_id) + self.requests_with_ready_chunks.discard(request_id) + self._waiting_since.pop(request_id, None) + + def collect_timed_out_request_ids( + self, + timeout_s: float, + ) -> set[str]: + """Return IDs of requests that have been waiting longer than *timeout_s*. + + Uses ``_waiting_since`` timestamps (always up-to-date) to detect + timed-out requests. This method is safe to call at any point in + the scheduling cycle — it does **not** rely on coordinator internal + queues (which are empty after ``restore_queues()``). + + Clears ``_waiting_since`` for timed-out IDs and defensively removes + them from coordinator internal queues if present. The caller + (scheduler) should then remove the requests from its queues, + set ``FINISHED_ERROR``, and call ``_free_request()`` so that + ``cleanup_finished_request()`` fires in the model runner mixin. + """ + if timeout_s <= 0: + return set() + now = time.monotonic() + timed_out_ids: set[str] = set() + for req_id, start_time in self._waiting_since.items(): + if now - start_time > timeout_s: + timed_out_ids.add(req_id) + if not timed_out_ids: + return set() + + # Defensively remove from coordinator internal queues (may already + # be empty if restore_queues() has run). 
+ for queue_attr in ( + "_waiting_for_chunk_waiting", + "_waiting_for_chunk_running", + "_waiting_for_input", + ): + queue = getattr(self, queue_attr) + remaining: deque[Any] = deque() + for request in queue: + if request.request_id not in timed_out_ids: + remaining.append(request) + setattr(self, queue_attr, remaining) + + for req_id in timed_out_ids: + self._waiting_since.pop(req_id, None) + logger.warning( + "[Coordinator stage-%s] Request %s timed out waiting for chunk/input (waited > %.0fs)", + self._stage_id, + req_id, + timeout_s, + ) + + return timed_out_ids + + def restore_queues( + self, + waiting_queue: Any, + running_queue: list[Request], + ) -> None: + """Return waiting-for-chunk/input requests to scheduling queues.""" + for request in self._waiting_for_chunk_waiting: + waiting_queue.add_request(request) + self._waiting_for_chunk_waiting = deque() + + if self._waiting_for_chunk_running: + running_queue.extend(self._waiting_for_chunk_running) + self._waiting_for_chunk_running = deque() + + for request in self._waiting_for_input: + waiting_queue.add_request(request) + self._waiting_for_input = deque() + + def update_request_metadata( + self, + requests: dict[str, Request], + request_metadata: dict[str, dict[str, Any]], + model_mode: str = "ar", + ) -> None: + """Apply received scheduling metadata to request objects. + + For AR mode: only scheduler-visible metadata is applied locally. + For Generation mode: updates ``request.prompt_token_ids``. + + Additionally, if the payload contains ``next_stage_prompt_len``, + updates the request's ``prompt_token_ids`` to the correct length. + """ + for req_id, metadata in request_metadata.items(): + request = requests.get(req_id) + if request is None: + continue + + # Handle next_stage_prompt_len if present (for models like Qwen3-Omni). + # Only apply when the request has not started decoding yet + # (no output tokens). Resetting a mid-decode request would + # destroy generated tokens and desync KV cache state. 
+ if "next_stage_prompt_len" in metadata: + next_len = metadata["next_stage_prompt_len"] + if isinstance(next_len, int) and next_len > 0: + output_token_ids = getattr(request, "_output_token_ids", None) + has_decode_output = output_token_ids is not None and len(output_token_ids) > 0 + if has_decode_output: + logger.debug( + "[Coordinator stage-%s] Skipping prompt resize for req %s: " + "request already has %s output tokens", + self._stage_id, + req_id, + len(output_token_ids), + ) + else: + current_prompt_ids = getattr(request, "prompt_token_ids", []) or [] + current_prompt_len = len(current_prompt_ids) + if current_prompt_len != next_len or getattr(request, "num_prompt_tokens", None) != next_len: + new_prompt = [0] * next_len + request.prompt_token_ids = new_prompt + request.num_prompt_tokens = next_len + request._all_token_ids.clear() + request._all_token_ids.extend(new_prompt) + request._output_token_ids.clear() + request.num_computed_tokens = 0 + logger.debug( + "[Coordinator stage-%s] Updated prompt_token_ids length to %s for req %s", + self._stage_id, + next_len, + req_id, + ) + + if model_mode != "ar": + new_ids = metadata.get("code_predictor_codes", []) + runtime_seed = None + if "left_context_size" in metadata: + runtime_seed = { + "left_context_size": metadata["left_context_size"], + } + request._omni_initial_model_buffer = runtime_seed + if new_ids: + request.prompt_token_ids = new_ids + request.num_computed_tokens = 0 + + def postprocess_scheduler_output( + self, + scheduler_output: Any, + requests: dict[str, Request] | None = None, + ) -> None: + """Clear per-cycle ready state after scheduler output is materialized.""" + self._clear_chunk_ready(scheduler_output) + + # ------------------------------------------------------------------ # + # Internal helpers + # ------------------------------------------------------------------ # + + def _process_chunk_queue( + self, + queue: Any, + waiting_for_chunk_list: deque[Any], + target_status: RequestStatus, + chunk_ready_req_ids: set[str], + ) -> None: + queue_snapshot = list(queue) + for request in queue_snapshot: + if request.status != RequestStatus.WAITING_FOR_CHUNK: + if request.request_id in self.requests_with_ready_chunks: + continue + if request.request_id in self.finished_requests: + continue + if request.status == RequestStatus.WAITING_FOR_INPUT: + continue + if request.request_id in chunk_ready_req_ids: + self.requests_with_ready_chunks.add(request.request_id) + continue + self.pending_chunk_registrations.append(request) + request.status = RequestStatus.WAITING_FOR_CHUNK + self._waiting_since.setdefault(request.request_id, time.monotonic()) + else: + if request.request_id in chunk_ready_req_ids: + request.status = target_status + self.requests_with_ready_chunks.add(request.request_id) + self._waiting_since.pop(request.request_id, None) + continue + queue.remove(request) + waiting_for_chunk_list.append(request) + + def _clear_chunk_ready(self, scheduler_output: Any) -> None: + if scheduler_output.scheduled_new_reqs: + for req_data in scheduler_output.scheduled_new_reqs: + self.requests_with_ready_chunks.discard( + getattr(req_data, "req_id", None), + ) + + if scheduler_output.scheduled_cached_reqs: + for req_id in scheduler_output.scheduled_cached_reqs.req_ids: + self.requests_with_ready_chunks.discard(req_id) + + +# Backward-compatible alias +ChunkSchedulingCoordinator = OmniSchedulingCoordinator diff --git a/vllm_omni/deploy/qwen2_5_omni.yaml b/vllm_omni/deploy/qwen2_5_omni.yaml new file mode 100644 index 
0000000000..41aef0df6f --- /dev/null +++ b/vllm_omni/deploy/qwen2_5_omni.yaml @@ -0,0 +1,92 @@ +# Qwen2.5-Omni deploy: CUDA defaults + platform overrides, verified on 2x H100. +# Stage 2 disables flashinfer autotune because its DiT block never invokes +# flashinfer; the autotune dummy run OOMs the shared cuda:0 device otherwise. +# +# Fields omitted from a stage fall back to StageDeployConfig dataclass +# defaults (see vllm_omni/config/stage_config.py). For instance, every +# stage here uses vLLM's default max_num_batched_tokens=32768 because +# chat-sized prefill comfortably fits; only models with codec prefill +# (Qwen3-Omni, Qwen3-TTS) need to bump it above 32k. +# +# enforce_eager policy across the three deploy YAMLs: +# * code2wav / generation stages: always true (cudagraph incompatible with +# the custom generation loop — set explicitly everywhere). +# * AR stages (thinker, talker): model-dependent. Qwen2.5-Omni runs eager +# on CUDA (thinker uses custom ops that don't trace cleanly); NPU / XPU +# platform overrides flip back to false where cudagraph is verified. +# Qwen3-Omni / Qwen3-TTS AR stages use the default (false = cudagraph on). +async_chunk: false + +stages: + - stage_id: 0 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + enforce_eager: true + mm_processor_cache_gb: 0 + devices: "0" + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + repetition_penalty: 1.1 + + - stage_id: 1 + max_num_seqs: 1 + gpu_memory_utilization: 0.8 + enforce_eager: true + devices: "1" + default_sampling_params: + temperature: 0.9 + top_p: 0.8 + top_k: 40 + max_tokens: 2048 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 2 + max_num_seqs: 1 + gpu_memory_utilization: 0.15 + enforce_eager: true + enable_flashinfer_autotune: false + async_scheduling: false + devices: "0" + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + repetition_penalty: 1.1 + +platforms: + npu: + stages: + # NPU has cudagraph support for the thinker, unlike GPU which still + # only runs eager. + - stage_id: 0 + enforce_eager: false + - stage_id: 2 + # 3-NPU layout: stage 2 lives on its own card. + devices: "2" + + rocm: + stages: + - stage_id: 2 + # 3-GPU MI325 layout: stage 2 on a separate card. + devices: "2" + + xpu: + stages: + # Verified on 2x Intel Arc Pro B60. Both AR stages use cudagraphs. + - stage_id: 0 + gpu_memory_utilization: 0.9 + enforce_eager: false + - stage_id: 1 + gpu_memory_utilization: 0.5 + enforce_eager: false + - stage_id: 2 + gpu_memory_utilization: 0.3 + # Stage 2 colocates with stage 1's device on XPU. + devices: "1" diff --git a/vllm_omni/deploy/qwen3_omni_moe.yaml b/vllm_omni/deploy/qwen3_omni_moe.yaml new file mode 100644 index 0000000000..fb8b616213 --- /dev/null +++ b/vllm_omni/deploy/qwen3_omni_moe.yaml @@ -0,0 +1,98 @@ +# Qwen3-Omni-MoE production deploy, verified on 2x H100 (stage 0 on cuda:0, +# stages 1+2 on cuda:1). +# +# Fields omitted from a stage fall back to StageDeployConfig defaults (see +# vllm_omni/config/stage_config.py). Notable implicit defaults for this +# model: +# * Stages 0/1 (thinker, talker) do not set max_num_batched_tokens — +# chat-sized prefill fits in the 32768 default. +# * Stages 0/1 do not set enforce_eager — cudagraph runs by default +# (false). Stage 2 (code2wav) sets true because its generation loop +# is cudagraph-incompatible. +# * Platform sections flip enforce_eager per-stage where platform +# cudagraph support differs. 
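A note on how these deploy files are meant to be read: the platform sections are per-stage overlays keyed by `stage_id`, applied on top of the base `stages` entries. The actual merge is implemented in the stage-config code and is not shown in this diff; the following is only a rough illustration of that intent (hypothetical helper, field names taken from the YAML):

    def merge_platform_overrides(stages: list[dict], overrides: list[dict]) -> list[dict]:
        # Illustrative only: overlay platform-specific fields onto base stages by stage_id.
        by_id = {s["stage_id"]: dict(s) for s in stages}
        for o in overrides:
            by_id.setdefault(o["stage_id"], {"stage_id": o["stage_id"]}).update(o)
        return [by_id[i] for i in sorted(by_id)]

    base = [{"stage_id": 0, "enforce_eager": False}, {"stage_id": 1, "devices": "1"}]
    rocm = [{"stage_id": 0, "enforce_eager": True}]
    assert merge_platform_overrides(base, rocm)[0]["enforce_eager"] is True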
+async_chunk: true + +connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + codec_chunk_frames: 25 + codec_left_context_frames: 25 + +stages: + - stage_id: 0 + gpu_memory_utilization: 0.9 + devices: "0" + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 1 + gpu_memory_utilization: 0.6 + devices: "1" + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 2 + gpu_memory_utilization: 0.1 + max_num_seqs: 1 + enforce_eager: true + async_scheduling: false + max_num_batched_tokens: 51200 + devices: "1" + input_connectors: + from_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + repetition_penalty: 1.1 + +platforms: + npu: + stages: + - stage_id: 0 + gpu_memory_utilization: 0.6 + tensor_parallel_size: 2 + devices: "0,1" + - stage_id: 1 + gpu_memory_utilization: 0.6 + enforce_eager: true + devices: "2" + - stage_id: 2 + gpu_memory_utilization: 0.3 + devices: "2" + + rocm: + stages: + - stage_id: 0 + enforce_eager: true + + xpu: + stages: + - stage_id: 0 + tensor_parallel_size: 4 + enforce_eager: true + max_cudagraph_capture_size: 0 + devices: "0,1,2,3" + - stage_id: 1 + enforce_eager: true + max_cudagraph_capture_size: 0 + devices: "4" + - stage_id: 2 + gpu_memory_utilization: 0.3 + max_cudagraph_capture_size: 0 + devices: "4" diff --git a/vllm_omni/deploy/qwen3_tts.yaml b/vllm_omni/deploy/qwen3_tts.yaml new file mode 100644 index 0000000000..32dceebd80 --- /dev/null +++ b/vllm_omni/deploy/qwen3_tts.yaml @@ -0,0 +1,81 @@ +# Qwen3-TTS deploy: talker → code2wav via shared-memory chunk streaming. +# Verified on 1x H100. +# +# Fields omitted from a stage fall back to StageDeployConfig defaults (see +# vllm_omni/config/stage_config.py). Notable choices for this model: +# * Stage 0 (talker) sets max_num_batched_tokens=512 for async-chunk +# latency tuning (not correctness) — small per-step batches keep +# first-chunk latency low. +# * Stage 1 (code2wav) sets max_num_batched_tokens=65536 for correctness: +# codec prefill length (Q * num_frames) exceeds the 32k default. +# * Stage 0 does not set enforce_eager — talker runs cudagraph by default. +# Stage 1 sets true because its codec generation loop is not +# cudagraph-compatible. NPU platform flips stage 0 to true where +# cudagraph is not yet verified. +async_chunk: true + +connectors: + connector_of_shared_memory: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + codec_streaming: true + connector_get_sleep_s: 0.01 + connector_get_max_wait_first_chunk: 3000 + connector_get_max_wait: 300 + # Must match the decoder sliding attention window. + codec_chunk_frames: 25 + codec_left_context_frames: 72 + +stages: + - stage_id: 0 + max_num_seqs: 10 + gpu_memory_utilization: 0.3 + async_scheduling: true + max_num_batched_tokens: 512 + max_model_len: 4096 + devices: "0" + output_connectors: + to_stage_1: connector_of_shared_memory + default_sampling_params: + temperature: 0.9 + top_k: 50 + max_tokens: 4096 + seed: 42 + repetition_penalty: 1.05 + + - stage_id: 1 + max_num_seqs: 1 + gpu_memory_utilization: 0.3 + enforce_eager: true + async_scheduling: true + # Must be divisible by num_code_groups and cover (left_context + chunk). + # Prefill length is Q * num_frames (e.g. 
16 * 2148 = 34368); keep + # headroom past 32k. + max_num_batched_tokens: 65536 + # async_chunk appends windows per step; max_model_len must cover the + # accumulated flat codec stream. + max_model_len: 65536 + devices: "0" + input_connectors: + from_stage_0: connector_of_shared_memory + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 65536 + seed: 42 + repetition_penalty: 1.0 + +platforms: + npu: + stages: + # NPU does not yet support async-scheduling for TTS, and the + # talker fits at max_num_seqs=1 only. + - stage_id: 0 + max_num_seqs: 1 + enforce_eager: true + async_scheduling: false + - stage_id: 1 + gpu_memory_utilization: 0.2 + async_scheduling: false diff --git a/vllm_omni/diffusion/attention/parallel/ulysses.py b/vllm_omni/diffusion/attention/parallel/ulysses.py index 5d860b3350..326b5d4567 100644 --- a/vllm_omni/diffusion/attention/parallel/ulysses.py +++ b/vllm_omni/diffusion/attention/parallel/ulysses.py @@ -414,10 +414,6 @@ def pre_attention( def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor: assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}" - # If we have joint tensors (Text), they were Head-Sliced. - # The main sequence (Image) was Sequence-Sliced. - # attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front'). - if ctx.joint_len > 0: joint_len = ctx.joint_len diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py index a5055a0688..d5397dd166 100644 --- a/vllm_omni/diffusion/cache/cache_dit_backend.py +++ b/vllm_omni/diffusion/cache/cache_dit_backend.py @@ -281,6 +281,7 @@ def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable ], forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1], params_modifiers=[modifier], + has_separate_cfg=True, ) ), cache_config=db_cache_config, @@ -464,6 +465,77 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context +def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Stable Audio Open pipeline. + + Args: + pipeline: The StableAudioPipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + + Returns: + A refresh function that can be called to update cache context with new num_inference_steps. + """ + db_cache_config = _build_db_cache_config(cache_config) + + calibrator_config = None + if cache_config.enable_taylorseer: + taylorseer_order = cache_config.taylorseer_order + calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # StableAudio is officially registered in CacheDiT as Pattern_3: + # https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562 + # + # Pattern_3 is required because StableAudioDiT uses cross-attention + # with static encoder_hidden_states that do not change inside the + # transformer block loop. 
+ cache_dit.enable_cache( + BlockAdapter( + transformer=pipeline.transformer, + blocks=pipeline.transformer.transformer_blocks, + forward_pattern=ForwardPattern.Pattern_3, + params_modifiers=[ + ParamsModifier( + cache_config=db_cache_config, + calibrator_config=calibrator_config, + ) + ], + ), + cache_config=db_cache_config, + ) + + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for the transformer with new num_inference_steps. + + Args: + pipeline: The StableAudioPipeline instance. + num_inference_steps: New number of inference steps. + verbose: Whether to log refresh operations. + """ + # Bypass SCM for step counts that don't support predefined masks (e.g., vLLM's 1-step dummy run) + scm_supported_steps = num_inference_steps >= 8 or num_inference_steps in (4, 6) + + if cache_config.scm_steps_mask_policy is None or not scm_supported_steps: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + updated_scm_config = DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ) + + cache_dit.refresh_context( + pipeline.transformer, + cache_config=updated_scm_config, + verbose=verbose, + ) + + return refresh_cache_context + + def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]: """Enable cache-dit for StableDiffusion3Pipeline. @@ -561,6 +633,7 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N forward_pattern=ForwardPattern.Pattern_0, # Treat audio_hidden_states as encoder_hidden_states in Pattern_0 check_forward_pattern=False, + has_separate_cfg=True, ), cache_config=db_cache_config, calibrator_config=calibrator_config, @@ -1097,41 +1170,85 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool return refresh_cache_context -def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]: - """Enable cache-dit for GLM-Image pipeline. +def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for Flux.2-dev pipeline. - GLM-Image processes prompt and image by calling the transformer before the - denoising loop. When an input image is provided (editing mode), the cache must - be force-refreshed after the preprocessing step so stale hidden states are - discarded. Set force_refresh_step_hint = 1 for editing, None for text-to-image. + Args: + pipeline: The Flux2 pipeline instance. + cache_config: DiffusionCacheConfig instance with cache configuration. + Returns: + A refresh function that can be called with a new ``num_inference_steps`` + to update the cache context for the pipeline. 
""" + # Build DBCacheConfig for transformer db_cache_config = _build_db_cache_config(cache_config) - calibrator_config = None + calibrator = None if cache_config.enable_taylorseer: - calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=cache_config.taylorseer_order) - logger.info(f"TaylorSeer enabled with order={cache_config.taylorseer_order}") + taylorseer_order = cache_config.taylorseer_order + calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order) + logger.info(f"TaylorSeer enabled with order={taylorseer_order}") + + # Build ParamsModifier for transformer + modifier = ParamsModifier( + cache_config=db_cache_config, + calibrator_config=calibrator, + ) logger.info( - f"Enabling cache-dit on GLM-Image transformer: " + f"Enabling cache-dit on Flux transformer with BlockAdapter: " f"Fn={db_cache_config.Fn_compute_blocks}, " f"Bn={db_cache_config.Bn_compute_blocks}, " f"W={db_cache_config.max_warmup_steps}, " - f"force_refresh_step_hint={db_cache_config.force_refresh_step_hint}, " ) + # Enable cache-dit using BlockAdapter for transformer cache_dit.enable_cache( - pipeline.transformer, + ( + BlockAdapter( + transformer=pipeline.transformer, + blocks=[ + pipeline.transformer.transformer_blocks, + pipeline.transformer.single_transformer_blocks, + ], + forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2], + params_modifiers=[modifier], + ) + ), cache_config=db_cache_config, - calibrator_config=calibrator_config, ) + def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: + """Refresh cache context for the transformer with new num_inference_steps. -def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]: - """Enable cache-dit for Flux.2-dev pipeline. + Args: + pipeline: The Flux2 pipeline instance. + num_inference_steps: New number of inference steps. + """ + if cache_config.scm_steps_mask_policy is None: + cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose) + else: + cache_dit.refresh_context( + pipeline.transformer, + cache_config=DBCacheConfig().reset( + num_inference_steps=num_inference_steps, + steps_computation_mask=cache_dit.steps_mask( + mask_policy=cache_config.scm_steps_mask_policy, + total_steps=num_inference_steps, + ), + steps_computation_policy=cache_config.scm_steps_policy, + ), + verbose=verbose, + ) + + return refresh_cache_context + + +def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]: + """Enable cache-dit for GlmImage pipeline. Args: - pipeline: The Flux2 pipeline instance. + pipeline: The GlmImage pipeline instance. cache_config: DiffusionCacheConfig instance with cache configuration. Returns: A refresh function that can be called with a new ``num_inference_steps`` @@ -1153,23 +1270,25 @@ def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], ) logger.info( - f"Enabling cache-dit on Flux transformer with BlockAdapter: " + f"Enabling cache-dit on GlmImage transformer with BlockAdapter: " f"Fn={db_cache_config.Fn_compute_blocks}, " f"Bn={db_cache_config.Bn_compute_blocks}, " f"W={db_cache_config.max_warmup_steps}, " ) # Enable cache-dit using BlockAdapter for transformer + # Note: We don't use patch_functor here because it's designed for diffusers' GlmImage, + # and our vllm-omni implementation has a different forward signature. 
+ # We use ForwardPattern.Pattern_0 because our block returns (hidden_states, encoder_hidden_states) cache_dit.enable_cache( ( BlockAdapter( transformer=pipeline.transformer, - blocks=[ - pipeline.transformer.transformer_blocks, - pipeline.transformer.single_transformer_blocks, - ], - forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2], + blocks=pipeline.transformer.transformer_blocks, + forward_pattern=ForwardPattern.Pattern_0, params_modifiers=[modifier], + patch_functor=None, + has_separate_cfg=True, ) ), cache_config=db_cache_config, @@ -1179,7 +1298,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool """Refresh cache context for the transformer with new num_inference_steps. Args: - pipeline: The Flux2 pipeline instance. + pipeline: The GlmImage pipeline instance. num_inference_steps: New number of inference steps. """ if cache_config.scm_steps_mask_policy is None: @@ -1212,6 +1331,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool "Flux2KleinPipeline": enable_cache_for_flux2_klein, "LongCatImagePipeline": enable_cache_for_longcat_image, "LongCatImageEditPipeline": enable_cache_for_longcat_image, + "StableAudioPipeline": enable_cache_for_stable_audio_open, "StableDiffusion3Pipeline": enable_cache_for_sd3, "LTX2Pipeline": enable_cache_for_ltx2, "LTX2ImageToVideoPipeline": enable_cache_for_ltx2, diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py index 5dd80718d1..38c805c28d 100644 --- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py +++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py @@ -1,19 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from typing import Any import numpy as np import torch from vllm.config import LoadConfig -from vllm.utils.torch_utils import set_default_torch_dtype +from vllm.transformers_utils.config import get_hf_file_to_dict from vllm_omni.diffusion.cache.teacache.extractors import get_extractor -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.hooks import HookRegistry, ModelHook from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader -from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline -from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -35,6 +34,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: ctx = self.extractor_fn(module, *args, **kwargs) + # NOTE: We upcast to float32 to also handle bfloat16. 
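The float32 upcast noted above is needed because NumPy has no bfloat16 dtype, so calling `.numpy()` directly on a bfloat16 tensor raises; routing through float32 is the standard workaround. Standalone check (illustrative):

    import torch

    t = torch.randn(4, dtype=torch.bfloat16)
    try:
        t.cpu().numpy()                        # raises: unsupported ScalarType BFloat16
    except TypeError:
        pass
    arr = t.detach().float().cpu().numpy()     # upcast first, then convert
    assert arr.dtype.name == "float32"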
modulated_input_cpu = ctx.modulated_input.detach().float().cpu().numpy() outputs = ctx.run_transformer_blocks() @@ -53,23 +53,39 @@ def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]: return list(self.current_trajectory) -class BagelAdapter: - """Adapter for Bagel model.""" +class DefaultAdapter: + """Default adapter for standard diffusers pipelines.""" - @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline: - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - od_config.model_class_name = "BagelPipeline" + model_class_name = None + uses_tf_config = True + + @classmethod + def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any: + if cls.model_class_name is None: + raise ValueError("Adapter doesn't have a set class name.") + + od_config = OmniDiffusionConfig.from_kwargs( + model_class_name=cls.model_class_name, + model=model_path, + dtype=dtype, + ) + + if cls.uses_tf_config: + # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig + # instead of OmniDiffusion and remove the manual population here + tf_config_dict = get_hf_file_to_dict( + os.path.join("transformer", "config.json"), + od_config.model, + ) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) - pipeline = BagelPipeline(od_config=od_config) - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline + loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config) + # load_model will handle dtypes / device placement, put in .eval() mode + return loader.load_model(od_config=od_config, load_device=device) @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.bagel, "Bagel" + return pipeline.transformer, pipeline.transformer.__class__.__name__ @staticmethod def install_hook(transformer: Any, hook: DataCollectionHook) -> None: @@ -77,25 +93,17 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: registry.register_hook(hook._HOOK_NAME, hook) -class StableAudioAdapter: - """Adapter for Stable Audio Open 1.0 coefficient estimation.""" - - @staticmethod - def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.float16) -> Any: - od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype) - - # Strictly necessary because we bypass loader.load_model() - with set_default_torch_dtype(dtype): - pipeline = StableAudioPipeline(od_config=od_config) +class BagelAdapter(DefaultAdapter): + """Adapter for Bagel model.""" - loader = DiffusersPipelineLoader(LoadConfig()) - loader.load_weights(pipeline) - pipeline.to(device) - return pipeline + model_class_name = "BagelPipeline" + # Skip the hack for loading the tf model config, + # because bagel doesn't use it. 
+ uses_tf_config = False @staticmethod def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.transformer, "StableAudioDiTModel" + return pipeline.bagel, "Bagel" @staticmethod def install_hook(transformer: Any, hook: DataCollectionHook) -> None: @@ -103,26 +111,32 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None: registry.register_hook(hook._HOOK_NAME, hook) -class DefaultAdapter: - """Default adapter for standard diffusers pipelines.""" +class Flux2Adapter(DefaultAdapter): + """Adapter for Flux2 model coefficient estimation.""" - @staticmethod - def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any: - raise NotImplementedError("DefaultAdapter.load_pipeline not implemented") + model_class_name = "Flux2Pipeline" - @staticmethod - def get_transformer(pipeline: Any) -> tuple[Any, str]: - return pipeline.transformer, pipeline.transformer.__class__.__name__ - @staticmethod - def install_hook(transformer: Any, hook: DataCollectionHook) -> None: - registry = HookRegistry.get_or_create(transformer) - registry.register_hook(hook._HOOK_NAME, hook) +class LongCatAdapter(DefaultAdapter): + """Adapter for LongCat Image - NOTE: currently this model needs the vLLM + context to be correctly configured to actually run the estimation, since it + uses vLLM norm layers etc. + """ + + model_class_name = "LongCatImagePipeline" + + +class StableAudioAdapter(DefaultAdapter): + """Adapter for Stable Audio Open 1.0 coefficient estimation.""" + + model_class_name = "StableAudioPipeline" _MODEL_ADAPTERS: dict[str, type] = { "Bagel": BagelAdapter, "StableAudio": StableAudioAdapter, + "Flux2": Flux2Adapter, + "LongCat": LongCatAdapter, } _EPSILON = 1e-6 @@ -169,7 +183,6 @@ def __init__( device: str = "cuda", dtype: torch.dtype = torch.bfloat16, ): - # Add validation here ⬇️ if model_type not in _MODEL_ADAPTERS: available_types = list(_MODEL_ADAPTERS.keys()) raise ValueError( @@ -178,7 +191,7 @@ def __init__( f"To add support for a new model, add an entry to _MODEL_ADAPTERS." 
            )
-        adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter)
+        adapter = _MODEL_ADAPTERS[model_type]
         self.pipeline = adapter.load_pipeline(model_path, device, dtype)
         self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
         self.hook = DataCollectionHook(self.transformer_type)
diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py
index 96cf3f03ee..7efdd418e1 100644
--- a/vllm_omni/diffusion/cache/teacache/config.py
+++ b/vllm_omni/diffusion/cache/teacache/config.py
@@ -64,6 +64,17 @@
         -1.04182570e01,
         6.78098549e-01,
     ],
+    # Flux2 transformer coefficients
+    # Copied from Qwen-Image, need to be tuned specifically for Flux2 in future
+    "Flux2Transformer2DModel": [
+        -4.50000000e02,
+        2.80000000e02,
+        -4.50000000e01,
+        3.20000000e00,
+        -2.00000000e-02,
+    ],
+    # LongCat Image transformer coefficients
+    "LongCatImageTransformer2DModel": [652.5980, -424.1615, 84.5526, -4.5923, 0.1694],
 }
diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py
index bdb3f6a786..d0da0d9df3 100644
--- a/vllm_omni/diffusion/cache/teacache/extractors.py
+++ b/vllm_omni/diffusion/cache/teacache/extractors.py
@@ -19,8 +19,12 @@
 import torch
 import torch.nn as nn
 
+from vllm.logger import init_logger
 from vllm_omni.diffusion.forward_context import get_forward_context
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)
 
 
 @dataclass
@@ -221,7 +225,8 @@ def extract_qwen_context(
     block = module.transformer_blocks[0]
     img_mod_params = block.img_mod(temb)
     img_mod1, _ = img_mod_params.chunk(2, dim=-1)
-    img_modulated, _ = block.img_norm1(hidden_states, img_mod1)
+    img_scale1, img_shift1, _ = block._modulate(img_mod1)
+    img_modulated = block.img_norm1(hidden_states, img_scale1, img_shift1)
 
     # ============================================================================
     # DEFINE TRANSFORMER EXECUTION (Qwen-specific)
@@ -721,6 +726,105 @@ def postprocess(h):
     )
 
 
+def extract_longcat_context(
+    module: nn.Module,  # LongCatImageTransformer2DModel
+    hidden_states,
+    timestep,
+    guidance,
+    encoder_hidden_states,
+    txt_ids,
+    img_ids,
+    **kwargs,
+) -> CacheContext:
+    """Extract the cache context for LongCat Image.
+
+    Similar to other extractors, this is currently the only code needed
+    for TeaCache support for LongCat image, and encapsulates preprocessing,
+    modulated input extraction, transformer execution, and postprocessing
+    logic.
+
+    Args & kwargs are identical to the inputs to LongCat Image's forward.
+
+    Returns:
+        CacheContext with all information needed for generic caching
+    """
+    # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward
+    from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+    # 1. Model specific preprocessing
+    fwd_context = get_forward_context()
+    sp_size = module.parallel_config.sequence_parallel_size
+    if sp_size is not None and sp_size > 1:
+        # NOTE: For now, we set this to False on the forward context
+        # to be consistent with LongCat Image's current behavior when
+        # TeaCache is enabled. We do not need to reset it in post process
+        # since we should never split text embed in sp for this model.
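On the new coefficient entries in the teacache config hunk above: TeaCache-style caches typically evaluate a degree-4 polynomial over the relative change of the modulated input, with coefficients stored highest order first (NumPy `polyval` ordering). How these particular lists are consumed is defined elsewhere in the cache code, so treat the following purely as an illustration of that convention:

    import numpy as np

    # LongCat Image coefficients from the config hunk (highest order first).
    coeffs = [652.5980, -424.1615, 84.5526, -4.5923, 0.1694]
    rel_l1_change = 0.05                           # hypothetical per-step input change
    rescaled = np.polyval(coeffs, rel_l1_change)   # value a TeaCache-style threshold would inspect
    print(float(rescaled))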
+ fwd_context.split_text_embed_in_sp = False + + hidden_states = module.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + temb = module.time_embed(timestep, hidden_states.dtype) + encoder_hidden_states = module.context_embedder(encoder_hidden_states) + + # Compute RoPE embeddings via rope_preparer module + # _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3) + # txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention + txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids) + + # Reconstruct image_rotary_emb with chunked values + # Final shape: (txt_seq_len + img_seq_len // SP, head_dim) + image_rotary_emb = ( + torch.cat([txt_cos, img_cos], dim=0), + torch.cat([txt_sin, img_sin], dim=0), + ) + + # 2. Extract the modulated output from the first mm-DiT block + first_block = module.transformer_blocks[0] + img_modulated = first_block.norm1(hidden_states, emb=temb)[0] + + # 3. Define the transformer execution + def run_transformer_blocks(): + """Execute all Longcat transformer blocks.""" + h = hidden_states + e = encoder_hidden_states + for block in module.transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for block in module.single_transformer_blocks: + e, h = block( + hidden_states=h, + encoder_hidden_states=e, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + # Hook expects hidden states to be first + return (h, e) + + # 4. Postprocessing + def postprocess(h): + """Apply Longcat-specific output postprocessing.""" + h = module.norm_out(h, temb) + output = module.proj_out(h) + return Transformer2DModelOutput(sample=output) + + # 5. Return the CacheContext + return CacheContext( + modulated_input=img_modulated, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + ) + + def extract_stable_audio_context( module: nn.Module, hidden_states: torch.Tensor, @@ -827,6 +931,144 @@ def postprocess(h: torch.Tensor) -> Any: ) +def extract_flux2_context( + module: nn.Module, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + **kwargs: Any, +) -> CacheContext: + """ + Extract cache context for Flux2Transformer2DModel. + + This is the ONLY Flux2-specific code needed for TeaCache support. + It encapsulates preprocessing, modulated input extraction, transformer execution, + and postprocessing logic. 
+ + Args: + module: Flux2Transformer2DModel instance + hidden_states: Input hidden states tensor + encoder_hidden_states: Text encoder outputs + timestep: Current diffusion timestep + img_ids: Image inputs for position embedding + txt_ids: Text inputs for position embedding + guidance: Optional guidance scale for CFG + joint_attention_kwargs: Additional attention arguments + return_dict: Whether to return a Transformer2DModelOutput instead of a plain tensor + **kwargs: Additional keyword arguments ignored by this extractor + + Returns: + CacheContext with all information needed for generic caching + """ + + from diffusers.models.modeling_outputs import Transformer2DModelOutput + + if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0: + raise ValueError("Module must have transformer_blocks") + + # ============================================================================ + # PREPROCESSING (Flux2-specific) + # ============================================================================ + num_txt_tokens = encoder_hidden_states.shape[1] + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + + temb = module.time_guidance_embed(timestep, guidance) + + double_stream_mod_img = module.double_stream_modulation_img(temb) + double_stream_mod_txt = module.double_stream_modulation_txt(temb) + single_stream_mod = module.single_stream_modulation(temb)[0] + + hidden_states = module.x_embedder(hidden_states) + encoder_hidden_states = module.context_embedder(encoder_hidden_states) + + if img_ids.ndim == 3: + img_ids = img_ids[0] + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + + if current_omni_platform.is_npu(): + freqs_cos_image, freqs_sin_image = module.pos_embed(img_ids.cpu()) + image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) + freqs_cos_text, freqs_sin_text = module.pos_embed(txt_ids.cpu()) + text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) + else: + image_rotary_emb = module.pos_embed(img_ids) + text_rotary_emb = module.pos_embed(txt_ids) + concat_rotary_emb = ( + torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), + torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), + ) + + # ============================================================================ + # EXTRACT MODULATED INPUT (for cache decision) + # ============================================================================ + block = module.transformer_blocks[0] + (shift_msa, scale_msa, gate_msa), _ = double_stream_mod_img + modulated_input = block.norm1(hidden_states) + modulated_input = (1 + scale_msa) * modulated_input + shift_msa + + # ============================================================================ + # DEFINE TRANSFORMER EXECUTION (Flux2-specific) + # ============================================================================ + def run_transformer_blocks(): + """Execute all Flux2 transformer blocks.""" + h = hidden_states + e = encoder_hidden_states + + for transformer_block in module.transformer_blocks: + e, h = transformer_block( + hidden_states=h, + encoder_hidden_states=e, + temb_mod_params_img=double_stream_mod_img, + temb_mod_params_txt=double_stream_mod_txt, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + h = torch.cat([e, h], dim=1) + + for single_transformer_block in module.single_transformer_blocks: + h = single_transformer_block( + hidden_states=h, + encoder_hidden_states=None, + 
temb_mod_params=single_stream_mod, + image_rotary_emb=concat_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + h = h[:, num_txt_tokens:, ...] + return (h,) + + # ============================================================================ + # DEFINE POSTPROCESSING + # ============================================================================ + def postprocess(h): + h = module.norm_out(h, temb) + output = module.proj_out(h) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + # ============================================================================ + # RETURN CONTEXT + # ============================================================================ + return CacheContext( + modulated_input=modulated_input, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + run_transformer_blocks=run_transformer_blocks, + postprocess=postprocess, + ) + + # Registry for model-specific extractors # Key: Transformer class name # Value: extractor function with signature (module, *args, **kwargs) -> CacheContext @@ -839,6 +1081,8 @@ def postprocess(h: torch.Tensor) -> Any: "ZImageTransformer2DModel": extract_zimage_context, "Flux2Klein": extract_flux2_klein_context, "StableAudioDiTModel": extract_stable_audio_context, + "Flux2Transformer2DModel": extract_flux2_context, + "LongCatImageTransformer2DModel": extract_longcat_context, # Future models: # "FluxTransformer2DModel": extract_flux_context, # "CogVideoXTransformer3DModel": extract_cogvideox_context, diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 56a891aa5c..0a19eb1197 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -18,6 +18,7 @@ QuantizationConfig, ) +from vllm_omni.diffusion.model_metadata import get_diffusion_model_metadata from vllm_omni.diffusion.utils.network_utils import is_port_available from vllm_omni.quantization import build_quant_config @@ -481,8 +482,10 @@ class OmniDiffusionConfig: # Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p) flow_shift: float | None = None - # support multi images input + # Support multi-image inputs and expose any model-specific request limit + # through a generic config field so serving code stays model-agnostic. supports_multimodal_inputs: bool = False + max_multimodal_image_inputs: int | None = None log_level: str = "info" @@ -664,7 +667,54 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None: ) def update_multimodal_support(self) -> None: - self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"} + # Resolve serving-visible multimodal behavior from shared metadata + # instead of importing concrete pipeline modules into the config layer. + metadata = get_diffusion_model_metadata(self.model_class_name) + self.supports_multimodal_inputs = metadata.supports_multimodal_inputs + self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs + + def enrich_config(self) -> None: + """Load model metadata from HuggingFace and populate config fields. + + Diffusers-style models expose ``model_index.json`` with ``_class_name``. + Non-diffusers models (e.g. Bagel, NextStep) only have ``config.json``, + so we fall back to reading that and mapping model_type manually. 
+ """ + from vllm.transformers_utils.config import get_hf_file_to_dict + + try: + config_dict = get_hf_file_to_dict("model_index.json", self.model) + if config_dict is not None: + if self.model_class_name is None: + self.model_class_name = config_dict.get("_class_name", None) + self.update_multimodal_support() + + tf_config_dict = get_hf_file_to_dict("transformer/config.json", self.model) + self.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + else: + raise FileNotFoundError("model_index.json not found") + except (AttributeError, OSError, ValueError, FileNotFoundError): + cfg = get_hf_file_to_dict("config.json", self.model) + if cfg is None: + raise ValueError(f"Could not find config.json or model_index.json for model {self.model}") + + self.tf_model_config = TransformerConfig.from_dict(cfg) + model_type = cfg.get("model_type") + architectures = cfg.get("architectures") or [] + + if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: + self.model_class_name = "BagelPipeline" + self.tf_model_config = TransformerConfig() + self.update_multimodal_support() + elif model_type == "nextstep": + if self.model_class_name is None: + self.model_class_name = "NextStep11Pipeline" + self.tf_model_config = TransformerConfig() + self.update_multimodal_support() + elif architectures and len(architectures) == 1: + self.model_class_name = architectures[0] + else: + raise @classmethod def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig": diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 422ef479b0..fe940d623e 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -3,6 +3,7 @@ from __future__ import annotations +import inspect import queue import threading import time @@ -78,6 +79,12 @@ def __init__( self.post_process_func = get_diffusion_post_process_func(od_config) self.pre_process_func = get_diffusion_pre_process_func(od_config) + # Cache whether the model-specific postprocess accepts request-level + # sampling params so step() can support both legacy and extended hooks. + self._post_process_accepts_sampling_params = bool( + self.post_process_func is not None + and "sampling_params" in inspect.signature(self.post_process_func).parameters + ) executor_class = DiffusionExecutor.get_class(od_config) self.executor = executor_class(od_config) @@ -143,12 +150,22 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: output_data = output_data.cpu() postprocess_start_time = time.perf_counter() - outputs = self.post_process_func(output_data) if self.post_process_func is not None else output_data + if self.post_process_func is not None: + # Some video pipelines need request-level controls during + # postprocess (for example worker-side frame interpolation). 
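The `inspect.signature` probe that backs `_post_process_accepts_sampling_params` is easy to sanity-check in isolation; a minimal sketch of gating an optional keyword argument the same way (hypothetical function names):

    import inspect

    def legacy_postprocess(output):
        return output

    def extended_postprocess(output, sampling_params=None):
        return (output, sampling_params)

    def call_postprocess(fn, output, sampling_params):
        # In the engine this check is cached once at __init__; shown inline for brevity.
        accepts = "sampling_params" in inspect.signature(fn).parameters
        return fn(output, sampling_params=sampling_params) if accepts else fn(output)

    assert call_postprocess(legacy_postprocess, "x", {"fps": 24}) == "x"
    assert call_postprocess(extended_postprocess, "x", {"fps": 24}) == ("x", {"fps": 24})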
+ if self._post_process_accepts_sampling_params: + outputs = self.post_process_func(output_data, sampling_params=request.sampling_params) + else: + outputs = self.post_process_func(output_data) + else: + outputs = output_data audio_payload = None + custom_output = output.custom_output or {} model_audio_sample_rate = None model_fps = None if isinstance(outputs, dict): audio_payload = outputs.get("audio") + custom_output.update(outputs.get("custom_output") or {}) model_audio_sample_rate = outputs.get("audio_sample_rate") model_fps = outputs.get("fps") outputs = outputs.get("video", outputs) @@ -225,7 +242,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: trajectory_timesteps=output.trajectory_timesteps, trajectory_log_probs=output.trajectory_log_probs, trajectory_decoded=output.trajectory_decoded, - custom_output=output.custom_output or {}, + custom_output=custom_output, multimodal_output=mm_output, stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, @@ -295,7 +312,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: trajectory_timesteps=output.trajectory_timesteps, trajectory_log_probs=output.trajectory_log_probs, trajectory_decoded=output.trajectory_decoded, - custom_output=output.custom_output or {}, + custom_output=custom_output, multimodal_output=mm_output, stage_durations=output.stage_durations, peak_memory_mb=output.peak_memory_mb, @@ -361,15 +378,11 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus ) def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: - """Start or stop torch profiling on all diffusion workers. + """Start or stop profiling on all diffusion workers. Args: is_start: True to start profiling, False to stop. - profile_prefix: Optional prefix for trace filename (vLLM compat). - - Note: - Matches vLLM's worker.profile() signature for consistency. - Traces are saved automatically via on_trace_ready callback. + profile_prefix: Optional prefix for trace filename. """ if is_start: if profile_prefix is None: diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py index a8b0012f66..98757006bf 100644 --- a/vllm_omni/diffusion/distributed/cfg_parallel.py +++ b/vllm_omni/diffusion/distributed/cfg_parallel.py @@ -9,6 +9,7 @@ from typing import Any import torch +from vllm.logger import init_logger from vllm_omni.diffusion.distributed.parallel_state import ( get_cfg_group, @@ -16,6 +17,8 @@ get_classifier_free_guidance_world_size, ) +logger = init_logger(__name__) + def _wrap(pred: torch.Tensor | tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: """Normalize prediction to tuple form.""" @@ -32,6 +35,24 @@ def _slice_pred(pred: tuple[torch.Tensor, ...], output_slice: int) -> tuple[torc return tuple(p[:, :output_slice] for p in pred) +def _dispatch_branches(n_branches: int, n_ranks: int) -> list[list[int]]: + """ + Round-robin dispatch N branches to M ranks. + + Rule: branch i → rank (i % n_ranks). + + Examples: + _dispatch_branches(3, 2) -> [[0, 2], [1]] + _dispatch_branches(3, 3) -> [[0], [1], [2]] + _dispatch_branches(4, 2) -> [[0, 2], [1, 3]] + _dispatch_branches(4, 4) -> [[0], [1], [2], [3]] + """ + assignments: list[list[int]] = [[] for _ in range(n_ranks)] + for i in range(n_branches): + assignments[i % n_ranks].append(i) + return assignments + + class CFGParallelMixin(metaclass=ABCMeta): """ Base Mixin class for Diffusion pipelines providing shared CFG methods. 
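The round-robin rule in `_dispatch_branches`, together with the inverse mapping used later when reconstructing gathered outputs (`owner_rank = branch_id % world_size`, `slot = branch_id // world_size`), can be checked standalone; the logic below mirrors the patch and is shown only for illustration:

    def dispatch_branches(n_branches: int, n_ranks: int) -> list[list[int]]:
        # branch i -> rank (i % n_ranks), as in _dispatch_branches.
        assignments: list[list[int]] = [[] for _ in range(n_ranks)]
        for i in range(n_branches):
            assignments[i % n_ranks].append(i)
        return assignments

    assert dispatch_branches(3, 2) == [[0, 2], [1]]
    assert dispatch_branches(4, 4) == [[0], [1], [2], [3]]

    # Inverse mapping: recover which rank computed a branch and in which slot.
    n_ranks = 2
    for branch_id in range(3):
        owner_rank, slot = branch_id % n_ranks, branch_id // n_ranks
        assert dispatch_branches(3, n_ranks)[owner_rank][slot] == branch_id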
@@ -189,6 +210,165 @@ def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, nor results.append(comb) return _unwrap(tuple(results)) + # ── N-branch CFG interface (for 3+ branch models) ── + + def predict_noise_with_multi_branch_cfg( + self, + do_true_cfg: bool, + true_cfg_scale: float | dict[str, float], + branches_kwargs: list[dict[str, Any]], + cfg_normalize: bool = False, + output_slice: int | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Predict noise with N-branch CFG dispatch across M GPUs. + + This is the multi-branch counterpart of predict_noise_maybe_with_cfg(). + Use this for models with 3 or more CFG branches (e.g., OmniGen2, Bagel, + DreamID). Existing 2-branch models should continue using + predict_noise_maybe_with_cfg(). + + Args: + do_true_cfg: Whether to apply CFG. + true_cfg_scale: CFG scale factor (passed to combine_multi_branch_cfg_noise). + branches_kwargs: List of N dicts, each containing kwargs for one + predict_noise() call. branches_kwargs[0] is always the + positive/conditional branch. + cfg_normalize: Whether to normalize (passed to combine_multi_branch_cfg_noise). + output_slice: If set, slice each output to [:, :output_slice]. + + Returns: + Combined noise prediction, identical on all ranks in CFG parallel. + """ + if do_true_cfg: + n_branches = len(branches_kwargs) + cfg_world_size = get_classifier_free_guidance_world_size() + cfg_parallel_ready = cfg_world_size > 1 + + if cfg_parallel_ready: + return self._predict_multi_branch_parallel( + branches_kwargs, + n_branches, + cfg_world_size, + true_cfg_scale, + cfg_normalize, + output_slice, + ) + else: + # Sequential: run all N branches on single device + preds: list[torch.Tensor | tuple[torch.Tensor, ...]] = [] + for kw in branches_kwargs: + pred = _wrap(self.predict_noise(**kw)) + if output_slice is not None: + pred = _slice_pred(pred, output_slice) + preds.append(_unwrap(pred)) + return self.combine_multi_branch_cfg_noise(preds, true_cfg_scale, cfg_normalize) + else: + # No CFG: only compute positive/conditional prediction + pred = self.predict_noise(**branches_kwargs[0]) + if output_slice is not None: + pred = _unwrap(_slice_pred(_wrap(pred), output_slice)) + return pred + + def _predict_multi_branch_parallel( + self, + branches_kwargs: list[dict[str, Any]], + n_branches: int, + cfg_world_size: int, + true_cfg_scale: float, + cfg_normalize: bool, + output_slice: int | None, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Dispatch N branches across M ranks, all_gather, then combine.""" + cfg_group = get_cfg_group() + cfg_rank = get_classifier_free_guidance_rank() + + if cfg_world_size > n_branches: + logger.warning_once( + "cfg_parallel_size=%d > n_branches=%d, %d GPU(s) will be idle for CFG", + cfg_world_size, + n_branches, + cfg_world_size - n_branches, + ) + + # Assign branches to ranks via round-robin + assignments = _dispatch_branches(n_branches, cfg_world_size) + my_branch_ids = assignments[cfg_rank] + max_per_rank = max(len(a) for a in assignments) + + # Run assigned branches + my_preds: list[tuple[torch.Tensor, ...]] = [] + for bid in my_branch_ids: + pred = _wrap(self.predict_noise(**branches_kwargs[bid])) + if output_slice is not None: + pred = _slice_pred(pred, output_slice) + my_preds.append(pred) + + # Idle ranks (cfg_world_size > n_branches) run a forward pass to get the output shape for all_gather. + # Output shape cannot be inferred from kwargs — may be tuple, sliced, etc. 
+ if not my_preds: + pred = _wrap(self.predict_noise(**branches_kwargs[0])) + if output_slice is not None: + pred = _slice_pred(pred, output_slice) + my_preds.append(pred) + + # Pad to max_per_rank with zeros so all ranks have same size + ref_pred = my_preds[0] + while len(my_preds) < max_per_rank: + my_preds.append(tuple(torch.zeros_like(t) for t in ref_pred)) + + # All-gather each output element separately (like predict_noise_maybe_with_cfg) + # For each slot, gather across ranks; then pick valid results by owner_rank + # all_slots[slot][elem_idx] = [rank0_tensor, rank1_tensor, ...] + all_slots: list[list[list[torch.Tensor]]] = [] + for slot in range(max_per_rank): + slot_results: list[list[torch.Tensor]] = [] + for p in my_preds[slot]: + gathered = cfg_group.all_gather(p, separate_tensors=True) + slot_results.append(gathered) + all_slots.append(slot_results) + + # Reconstruct final_preds in branch order + final_preds: list[torch.Tensor | tuple[torch.Tensor, ...]] = [] + for bid in range(n_branches): + owner_rank = bid % cfg_world_size + slot_idx = bid // cfg_world_size + elements = tuple(all_slots[slot_idx][elem_idx][owner_rank] for elem_idx in range(len(ref_pred))) + final_preds.append(_unwrap(elements)) + + return self.combine_multi_branch_cfg_noise(final_preds, true_cfg_scale, cfg_normalize) + + def combine_multi_branch_cfg_noise( + self, + predictions: list[torch.Tensor | tuple[torch.Tensor, ...]], + true_cfg_scale: float | dict[str, float], + cfg_normalize: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Combine N branch predictions. Default: standard 2-branch CFG formula. + + Override this method for custom multi-branch combine logic. + + Args: + predictions: List of N predictions, where predictions[0] is always + the positive/conditional branch. + true_cfg_scale: CFG scale factor (float for 2-branch, dict for multi-branch). + cfg_normalize: Whether to normalize the combined prediction. + + Returns: + Combined noise prediction. + """ + positive = _wrap(predictions[0]) + negative = _wrap(predictions[1]) + + results = [] + for p, n in zip(positive, negative): + comb = n + true_cfg_scale * (p - n) + if cfg_normalize: + comb = self.cfg_normalize_function(p, comb) + results.append(comb) + return _unwrap(tuple(results)) + def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor | tuple[torch.Tensor, ...]: """ Forward pass through transformer to predict noise. diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py index 8ab38f2a65..5294e6c9ed 100644 --- a/vllm_omni/diffusion/distributed/group_coordinator.py +++ b/vllm_omni/diffusion/distributed/group_coordinator.py @@ -104,6 +104,7 @@ def __init__( self.local_rank = local_rank self.device_group = None self.cpu_group = None + self.shm_broadcaster = None for ranks in group_ranks: device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend) @@ -316,7 +317,7 @@ def send_object(self, obj: Any, dst: int) -> None: assert dst < self.world_size, f"Invalid dst rank ({dst})" - assert dst != self.rank, "Invalid destination rank. Destination rank is the same as the current rank." + assert dst != self.rank_in_group, "Invalid destination rank. Destination rank is the same as the current rank." 
        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
@@ -338,7 +339,7 @@ def recv_object(self, src: int) -> Any:

         assert src < self.world_size, f"Invalid src rank ({src})"

-        assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank."
+        assert src != self.rank_in_group, "Invalid source rank. Source rank is the same as the current rank."

         size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py
index cda4201ccf..517c661587 100644
--- a/vllm_omni/diffusion/hooks/base.py
+++ b/vllm_omni/diffusion/hooks/base.py
@@ -8,6 +8,7 @@

 from __future__ import annotations

+import functools
 import inspect
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -94,10 +95,9 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
         return output

     def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any:
-        """Override the module's forward pass completely.
-
-        The default implementation calls pre_forward, then the original forward,
-        then post_forward. Override this method for more complex behavior.
+        """Replace the module's forward pass entirely. Override this for more complex
+        cases, e.g., TeaCache. When a subclass overrides this method, it is called
+        instead of self.module._omni_original_forward while executing the hooks.

         Args:
             module: The module being called.
@@ -105,11 +105,9 @@ def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any:
             **kwargs: Keyword arguments to forward.

         Returns:
-            The output of the forward pass.
+            The output of the replacement forward pass.
         """
-        args, kwargs = self.pre_forward(module, *args, **kwargs)
-        output = module._omni_original_forward(*args, **kwargs)  # type: ignore[attr-defined]
-        return self.post_forward(module, output)
+        raise NotImplementedError("By default, hooks do not implement new_forward")

     def reset_state(self, module: nn.Module) -> nn.Module:
         """Reset any state associated with this hook.
@@ -136,6 +134,21 @@ def __call__(self, *args: Any, **kwargs: Any):
         return registry.dispatch(*args, **kwargs)


+def sort_hooks_after_call(func):
+    """Call the wrapped registry method, then re-sort the hooks.
+
+    Apply this decorator to HookRegistry methods that add or remove hooks.
+    """
+
+    @functools.wraps(func)
+    def wrapper(self: HookRegistry, *args, **kwargs):
+        res = func(self, *args, **kwargs)
+        self.update_sorted_hooks()
+        return res
+
+    return wrapper
+
+
 class HookRegistry:
     """Registry of hooks attached to a module.
@@ -146,6 +159,10 @@ class HookRegistry: def __init__(self, module: nn.Module): self.module = module self._hooks: dict[str, ModelHook] = {} + # Hooks sorted by execution order + self._sorted_hooks: list[ModelHook] = [] + # Hooks overriding new_forward (if any) + self._new_fwd_impl_hook: ModelHook | None = None @classmethod def get_or_create(cls, module: nn.Module) -> HookRegistry: @@ -173,6 +190,14 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry: return registry + def update_sorted_hooks(self): + """Sort hooks by name, which dictates pre/post process order.""" + sorted_hooks = [self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook] + if self._new_fwd_impl_hook is not None: + sorted_hooks.append(self._new_fwd_impl_hook) + self._sorted_hooks = sorted_hooks + + @sort_hooks_after_call def register_hook(self, name: str, hook: ModelHook) -> None: """Register a hook with the given name. @@ -182,7 +207,14 @@ def register_hook(self, name: str, hook: ModelHook) -> None: """ hook.initialize_hook(self.module) self._hooks[name] = hook - + # We can only have one hook that overrides new_forward, + # since we don't currently have a mechanism for combining them. + if type(hook).new_forward is not ModelHook.new_forward: + if self._new_fwd_impl_hook is not None: + raise RuntimeError("Cannot have multiple hooks that override forward active simultaneously") + self._new_fwd_impl_hook = hook + + @sort_hooks_after_call def remove_hook(self, name: str) -> None: """Remove a hook by name. @@ -190,6 +222,9 @@ def remove_hook(self, name: str) -> None: name: The name of the hook to remove. """ if name in self._hooks: + # clear the forward hook if it's the one to delete + if self._new_fwd_impl_hook is self._hooks[name]: + self._new_fwd_impl_hook = None del self._hooks[name] def get_hook(self, name: str) -> ModelHook | None: @@ -206,8 +241,18 @@ def get_hook(self, name: str) -> ModelHook | None: def dispatch(self, *args: Any, **kwargs: Any) -> Any: """Dispatch a forward call through registered hooks. - Currently supports a single active hook. Multiple hooks are called - in sorted order by name, with each hook's output passed to the next. + Multiple hooks may be used with the caveat that only one hook + may override new_forward. While it is assumed that pre/post process + on hooks are composable, the execution flow is as follows for determinism: + + - Run preprocess on all hooks in their sorted order; hooks are sorted alphabetically, + except for the hook overriding forward (`self._new_fwd_impl_hook`), which is last + if it exists. + + - If `self._new_fwd_impl_hook` isn't None, call its forward. Otherwise call the + original model forward. + + - Run post process on all hooks in the reverse sorted order. Args: *args: Positional arguments to forward. 
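# Illustrative sketch (not part of the patch): the dispatch ordering described in the
# docstring above. pre_forward runs over hooks sorted by name with the
# new_forward-overriding hook (if any) last; then either that hook's new_forward or
# the original forward runs; post_forward then runs in reverse order. The hook names
# used here ("offload", "teacache") are hypothetical.
def simulate_dispatch_order(hook_names: list[str], forward_override: str | None = None) -> list[str]:
    ordered = sorted(n for n in hook_names if n != forward_override)
    if forward_override is not None:
        ordered.append(forward_override)  # the overriding hook always runs its pre_forward last
    calls = [f"pre_forward:{n}" for n in ordered]
    calls.append(f"new_forward:{forward_override}" if forward_override else "original_forward")
    calls += [f"post_forward:{n}" for n in reversed(ordered)]
    return calls

# With a "teacache" hook overriding new_forward plus an "offload" hook:
# ['pre_forward:offload', 'pre_forward:teacache', 'new_forward:teacache',
#  'post_forward:teacache', 'post_forward:offload']
print(simulate_dispatch_order(["teacache", "offload"], forward_override="teacache"))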
@@ -219,24 +264,19 @@ def dispatch(self, *args: Any, **kwargs: Any) -> Any: if not self._hooks: return self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] - # For single hook case, call directly - if len(self._hooks) == 1: - hook = next(iter(self._hooks.values())) - return hook.new_forward(self.module, *args, **kwargs) - - # For multiple hooks, chain them in sorted order - # Each hook can modify args/kwargs via pre_forward - sorted_hooks = sorted(self._hooks.items(), key=lambda x: x[0]) - - # Apply all pre_forward hooks - for _, hook in sorted_hooks: + # Apply all pre_forward hooks; if _new_fwd_impl_hook is set, it's last + for hook in self._sorted_hooks: args, kwargs = hook.pre_forward(self.module, *args, **kwargs) - # Call original forward - output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] + # If we have a hook that overrides new_forward, call it directly + if self._new_fwd_impl_hook is not None: + output = self._new_fwd_impl_hook.new_forward(self.module, *args, **kwargs) + # Otherwise just call the original forward. + else: + output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined] - # Apply all post_forward hooks in reverse order - for _, hook in reversed(sorted_hooks): + # Apply all post_forward hooks in reverse order; if _new_fwd_impl_hook is set, it's first + for hook in reversed(self._sorted_hooks): output = hook.post_forward(self.module, output) return output diff --git a/vllm_omni/diffusion/inline_stage_diffusion_client.py b/vllm_omni/diffusion/inline_stage_diffusion_client.py new file mode 100644 index 0000000000..a33a3e9561 --- /dev/null +++ b/vllm_omni/diffusion/inline_stage_diffusion_client.py @@ -0,0 +1,348 @@ +"""Inline Stage Diffusion Client for vLLM-Omni multi-stage runtime. + +Runs DiffusionEngine in a ThreadPoolExecutor inside the Orchestrator process +instead of spawning a separate StageDiffusionProc subprocess, eliminating ZMQ +IPC overhead. Used when there is only a single diffusion stage. 
+""" + +from __future__ import annotations + +import asyncio +import time +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any + +import torch +from PIL import Image +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import DiffusionRequestAbortedError +from vllm_omni.diffusion.diffusion_engine import DiffusionEngine +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.engine.stage_init_utils import StageMetadata +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +if TYPE_CHECKING: + from vllm_omni.diffusion.data import OmniDiffusionConfig + from vllm_omni.inputs.data import OmniPromptType + +logger = init_logger(__name__) + + +class InlineStageDiffusionClient: + """Runs DiffusionEngine in a thread executor inside the Orchestrator.""" + + stage_type: str = "diffusion" + + def __init__( + self, + model: str, + od_config: OmniDiffusionConfig, + metadata: StageMetadata, + batch_size: int = 1, + ) -> None: + self.model = model + self.od_config = od_config + self.stage_id = metadata.stage_id + self.final_output = metadata.final_output + self.final_output_type = metadata.final_output_type + self.default_sampling_params = metadata.default_sampling_params + self.custom_process_input_func = metadata.custom_process_input_func + self.engine_input_source = metadata.engine_input_source + self.batch_size = batch_size + + self._enrich_config() + self._engine = DiffusionEngine.make_engine(self.od_config) + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="inline-diffusion") + + self._output_queue: asyncio.Queue[OmniRequestOutput] = asyncio.Queue() + self._tasks: dict[str, asyncio.Task] = {} + self._shutting_down = False + + logger.info( + "[InlineStageDiffusionClient] Stage-%s initialized inline (batch_size=%d)", + self.stage_id, + self.batch_size, + ) + + def _enrich_config(self) -> None: + """Load model metadata from HuggingFace and populate od_config fields.""" + self.od_config.enrich_config() + + # ------------------------------------------------------------------ + # Request processing + # ------------------------------------------------------------------ + + async def add_request_async( + self, + request_id: str, + prompt: OmniPromptType, + sampling_params: OmniDiffusionSamplingParams, + kv_sender_info: dict[int, dict[str, Any]] | None = None, + ) -> None: + task = asyncio.create_task( + self._dispatch_request( + request_id, + prompt, + sampling_params, + kv_sender_info, + ) + ) + self._tasks[request_id] = task + + async def _dispatch_request( + self, + request_id: str, + prompt: Any, + sampling_params: OmniDiffusionSamplingParams, + kv_sender_info: dict[str, Any] | None = None, + ) -> None: + try: + request = OmniDiffusionRequest( + prompts=[prompt], + sampling_params=sampling_params, + request_ids=[request_id], + request_id=request_id, + kv_sender_info=kv_sender_info, + ) + + loop = asyncio.get_running_loop() + results = await loop.run_in_executor(self._executor, self._engine.step, request) + result = results[0] + if not result.request_id: + result.request_id = request_id + + self._output_queue.put_nowait(result) + except DiffusionRequestAbortedError as e: + logger.info("request_id: %s aborted: %s", request_id, str(e)) + except Exception as e: + logger.exception("Diffusion request %s failed: %s", request_id, e) + error_output = OmniRequestOutput.from_diffusion( + request_id=request_id, + images=[], + ) + error_output.error = str(e) + 
self._output_queue.put_nowait(error_output) + finally: + self._tasks.pop(request_id, None) + + async def add_batch_request_async( + self, + request_id: str, + prompts: list[OmniPromptType], + sampling_params: OmniDiffusionSamplingParams, + kv_sender_info: dict[int, dict[str, Any]] | None = None, + ) -> None: + task = asyncio.create_task( + self._dispatch_batch( + request_id, + prompts, + sampling_params, + kv_sender_info, + ) + ) + self._tasks[request_id] = task + + async def _dispatch_batch( + self, + request_id: str, + prompts: list[Any], + sampling_params: OmniDiffusionSamplingParams, + kv_sender_info: dict[str, Any] | None = None, + ) -> None: + try: + request = OmniDiffusionRequest( + prompts=prompts, + sampling_params=sampling_params, + request_ids=[f"{request_id}-{i}" for i in range(len(prompts))], + request_id=request_id, + kv_sender_info=kv_sender_info, + ) + + loop = asyncio.get_running_loop() + results = await loop.run_in_executor(self._executor, self._engine.step, request) + + all_images: list = [] + merged_mm: dict[str, Any] = {} + merged_metrics: dict[str, Any] = {} + merged_durations: dict[str, float] = {} + merged_custom: dict[str, Any] = {} + peak_mem = 0.0 + latents = None + trajectory_latents: list[torch.Tensor] | None = None + trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_log_probs: torch.Tensor | None = None + trajectory_decoded: list[Image.Image] | None = None + final_output_type = "image" + + for r in results: + all_images.extend(r.images) + merged_mm.update(r._multimodal_output) + merged_metrics.update(r.metrics) + merged_durations.update(r.stage_durations) + merged_custom.update(r._custom_output) + peak_mem = max(peak_mem, r.peak_memory_mb) + if latents is None and r.latents is not None: + latents = r.latents + if trajectory_latents is None: + trajectory_latents = r.trajectory_latents + if trajectory_timesteps is None: + trajectory_timesteps = r.trajectory_timesteps + if trajectory_log_probs is None: + trajectory_log_probs = r.trajectory_log_probs + if trajectory_decoded is None: + trajectory_decoded = r.trajectory_decoded + if r.final_output_type != "image": + final_output_type = r.final_output_type + + result = OmniRequestOutput.from_diffusion( + request_id=request_id, + images=all_images, + prompt=prompts[0] if len(prompts) == 1 else None, + metrics=merged_metrics, + latents=latents, + trajectory_latents=trajectory_latents, + trajectory_timesteps=trajectory_timesteps, + trajectory_log_probs=trajectory_log_probs, + trajectory_decoded=trajectory_decoded, + custom_output=merged_custom or None, + multimodal_output=merged_mm or None, + final_output_type=final_output_type, + stage_durations=merged_durations, + peak_memory_mb=peak_mem, + ) + + self._output_queue.put_nowait(result) + except DiffusionRequestAbortedError as e: + logger.info("request_id: %s aborted: %s", request_id, str(e)) + except Exception as e: + logger.exception("Batch diffusion request %s failed: %s", request_id, e) + error_output = OmniRequestOutput.from_diffusion( + request_id=request_id, + images=[], + ) + error_output.error = str(e) + self._output_queue.put_nowait(error_output) + finally: + self._tasks.pop(request_id, None) + + def get_diffusion_output_nowait(self) -> OmniRequestOutput | None: + try: + return self._output_queue.get_nowait() + except asyncio.QueueEmpty: + return None + + async def abort_requests_async(self, request_ids: list[str]) -> None: + for rid in request_ids: + task = self._tasks.pop(rid, None) + if task: + task.cancel() + self._engine.abort(rid) + + 
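# Illustrative usage sketch (not part of the patch): how an orchestrator coroutine
# could drive the inline client defined in this file. The request id, prompt, and
# sampling params are placeholders, and poll_interval is a hypothetical knob.
import asyncio

async def run_inline_request(client, prompt, sampling_params, poll_interval: float = 0.01):
    # Submit the request; the engine step runs inside the client's thread executor.
    await client.add_request_async("req-0", prompt, sampling_params)
    # Poll the non-blocking output queue until the result (or an error output) arrives.
    while True:
        output = client.get_diffusion_output_nowait()
        if output is not None:
            return output
        await asyncio.sleep(poll_interval)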
async def collective_rpc_async( + self, + method: str, + timeout: float | None = None, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + loop = asyncio.get_running_loop() + + if method == "profile": + is_start = args[0] if args else True + profile_prefix = args[1] if len(args) > 1 else None + if is_start and profile_prefix is None: + profile_prefix = f"stage_{self.stage_id}_diffusion_{int(time.time())}" + return await loop.run_in_executor( + self._executor, + self._engine.profile, + is_start, + profile_prefix, + ) + + kwargs = kwargs or {} + + # LoRA methods + if method == "add_lora": + lora_request = args[0] if args else kwargs.get("lora_request") + results = await loop.run_in_executor( + self._executor, + self._engine.collective_rpc, + "add_lora", + timeout, + (), + {"lora_request": lora_request}, + None, + ) + return all(results) if isinstance(results, list) else results + + if method == "remove_lora": + results = await loop.run_in_executor( + self._executor, + self._engine.collective_rpc, + "remove_lora", + timeout, + args, + kwargs, + None, + ) + return all(results) if isinstance(results, list) else results + + if method == "list_loras": + results = await loop.run_in_executor( + self._executor, + self._engine.collective_rpc, + "list_loras", + timeout, + (), + {}, + None, + ) + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + merged.update(part or []) + return sorted(merged) + + if method == "pin_lora": + lora_id = args[0] if args else kwargs.get("adapter_id") + results = await loop.run_in_executor( + self._executor, + self._engine.collective_rpc, + "pin_lora", + timeout, + (), + {"adapter_id": lora_id}, + None, + ) + return all(results) if isinstance(results, list) else results + + return await loop.run_in_executor( + self._executor, + self._engine.collective_rpc, + method, + timeout, + args, + kwargs, + None, + ) + + def shutdown(self) -> None: + self._shutting_down = True + + # Cancel all pending tasks + for task in self._tasks.values(): + task.cancel() + + try: + # Cancel queued futures and wait for the running one to complete deterministically + self._executor.shutdown(wait=True, cancel_futures=True) + except Exception: + pass + + try: + self._engine.close() + except Exception: + pass diff --git a/vllm_omni/diffusion/layers/adalayernorm.py b/vllm_omni/diffusion/layers/adalayernorm.py index 35f63e2fc9..d147bdcfeb 100644 --- a/vllm_omni/diffusion/layers/adalayernorm.py +++ b/vllm_omni/diffusion/layers/adalayernorm.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm_omni.diffusion.layers.custom_op import CustomOp +from vllm_omni.diffusion.layers.norm import LayerNorm if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -27,107 +28,63 @@ def __init__(self, hidden_size: int, elementwise_affine: bool = False, eps: floa self.eps = eps self.elementwise_affine = elementwise_affine self.hidden_size = hidden_size - self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps) - - def preprocess( - self, - mod_params: torch.Tensor, - index: torch.Tensor = None, - ) -> torch.Tensor: - # shift: b d, scale: b d, gate: b d - shift, scale, gate = mod_params.chunk(3, dim=-1) - - if index is not None: - # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) - # So shift, scale, gate have shape [2*actual_batch, d] - actual_batch = shift.size(0) // 2 - 
shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d] - scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] - gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] - - # index: [b, l] where b is actual batch size - # Expand to [b, l, 1] to match feature dimension - index_expanded = index.unsqueeze(-1) # [b, l, 1] - - # Expand chunks to [b, 1, d] then broadcast to [b, l, d] - shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d] - shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d] - scale_0_exp = scale_0.unsqueeze(1) - scale_1_exp = scale_1.unsqueeze(1) - gate_0_exp = gate_0.unsqueeze(1) - gate_1_exp = gate_1.unsqueeze(1) - - # Use torch.where to select based on index - shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp) - scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp) - gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp) - else: - shift_result = shift.unsqueeze(1) - scale_result = scale.unsqueeze(1) - gate_result = gate.unsqueeze(1) - - return shift_result, scale_result, gate_result + self.layernorm = LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps) def forward_cuda( self, x: torch.Tensor, - mod_params: torch.Tensor, - index: torch.Tensor = None, + scale: torch.Tensor, + shift: torch.Tensor, ) -> torch.Tensor: - return self.forward_native(x, mod_params, index) + return self.forward_native(x, scale, shift) def forward_hip( self, x: torch.Tensor, - mod_params: torch.Tensor, - index: torch.Tensor = None, + scale: torch.Tensor, + shift: torch.Tensor, ) -> torch.Tensor: - return self.forward_native(x, mod_params, index) + return self.forward_native(x, scale, shift) def forward_npu( self, x: torch.Tensor, - mod_params: torch.Tensor, - index: torch.Tensor = None, + scale: torch.Tensor, + shift: torch.Tensor, ) -> torch.Tensor: - shift_result, scale_result, gate_result = self.preprocess(mod_params, index) - if _HAS_MINDIESD: try: from mindiesd import layernorm_scale_shift - output = layernorm_scale_shift(self.layernorm, x, scale_result, shift_result, fused=True) + output = layernorm_scale_shift(self.layernorm, x, scale, shift, fused=True) - return output, gate_result + return output except ImportError as e: logger.warning_once(f"mindiesd import failed, falling back to torch_npu: {e}") import torch_npu output = ( - torch_npu.npu_layer_norm_eval(x, normalized_shape=[self.hidden_size], eps=self.eps) * (1 + scale_result) - + shift_result + torch_npu.npu_layer_norm_eval(x, normalized_shape=[self.hidden_size], eps=self.eps) * (1 + scale) + shift ) - return output, gate_result + return output def forward_xpu( self, x: torch.Tensor, - mod_params: torch.Tensor, - index: torch.Tensor = None, + scale: torch.Tensor, + shift: torch.Tensor, ) -> torch.Tensor: - return self.forward_native(x, mod_params, index) + return self.forward_native(x, scale, shift) def forward_native( self, x: torch.Tensor, - mod_params: torch.Tensor, - index: torch.Tensor = None, + scale: torch.Tensor, + shift: torch.Tensor, ) -> torch.Tensor: - shift_result, scale_result, gate_result = self.preprocess(mod_params, index) - - return self.layernorm(x) * (1 + scale_result) + shift_result, gate_result + return self.layernorm(x) * (1 + scale) + shift class AdaLayerNormZero(nn.Module): diff --git a/vllm_omni/diffusion/layers/norm.py b/vllm_omni/diffusion/layers/norm.py new file mode 100644 index 0000000000..6096ad7c37 --- /dev/null +++ b/vllm_omni/diffusion/layers/norm.py @@ -0,0 +1,110 @@ +from 
importlib.util import find_spec + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vllm.logger import init_logger + +from vllm_omni.diffusion.layers.custom_op import CustomOp + +logger = init_logger(__name__) + +_HAS_MINDIESD = find_spec("mindiesd") is not None + + +class LayerNorm(nn.LayerNorm, CustomOp): + """ + LayerNorm implementation that inherits from both ``nn.LayerNorm`` and ``CustomOp``. + NPU: + Uses ``mindiesd.fast_layernorm(self, x)`` when MindIE-SD is installed. + CUDA / HIP / XPU / native: + Falls back to FP32 nn.LayerNorm implementation. + """ + + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True): + super().__init__(normalized_shape=dim, eps=eps, elementwise_affine=elementwise_affine) + # CustomOp.__init__ cannot be called here because it would re-run + # nn.Module initialization and clear LayerNorm parameters. + self._forward_method = CustomOp.dispatch_forward(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._forward_method(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + def forward_npu(self, x: torch.Tensor) -> torch.Tensor: + if _HAS_MINDIESD: + try: + from mindiesd import fast_layernorm + + return fast_layernorm(self, x) + except ImportError as e: + logger.warning_once( + "mindiesd.fast_layernorm import failed, falling back to FP32 layer_norm: %s", + e, + ) + + return self.forward_native(x) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + origin_dtype = x.dtype + return F.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class RMSNorm(CustomOp): + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward_cuda( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return self.forward_native(x) + + def forward_hip( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return self.forward_native(x) + + def forward_npu( + self, + x: torch.Tensor, + ) -> torch.Tensor: + import torch_npu + + output = torch_npu.npu_rms_norm(x, gamma=self.weight, epsilon=self.variance_epsilon)[0] + + return output + + def forward_xpu( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return self.forward_native(x) + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + input_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + out = x * torch.rsqrt(variance + self.variance_epsilon) + out = self.weight.to(torch.float32) * out + return out.to(input_dtype) diff --git a/vllm_omni/diffusion/model_metadata.py b/vllm_omni/diffusion/model_metadata.py new file mode 100644 index 0000000000..ec133e7380 --- /dev/null +++ b/vllm_omni/diffusion/model_metadata.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class DiffusionModelMetadata: + # Keep serving-facing capability metadata in a lightweight shared module so + # config/model plumbing can read it without importing concrete pipelines. 
+ supports_multimodal_inputs: bool = False + max_multimodal_image_inputs: int | None = None + + +QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4 + + +_DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = { + "QwenImageEditPlusPipeline": DiffusionModelMetadata( + supports_multimodal_inputs=True, + max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES, + ), +} + + +def get_diffusion_model_metadata(model_class_name: str | None) -> DiffusionModelMetadata: + # Unknown models fall back to "no special multimodal capabilities" so new + # pipelines do not accidentally inherit limits meant for other models. + if model_class_name is None: + return DiffusionModelMetadata() + return _DIFFUSION_MODEL_METADATA.get(model_class_name, DiffusionModelMetadata()) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index f848077568..d1254f8456 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -854,6 +854,7 @@ def __init__( config, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.model" ) self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() @@ -864,6 +865,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + def set_decoder(self, decoder): self.model = decoder @@ -1207,7 +1214,7 @@ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) curr += curr_kvlen - text_ids = tokenizer.encode(prompt) + text_ids = tokenizer.encode(prompt, add_special_tokens=False) text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]] text_token_lens.append(len(text_ids)) packed_text_ids.extend(text_ids) @@ -1619,10 +1626,110 @@ def _merge_naive_caches(caches: list) -> NaiveCache: num_layers = len(caches[0].key_cache) merged = NaiveCache(num_layers) for layer_idx in range(num_layers): - merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0) - merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0) + key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None] + val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None] + merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None + merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None return merged + def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids): + """Prepare start tokens for autoregressive text generation. + + Ported from the original BAGEL ``Bagel.prepare_start_tokens``. 
+ """ + packed_start_tokens, packed_key_value_indexes = list(), list() + packed_query_position_ids = list() + + curr = 0 + for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + packed_start_tokens.append(new_token_ids["bos_token_id"]) + packed_query_position_ids.append(curr_position_id) + curr += curr_kvlen + + generation_input = { + "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long), + "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + } + return generation_input + + @torch.no_grad() + def generate_text( + self, + past_key_values: NaiveCache, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + packed_start_tokens: torch.LongTensor, + packed_query_position_ids: torch.LongTensor, + max_length: int, + do_sample: bool = False, + temperature: float = 1.0, + end_token_id: int | None = None, + ): + """Autoregressive text generation (ported from original BAGEL). + + Decodes tokens one at a time, appending to ``past_key_values`` + until ``max_length`` is reached or ``end_token_id`` is generated. + """ + step = 0 + generated_sequence = [] + curr_tokens = packed_start_tokens + while step < max_length: + generated_sequence.append(curr_tokens) + packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens) + query_lens = torch.ones_like(curr_tokens) + packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange( + 0, + len(key_values_lens), + device=key_values_lens.device, + dtype=key_values_lens.dtype, + ) + + uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) + for i in range(len(uppacked)): + uppacked[i] += i + packed_key_value_indexes = torch.cat(uppacked, dim=0) + + output = self.language_model( + packed_query_sequence=packed_text_embedding, + query_lens=query_lens, + packed_query_position_ids=packed_query_position_ids, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=True, + is_causal=True, + mode="und", + ) + past_key_values = output.past_key_values + packed_query_sequence = output.packed_query_sequence + pred_logits = self.language_model.lm_head(packed_query_sequence) + + if do_sample: + probs = nn.functional.softmax(pred_logits / temperature, dim=-1) + curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + curr_tokens = torch.argmax(pred_logits, dim=-1) + + uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)) + for i in range(len(uppacked)): + uppacked[i] = torch.cat( + [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0 + ) + packed_key_value_indexes = torch.cat(uppacked, dim=0) + key_values_lens = key_values_lens + 1 + packed_query_position_ids = packed_query_position_ids + 1 + step += 1 + + if end_token_id is not None and curr_tokens[0] == end_token_id: + break + + output_device = generated_sequence[0].device + return torch.stack([i.to(output_device) for i in generated_sequence], dim=0) + def generate_image( self, packed_text_ids: torch.LongTensor, diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 13d0cc2093..90baf5f676 
100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -365,35 +365,65 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: if req.sampling_params.kv_metadata and "image_shape" in req.sampling_params.kv_metadata: image_shape = tuple(req.sampling_params.kv_metadata["image_shape"]) - cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None) + branch_kvs = getattr(req.sampling_params, "cfg_branch_past_key_values", None) or {} + branch_metadata = getattr(req.sampling_params, "cfg_branch_kv_metadata", None) or {} + active_branch = getattr(req.sampling_params, "cfg_active_branch", None) + branch_roles = getattr(req.sampling_params, "cfg_branch_roles", None) or list(branch_kvs.keys()) + + cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None) or branch_kvs.get("cfg_text") + cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None) or branch_metadata.get( + "cfg_text" + ) + cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or branch_kvs.get("cfg_img") + cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None) or branch_metadata.get( + "cfg_img" + ) + + cfg_parallel_contract = ( + active_branch is not None or bool(branch_roles) or cfg_text_kv is not None or cfg_img_kv is not None + ) + if cfg_parallel_contract: + logger.info( + "CFG enabled with injected branch KV context roles=%s active=%s", + branch_roles, + active_branch, + ) + if cfg_text_kv is not None: - logger.info("CFG enabled with multi-KV: using injected cfg_text KV Cache") cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0] cfg_text_context["past_key_values"] = cfg_text_kv cfg_text_context["kv_lens"] = [cfg_text_seq_len] - cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None) if cfg_text_metadata and "ropes" in cfg_text_metadata: cfg_text_context["ropes"] = cfg_text_metadata["ropes"] else: cfg_text_context["ropes"] = [cfg_text_seq_len] - - cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or injected_kv + else: + # No cfg_text companion received. For text2img this is the + # expected path: original BAGEL uses an empty KV cache (0 + # tokens) as the text-unconditional branch. Keep the default + # empty NaiveCache in cfg_text_context and preserve the + # original cfg_text_scale so CFG still applies. + pass + + if cfg_img_kv is None: + # text2img multi-stage: cfg_img reuses gen KV (positive prompt, + # no image), mirroring forward_cache_update_text on cfg_img_context + # in the single-stage path. 
+ cfg_img_seq_len = injected_kv.key_cache[0].shape[0] + cfg_img_context["past_key_values"] = injected_kv + cfg_img_context["kv_lens"] = [cfg_img_seq_len] + if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata: + cfg_img_context["ropes"] = req.sampling_params.kv_metadata["ropes"] + else: + cfg_img_context["ropes"] = [cfg_img_seq_len] + else: cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0] cfg_img_context["past_key_values"] = cfg_img_kv cfg_img_context["kv_lens"] = [cfg_img_seq_len] - cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None) if cfg_img_metadata and "ropes" in cfg_img_metadata: cfg_img_context["ropes"] = cfg_img_metadata["ropes"] else: cfg_img_context["ropes"] = [cfg_img_seq_len] - else: - logger.warning("CFG is disabled: only single KV cache available") - gen_params = BagelGenParams( - num_timesteps=gen_params.num_timesteps, - timestep_shift=gen_params.timestep_shift, - cfg_text_scale=1.0, - cfg_img_scale=1.0, - ) else: image_input = ( @@ -495,11 +525,15 @@ def vae_transforms(img): cfg_text_context = deepcopy(gen_context) + # Strip <|im_start|>/<|im_end|> wrappers that end2end.py may have + # already added, so prepare_prompts doesn't double-add bos/eos. + clean_prompt = prompt.removeprefix("<|im_start|>").removesuffix("<|im_end|>") + # Update gen_context with text prompt generation_input, newlens, new_rope = self.bagel.prepare_prompts( curr_kvlens=gen_context["kv_lens"], curr_rope=gen_context["ropes"], - prompts=[prompt], + prompts=[clean_prompt], tokenizer=self.tokenizer, new_token_ids=self.new_token_ids, ) @@ -527,34 +561,37 @@ def vae_transforms(img): gen_context["kv_lens"] = newlens gen_context["ropes"] = new_rope - # cfg_text_context: update with negative prompt (no text condition) + # cfg_text_context: update with negative prompt (no text condition). + # When empty, keep cfg_text_context as-is (kv_lens=0) to match + # original BAGEL; _merge_naive_caches handles None KV entries. 
neg_prompt = extra_args.get("negative_prompt", "") - neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts( - curr_kvlens=cfg_text_context["kv_lens"], - curr_rope=cfg_text_context["ropes"], - prompts=[neg_prompt], - tokenizer=self.tokenizer, - new_token_ids=self.new_token_ids, - ) - for k, v in neg_input.items(): - if torch.is_tensor(v): - neg_input[k] = v.to(self.device) - with torch.autocast( - device_type=self.device.type, - enabled=self.device.type != "cpu", - dtype=self.od_config.dtype, - ): - cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text( - cfg_text_context["past_key_values"], **neg_input + if neg_prompt: + neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts( + curr_kvlens=cfg_text_context["kv_lens"], + curr_rope=cfg_text_context["ropes"], + prompts=[neg_prompt], + tokenizer=self.tokenizer, + new_token_ids=self.new_token_ids, ) - cfg_text_context["kv_lens"] = neg_newlens - cfg_text_context["ropes"] = neg_rope + for k, v in neg_input.items(): + if torch.is_tensor(v): + neg_input[k] = v.to(self.device) + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text( + cfg_text_context["past_key_values"], **neg_input + ) + cfg_text_context["kv_lens"] = neg_newlens + cfg_text_context["ropes"] = neg_rope # cfg_img_context: update with text prompt (no image condition) cfg_img_generation_input, cfg_img_newlens, cfg_img_new_rope = self.bagel.prepare_prompts( curr_kvlens=cfg_img_context["kv_lens"], curr_rope=cfg_img_context["ropes"], - prompts=[prompt], + prompts=[clean_prompt], tokenizer=self.tokenizer, new_token_ids=self.new_token_ids, ) @@ -572,6 +609,96 @@ def vae_transforms(img): cfg_img_context["kv_lens"] = cfg_img_newlens cfg_img_context["ropes"] = cfg_img_new_rope + # ---- Detect output modality and think mode ---- + modalities = first_prompt.get("modalities", []) if isinstance(first_prompt, dict) else [] + is_text_output = "text" in modalities + think_enabled = extra_args.get("think", False) + think_text = None + + if think_enabled and injected_kv is None: + max_think_tokens = int(extra_args.get("max_think_tokens", 1000)) + do_sample = bool(extra_args.get("do_sample", False)) + text_temperature = float(extra_args.get("text_temperature", 0.3)) + + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + start_input = self.bagel.prepare_start_tokens( + gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids + ) + for k, v in start_input.items(): + if torch.is_tensor(v): + start_input[k] = v.to(self.device) + + gen_ctx_copy = deepcopy(gen_context) + token_ids = self.bagel.generate_text( + past_key_values=gen_ctx_copy["past_key_values"], + max_length=max_think_tokens, + do_sample=do_sample, + temperature=text_temperature, + end_token_id=self.new_token_ids["eos_token_id"], + **start_input, + ) + # token_ids shape: (seq_len, batch=1) + decoded = self.tokenizer.decode(token_ids[:, 0].tolist()) + # Strip chat markers to get clean text + think_text = decoded.split("<|im_end|>")[0] + if "<|im_start|>" in think_text: + think_text = think_text.split("<|im_start|>")[-1] + logger.info("Think mode generated %d tokens", token_ids.shape[0]) + + if not is_text_output: + # Use the autoregressive KV cache from think generation + # directly, instead of decode→re-encode which adds extra + # bos/eos and may alter tokenization. 
+ num_think_tokens = token_ids.shape[0] + gen_context["past_key_values"] = gen_ctx_copy["past_key_values"] + gen_context["kv_lens"] = [kl + num_think_tokens for kl in gen_context["kv_lens"]] + gen_context["ropes"] = [r + num_think_tokens for r in gen_context["ropes"]] + + # ---- Text-only output (text2text / img2text) ---- + if is_text_output and injected_kv is None: + if think_text is not None: + # Think mode already generated the text (including reasoning) + text_output = think_text + else: + max_text_tokens = int(extra_args.get("max_think_tokens", 500)) + do_sample = bool(extra_args.get("do_sample", False)) + text_temperature = float(extra_args.get("text_temperature", 0.3)) + + with torch.autocast( + device_type=self.device.type, + enabled=self.device.type != "cpu", + dtype=self.od_config.dtype, + ): + start_input = self.bagel.prepare_start_tokens( + gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids + ) + for k, v in start_input.items(): + if torch.is_tensor(v): + start_input[k] = v.to(self.device) + token_ids = self.bagel.generate_text( + past_key_values=gen_context["past_key_values"], + max_length=max_text_tokens, + do_sample=do_sample, + temperature=text_temperature, + end_token_id=self.new_token_ids["eos_token_id"], + **start_input, + ) + decoded = self.tokenizer.decode(token_ids[:, 0].tolist()) + text_output = decoded.split("<|im_end|>")[0] + if "<|im_start|>" in text_output: + text_output = text_output.split("<|im_start|>")[-1] + + return DiffusionOutput( + output=text_output, + custom_output={"text_output": text_output}, + stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, + ) + + # ---- Image generation (text2img / img2img) ---- if req.sampling_params.seed is not None: torch.manual_seed(req.sampling_params.seed) if self.device.type == "cuda": @@ -676,12 +803,17 @@ def vae_transforms(img): if trajectory_log_probs: trajectory_log_probs_stacked = torch.stack(trajectory_log_probs) + custom = {} + if think_text is not None: + custom["think_text"] = think_text + return DiffusionOutput( output=img, trajectory_latents=trajectory_latents_stacked, trajectory_timesteps=trajectory_timesteps_stacked, trajectory_log_probs=trajectory_log_probs_stacked, trajectory_decoded=trajectory_decoded, + custom_output=custom, stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None, ) diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py index a534f5a76f..abca4c9474 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py +++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py @@ -1,3 +1,5 @@ +import re + import torch import torch.nn as nn from vllm.logger import init_logger @@ -15,78 +17,26 @@ logger = init_logger(__name__) -class FusionModel(nn.Module): - def __init__(self, video_config=None, audio_config=None): - super().__init__() - has_video = True - has_audio = True - if video_config is not None: - self.video_model = WanModel(**video_config) - else: - has_video = False - self.video_model = None - logger.warning("No video model is provided!") - - if audio_config is not None: - self.audio_model = WanModel(**audio_config) - else: - has_audio = False - self.audio_model = None - logger.warning("No audio model is provided!") - - if has_video and has_audio: - assert len(self.video_model.blocks) == len(self.audio_model.blocks) - self.num_blocks = len(self.video_model.blocks) - - self.inject_cross_attention_kv_projections() - self.device = get_local_device() - - 
self.num_heads = self.video_model.num_heads - self.head_dim = self.video_model.dim // self.video_model.num_heads - self.attn = Attention( - num_heads=self.num_heads, - head_size=self.head_dim, - num_kv_heads=self.num_heads, - softmax_scale=1.0 / (self.head_dim**0.5), - causal=False, - ) - - def inject_cross_attention_kv_projections(self): - for vid_block in self.video_model.blocks: - vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim) - vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim) - vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True) - vid_block.cross_attn.norm_k_fusion = ( - WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity() - ) +class FusedBlock(nn.Module): + """Wrapper pairing a video block and audio block for layerwise offloading. - for audio_block in self.audio_model.blocks: - audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim) - audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim) - audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True) - audio_block.cross_attn.norm_k_fusion = ( - WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity() - ) + Registers both blocks as submodules so their parameters are visible to the offload hooks. + """ - def merge_kwargs(self, vid_kwargs, audio_kwargs): - """ - keys in each kwarg: - e - seq_lens - grid_sizes - freqs - context - context_lens - """ - merged_kwargs = {} - for key in vid_kwargs: - merged_kwargs[f"vid_{key}"] = vid_kwargs[key] - for key in audio_kwargs: - merged_kwargs[f"audio_{key}"] = audio_kwargs[key] - return merged_kwargs + def __init__( + self, + vid_block: nn.Module, + audio_block: nn.Module, + device: torch.device, + ): + super().__init__() + self.vid_block = vid_block + self.audio_block = audio_block + self.device = device - def single_fusion_cross_attention_forward( + def _cross_attention_forward( self, + attn: Attention, cross_attn_block, src_seq, src_grid_sizes, @@ -104,21 +54,17 @@ def single_fusion_cross_attention_forward( ): b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim if hasattr(cross_attn_block, "k_img"): - ## means is i2v block q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context) else: - ## means is t2v block q, k, v = cross_attn_block.qkv_fn(src_seq, context) k_img = v_img = None - x = self.attn(q, k, v) + x = attn(q, k, v) if k_img is not None: - img_x = self.attn(q, k_img, v_img) + img_x = attn(q, k_img, v_img) x = x + img_x - # is_vid = src_grid_sizes.shape[1] > 1 - # compute target attention target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq) k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d) v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d) @@ -132,17 +78,16 @@ def single_fusion_cross_attention_forward( freqs_scaling=target_freqs_scaling, ) - target_x = self.attn(q, k_target, v_target) + target_x = attn(q, k_target, v_target) x = x + target_x - - x = x.flatten(2) # [B, L/P, C] - + x = x.flatten(2) x = cross_attn_block.o(x) return x - def single_fusion_cross_attention_ffn_forward( + def _cross_attention_ffn_forward( self, + attn: Attention, attn_block, src_seq, src_grid_sizes, @@ -159,7 +104,8 @@ def single_fusion_cross_attention_ffn_forward( target_ref_lengths=None, target_freqs_scaling=None, ): - src_seq = src_seq + 
self.single_fusion_cross_attention_forward( + src_seq = src_seq + self._cross_attention_forward( + attn, attn_block.cross_attn, attn_block.norm3(src_seq), src_grid_sizes=src_grid_sizes, @@ -180,12 +126,11 @@ def single_fusion_cross_attention_ffn_forward( src_seq = src_seq + y * src_e[5].squeeze(2) return src_seq - def single_fusion_block_forward( + def forward( self, - vid_block, - audio_block, vid, audio, + attn: Attention, vid_e, vid_seq_lens, vid_grid_sizes, @@ -203,6 +148,9 @@ def single_fusion_block_forward( audio_ref_lengths, audio_freqs_scaling, ): + vid_block = self.vid_block + audio_block = self.audio_block + ## audio modulation assert audio_e.dtype == torch.bfloat16 assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], ( @@ -246,7 +194,8 @@ def single_fusion_block_forward( og_audio = audio # audio cross-attention - audio = self.single_fusion_cross_attention_ffn_forward( + audio = self._cross_attention_ffn_forward( + attn, audio_block, audio, audio_grid_sizes, @@ -267,7 +216,8 @@ def single_fusion_block_forward( assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!" # video cross-attention - vid = self.single_fusion_cross_attention_ffn_forward( + vid = self._cross_attention_ffn_forward( + attn, vid_block, vid, vid_grid_sizes, @@ -287,6 +237,128 @@ def single_fusion_block_forward( return vid, audio + +class FusionModel(nn.Module): + _layerwise_offload_blocks_attrs = ["fused_blocks"] + + def __init__(self, video_config=None, audio_config=None): + super().__init__() + has_video = True + has_audio = True + self.device = get_local_device() + if video_config is not None: + self.video_model = WanModel(**video_config) + else: + has_video = False + self.video_model = None + logger.warning("No video model is provided!") + + if audio_config is not None: + self.audio_model = WanModel(**audio_config) + else: + has_audio = False + self.audio_model = None + logger.warning("No audio model is provided!") + + if has_video and has_audio: + assert len(self.video_model.blocks) == len(self.audio_model.blocks) + self.num_blocks = len(self.video_model.blocks) + + self.inject_cross_attention_kv_projections() + + self.num_heads = self.video_model.num_heads + self.head_dim = self.video_model.dim // self.video_model.num_heads + # Make a single shared instance to pass in at forward time + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + if has_video and has_audio: + self.fused_blocks = nn.ModuleList( + [ + FusedBlock( + self.video_model.blocks[i], + self.audio_model.blocks[i], + self.device, + ) + for i in range(self.num_blocks) + ] + ) + + def load_state_dict(self, state_dict, strict=True, assign=False): + """Remap checkpoints where blocks are stored under + `video_model.blocks.N.*` / `audio_model.blocks.N.*` to the current + `fused_blocks.N.vid_block.*` / `fused_blocks.N.audio_block.*`. 
+        """
+        needs_remap = any(re.match(r"^(video_model|audio_model)\.blocks\.\d+\.", k) for k in state_dict)
+        if needs_remap:
+            remapped = {}
+            for k, v in state_dict.items():
+                new_k = re.sub(r"^video_model\.blocks\.(\d+)\.", r"fused_blocks.\1.vid_block.", k)
+                new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", new_k)
+                remapped[new_k] = v
+            state_dict = remapped
+
+        self._detach_blocks_from_backbones()
+
+        return super().load_state_dict(state_dict, strict=strict, assign=assign)
+
+    def inject_cross_attention_kv_projections(self):
+        for vid_block in self.video_model.blocks:
+            vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+            vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+            vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
+            vid_block.cross_attn.norm_k_fusion = (
+                WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
+            )
+
+        for audio_block in self.audio_model.blocks:
+            audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+            audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+            audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
+            audio_block.cross_attn.norm_k_fusion = (
+                WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
+            )
+
+    def _detach_blocks_from_backbones(self) -> None:
+        """Keep offloadable blocks owned by exactly one parent module.
+
+        NOTE: This is a workaround to support layerwise offloading.
+        The model registers the same Wan blocks under both the video/audio
+        backbones and `fused_blocks`, the wrapper that the unified forward
+        walks through. However, layerwise offloading only treats
+        `fused_blocks` as offloadable and materializes every other module on
+        device, which would also pull the blocks still registered under
+        `video_model` and `audio_model` onto the device, defeating the offload.
+        """
+        video_blocks = list(self.video_model.blocks)
+        audio_blocks = list(self.audio_model.blocks)
+        self.video_model._modules.pop("blocks", None)
+        self.audio_model._modules.pop("blocks", None)
+        self.video_model.blocks = tuple(video_blocks)
+        self.audio_model.blocks = tuple(audio_blocks)
+
+    def merge_kwargs(self, vid_kwargs, audio_kwargs):
+        """
+        keys in each kwarg:
+        e
+        seq_lens
+        grid_sizes
+        freqs
+        context
+        context_lens
+        """
+        merged_kwargs = {}
+        for key in vid_kwargs:
+            merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
+        for key in audio_kwargs:
+            merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
+        return merged_kwargs
+
     def forward(
         self,
         vid,
@@ -316,17 +388,8 @@ def forward(

         kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)

-        for i in range(self.num_blocks):
-            """
-            1 fusion block refers to 1 audio block with 1 video block.
- """ - - vid_block = self.video_model.blocks[i] - audio_block = self.audio_model.blocks[i] - - vid, audio = self.single_fusion_block_forward( - vid_block=vid_block, audio_block=audio_block, vid=vid, audio=audio, **kwargs - ) + for fused_block in self.fused_blocks: + vid, audio = fused_block(vid, audio, self.attn, **kwargs) vid = self.video_model.post_transformer_block_out(vid, vid_kwargs["grid_sizes"], vid_e) audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs["grid_sizes"], audio_e) diff --git a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py index e22765f80e..cc932f8c1f 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py +++ b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py @@ -4,6 +4,7 @@ import logging import math import os +from collections.abc import Iterable import torch import torch.distributed @@ -15,12 +16,8 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin -from vllm_omni.diffusion.distributed.parallel_state import ( - get_cfg_group, - get_classifier_free_guidance_rank, - get_classifier_free_guidance_world_size, -) from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -32,7 +29,6 @@ init_mmaudio_vae, init_text_model, init_wan_vae_2_2, - load_fusion_checkpoint, ) from dreamid_omni.utils.rearrange import Rearrange from dreamid_omni.utils.resize import NaResize @@ -43,6 +39,21 @@ logger = logging.getLogger(__name__) +def get_dreamid_omni_post_process_func(*args, **kwargs): + def post_process(output): + if isinstance(output, tuple) and len(output) == 2: + video, audio = output + return { + "video": video, + "audio": audio, + "audio_sample_rate": 16000, + "fps": 24, + } + return output + + return post_process + + AUDIO_CONFIG = { "patch_size": [1], "model_type": "t2a", @@ -112,16 +123,24 @@ def __init__( self.text_model = init_text_model(model, rank=self.device) self.text_encoder = self.text_model.model - # Fusion model - ## load audio/video model config - Fusion_model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG) - - checkpoint_path = self.od_config.tf_model_config.get("fusion", None) - assert checkpoint_path is not None, "fusion checkpoint path is None" - load_fusion_checkpoint(Fusion_model, checkpoint_path=os.path.join(model, checkpoint_path)) - self.model = Fusion_model + # Fusion model — weights are loaded later via load_weights() + self.model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG) self.transformer = self.model + fusion_path = self.od_config.tf_model_config.get("fusion", None) + assert fusion_path is not None, "fusion checkpoint path is None in transformer config" + fusion_subfolder = os.path.dirname(fusion_path) or None + fusion_filename = os.path.basename(fusion_path) + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=model, + subfolder=fusion_subfolder, + revision=None, + prefix="model.", + allow_patterns_overrides=[fusion_filename], + ) + ] + # Fixed attributes, non-configurable self.audio_latent_channel = AUDIO_CONFIG.get("in_dim") self.video_latent_channel = VIDEO_CONFIG.get("in_dim") @@ -216,8 +235,11 @@ def load_image_latent_ref_ip_video( return 
ref_vae_latents, ref_audio_lengths - def load_weights(self, weights): - pass + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + prefix = "model." + state_dict = {name[len(prefix) :]: tensor for name, tensor in weights if name.startswith(prefix)} + self.model.load_state_dict(state_dict, strict=True) + return {prefix + k for k in state_dict} def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0, shift=5.0): torch.manual_seed(4) @@ -249,6 +271,28 @@ def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0 return sample_scheduler, timesteps + def predict_noise(self, **kwargs): + pred_vid, pred_audio = self.model(**kwargs) + return (pred_vid[0], pred_audio[0]) + + def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False): + vid_pos, audio_pos = predictions[0] + vid_neg, audio_neg = predictions[1] + vid_ip_neg, _ = predictions[2] + _, refaudio_neg = predictions[3] + + pred_video = ( + vid_neg + + true_cfg_scale["video_cfg_scale"] * (vid_pos - vid_neg) + + true_cfg_scale["video_ref_cfg_scale"] * (vid_pos - vid_ip_neg) + ) + pred_audio = ( + audio_neg + + true_cfg_scale["audio_cfg_scale"] * (audio_pos - audio_neg) + + true_cfg_scale["audio_ref_cfg_scale"] * (audio_pos - refaudio_neg) + ) + return (pred_video, pred_audio) + def diffuse( self, video_noise: torch.Tensor, @@ -306,72 +350,22 @@ def diffuse( "vid_context": [text_embeddings_video_neg], } - if get_classifier_free_guidance_world_size() > 1: - # Enable CFG-parallel: rank0 computes positive, rank1 computes negative. - cfg_group = get_cfg_group() - cfg_rank = get_classifier_free_guidance_rank() - - if cfg_rank == 0: - pred_vid, pred_audio = self.model( - vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **pos_args - ) - pre_vid_ip_neg, _ = self.model( - vid=[model_input_video_neg], audio=[model_input_audio], t=timestep_input, **pos_args - ) - pred_vid_0 = pred_vid[0] - pred_audio_0 = pred_audio[0] - pre_vid_ip_0 = pre_vid_ip_neg[0] - pred_refaudio_0 = torch.zeros_like(pred_audio_0) # dummy tensor - else: - pred_vid, pred_audio = self.model( - vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **neg_args - ) - _, pred_refaudio_neg = self.model( - vid=[model_input_video], audio=[model_input_audio_neg], t=timestep_input, **pos_args - ) - pred_vid_0 = pred_vid[0] - pred_audio_0 = pred_audio[0] - pre_vid_ip_0 = torch.zeros_like(pred_vid_0) # dummy tensor - pred_refaudio_0 = pred_refaudio_neg[0] - - pred_vid_gathered = cfg_group.all_gather(pred_vid_0, separate_tensors=True) - pred_audio_gathered = cfg_group.all_gather(pred_audio_0, separate_tensors=True) - pre_vid_ip_gathered = cfg_group.all_gather(pre_vid_ip_0, separate_tensors=True) - pred_refaudio_gathered = cfg_group.all_gather(pred_refaudio_0, separate_tensors=True) - - pred_vid_pos = [pred_vid_gathered[0]] - pred_vid_neg = [pred_vid_gathered[1]] - pred_audio_pos = [pred_audio_gathered[0]] - pred_audio_neg = [pred_audio_gathered[1]] - pre_vid_ip_neg = [pre_vid_ip_gathered[0]] - pred_refaudio_neg = [pred_refaudio_gathered[1]] - else: - pred_vid_pos, pred_audio_pos = self.model( - vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **pos_args - ) - - pred_vid_neg, pred_audio_neg = self.model( - vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **neg_args - ) - - pre_vid_ip_neg, _ = self.model( - vid=[model_input_video_neg], audio=[model_input_audio], t=timestep_input, **pos_args - ) - - _, 
pred_refaudio_neg = self.model( - vid=[model_input_video], audio=[model_input_audio_neg], t=timestep_input, **pos_args - ) - - pred_video_guided = ( - pred_vid_neg[0] - + self.video_cfg_scale * (pred_vid_pos[0] - pred_vid_neg[0]) - + self.video_ref_cfg_scale * (pred_vid_pos[0] - pre_vid_ip_neg[0]) - ) - - pred_audio_guided = ( - pred_audio_neg[0] - + self.audio_cfg_scale * (pred_audio_pos[0] - pred_audio_neg[0]) - + self.audio_ref_cfg_scale * (pred_audio_pos[0] - pred_refaudio_neg[0]) + branches_kwargs = [ + {"vid": [model_input_video], "audio": [model_input_audio], "t": timestep_input, **pos_args}, + {"vid": [model_input_video], "audio": [model_input_audio], "t": timestep_input, **neg_args}, + {"vid": [model_input_video_neg], "audio": [model_input_audio], "t": timestep_input, **pos_args}, + {"vid": [model_input_video], "audio": [model_input_audio_neg], "t": timestep_input, **pos_args}, + ] + + pred_video_guided, pred_audio_guided = self.predict_noise_with_multi_branch_cfg( + do_true_cfg=True, + true_cfg_scale={ + "video_cfg_scale": self.video_cfg_scale, + "video_ref_cfg_scale": self.video_ref_cfg_scale, + "audio_cfg_scale": self.audio_cfg_scale, + "audio_ref_cfg_scale": self.audio_ref_cfg_scale, + }, + branches_kwargs=branches_kwargs, ) video_noise = scheduler_video.step( pred_video_guided.unsqueeze(0), t_v, video_noise.unsqueeze(0), return_dict=False diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py index 680b8bfbbe..dd88ee76c1 100644 --- a/vllm_omni/diffusion/models/flux/flux_transformer.py +++ b/vllm_omni/diffusion/models/flux/flux_transformer.py @@ -381,7 +381,9 @@ def __init__( super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) - self.norm = AdaLayerNormZeroSingle(dim, quant_config=quant_config, prefix=f"{prefix}.norm") + # Modulation linear kept full precision; shift/scale/gate outputs + # are multiplied into the residual stream every block (see #2728). + self.norm = AdaLayerNormZeroSingle(dim, quant_config=None, prefix=f"{prefix}.norm") self.proj_mlp = ReplicatedLinear( dim, self.mlp_hidden_dim, @@ -524,10 +526,10 @@ def _is_transformer_block(name: str, module) -> bool: def __init__( self, - od_config: OmniDiffusionConfig = None, + od_config: OmniDiffusionConfig | None = None, patch_size: int = 1, in_channels: int = 64, - out_channels: int = None, + out_channels: int | None = None, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, @@ -563,13 +565,16 @@ def __init__( self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) self.x_embedder = nn.Linear(in_channels, self.inner_dim) + # Dual-stream blocks kept full precision — FP8 on their joint + # attention path causes noise on FLUX (#2728). Single-stream + # blocks (38 vs 19) still get FP8 for memory savings. self.transformer_blocks = nn.ModuleList( [ FluxTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, - quant_config=quant_config, + quant_config=None, prefix=f"transformer_blocks.{i}", ) for i in range(num_layers) @@ -589,12 +594,13 @@ def __init__( ] ) + # Final modulation feeds proj_out; keep full precision (see #2728). 
self.norm_out = AdaLayerNormContinuous( self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6, - quant_config=quant_config, + quant_config=None, prefix="norm_out", ) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py index 6f43e8dbb5..70d572d9a6 100644 --- a/vllm_omni/diffusion/models/flux/pipeline_flux.py +++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py @@ -30,6 +30,7 @@ from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific logger = logging.getLogger(__name__) @@ -106,7 +107,11 @@ def __init__( self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self.device ) - self.transformer = FluxTransformer2DModel(od_config=od_config, quant_config=od_config.quantization_config) + + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, FluxTransformer2DModel) + self.transformer = FluxTransformer2DModel( + **transformer_kwargs, od_config=od_config, quant_config=od_config.quantization_config + ) self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) self.tokenizer_2 = T5TokenizerFast.from_pretrained( diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 437dd58d0c..021d28c2ac 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -234,7 +234,15 @@ def __init__( self.transformer = Flux2Transformer2DModel(quant_config=od_config.quantization_config, **transformer_kwargs) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.latent_channels if hasattr(self.vae, "config") else 16 self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) self.tokenizer_max_length = 512 self.default_sample_size = 128 @@ -247,6 +255,14 @@ def __init__( enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler ) + def get_timesteps(self, num_inference_steps, strength, device): + init_timestep = min(num_inference_steps * strength, num_inference_steps) + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + return timesteps, num_inference_steps - t_start + @staticmethod def _get_qwen3_prompt_embeds( text_encoder: Qwen3ForCausalLM, @@ -582,6 +598,54 @@ def prepare_image_latents( return image_latents, image_latent_ids + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, 
+ dtype, + device, + generator, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + if masked_image.shape[1] != num_channels_latents: + masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator) + else: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError("The passed mask and the required batch size don't match.") + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents is not None and masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError("The passed mask and the required batch size don't match.") + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + if masked_image_latents is not None: + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents(masked_image_latents) + + mask = mask.repeat(1, self.latent_channels, 1, 1) + mask = self._patchify_latents(mask) + mask = self._pack_latents(mask) + + return mask, masked_image_latents + def check_inputs( self, prompt, @@ -590,7 +654,15 @@ def check_inputs( prompt_embeds=None, callback_on_step_end_tensor_inputs=None, guidance_scale=None, + strength=None, + num_inference_steps=None, ): + if strength is not None: + if strength < 0 or strength > 1: + raise ValueError(f"strength must be between 0 and 1, got {strength}") + if num_inference_steps is not None and num_inference_steps <= 0: + raise ValueError(f"num_inference_steps must be positive, got {num_inference_steps}") + if ( height is not None and height % (self.vae_scale_factor * 2) != 0 @@ -622,7 +694,7 @@ def check_inputs( elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if guidance_scale > 1.0 and self.is_distilled: + if guidance_scale is not None and guidance_scale > 1.0 and self.is_distilled: logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") @property @@ -653,11 +725,14 @@ def forward( self, req: OmniDiffusionRequest, image: PIL.Image.Image | list[PIL.Image.Image] | None = None, + reference_image: PIL.Image.Image | list[PIL.Image.Image] | None = None, + mask_image: PIL.Image.Image | list[PIL.Image.Image] | None = None, prompt: str | list[str] | None = None, height: int | None = None, width: int | None = None, num_inference_steps: int = 50, sigmas: list[float] | None = None, + strength: float = 1.0, guidance_scale: float | None = 4.0, num_images_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -671,6 +746,7 @@ def forward( callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 512, text_encoder_out_layers: tuple[int, ...] = (9, 18, 27), + padding_mask_crop: int | None = None, ) -> DiffusionOutput: r""" Function invoked when calling the pipeline for generation. 
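The new `strength` argument follows the usual img2img convention implemented by `get_timesteps` above: only the tail of the schedule is denoised. A minimal sketch of that arithmetic, with illustrative values (not taken from the diff):

num_inference_steps = 50
strength = 0.6
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
# For a scheduler with order 1 the loop then runs timesteps[t_start:],
# i.e. the last 30 of the 50 scheduled steps; strength=1.0 keeps the full schedule.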
@@ -804,6 +880,8 @@ def forward( prompt_embeds=prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, guidance_scale=guidance_scale, + strength=strength, + num_inference_steps=num_inference_steps, ) self._guidance_scale = guidance_scale @@ -848,6 +926,9 @@ def forward( if image is not None and not isinstance(image, list): image = [image] + multiple_of = self.vae_scale_factor * 2 + crops_coords = None + resize_mode = "crop" condition_images = None if image is not None: for img in image: @@ -860,10 +941,14 @@ def forward( img = self.image_processor._resize_to_target_area(img, 1024 * 1024) image_width, image_height = img.size - multiple_of = self.vae_scale_factor * 2 image_width = (image_width // multiple_of) * multiple_of image_height = (image_height // multiple_of) * multiple_of - img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + img = self.image_processor.preprocess( + img, height=image_height, width=image_width, crops_coords=crops_coords, resize_mode=resize_mode + ) condition_images.append(img) height = height or image_height width = width or image_width @@ -871,6 +956,16 @@ def forward( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + # Get mask_image and reference_image + multi_modal_data = req.prompts[0].get("multi_modal_data", {}) if req.prompts else {} + mask_image = multi_modal_data.get("mask_image") + reference_image = multi_modal_data.get("reference_image") + + if mask_image is not None and (image is None or (isinstance(image, list) and len(image) == 0)): + raise ValueError("image must be provided when using mask_image for inpainting") + + init_image = condition_images[0] if condition_images else None + # 5. 
prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_ids = self.prepare_latents( @@ -884,6 +979,8 @@ def forward( latents=latents, ) + original_latent_ids = latent_ids + image_latents = None image_latent_ids = None if condition_images is not None: @@ -895,6 +992,71 @@ def forward( dtype=self.vae.dtype, ) + # Preprocess reference_image + if reference_image is not None and not ( + isinstance(reference_image, torch.Tensor) and reference_image.size(1) == self.latent_channels + ): + if ( + isinstance(reference_image, list) + and isinstance(reference_image[0], torch.Tensor) + and reference_image[0].ndim == 4 + ): + reference_image = torch.cat(reference_image, dim=0) + img_reference = reference_image[0] if isinstance(reference_image, list) else reference_image + reference_image_height, reference_image_width = self.image_processor.get_default_height_width(img_reference) + + reference_image_width = reference_image_width // multiple_of * multiple_of + reference_image_height = reference_image_height // multiple_of * multiple_of + reference_image = self.image_processor.resize( + reference_image, reference_image_height, reference_image_width + ) + reference_image = self.image_processor.preprocess( + reference_image, + reference_image_height, + reference_image_width, + crops_coords=crops_coords, + resize_mode=resize_mode, + ) + else: + pass + + reference_image_latents = None + reference_image_latent_ids = None + if reference_image is not None: + reference_image_latents, reference_image_latent_ids = self.prepare_image_latents( + images=[reference_image], + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + if reference_image_latent_ids is not None: + latent_ids = torch.cat([latent_ids, reference_image_latent_ids], dim=1) + elif image_latent_ids is not None: + latent_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + mask = None + if mask_image is not None: + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if init_image is not None: + masked_image = init_image * (mask_condition < 0.5).to(init_image.dtype) + else: + masked_image = None + mask, _ = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + self.latent_channels, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) # 6. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: @@ -908,6 +1070,13 @@ def forward( sigmas=sigmas, mu=mu, ) + if reference_image is not None or mask_image is not None: + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) self._num_timesteps = len(timesteps) # 7. 
Denoising loop @@ -924,7 +1093,10 @@ def forward( latent_model_input = latents.to(self.transformer.dtype) latent_image_ids = latent_ids - if image_latents is not None: + if reference_image_latents is not None: + latent_model_input = torch.cat([latents, reference_image_latents], dim=1) + latent_image_ids = latent_ids + elif image_latents is not None: latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) @@ -953,7 +1125,9 @@ def forward( negative_kwargs = None # For editing pipelines, we need to slice the output to remove condition latents - output_slice = latents.size(1) if image_latents is not None else None + output_slice = ( + latents.size(1) if (image_latents is not None or reference_image_latents is not None) else None + ) noise_pred = self.predict_noise_maybe_with_cfg( do_true_cfg=self.do_classifier_free_guidance, @@ -964,9 +1138,22 @@ def forward( output_slice=output_slice, ) + latents_dtype = latents.dtype # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance) + if mask is not None and image_latents is not None: + init_latents_proper = image_latents + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep], device=device), latents + ) + latents = (1 - mask) * init_latents_proper + mask * latents + + if latents.dtype != latents_dtype and torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: @@ -978,7 +1165,7 @@ def forward( self._current_timestep = None - latents = self._unpack_latents_with_ids(latents, latent_ids) + latents = self._unpack_latents_with_ids(latents, original_latent_ids) latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py index 490e0198b9..7ff42a5f00 100644 --- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py +++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py @@ -19,10 +19,16 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.cache.base import CachedTransformer -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, +) +from vllm_omni.diffusion.forward_context import get_forward_context logger = init_logger(__name__) @@ -108,8 +114,8 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, channel, height, width = hidden_states.shape - post_patch_height = height // self.patch_size - post_patch_width = width // self.patch_size + post_patch_height = torch.tensor(height // self.patch_size, 
device=hidden_states.device, dtype=torch.int64) + post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64) # Reshape: [B, C, H, W] -> [B, H', W', C*p*p] -> [B, H'*W', C*p*p] hidden_states = hidden_states.reshape( @@ -159,6 +165,65 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens return (freqs.cos(), freqs.sin()) +class GlmImagePrepare(nn.Module): + """Prepare module for GLM-Image that handles patch embedding and RoPE computation. + + This module encapsulates the input processing pipeline to create a module boundary + where _sp_plan can shard outputs via split_output=True. + + Similar to Qwen-Image's ImageRopePrepare, this ensures hidden_states and RoPE + embeddings are sharded together to maintain dimension alignment. + """ + + def __init__( + self, + image_projector: nn.Module, + rope: GlmImageRotaryPosEmbed, + patch_size: int, + ): + super().__init__() + self.image_projector = image_projector + self.rope = rope + self.patch_size = patch_size + + def forward( + self, + hidden_states: torch.Tensor, + prior_hidden_states: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Process hidden_states and compute RoPE embeddings. + + Args: + hidden_states: Input latent tensor [B, C, H, W] + prior_hidden_states: Optional prior embedding to add + + Returns: + hidden_states: Patched hidden states [B, seq_len, D] + rope_cos: RoPE cos embeddings [seq_len, dim] + rope_sin: RoPE sin embeddings [seq_len, dim] + post_patch_height: Scalar tensor for height after patching + post_patch_width: Scalar tensor for width after patching + """ + batch_size, num_channels, height, width = hidden_states.shape + + post_patch_height = torch.tensor(height // self.patch_size, device=hidden_states.device, dtype=torch.int64) + post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64) + + # Compute RoPE (uses original 4D hidden_states shape) + image_rotary_emb = self.rope(hidden_states) + rope_cos = image_rotary_emb[0].to(hidden_states.device) + rope_sin = image_rotary_emb[1].to(hidden_states.device) + + # Patch embedding: [B, C, H, W] -> [B, seq_len, D] + hidden_states = self.image_projector(hidden_states) + + # Add prior embedding if provided + if prior_hidden_states is not None: + hidden_states = hidden_states + prior_hidden_states + + return hidden_states, rope_cos, rope_sin, post_patch_height, post_patch_width + + class GlmImageAdaLayerNormZero(nn.Module): """Adaptive LayerNorm with zero initialization for both image and text streams.""" @@ -397,6 +462,7 @@ def __init__( dim: int, num_heads: int, head_dim: int, + parallel_config: DiffusionParallelConfig | None = None, out_bias: bool = True, eps: float = 1e-5, ): @@ -404,6 +470,7 @@ def __init__( self.dim = dim self.total_num_heads = num_heads self.head_dim = head_dim + self.parallel_config = parallel_config # QKV projection (fused for efficiency) self.to_qkv = QKVParallelLinear( @@ -450,16 +517,19 @@ def forward( attention_mask: torch.Tensor | None = None, kv_cache: GlmImageLayerKVCache | None = None, kv_cache_mode: KVCacheMode | None = None, + hidden_states_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for joint attention. 
Args: - hidden_states: Image hidden states [B, img_seq_len, D] - encoder_hidden_states: Text hidden states [B, text_seq_len, D] - image_rotary_emb: Tuple of (cos, sin) for RoPE + hidden_states: Image hidden states [B, img_seq_len, D] (sharded in SP mode) + encoder_hidden_states: Text hidden states [B, text_seq_len, D] (full in SP mode) + image_rotary_emb: Tuple of (cos, sin) for RoPE (sharded in SP mode) + attention_mask: Optional attention mask kv_cache: Optional layer KV cache for image editing kv_cache_mode: Cache mode (WRITE, READ, SKIP) + hidden_states_mask: Mask for SP padding (True=valid, False=padding) Returns: Tuple of (image_hidden_states, text_hidden_states) @@ -467,6 +537,13 @@ def forward( dtype = encoder_hidden_states.dtype batch_size, text_seq_length, _ = encoder_hidden_states.shape + # Check if SP is enabled + sp_size = self.parallel_config.sequence_parallel_size if self.parallel_config else None + use_sp = sp_size is not None and sp_size > 1 + if use_sp: + forward_ctx = get_forward_context() + use_sp = not forward_ctx.split_text_embed_in_sp + # Concatenate text and image: [text, image] hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -485,41 +562,88 @@ def forward( query = self.norm_q(query).to(dtype=dtype) key = self.norm_k(key).to(dtype=dtype) - # Apply RoPE only to image tokens (not text tokens) - if image_rotary_emb is not None: - # Only apply RoPE to image part (after text_seq_length) - query_img = query[:, text_seq_length:, :, :] - key_img = key[:, text_seq_length:, :, :] - from diffusers.models.embeddings import apply_rotary_emb - - query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) - key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) - query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1) - key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1) - - # Handle KV cache for image editing - if kv_cache is not None and kv_cache_mode is not None: - if kv_cache_mode == KVCacheMode.WRITE: - kv_cache.store(key, value) - elif kv_cache_mode == KVCacheMode.READ: - k_cached, v_cached = kv_cache.get() - if k_cached is not None: - key = torch.cat([k_cached, key], dim=1) - value = torch.cat([v_cached, value], dim=1) - # KVCacheMode.SKIP: do nothing - - # Attention computation - hidden_states_out = self.attn(query, key, value) - hidden_states_out = hidden_states_out.flatten(2, 3) - hidden_states_out = hidden_states_out.to(dtype) + if use_sp: + # SP mode: use joint attention mechanism + # Split Q/K/V into text and image parts + text_query = query[:, :text_seq_length, :, :] + text_key = key[:, :text_seq_length, :, :] + text_value = value[:, :text_seq_length, :, :] + img_query = query[:, text_seq_length:, :, :] + img_key = key[:, text_seq_length:, :, :] + img_value = value[:, text_seq_length:, :, :] + + # Apply RoPE only to image part + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + + img_query = apply_rotary_emb(img_query, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + img_key = apply_rotary_emb(img_key, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + + # Create attention metadata for joint attention + attn_metadata = AttentionMetadata( + joint_query=text_query, + joint_key=text_key, + joint_value=text_value, + joint_strategy="front", + ) - # Output projection - for module in self.to_out: - hidden_states_out = module(hidden_states_out) + # Add padding mask for 
SP if available + if hidden_states_mask is not None: + attn_metadata.attn_mask = hidden_states_mask + + # Attention computation with joint text/image + # Note: Ulysses post_attention returns [text, image] concatenated + joint_hidden_states_out = self.attn(img_query, img_key, img_value, attn_metadata) + + # Project combined [text, image] outputs, then split. + # This keeps SP numerically aligned with the non-SP path. + joint_hidden_states_out = joint_hidden_states_out.flatten(2, 3).to(dtype) + for module in self.to_out: + joint_hidden_states_out = module(joint_hidden_states_out) - # Split back to text and image - encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :] - hidden_states_out = hidden_states_out[:, text_seq_length:, :] + encoder_hidden_states_out = joint_hidden_states_out[:, :text_seq_length, :] + hidden_states_out = joint_hidden_states_out[:, text_seq_length:, :] + else: + # Non-SP mode: original logic + # Apply RoPE only to image tokens (not text tokens) + if image_rotary_emb is not None: + query_img = query[:, text_seq_length:, :, :] + key_img = key[:, text_seq_length:, :, :] + from diffusers.models.embeddings import apply_rotary_emb + + query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2) + query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1) + key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1) + + # Handle KV cache for image editing + if kv_cache is not None and kv_cache_mode is not None: + if kv_cache_mode == KVCacheMode.WRITE: + kv_cache.store(key, value) + elif kv_cache_mode == KVCacheMode.READ: + k_cached, v_cached = kv_cache.get() + if k_cached is not None: + key = torch.cat([k_cached, key], dim=1) + value = torch.cat([v_cached, value], dim=1) + + # Attention computation + attn_metadata = None + if attention_mask is not None: + if attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + attn_metadata = AttentionMetadata(attn_mask=attention_mask) + + hidden_states_out = self.attn(query, key, value, attn_metadata) + hidden_states_out = hidden_states_out.flatten(2, 3) + hidden_states_out = hidden_states_out.to(dtype) + + # Output projection + for module in self.to_out: + hidden_states_out = module(hidden_states_out) + + # Split back to text and image + encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :] + hidden_states_out = hidden_states_out[:, text_seq_length:, :] return hidden_states_out, encoder_hidden_states_out @@ -628,6 +752,7 @@ def __init__( attention_head_dim: int = 40, time_embed_dim: int = 512, ffn_hidden_dim: int | None = None, + parallel_config: DiffusionParallelConfig | None = None, ) -> None: super().__init__() @@ -637,6 +762,7 @@ def __init__( dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, + parallel_config=parallel_config, ) # 2. Feedforward @@ -654,6 +780,7 @@ def forward( attention_kwargs: dict[str, Any] | None = None, kv_cache: GlmImageLayerKVCache | None = None, kv_cache_mode: KVCacheMode | None = None, + hidden_states_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass for transformer block. 
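The `hidden_states_mask` threaded through these blocks exists because sequence parallelism may pad the image sequence so it splits evenly across SP ranks, and the padded tail must be excluded from attention. A rough sketch of what such a mask looks like, assuming the padding size is simply whatever makes the length divisible by the SP world size (the real values come from the forward context populated later in this file):

import torch

sp_size = 4
img_seq_len = 1023                          # original image token count
pad = (-img_seq_len) % sp_size              # 1 token of padding -> length 1024
mask = torch.ones(1, img_seq_len + pad, dtype=torch.bool)
mask[:, img_seq_len:] = False               # True = valid token, False = SP padding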
@@ -667,6 +794,7 @@ def forward( attention_kwargs: Additional attention arguments kv_cache: Layer-specific KV cache for image editing kv_cache_mode: Cache mode (WRITE, READ, SKIP) + hidden_states_mask: Mask for SP padding (True=valid, False=padding) Returns: Tuple of (image_hidden_states, text_hidden_states) @@ -693,6 +821,7 @@ def forward( attention_mask=attention_mask, kv_cache=kv_cache, kv_cache_mode=kv_cache_mode, + hidden_states_mask=hidden_states_mask, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) @@ -724,6 +853,26 @@ class GlmImageTransformer2DModel(CachedTransformer): """ _repeated_blocks = ["GlmImageTransformerBlock"] + # SP plan using GlmImagePrepare module for sharding hidden_states and RoPE together. + # Similar to Qwen-Image's ImageRopePrepare, this creates a module boundary where + # _sp_plan can shard outputs via split_output=True. + # + # Key insight: hidden_states and RoPE embeddings MUST be sharded together + # to maintain dimension alignment for RoPE computation in attention layers. + _sp_plan = { + # Shard GlmImagePrepare outputs (hidden_states and RoPE must be sharded together) + "prepare": { + # hidden_states: [B, seq_len, D] - shard along sequence dimension + 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True), + # RoPE cos: [seq_len, dim] - shard along sequence dimension + 1: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + # RoPE sin: [seq_len, dim] - shard along sequence dimension + 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), + # post_patch_height and post_patch_width are scalars, not sharded + }, + # Gather output at proj_out + "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } _hsdp_shard_conditions = [is_transformer_block_module] @@ -790,6 +939,9 @@ def __init__( dim=inner_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="linear-silu" ) + # Prepare module for SP (encapsulates patch embedding and RoPE for _sp_plan) + self.prepare = GlmImagePrepare(self.image_projector, self.rope, patch_size) + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( embedding_dim=time_embed_dim, condition_dim=condition_dim, @@ -806,6 +958,7 @@ def __init__( attention_head_dim, time_embed_dim, ffn_hidden_dim=ffn_hidden_dim, + parallel_config=self.parallel_config, ) for _ in range(num_layers) ] @@ -859,33 +1012,51 @@ def forward( # Get KV cache mode kv_cache_mode = kv_cache.mode if kv_cache is not None else None - # 1. RoPE - if image_rotary_emb is None: - image_rotary_emb = self.rope(hidden_states) - # Move to correct device - image_rotary_emb = ( - image_rotary_emb[0].to(hidden_states.device), - image_rotary_emb[1].to(hidden_states.device), - ) - - # 2. 
Patch & Timestep embeddings - p = self.patch_size - post_patch_height = height // p - post_patch_width = width // p + # Set SP context if enabled + sp_size = self.parallel_config.sequence_parallel_size + if sp_size is not None and sp_size > 1: + get_forward_context().split_text_embed_in_sp = False - hidden_states = self.image_projector(hidden_states) + # Text embedding projection encoder_hidden_states = self.glyph_projector(encoder_hidden_states) # Prior embedding with dropout prior_embedding = self.prior_token_embedding(prior_token_id) prior_embedding[prior_token_drop] *= 0.0 prior_hidden_states = self.prior_projector(prior_embedding) - hidden_states = hidden_states + prior_hidden_states + + # 1. Prepare hidden_states and RoPE via GlmImagePrepare module + # _sp_plan will shard hidden_states and RoPE together via split_output=True + hidden_states, rope_cos, rope_sin, post_patch_height_t, post_patch_width_t = self.prepare( + hidden_states, prior_hidden_states + ) + image_rotary_emb = (rope_cos, rope_sin) + post_patch_height = int(post_patch_height_t.item()) + post_patch_width = int(post_patch_width_t.item()) # Timestep conditioning temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) - # 3. Transformer blocks + # Create padding mask for SP if needed (after _sp_plan hooks have run) + hidden_states_mask = None + if sp_size is not None and sp_size > 1: + from vllm_omni.diffusion.forward_context import is_forward_context_available + + if is_forward_context_available(): + ctx = get_forward_context() + if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: + img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size + hidden_states_mask = torch.ones( + batch_size, + img_padded_seq_len, + dtype=torch.bool, + device=hidden_states.device, + ) + hidden_states_mask[:, ctx.sp_original_seq_len :] = False + if hidden_states_mask.all(): + hidden_states_mask = None + + # 2. Transformer blocks for layer_idx, block in enumerate(self.transformer_blocks): # Get layer-specific KV cache if available layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None @@ -899,13 +1070,16 @@ def forward( attention_kwargs, kv_cache=layer_kv_cache, kv_cache_mode=kv_cache_mode, + hidden_states_mask=hidden_states_mask, ) - # 4. Output norm & projection + # 3. Output norm & projection + # _sp_plan will gather hidden_states via proj_out hook hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - # 5. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W] + # 4. 
Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W] + p = self.patch_size hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py index 375f7e7b80..0386364998 100644 --- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py +++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py @@ -712,6 +712,14 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: if img is not None: preprocessed_images = [img] + # Priority: prompt dict (from ar2diffusion) > sampling_params + # ar2diffusion returns adjusted height/width that matches prior_token_ids + if not isinstance(first_prompt, str): + ar_height = first_prompt.get("height") + ar_width = first_prompt.get("width") + else: + ar_height = ar_width = None + img_height = req.sampling_params.height img_width = req.sampling_params.width @@ -719,12 +727,19 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: # Treat that as t2i warmup to avoid requiring i2i-only KV-cache inputs. is_image_edit = (preprocessed_images is not None) and (not is_dummy_warmup) - # Use image dimensions as default if available - height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor - width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor + # Use prompt dict dimensions (from ar2diffusion) as priority, then sampling_params + height = ( + ar_height or req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor + ) + width = ar_width or req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor num_inference_steps = req.sampling_params.num_inference_steps or 50 guidance_scale = req.sampling_params.guidance_scale or 1.5 + # Ensure dimensions are multiples of vae_scale_factor * patch_size + multiple_of = self.vae_scale_factor * self._patch_size + height = height // multiple_of * multiple_of + width = width // multiple_of * multiple_of + self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds) batch_size = 1 @@ -753,6 +768,20 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: prior_token_id = prior_token_id.to(device=self.device, dtype=torch.long) if prior_token_id.dim() == 1: prior_token_id = prior_token_id.unsqueeze(0) + + # Validate that prior_token_id seq_len matches dimensions + prior_seq_len = prior_token_id.shape[1] + expected_seq_len = (height // self.vae_scale_factor // self._patch_size) * ( + width // self.vae_scale_factor // self._patch_size + ) + if prior_seq_len != expected_seq_len: + raise ValueError( + f"prior_token_ids seq_len ({prior_seq_len}) doesn't match dimensions " + f"({height}x{width}, expected seq_len={expected_seq_len}). " + f"This indicates a mismatch between AR output and Diffusion input. " + f"Please ensure ar2diffusion returns correct height/width." 
+ ) + prior_token_image_ids = None if external_prior_image_ids is not None: if isinstance(external_prior_image_ids, torch.Tensor): diff --git a/vllm_omni/diffusion/models/helios/helios_transformer.py b/vllm_omni/diffusion/models/helios/helios_transformer.py index b3d2621ad8..5e7934c3ba 100644 --- a/vllm_omni/diffusion/models/helios/helios_transformer.py +++ b/vllm_omni/diffusion/models/helios/helios_transformer.py @@ -62,10 +62,16 @@ def apply_rotary_emb_helios( """ x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) - out = torch.empty_like(hidden_states) - out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2] - out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2] - return out.type_as(hidden_states) + # Use stack+flatten instead of strided slice assignment for contiguous + # memory layout and better performance on GPU/NPU (#2436, cf. PR #2393). + rotated = torch.stack( + ( + x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2], + x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2], + ), + dim=-1, + ) + return rotated.flatten(-2, -1).type_as(hidden_states) class DistributedRMSNorm(nn.Module): diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py b/vllm_omni/diffusion/models/hunyuan_image3/__init__.py similarity index 58% rename from vllm_omni/diffusion/models/hunyuan_image_3/__init__.py rename to vllm_omni/diffusion/models/hunyuan_image3/__init__.py index cbc6a8ad1f..6612bd855b 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/__init__.py @@ -2,12 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Hunyuan Image 3 diffusion model components.""" -from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE -from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_image_3_transformer import ( +from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import HunyuanFusedMoE +from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_image3_transformer import ( HunyuanImage3Model, HunyuanImage3Text2ImagePipeline, ) -from vllm_omni.diffusion.models.hunyuan_image_3.pipeline_hunyuan_image_3 import ( +from vllm_omni.diffusion.models.hunyuan_image3.pipeline_hunyuan_image3 import ( HunyuanImage3Pipeline, ) diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/autoencoder.py b/vllm_omni/diffusion/models/hunyuan_image3/autoencoder.py similarity index 100% rename from vllm_omni/diffusion/models/hunyuan_image_3/autoencoder.py rename to vllm_omni/diffusion/models/hunyuan_image3/autoencoder.py diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py similarity index 100% rename from vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py rename to vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py similarity index 99% rename from vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py rename to vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py index ce563f7115..4a29e9df93 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py @@ -13,7 +13,7 @@ from transformers import AutoTokenizer from vllm.logger import 
init_logger -from .hunyuan_image_3_transformer import ImageInfo, JointImageInfo, default +from .hunyuan_image3_transformer import ImageInfo, JointImageInfo, default logger = init_logger(__name__) diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py similarity index 99% rename from vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py rename to vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py index bc81ca9c3e..fbdacddaf3 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py @@ -74,7 +74,7 @@ ) from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.layers.rope import RotaryEmbedding -from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE +from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import HunyuanFusedMoE logger = logging.getLogger(__name__) @@ -1684,7 +1684,8 @@ def forward( else: attn_output = self.attn(q, k, v) # For o_proj - attn_output = attn_output.view(q.shape[0], -1) + # image_attn may return a non-contiguous tensor; reshape is safe here. + attn_output = attn_output.reshape(q.shape[0], -1) output, _ = self.o_proj(attn_output) output = output.reshape(bsz, q_len, -1) return output, None, past_key_value diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py similarity index 99% rename from vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py rename to vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py index 7e9e2d2787..3de0ab3101 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py @@ -6,7 +6,6 @@ from collections.abc import Iterable from typing import Any -import numpy as np import torch import torch.nn as nn from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler @@ -25,8 +24,8 @@ from vllm_omni.diffusion.request import OmniDiffusionRequest from .autoencoder import AutoencoderKLConv3D -from .hunyuan_image_3_tokenizer import TokenizerWrapper -from .hunyuan_image_3_transformer import ( +from .hunyuan_image3_tokenizer import TokenizerWrapper +from .hunyuan_image3_transformer import ( CausalMMOutputWithPast, HunyuanImage3ImageProcessor, HunyuanImage3Model, @@ -544,7 +543,7 @@ def prepare_model_inputs( generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds] # 3. apply chat template - cfg_factor = {"gen_text": 1, "gen_image": 2} + cfg_factor = {"gen_text": 1, "gen_image": 1 + int(guidance_scale > 1.0)} bot_task = kwargs.pop("bot_task", "auto") # If `drop_think` enabled, always drop parts in the context. 
drop_think = kwargs.get("drop_think", self.generation_config.drop_think) @@ -1009,8 +1008,7 @@ def forward( if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale if guidance_scale <= 1.0: - logger.warning("HunyuanImage3.0 does not support guidance_scale <= 1.0, will set it to 1.0 + epsilon.") - guidance_scale = 1.0 + np.finfo(float).eps + logger.info("HunyuanImage3.0 runs without classifier-free guidance when guidance_scale <= 1.0.") image_size = (height, width) model_inputs = self.prepare_model_inputs( prompt=prompt, diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/system_prompt.py b/vllm_omni/diffusion/models/hunyuan_image3/system_prompt.py similarity index 100% rename from vllm_omni/diffusion/models/hunyuan_image_3/system_prompt.py rename to vllm_omni/diffusion/models/hunyuan_image3/system_prompt.py diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py index a1bf7f7809..95ef919c24 100644 --- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py +++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py @@ -1264,6 +1264,7 @@ class LTX2VideoTransformer3DModel(nn.Module): _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["LTX2VideoTransformerBlock"] + _layerwise_offload_blocks_attrs = ["transformer_blocks"] _sp_plan: dict[str, Any] | None = None @staticmethod diff --git a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py index 9ff681a3c0..3f03563a1c 100644 --- a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py +++ b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn +import torch.nn.functional as F +import vllm._custom_ops as ops from diffusers.models.activations import get_activation from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed from diffusers.models.modeling_outputs import Transformer2DModelOutput @@ -16,6 +18,7 @@ QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm_omni.diffusion.attention.layer import Attention @@ -24,6 +27,105 @@ logger = logging.getLogger(__name__) +def _patch_cutlass_padded_fp8(): + """Monkey-patch vllm._custom_ops.cutlass_scaled_mm to pad tensors whose + dimensions are not multiples of 16, so the CUTLASS FP8 kernel is used. + + OmniGen2 has hidden_size=2520 (2520 % 16 == 8). Without this patch, + vLLM's cutlass_scaled_mm falls back to a Triton scaled_mm kernel for + every FP8 linear layer (QKV, attn output, gate_up_proj, down_proj), + which is dramatically slower than the native CUTLASS FP8 tensor-core + path on H100/H200 GPUs. + + Weight tensors (b) are constant across forward passes, so padded + versions are computed once and cached by data_ptr to avoid repeated + allocation and column-major conversion overhead. 
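+
+ For example, with hidden_size=2520 the weight's K dimension is padded by
+ (16 - 2520 % 16) % 16 == 8 up to 2528; activations are padded along K to
+ match, and any padded N columns are sliced off the output after the matmul.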
+ """ + _orig_cutlass_scaled_mm = ops.cutlass_scaled_mm + # Cache: data_ptr → (padded_b, padded_bias, padded_scale_b, pad_k, pad_n, orig_n) + _weight_cache: dict[int, tuple] = {} + + def _padded_cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + if b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0: + return _orig_cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + # Reshape to 2D (mirrors the original function) + target_shape = (*a.shape[:-1], b.shape[1]) + a = a.view(-1, a.shape[-1]) + orig_n = b.shape[1] + + # Cache the padded weight — it's a model parameter that never changes. + key = b.data_ptr() + if key not in _weight_cache: + pad_k = (16 - b.shape[0] % 16) % 16 + pad_n = (16 - orig_n % 16) % 16 + b_pad = b + if pad_k > 0: + b_pad = F.pad(b_pad, (0, 0, 0, pad_k)) + if pad_n > 0: + b_pad = F.pad(b_pad, (0, pad_n)) + # CUTLASS requires b column-major (stride(0)==1). + b_pad = b_pad.t().contiguous().t() + + bias_pad = None + if bias is not None and pad_n > 0: + bias_pad = F.pad(bias, (0, pad_n)) + + scale_b_pad = scale_b + if scale_b.numel() > 1 and pad_n > 0: + scale_b_pad = F.pad( + scale_b.view(-1, scale_b.shape[-1]), + (0, pad_n), + value=1.0, + ) + + _weight_cache[key] = ( + b_pad, + bias_pad, + scale_b_pad, + pad_k, + pad_n, + orig_n, + ) + + b_pad, bias_pad, scale_b_pad, pad_k, pad_n, orig_n = _weight_cache[key] + + # Pad activations on K dimension (cheap — activations are small). + if pad_k > 0: + a = F.pad(a, (0, pad_k)).contiguous() + + out = torch.empty((a.shape[0], b_pad.shape[1]), dtype=out_dtype, device=a.device) + torch.ops._C.cutlass_scaled_mm( + out, + a, + b_pad, + scale_a, + scale_b_pad, + bias_pad if bias is not None else None, + ) + + if pad_n > 0: + out = out[:, :orig_n] + + return out.view(*target_shape) + + ops.cutlass_scaled_mm = _padded_cutlass_scaled_mm + logger.info( + "Patched vllm._custom_ops.cutlass_scaled_mm with CUTLASS-padded FP8 " + "variant (avoids slow Triton fallback for non-%%16 dimensions)" + ) + + +_patch_cutlass_padded_fp8() + + class OmniGen2Attention(nn.Module): def __init__( self, @@ -31,6 +133,8 @@ def __init__( num_heads: int, num_kv_heads: int, eps: float = 1e-5, + quant_config: QuantizationConfig | None = None, + prefix: str = "", ): super().__init__() self.dim = dim @@ -46,12 +150,26 @@ def __init__( total_num_kv_heads=num_kv_heads, disable_tp=True, bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_qkv", ) self.norm_q = RMSNorm(self.head_dim, eps=eps) self.norm_k = RMSNorm(self.head_dim, eps=eps) - self.to_out = nn.ModuleList([nn.Linear(dim, dim, bias=False)]) + self.to_out = nn.ModuleList( + [ + RowParallelLinear( + dim, + dim, + bias=False, + input_is_parallel=False, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.to_out.0", + ) + ] + ) self.attn = Attention( num_heads=num_heads, head_size=self.head_dim, @@ -78,6 +196,9 @@ def forward( """ batch_size = hidden_states.shape[0] + # Contiguous layout for FP8 quantized linear GEMMs (matches FLUX DiT). 
+ hidden_states = hidden_states.contiguous() + # Get Query-Key-Value Pair qkv, _ = self.to_qkv(hidden_states) @@ -121,7 +242,7 @@ def forward( hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(dtype) - hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[0](hidden_states.contiguous()) return hidden_states @@ -233,6 +354,7 @@ def __init__( embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool, + **kwargs, ): super().__init__() self.silu = nn.SiLU() @@ -325,6 +447,8 @@ def __init__( inner_dim: int, multiple_of: int | None = 256, ffn_dim_multiplier: float | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", ): super().__init__() @@ -338,6 +462,8 @@ def __init__( [inner_dim, inner_dim], bias=False, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", ) self.act_fn = get_act_and_mul_fn("silu") self.down_proj = RowParallelLinear( @@ -346,6 +472,8 @@ def __init__( bias=False, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", ) def forward(self, x): @@ -591,6 +719,8 @@ def __init__( ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", ) -> None: """Initialize the transformer block.""" super().__init__() @@ -602,6 +732,8 @@ def __init__( num_heads=num_attention_heads, num_kv_heads=num_kv_heads, eps=1e-5, + quant_config=quant_config, + prefix=f"{prefix}.attn", ) # Initialize feed-forward network @@ -610,11 +742,19 @@ def __init__( inner_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", ) # Initialize normalization layers if modulation: - self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True) + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + norm_elementwise_affine=True, + quant_config=quant_config, + prefix=f"{prefix}.norm1", + ) else: self.norm1 = RMSNorm(dim, eps=norm_eps) @@ -713,6 +853,7 @@ def __init__( axes_lens: tuple[int, int, int] = (1024, 1664, 1664), text_feat_dim: int = 2048, timestep_scale: float = 1000.0, + quant_config: QuantizationConfig | None = None, ) -> None: """Initialize the OmniGen2 transformer model.""" super().__init__() @@ -770,8 +911,10 @@ def __init__( ffn_dim_multiplier, norm_eps, modulation=True, + quant_config=quant_config, + prefix=f"noise_refiner.{i}", ) - for _ in range(num_refiner_layers) + for i in range(num_refiner_layers) ] ) @@ -785,8 +928,10 @@ def __init__( ffn_dim_multiplier, norm_eps, modulation=True, + quant_config=quant_config, + prefix=f"ref_image_refiner.{i}", ) - for _ in range(num_refiner_layers) + for i in range(num_refiner_layers) ] ) @@ -800,8 +945,10 @@ def __init__( ffn_dim_multiplier, norm_eps, modulation=False, + quant_config=quant_config, + prefix=f"context_refiner.{i}", ) - for _ in range(num_refiner_layers) + for i in range(num_refiner_layers) ] ) @@ -816,8 +963,10 @@ def __init__( ffn_dim_multiplier, norm_eps, modulation=True, + quant_config=quant_config, + prefix=f"layers.{i}", ) - for _ in range(num_layers) + for i in range(num_layers) ] ) @@ -847,11 +996,25 @@ def img_patch_embed_and_refine( temb, ): batch_size = len(hidden_states) + has_ref_tokens = any(ref_img_len > 0 for ref_lens in l_effective_ref_img_len for ref_img_len in ref_lens) max_combined_img_len = max( [img_len + 
sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)] ) hidden_states = self.x_embedder(hidden_states) + if not has_ref_tokens: + # FP8 kernels do not support zero-token GEMM on ref_image_patch_embedder; skip that path only. + # Still run noise_refiner and return the same combined layout as the no-ref case below + # (batch, max_combined_img_len, hidden) — not raw noise tokens alone. + for layer in self.noise_refiner: + hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) + combined_img_hidden_states = hidden_states.new_zeros( + batch_size, max_combined_img_len, self.config.hidden_size + ) + for i, img_len in enumerate(l_effective_img_len): + combined_img_hidden_states[i, :img_len] = hidden_states[i, :img_len] + return combined_img_hidden_states + ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) for i in range(batch_size): diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py index 2d370aea19..04720c932f 100644 --- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py +++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py @@ -29,6 +29,7 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.omnigen2.omnigen2_transformer import ( @@ -619,7 +620,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class OmniGen2Pipeline(nn.Module): +class OmniGen2Pipeline(CFGParallelMixin, nn.Module): """ Pipeline for text-to-image generation using OmniGen2. 
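
As context for the zero-reference fallback above: the early return keeps the output layout the downstream code expects, with each sample's refined noise tokens scattered into the front of a zero-padded (batch, max_combined_img_len, hidden) buffer. A minimal, hedged sketch of that layout (token lengths and hidden size are made up; the real code takes them from l_effective_img_len, l_effective_ref_img_len, and config.hidden_size):

import torch

hidden_size = 8
l_effective_img_len = [3, 5]            # noise-image tokens per sample
l_effective_ref_img_len = [[0], [0]]    # no reference-image tokens anywhere

max_combined_img_len = max(
    img_len + sum(ref_lens)
    for img_len, ref_lens in zip(l_effective_img_len, l_effective_ref_img_len)
)  # 5 in this example

# Refined noise tokens, padded to a common length as after noise_refiner.
refined = torch.randn(len(l_effective_img_len), max(l_effective_img_len), hidden_size)

combined = refined.new_zeros(len(l_effective_img_len), max_combined_img_len, hidden_size)
for i, img_len in enumerate(l_effective_img_len):
    combined[i, :img_len] = refined[i, :img_len]

print(combined.shape)  # torch.Size([2, 5, 8]) -- (batch, max_combined_img_len, hidden)
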
@@ -675,7 +676,10 @@ def __init__( ) transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, OmniGen2Transformer2DModel) - self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs) + self.transformer = OmniGen2Transformer2DModel( + **transformer_kwargs, + quant_config=od_config.quantization_config, + ) self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained( model, subfolder="mllm", local_files_only=local_files_only ).to(self.device) @@ -1171,7 +1175,14 @@ def processing( self._num_timesteps = len(timesteps) for i, t in enumerate(timesteps): - model_pred = self.predict( + text_guidance_scale = ( + self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + ) + image_guidance_scale = ( + self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + ) + + positive_kwargs = dict( t=t, latents=latents, prompt_embeds=prompt_embeds, @@ -1179,15 +1190,18 @@ def processing( prompt_attention_mask=prompt_attention_mask, ref_image_hidden_states=ref_latents, ) - text_guidance_scale = ( - self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 - ) - image_guidance_scale = ( - self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + uncond_kwargs = dict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, ) if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: - model_pred_ref = self.predict( + # 3-branch CFG: pos + ref_neg + uncond + ref_neg_kwargs = dict( t=t, latents=latents, prompt_embeds=negative_prompt_embeds, @@ -1195,31 +1209,24 @@ def processing( prompt_attention_mask=negative_prompt_attention_mask, ref_image_hidden_states=ref_latents, ) - - model_pred_uncond = self.predict( - t=t, - latents=latents, - prompt_embeds=negative_prompt_embeds, - freqs_cis=freqs_cis, - prompt_attention_mask=negative_prompt_attention_mask, - ref_image_hidden_states=None, - ) - - model_pred = ( - model_pred_uncond - + image_guidance_scale * (model_pred_ref - model_pred_uncond) - + text_guidance_scale * (model_pred - model_pred_ref) + model_pred = self.predict_noise_with_multi_branch_cfg( + do_true_cfg=True, + true_cfg_scale={ + "text": text_guidance_scale, + "image": image_guidance_scale, + }, + branches_kwargs=[positive_kwargs, ref_neg_kwargs, uncond_kwargs], ) elif text_guidance_scale > 1.0: - model_pred_uncond = self.predict( - t=t, - latents=latents, - prompt_embeds=negative_prompt_embeds, - freqs_cis=freqs_cis, - prompt_attention_mask=negative_prompt_attention_mask, - ref_image_hidden_states=None, + # 2-branch CFG: pos + uncond + model_pred = self.predict_noise_with_multi_branch_cfg( + do_true_cfg=True, + true_cfg_scale=text_guidance_scale, + branches_kwargs=[positive_kwargs, uncond_kwargs], ) - model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) + else: + # No CFG + model_pred = self.predict_noise(**positive_kwargs) latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] @@ -1249,8 +1256,6 @@ def predict( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - batch_size, num_channels_latents, height, width = latents.shape - optional_kwargs = {} if "ref_image_hidden_states" in set(inspect.signature(self.transformer.forward).parameters.keys()): 
optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states @@ -1265,6 +1270,21 @@ def predict( ) return model_pred + def predict_noise(self, **kwargs): + """Override CFGParallelMixin.predict_noise to use self.predict.""" + return self.predict(**kwargs) + + def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False): + """Override: 3-branch dual scale or 2-branch standard CFG.""" + if len(predictions) == 3: + text_scale = true_cfg_scale["text"] + image_scale = true_cfg_scale["image"] + pos, ref, uncond = predictions[0], predictions[1], predictions[2] + return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref) + # 2-branch: standard CFG + pos, neg = predictions[0], predictions[1] + return neg + true_cfg_scale * (pos - neg) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py index 568e2f5164..c330e91de8 100644 --- a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py +++ b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py @@ -16,6 +16,7 @@ from collections.abc import Iterable from typing import ClassVar +import numpy as np import torch from tokenizers import Tokenizer as HFTokenizer from torch import nn @@ -30,6 +31,13 @@ from vllm_omni.model_executor.models.omnivoice.omnivoice_decoder import OmniVoiceDecoder from vllm_omni.model_executor.models.omnivoice.omnivoice_generator import OmniVoiceGenerator +try: + from transformers import HiggsAudioV2TokenizerModel +except ImportError: + HiggsAudioV2TokenizerModel = None + +import torchaudio + logger = init_logger(__name__) @@ -79,6 +87,17 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): tokenizer_path = os.path.join(self.model_path, "tokenizer.json") self.tokenizer = HFTokenizer.from_file(tokenizer_path) + # Audio tokenizer for voice cloning (requires transformers>=5.3) + if HiggsAudioV2TokenizerModel is not None: + audio_tokenizer_path = os.path.join(self.model_path, "audio_tokenizer") + self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained( + audio_tokenizer_path, device_map=self.device + ).eval() + logger.info("HiggsAudioV2 tokenizer loaded for voice cloning on %s", self.device) + else: + self.audio_tokenizer = None + logger.warning("Voice cloning disabled (requires transformers>=5.3.0).") + # Duration estimator self.duration_estimator = RuleDurationEstimator() @@ -91,20 +110,46 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): self.class_temperature = self.config.class_temperature self.sample_rate = self.config.sample_rate + def _encode_ref_audio(self, audio_signal: torch.Tensor, sr: int) -> torch.Tensor: + """Encode reference audio to 8-codebook tokens for voice cloning.""" + if self.audio_tokenizer is None: + raise RuntimeError("Audio tokenizer not available for voice cloning") + if audio_signal.dim() == 1: + audio_signal = audio_signal.unsqueeze(0) + # Resample to tokenizer's expected sample rate + target_sr = self.audio_tokenizer.config.sample_rate + if sr != target_sr: + audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr) + # Ensure mono [B, 1, samples] + if audio_signal.dim() == 2: + audio_signal = audio_signal.unsqueeze(1) + with torch.inference_mode(): + tokens = self.audio_tokenizer.encode( + audio_signal.to(self.audio_tokenizer.device), 
return_dict=False + ) # [B, 8, T_ref] + tokens = tokens.squeeze(0) # [8, T_ref] + return tokens + @torch.inference_mode() def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: - """Generate speech audio from text. - - Args: - req: Diffusion request containing text prompt(s). + """Generate speech audio from text, optionally with voice cloning. - Returns: - DiffusionOutput with audio tensor in .output + Accepts either a plain text prompt or a structured dict: + {"text": "...", "ref_audio": (samples, sr), "ref_text": "...", + "lang": "...", "instruct": "..."} """ - # Extract text from request prompt = req.prompts[0] if req.prompts else "" + ref_audio = None + ref_text = None + lang = "None" + instruct = "None" + if isinstance(prompt, dict): text = prompt.get("input", prompt.get("text", str(prompt))) + ref_audio = prompt.get("ref_audio") + ref_text = prompt.get("ref_text") + lang = prompt.get("lang") or "None" + instruct = prompt.get("instruct") or "None" else: text = str(prompt) @@ -119,17 +164,37 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: target_len = self.duration_estimator.estimate_duration(text, "Nice to meet you.", 25) target_len = max(1, int(target_len)) - # Tokenize with control tokens - style = "<|denoise|><|lang_start|>None<|lang_end|><|instruct_start|>None<|instruct_end|>" - full_prompt = f"{style}<|text_start|>{text}<|text_end|>" + # Build text prompt with control tokens + style = f"<|denoise|><|lang_start|>{lang}<|lang_end|><|instruct_start|>{instruct}<|instruct_end|>" + if ref_text: + full_text = f"{ref_text} {text}" + else: + full_text = text + full_prompt = f"{style}<|text_start|>{full_text}<|text_end|>" encoding = self.tokenizer.encode(full_prompt) text_tokens = torch.tensor(encoding.ids, dtype=torch.long, device=device) text_len = text_tokens.shape[0] + # Encode reference audio tokens if provided + ref_audio_tokens = None + if ref_audio is not None: + if self.audio_tokenizer is None: + raise RuntimeError( + "Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'" + ) + audio_signal, sr = ref_audio + if isinstance(audio_signal, np.ndarray): + audio_signal = torch.from_numpy(audio_signal).float() + ref_audio_tokens = self._encode_ref_audio(audio_signal, int(sr)).to(device) + # Build conditional + unconditional batches [2, 8, max_len] text_ids = text_tokens.unsqueeze(0).repeat(num_cb, 1) target_ids = torch.full((num_cb, target_len), mask_id, dtype=torch.long, device=device) - cond_ids = torch.cat([text_ids, target_ids], dim=1) + + if ref_audio_tokens is not None: + cond_ids = torch.cat([text_ids, ref_audio_tokens, target_ids], dim=1) + else: + cond_ids = torch.cat([text_ids, target_ids], dim=1) cond_len = cond_ids.shape[1] uncond_ids = target_ids.clone() diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index 9f75c84538..9ef0cacd5a 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -34,6 +34,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -363,8 +366,10 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." 
) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -378,6 +383,8 @@ def _get_qwen_prompt_embeds( self, prompt: str | list[str] = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): dtype = dtype or self.text_encoder.dtype @@ -388,12 +395,27 @@ def _get_qwen_prompt_embeds( txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( txt, - max_length=self.tokenizer_max_length + drop_idx, padding=True, - truncation=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # Validate only the user prompt contribution. The Qwen template also + # adds a fixed suffix after the user text, so subtracting only + # prompt_template_encode_start_idx would overcount near-limit prompts. + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, return_tensors="pt", ).to(self.device) - # print(f"attention mask: {txt_tokens.attention_mask}") + validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -422,6 +444,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -439,7 +462,11 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -632,6 +659,7 @@ def _prepare_generation_context( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) else: negative_prompt_embeds = None @@ -703,7 +731,7 @@ def prepare_encode( num_images_per_prompt=sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1, generator=sampling.generator, true_cfg_scale=sampling.true_cfg_scale or 4.0, - max_sequence_length=sampling.max_sequence_length or 512, + max_sequence_length=sampling.max_sequence_length or self.tokenizer_max_length, attention_kwargs=kwargs.get("attention_kwargs"), ) @@ -934,7 +962,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: extracted_prompt, negative_prompt = 
self._extract_prompts(req.prompts) prompt = extracted_prompt or prompt diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py index dd77d71b1e..cef7fe473a 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py @@ -37,6 +37,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -323,8 +326,10 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -384,6 +389,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str] = None, image: PIL.Image.Image | torch.Tensor | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): """Get prompt embeddings with image support for editing.""" dtype = dtype or self.text_encoder.dtype @@ -393,6 +400,33 @@ def _get_qwen_prompt_embeds( template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # The edit template contains fixed multimodal scaffolding around the + # instruction. Validate against the empty-template baseline so image + # placeholder text does not consume the user's text budget. + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # Qwen-Image-Edit expands image placeholders into many vision tokens + # inside the processor. `max_sequence_length` is meant to constrain the + # prompt text length, so validate on the text template before image + # token expansion. 
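+    # Illustration only (not part of this change): the helper imported from
+    # vllm_omni.diffusion.utils.prompt_utils is expected to compare per-sample
+    # token counts against the budget after subtracting the template-only
+    # baseline, roughly along these lines:
+    #     prompt_lengths = txt_tokens.attention_mask.sum(dim=1)   # template + user text
+    #     baseline = template_tokens.attention_mask.sum(dim=1)    # template scaffolding only
+    #     user_lengths = prompt_lengths - baseline
+    #     if (user_lengths > (max_sequence_length or self.tokenizer_max_length)).any():
+    #         raise ValueError(...)  # names the offending prompt and the limit
+    # The call below is the source of truth; this sketch only documents the
+    # intent of the baseline subtraction.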
+ validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) # Use processor to handle both text and image inputs model_inputs = self.processor( @@ -434,6 +468,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -453,7 +488,12 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + image, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -624,7 +664,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: """Forward pass for image editing.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -739,6 +779,7 @@ def forward( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) num_channels_latents = self.transformer.in_channels // 4 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py index 6f6c9d2ba3..2e25d0fe6b 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.qwen_image.cfg_parallel import ( QwenImageCFGParallelMixin, @@ -40,6 +41,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -53,6 +57,12 @@ CONDITION_IMAGE_SIZE = 384 * 384 VAE_IMAGE_SIZE = 1024 * 1024 +# Keep this in sync with the practical conditioning-token budget for +# Qwen-Image-Edit-2511. Empirically, 4 images stays within the supported range +# while 5 images overflows the prompt/conditioning path and fails downstream. +# Re-export the shared metadata value locally so this pipeline keeps a nearby, +# descriptive constant for validation and tests without becoming the source of truth. 
+MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES def get_qwen_image_edit_plus_pre_process_func( @@ -90,6 +100,11 @@ def pre_process_func( if not isinstance(raw_image, list): raw_image = [raw_image] + if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES: + raise ValueError( + f"Received {len(raw_image)} input images. " + f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model." + ) image = [ PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im) for im in raw_image @@ -283,8 +298,10 @@ def check_inputs( "that was used to generate `negative_prompt_embeds`." ) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -299,6 +316,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str], image: list[torch.Tensor] | torch.Tensor | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): """Get prompt embeddings with support for multiple images.""" dtype = dtype or self.text_encoder.dtype @@ -319,6 +338,32 @@ def _get_qwen_prompt_embeds( template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(base_img_prompt + e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # Multi-image edit prepends "Picture N" placeholders before the user + # instruction. Subtract the placeholder-aware baseline so attached + # images do not reduce the remaining prompt budget. + template_tokens = self.tokenizer( + [template.format(base_img_prompt)], + padding=True, + truncation=False, + return_tensors="pt", + ).to(self.device) + # The processor expands image placeholders into many vision tokens. + # `max_sequence_length` should guard the prompt text length before that + # multimodal expansion happens. 
+ validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) # Use processor to handle both text and image inputs model_inputs = self.processor( @@ -360,6 +405,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -379,7 +425,12 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + image, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -557,7 +608,7 @@ def forward( output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 1024, ) -> DiffusionOutput: """Forward pass for image editing with support for multiple images.""" # TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "") @@ -692,6 +743,7 @@ def forward( prompt_embeds_mask=negative_prompt_embeds_mask, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) num_channels_latents = self.transformer.in_channels // 4 diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py index 38866d89c5..905ef5b424 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py @@ -36,6 +36,9 @@ ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.diffusion.utils.size_utils import ( normalize_min_aligned_size, ) @@ -340,8 +343,10 @@ def check_inputs( "generate `negative_prompt_embeds`." 
) - if max_sequence_length is not None and max_sequence_length > 1024: - raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -356,6 +361,8 @@ def _get_qwen_prompt_embeds( prompt: str | list[str] | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None, + max_sequence_length: int | None = None, + prompt_name: str = "prompt", ): device = device or self.device dtype = dtype or self.text_encoder.dtype @@ -368,8 +375,26 @@ def _get_qwen_prompt_embeds( txt_tokens = self.tokenizer( txt, padding=True, + truncation=False, + return_tensors="pt", + ).to(device) + # The layered template also appends fixed non-user tokens after the + # editable text, so use the empty-template tokenized baseline instead of + # counting everything after prompt_template_encode_start_idx. + template_tokens = self.tokenizer( + [template.format("")], + padding=True, + truncation=False, return_tensors="pt", ).to(device) + validate_prompt_sequence_lengths( + txt_tokens.attention_mask, + max_sequence_length=max_sequence_length or self.tokenizer_max_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name=prompt_name, + baseline_attention_mask=template_tokens.attention_mask, + error_context="after applying the Qwen prompt template", + ) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, @@ -399,6 +424,7 @@ def encode_prompt( prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, + prompt_name: str = "prompt", ): r""" @@ -419,7 +445,11 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, + max_sequence_length=max_sequence_length, + prompt_name=prompt_name, + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -603,7 +633,7 @@ def forward( negative_prompt_embeds_mask: torch.Tensor | None = None, output_type: str | None = "pil", attention_kwargs: dict[str, Any] | None = None, - max_sequence_length: int = 512, + max_sequence_length: int = 1024, resolution: int = 640, cfg_normalize: bool = False, use_en_prompt: bool = False, @@ -736,6 +766,7 @@ def forward( device=self.device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + prompt_name="negative_prompt", ) # 4. 
Prepare latent variables diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index b34f19e954..88a66d7f6b 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -169,12 +169,15 @@ def __init__( self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + # Time embedding MLP is kept full precision (quant_config=None) — + # small layers that feed per-block modulation; precision-sensitive + # (see #2728). self.timestep_embedder.linear_1 = ReplicatedLinear( 256, embedding_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="timestep_embedder.linear_1", ) self.timestep_embedder.linear_2 = ReplicatedLinear( @@ -182,7 +185,7 @@ def __init__( embedding_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="timestep_embedder.linear_2", ) self.use_additional_t_cond = use_additional_t_cond @@ -701,7 +704,10 @@ def __init__( self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim - # Image processing modules + # Image processing modules. + # Modulation linear is kept full precision (quant_config=None) — it + # produces shift/scale/gate values that are precision-sensitive + # (see #2728). self.img_mod = nn.Sequential( nn.SiLU(), ReplicatedLinear( @@ -709,7 +715,7 @@ def __init__( 6 * dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="img_mod.1", ), ) @@ -725,7 +731,7 @@ def __init__( self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) self.img_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="img_mlp") - # Text processing modules + # Text processing modules. 
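+        # As with img_mod above, the text-stream modulation linear below stays in
+        # full precision (quant_config=None); it produces the per-block
+        # shift/scale/gate for the text tokens (see #2728).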
self.txt_mod = nn.Sequential( nn.SiLU(), ReplicatedLinear( @@ -733,7 +739,7 @@ def __init__( 6 * dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="txt_mod.1", ), ) @@ -744,9 +750,9 @@ def __init__( self.zero_cond_t = zero_cond_t - def _modulate(self, x, mod_params, index=None): + def _modulate(self, mod_params, index=None): """Apply modulation to input tensor""" - # x: b l d, shift: b d, scale: b d, gate: b d + # shift: b d, scale: b d, gate: b d shift, scale, gate = mod_params.chunk(3, dim=-1) if index is not None: @@ -778,7 +784,7 @@ def _modulate(self, x, mod_params, index=None): scale_result = scale.unsqueeze(1) gate_result = gate.unsqueeze(1) - return x * (1 + scale_result) + shift_result, gate_result + return scale_result, shift_result, gate_result def forward( self, @@ -804,10 +810,12 @@ def forward( txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] # Process image stream - norm1 + modulation - img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index) + img_scale1, img_shift1, img_gate1 = self._modulate(img_mod1, modulate_index) + img_modulated = self.img_norm1(hidden_states, img_scale1, img_shift1) # Process text stream - norm1 + modulation - txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1) + txt_scale1, txt_shift1, txt_gate1 = self._modulate(txt_mod1) + txt_modulated = self.txt_norm1(encoder_hidden_states, txt_scale1, txt_shift1) # Use QwenAttnProcessor2_0 for joint attention computation # This directly implements the DoubleStreamLayerMegatron logic: @@ -832,13 +840,16 @@ def forward( encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output # Process image stream - norm2 + MLP - img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index) + img_scale2, img_shift2, img_gate2 = self._modulate(img_mod2, modulate_index) + img_modulated2 = self.img_norm2(hidden_states, img_scale2, img_shift2) img_mlp_output = self.img_mlp(img_modulated2) hidden_states = hidden_states + img_gate2 * img_mlp_output # Process text stream - norm2 + MLP - txt_modulated2, txt_gate2 = self.txt_norm2(encoder_hidden_states, txt_mod2) + txt_scale2, txt_shift2, txt_gate2 = self._modulate(txt_mod2) + txt_modulated2 = self.txt_norm2(encoder_hidden_states, txt_scale2, txt_shift2) + txt_mlp_output = self.txt_mlp(txt_modulated2) encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output @@ -958,12 +969,14 @@ def __init__( self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + # Entry projections (image/text) are kept full precision — + # small sensitive layers at the network boundary (see #2728). self.img_in = ReplicatedLinear( in_channels, self.inner_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="img_in", ) self.txt_in = ReplicatedLinear( @@ -971,7 +984,7 @@ def __init__( self.inner_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="txt_in", ) @@ -988,13 +1001,16 @@ def __init__( ] ) + # Final modulation and output projection are kept full precision — + # they produce the output latent and are precision-sensitive + # (see #2728). 
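+        # Context for the modulation-related exemptions (timestep embedder,
+        # img_mod/txt_mod, norm_out.linear): per _modulate and the block forward
+        # above, their outputs are consumed roughly as
+        #     shift, scale, gate = mod_params.chunk(3, dim=-1)
+        #     modulated = norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+        #     x = x + gate.unsqueeze(1) * sublayer(modulated)
+        # (norm_out.linear emits only the final shift/scale pair). Since these
+        # values scale or offset every token, quantization error here is amplified
+        # across the whole sequence; img_in/txt_in and proj_out are boundary
+        # projections kept in full precision for the same reason (see #2728).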
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.norm_out.linear = ReplicatedLinear( self.inner_dim, 2 * self.inner_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="norm_out.linear", ) self.proj_out = ReplicatedLinear( @@ -1002,7 +1018,7 @@ def __init__( patch_size * patch_size * self.out_channels, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="proj_out", ) diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py index 22d56ac1fd..4a4892673f 100644 --- a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py +++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py @@ -375,6 +375,8 @@ class StableAudioDiTModel(nn.Module): - Output: [B, out_channels, L] """ + _repeated_blocks = ["StableAudioDiTBlock"] + def __init__( self, od_config: OmniDiffusionConfig | None = None, diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index a550e576f0..652425d509 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -24,14 +24,60 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel +from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) DEBUG_PERF = False +WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"} +WAN22_MAX_SEQUENCE_LENGTH = 512 + + +def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any: + if sample_solver == "unipc": + return FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", + ) + if sample_solver == "euler": + return WanEulerScheduler( + num_train_timesteps=1000, + shift=flow_shift, + ) + + raise ValueError( + f"Unsupported Wan sample_solver: {sample_solver}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}" + ) + + +def resolve_wan_sample_solver(req: OmniDiffusionRequest, default: str = "unipc") -> str: + extra_args = getattr(req.sampling_params, "extra_args", {}) or {} + raw = extra_args.get("sample_solver", default) + sample_solver = str(raw).strip().lower() + if sample_solver not in WAN_SAMPLE_SOLVER_CHOICES: + raise ValueError(f"Invalid sample_solver={raw!r}. 
Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}") + return sample_solver + + +def resolve_wan_flow_shift(req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> float: + extra_args = getattr(req.sampling_params, "extra_args", {}) or {} + raw_flow_shift = extra_args.get("flow_shift") + if raw_flow_shift is None: + raw_flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + + try: + return float(raw_flow_shift) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid flow_shift={raw_flow_shift!r}. flow_shift must be a float.") from exc def retrieve_latents( @@ -121,10 +167,23 @@ def get_wan22_post_process_func( def post_process_func( video: torch.Tensor, output_type: str = "np", + sampling_params=None, ): if output_type == "latent": return video - return video_processor.postprocess_video(video, output_type=output_type) + custom_output = {} + if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False): + video, multiplier = interpolate_video_tensor( + video, + exp=sampling_params.frame_interpolation_exp, + scale=sampling_params.frame_interpolation_scale, + model_path=sampling_params.frame_interpolation_model_path, + ) + custom_output["video_fps_multiplier"] = multiplier + return { + "video": video_processor.postprocess_video(video, output_type=output_type), + "custom_output": custom_output, + } return post_process_func @@ -234,6 +293,7 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -296,13 +356,9 @@ def __init__( else: raise RuntimeError("No transformer loaded") - # Initialize UniPC scheduler - flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p - self.scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, - shift=flow_shift, - prediction_type="flow_prediction", - ) + self._sample_solver = "unipc" + self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 @@ -336,6 +392,102 @@ def num_timesteps(self): def current_timestep(self): return self._current_timestep + def diffuse( + self, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor | None, + guidance_low: float, + guidance_high: float, + boundary_timestep: float | None, + dtype: torch.dtype, + attention_kwargs: dict[str, Any], + latent_condition: torch.Tensor | None = None, + first_frame_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + with self.progress_bar(total=len(timesteps)) as pbar: + for t in timesteps: + self._current_timestep = t + + # Select model based on timestep and boundary_ratio + # High noise stage (t >= boundary_timestep): use transformer + # Low noise stage (t < boundary_timestep): use transformer_2 + if boundary_timestep is not None and t < boundary_timestep: + # Low noise stage - always use guidance_high for this stage + current_guidance_scale = guidance_high + if self.transformer_2 is not None: + current_model = self.transformer_2 + elif self.transformer is not None: + # Fallback 
to transformer if transformer_2 not loaded + current_model = self.transformer + else: + raise RuntimeError("No transformer available for low-noise stage") + else: + # High noise stage - always use guidance_low for this stage + current_guidance_scale = guidance_low + if self.transformer is not None: + current_model = self.transformer + elif self.transformer_2 is not None: + # Fallback to transformer_2 if transformer not loaded + current_model = self.transformer_2 + else: + raise RuntimeError("No transformer available for high-noise stage") + + if self.expand_timesteps and latent_condition is not None: + # I2V mode: blend condition with latents using mask + latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(dtype) + + # Expand timesteps per patch - use floor division to match patch embedding + patch_size = self.transformer_config.patch_size + patch_height = latents.shape[3] // patch_size[1] + patch_width = latents.shape[4] // patch_size[2] + + # Create mask at patch resolution (same as hidden states sequence length) + patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]] + patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions + temp_ts = (patch_mask[0][0] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # T2V mode: standard forward + latent_model_input = latents.to(dtype) + timestep = t.expand(latents.shape[0]) + + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + pbar.update() + + return latents + def forward( self, req: OmniDiffusionRequest, @@ -413,6 +565,7 @@ def forward( negative_prompt_embeds=negative_prompt_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -446,7 +599,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -462,6 +615,13 @@ def forward( current_omni_platform.synchronize() _t_text_enc_ms = (time.perf_counter() - _t_text_enc_start) * 1000 + sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver) + flow_shift = resolve_wan_flow_shift(req, self.od_config) + if sample_solver != self._sample_solver or 
abs(flow_shift - self._flow_shift) > 1e-6: + self.scheduler = build_wan_scheduler(sample_solver, flow_shift) + self._sample_solver = sample_solver + self._flow_shift = flow_shift + # Timesteps self.scheduler.set_timesteps(num_steps, device=device) timesteps = self.scheduler.timesteps @@ -571,90 +731,19 @@ def forward( if DEBUG_PERF: _t_denoise_start = time.perf_counter() - with self.progress_bar(total=len(timesteps)) as pbar: - for t in timesteps: - self._current_timestep = t - - # Select model based on timestep and boundary_ratio - # High noise stage (t >= boundary_timestep): use transformer - # Low noise stage (t < boundary_timestep): use transformer_2 - if boundary_timestep is not None and t < boundary_timestep: - # Low noise stage - always use guidance_high for this stage - current_guidance_scale = guidance_high - if self.transformer_2 is not None: - current_model = self.transformer_2 - elif self.transformer is not None: - # Fallback to transformer if transformer_2 not loaded - current_model = self.transformer - else: - raise RuntimeError("No transformer available for low-noise stage") - else: - # High noise stage - always use guidance_low for this stage - current_guidance_scale = guidance_low - if self.transformer is not None: - current_model = self.transformer - elif self.transformer_2 is not None: - # Fallback to transformer_2 if transformer not loaded - current_model = self.transformer_2 - else: - raise RuntimeError("No transformer available for high-noise stage") - - if self.expand_timesteps and latent_condition is not None: - # I2V mode: blend condition with latents using mask - latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents - latent_model_input = latent_model_input.to(dtype) - - # Expand timesteps per patch - use floor division to match patch embedding - patch_size = self.transformer_config.patch_size - num_latent_frames = latents.shape[2] - patch_height = latents.shape[3] // patch_size[1] - patch_width = latents.shape[4] // patch_size[2] - - # Create mask at patch resolution (same as hidden states sequence length) - patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]] - patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions - temp_ts = (patch_mask[0][0] * t).flatten() - timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) - else: - # T2V mode: standard forward - latent_model_input = latents.to(dtype) - timestep = t.expand(latents.shape[0]) - - do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None - # Prepare kwargs for positive and negative predictions - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": prompt_embeds, - "attention_kwargs": attention_kwargs, - "return_dict": False, - "current_model": current_model, - } - if do_true_cfg: - negative_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": negative_prompt_embeds, - "attention_kwargs": attention_kwargs, - "return_dict": False, - "current_model": current_model, - } - else: - negative_kwargs = None - - # Predict noise with automatic CFG parallel handling - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=do_true_cfg, - true_cfg_scale=current_guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - cfg_normalize=False, - ) - - # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync - latents = 
self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - - pbar.update() + latents = self.diffuse( + latents=latents, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_low=guidance_low, + guidance_high=guidance_high, + boundary_timestep=boundary_timestep, + dtype=dtype, + attention_kwargs=attention_kwargs, + latent_condition=latent_condition, + first_frame_mask=first_frame_mask, + ) # Wan2.2 is prone to out of memory errors when predicting large videos # so we empty the cache here to avoid OOM before vae decoding. @@ -743,6 +832,20 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", + ) text_inputs = self.tokenizer( prompt_clean, @@ -771,8 +874,24 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -844,6 +963,7 @@ def check_inputs( negative_prompt_embeds=None, guidance_scale_2=None, boundary_ratio=None, + max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -870,5 +990,10 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) + if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index c05ecc9c9a..95d1e08bbc 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -12,6 +12,7 @@ import numpy as np import PIL.Image import torch +import torchvision.transforms.functional as TF from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel @@ -24,14 
+25,21 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero -from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, + build_wan_scheduler, create_transformer_from_config, load_transformer_config, + resolve_wan_flow_shift, + resolve_wan_sample_solver, retrieve_latents, ) +from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -72,10 +80,23 @@ def get_wan22_i2v_post_process_func( def post_process_func( video: torch.Tensor, output_type: str = "np", + sampling_params=None, ): if output_type == "latent": return video - return video_processor.postprocess_video(video, output_type=output_type) + custom_output = {} + if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False): + video, multiplier = interpolate_video_tensor( + video, + exp=sampling_params.frame_interpolation_exp, + scale=sampling_params.frame_interpolation_scale, + model_path=sampling_params.frame_interpolation_model_path, + ) + custom_output["video_fps_multiplier"] = multiplier + return { + "video": video_processor.postprocess_video(video, output_type=output_type), + "custom_output": custom_output, + } return post_process_func @@ -197,6 +218,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -230,13 +252,9 @@ def __init__( else: self.transformer_2 = None - # Initialize UniPC scheduler - flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p - self.scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, - shift=flow_shift, - prediction_type="flow_prediction", - ) + self._sample_solver = "unipc" + self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift) # VAE scale factors self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 @@ -272,6 +290,82 @@ def num_timesteps(self): def current_timestep(self): return self._current_timestep + def diffuse( + self, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor | None, + image_embeds: torch.Tensor | None, + guidance_low: float, + guidance_high: float, + boundary_timestep: float | None, + dtype: torch.dtype, + attention_kwargs: dict[str, Any], + condition: torch.Tensor, + first_frame_mask: torch.Tensor, + ) -> torch.Tensor: + with self.progress_bar(total=len(timesteps)) as pbar: + for t in timesteps: + self._current_timestep = t + + # Select model and guidance scale based on timestep + current_model = 
self.transformer + current_guidance_scale = guidance_low + if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None: + current_model = self.transformer_2 + current_guidance_scale = guidance_high + + # Prepare latent input + if self.expand_timesteps: + # TI2V-5B style: blend condition with latents using mask + latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(dtype) + + # Expand timesteps for each patch + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # Wan2.1 style: concatenate condition with latents + latent_model_input = torch.cat([latents, condition], dim=1).to(dtype) + timestep = t.expand(latents.shape[0]) + + do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "encoder_hidden_states_image": image_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": current_model, + } + else: + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=current_guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + pbar.update() + + return latents + def encode_image( self, image: PIL.Image.Image | list[PIL.Image.Image], @@ -380,6 +474,7 @@ def forward( image_embeds=image_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -408,7 +503,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -440,6 +535,13 @@ def forward( current_omni_platform.synchronize() _t_img_enc_ms = (time.perf_counter() - _t_img_enc_start) * 1000 + sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver) + flow_shift = resolve_wan_flow_shift(req, self.od_config) + if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6: + self.scheduler = build_wan_scheduler(sample_solver, flow_shift) + self._sample_solver = sample_solver + self._flow_shift = flow_shift + # Timesteps self.scheduler.set_timesteps(num_steps, device=device) timesteps = self.scheduler.timesteps @@ -459,6 +561,7 @@ def forward( video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) if isinstance(image, PIL.Image.Image): + image = TF.to_tensor(image).to(device) image_tensor = video_processor.preprocess(image, height=height, width=width) else: image_tensor = image @@ 
-467,6 +570,7 @@ def forward(
         # Handle last_image if provided
         if last_image is not None:
             if isinstance(last_image, PIL.Image.Image):
+                last_image = TF.to_tensor(last_image).to(device)
                 last_image_tensor = video_processor.preprocess(last_image, height=height, width=width)
             else:
                 last_image_tensor = last_image
@@ -497,68 +601,20 @@ def forward(
         if DEBUG_PERF:
             _t_denoise_start = time.perf_counter()
 
-        with self.progress_bar(total=len(timesteps)) as pbar:
-            for t in timesteps:
-                self._current_timestep = t
-
-                # Select model and guidance scale based on timestep
-                current_model = self.transformer
-                current_guidance_scale = guidance_low
-                if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
-                    current_model = self.transformer_2
-                    current_guidance_scale = guidance_high
-
-                # Prepare latent input
-                if self.expand_timesteps:
-                    # TI2V-5B style: blend condition with latents using mask
-                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
-                    latent_model_input = latent_model_input.to(dtype)
-
-                    # Expand timesteps for each patch
-                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
-                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
-                else:
-                    # Wan2.1 style: concatenate condition with latents
-                    latent_model_input = torch.cat([latents, condition], dim=1).to(dtype)
-                    timestep = t.expand(latents.shape[0])
-
-                do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
-                # Prepare kwargs for positive and negative predictions
-                positive_kwargs = {
-                    "hidden_states": latent_model_input,
-                    "timestep": timestep,
-                    "encoder_hidden_states": prompt_embeds,
-                    "encoder_hidden_states_image": image_embeds,
-                    "attention_kwargs": attention_kwargs,
-                    "return_dict": False,
-                    "current_model": current_model,
-                }
-                if do_true_cfg:
-                    negative_kwargs = {
-                        "hidden_states": latent_model_input,
-                        "timestep": timestep,
-                        "encoder_hidden_states": negative_prompt_embeds,
-                        "encoder_hidden_states_image": image_embeds,
-                        "attention_kwargs": attention_kwargs,
-                        "return_dict": False,
-                        "current_model": current_model,
-                    }
-                else:
-                    negative_kwargs = None
-
-                # Predict noise with automatic CFG parallel handling
-                noise_pred = self.predict_noise_maybe_with_cfg(
-                    do_true_cfg=do_true_cfg,
-                    true_cfg_scale=current_guidance_scale,
-                    positive_kwargs=positive_kwargs,
-                    negative_kwargs=negative_kwargs,
-                    cfg_normalize=False,
-                )
-
-                # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
-                latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
-                pbar.update()
+        latents = self.diffuse(
+            latents=latents,
+            timesteps=timesteps,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            image_embeds=image_embeds,
+            guidance_low=guidance_low,
+            guidance_high=guidance_high,
+            boundary_timestep=boundary_timestep,
+            dtype=dtype,
+            attention_kwargs=attention_kwargs,
+            condition=condition,
+            first_frame_mask=first_frame_mask,
+        )
 
         # Wan2.2 is prone to out of memory errors when predicting large videos
        # so we empty the cache here to avoid OOM before vae decoding.
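The part of this refactor worth internalizing is the per-step expert switch that moved into diffuse(): timesteps run from high noise to low noise, so steps at or above boundary_timestep go through self.transformer with guidance_low, and once t drops below the boundary the pipeline switches to self.transformer_2 with guidance_high (when that second expert is loaded). The sketch below is illustrative only; select_expert and the example boundary value are invented for this note and do not exist in the codebase.

# Illustrative stand-alone version of the switch performed inside diffuse().
def select_expert(t, boundary_timestep, transformer, transformer_2, guidance_low, guidance_high):
    current_model, current_guidance_scale = transformer, guidance_low
    if boundary_timestep is not None and t < boundary_timestep and transformer_2 is not None:
        current_model, current_guidance_scale = transformer_2, guidance_high
    return current_model, current_guidance_scale

# With an illustrative boundary_timestep of 875:
#   t=950 -> (transformer, guidance_low)      high-noise expert
#   t=600 -> (transformer_2, guidance_high)   low-noise expert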
@@ -652,6 +708,20 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", + ) text_inputs = self.tokenizer( prompt_clean, @@ -680,8 +750,24 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -789,12 +875,14 @@ def prepare_latents( return latents, latent_condition, first_frame_mask # Wan2.1 style: create mask and concatenate with condition - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + mask_lat_size = torch.ones( + batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device + ) if last_image is None: - mask_lat_size[:, :, list(range(1, num_frames))] = 0 + mask_lat_size[:, :, 1:] = 0 else: - mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + mask_lat_size[:, :, 1 : num_frames - 1] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) @@ -823,6 +911,7 @@ def check_inputs( image_embeds=None, guidance_scale_2=None, boundary_ratio=None, + max_sequence_length=None, ): if image is None and image_embeds is None: raise ValueError("Provide either `image` or `image_embeds`. 
Cannot leave both undefined.") @@ -844,6 +933,11 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) + if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 261f62fb79..dba76ba8af 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -36,13 +36,20 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin -from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + WAN22_MAX_SEQUENCE_LENGTH, + build_wan_scheduler, create_transformer_from_config, load_transformer_config, + resolve_wan_flow_shift, + resolve_wan_sample_solver, retrieve_latents, ) +from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.prompt_utils import ( + validate_prompt_sequence_lengths, +) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -59,10 +66,23 @@ def get_wan22_ti2v_post_process_func( def post_process_func( video: torch.Tensor, output_type: str = "np", + sampling_params=None, ): if output_type == "latent": return video - return video_processor.postprocess_video(video, output_type=output_type) + custom_output = {} + if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False): + video, multiplier = interpolate_video_tensor( + video, + exp=sampling_params.frame_interpolation_exp, + scale=sampling_params.frame_interpolation_scale, + model_path=sampling_params.frame_interpolation_model_path, + ) + custom_output["video_fps_multiplier"] = multiplier + return { + "video": video_processor.postprocess_video(video, output_type=output_type), + "custom_output": custom_output, + } return post_process_func @@ -169,6 +189,7 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -183,13 +204,9 @@ def __init__( transformer_config = load_transformer_config(model, "transformer", local_files_only) self.transformer = create_transformer_from_config(transformer_config) - # Initialize UniPC scheduler - flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p - self.scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, - shift=flow_shift, - prediction_type="flow_prediction", - ) + self._sample_solver = "unipc" + self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + self.scheduler = 
build_wan_scheduler(self._sample_solver, self._flow_shift) # VAE scale factors self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 @@ -218,6 +235,77 @@ def num_timesteps(self): def current_timestep(self): return self._current_timestep + def diffuse( + self, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor | None, + guidance_scale: float, + dtype: torch.dtype, + attention_kwargs: dict[str, Any], + num_latent_frames: int, + latent_height: int, + latent_width: int, + latent_condition: torch.Tensor | None = None, + first_frame_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + with self.progress_bar(total=len(timesteps)) as pbar: + for t in timesteps: + self._current_timestep = t + + # Prepare latent input + if latent_condition is not None: + # I2V mode: blend condition with latents using mask + latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(dtype) + + # Expand timesteps for each patch (TI2V style) + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + # T2V mode: use latents directly + latent_model_input = latents.to(dtype) + + # Expand timesteps for TI2V model architecture + mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=latents.device) + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + + do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + if do_true_cfg: + negative_kwargs = { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "return_dict": False, + "current_model": self.transformer, + } + else: + negative_kwargs = None + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + pbar.update() + + return latents + def forward( self, req: OmniDiffusionRequest, @@ -289,6 +377,7 @@ def forward( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -312,7 +401,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -323,6 +412,13 @@ def forward( batch_size = prompt_embeds.shape[0] + sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver) + flow_shift = resolve_wan_flow_shift(req, self.od_config) + if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6: + 
self.scheduler = build_wan_scheduler(sample_solver, flow_shift) + self._sample_solver = sample_solver + self._flow_shift = flow_shift + # Timesteps self.scheduler.set_timesteps(num_steps, device=device) timesteps = self.scheduler.timesteps @@ -380,64 +476,20 @@ def forward( if attention_kwargs is None: attention_kwargs = {} - # Denoising loop - with self.progress_bar(total=len(timesteps)) as pbar: - for t in timesteps: - self._current_timestep = t - - # Prepare latent input - if latent_condition is not None: - # I2V mode: blend condition with latents using mask - latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents - latent_model_input = latent_model_input.to(dtype) - - # Expand timesteps for each patch (TI2V style) - temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() - timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) - else: - # T2V mode: use latents directly - latent_model_input = latents.to(dtype) - - # Expand timesteps for TI2V model architecture - mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=device) - temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() - timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) - - do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None - # Prepare kwargs for positive and negative predictions - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": prompt_embeds, - "attention_kwargs": attention_kwargs, - "return_dict": False, - "current_model": self.transformer, - } - if do_true_cfg: - negative_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": negative_prompt_embeds, - "attention_kwargs": attention_kwargs, - "return_dict": False, - "current_model": self.transformer, - } - else: - negative_kwargs = None - - # Predict noise with automatic CFG parallel handling - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=do_true_cfg, - true_cfg_scale=guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - cfg_normalize=False, - ) - - # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - - pbar.update() + latents = self.diffuse( + latents=latents, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + dtype=dtype, + attention_kwargs=attention_kwargs, + num_latent_frames=num_latent_frames, + latent_height=latent_height, + latent_width=latent_width, + latent_condition=latent_condition, + first_frame_mask=first_frame_mask, + ) # Wan2.2 is prone to out of memory errors when predicting large videos # so we empty the cache here to avoid OOM before vae decoding. 
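A detail shared by both diffuse() implementations above is how the scalar timestep t becomes a per-patch vector for the TI2V architecture: the mask is strided by 2 along both spatial axes (which appears to correspond to the transformer's 2x2 spatial patching; the diff does not state that explicitly) and multiplied by t, so positions where the mask is 0 (the injected condition) are presented to the model with timestep 0, while all other positions receive the full t. A toy shape check, with invented sizes and an invented mask layout:

import torch

B, F, H, W = 2, 3, 4, 4                      # illustrative latent batch / frames / grid
first_frame_mask = torch.ones(1, 1, F, H, W)
first_frame_mask[:, :, 0] = 0                # pretend the first latent frame is the condition
t = torch.tensor(875.0)                      # an arbitrary scheduler timestep

temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
timestep = temp_ts.unsqueeze(0).expand(B, -1)

print(timestep.shape)  # torch.Size([2, 12]) -> F * (H // 2) * (W // 2) entries per sample
print(temp_ts[:4])     # tensor([0., 0., 0., 0.]) -> conditioned patches carry timestep 0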
@@ -499,6 +551,20 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) + text_inputs_untruncated = self.tokenizer( + prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + error_context="for Wan2.2 text encoding", + ) text_inputs = self.tokenizer( prompt_clean, @@ -527,8 +593,24 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] + neg_text_inputs_untruncated = self.tokenizer( + negative_prompt_clean, + padding=True, + truncation=False, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + validate_prompt_sequence_lengths( + neg_text_inputs_untruncated.attention_mask, + max_sequence_length=max_sequence_length, + supported_max_sequence_length=self.tokenizer_max_length, + prompt_name="negative_prompt", + error_context="for Wan2.2 text encoding", + ) neg_text_inputs = self.tokenizer( - [self._prompt_clean(p) for p in negative_prompt], + negative_prompt_clean, padding="max_length", max_length=max_sequence_length, truncation=True, @@ -653,6 +735,7 @@ def check_inputs( width, prompt_embeds=None, negative_prompt_embeds=None, + max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -668,6 +751,11 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") + if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py index ea52336311..11408e2d24 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py @@ -176,6 +176,62 @@ def _create_transformer(self, config: dict) -> WanVACETransformer3DModel: """Build VACE transformer directly from config dict.""" return create_vace_transformer_from_config(config) + def diffuse( + self, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor | None, + guidance_scale: float, + dtype: torch.dtype, + attention_kwargs: dict[str, object], + vace_context: torch.Tensor | None, + vace_context_scale: float, + ) -> torch.Tensor: + with self.progress_bar(total=len(timesteps)) as pbar: + for t in timesteps: + self._current_timestep = t + latent_model_input = latents.to(dtype) + timestep = t.expand(latents.shape[0]) + + do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None + + positive_kwargs = { + "hidden_states": latent_model_input, + "timestep": 
timestep, + "encoder_hidden_states": prompt_embeds, + "attention_kwargs": attention_kwargs, + "vace_context": vace_context, + "vace_context_scale": vace_context_scale, + "return_dict": False, + } + negative_kwargs = ( + { + "hidden_states": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "attention_kwargs": attention_kwargs, + "vace_context": vace_context, + "vace_context_scale": vace_context_scale, + "return_dict": False, + } + if do_true_cfg + else None + ) + + noise_pred = self.predict_noise_maybe_with_cfg( + do_true_cfg=do_true_cfg, + true_cfg_scale=guidance_scale, + positive_kwargs=positive_kwargs, + negative_kwargs=negative_kwargs, + cfg_normalize=False, + ) + + latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + pbar.update() + + return latents + def check_inputs( self, prompt, @@ -187,6 +243,7 @@ def check_inputs( video=None, mask=None, reference_images=None, + max_sequence_length=None, ): super().check_inputs( prompt=prompt, @@ -195,6 +252,7 @@ def check_inputs( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, ) # VACE-specific: validate video/mask/reference_images consistency @@ -491,6 +549,7 @@ def forward( video=source_video, mask=source_mask, reference_images=reference_images, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) device = self.device @@ -509,7 +568,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or 512, + max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, device=device, dtype=dtype, ) @@ -569,48 +628,17 @@ def forward( timesteps = self.scheduler.timesteps self._num_timesteps = len(timesteps) - # Denoising loop - with self.progress_bar(total=len(timesteps)) as pbar: - for t in timesteps: - self._current_timestep = t - latent_model_input = latents.to(dtype) - timestep = t.expand(latents.shape[0]) - - do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None - - positive_kwargs = { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": prompt_embeds, - "attention_kwargs": attention_kwargs, - "vace_context": vace_context, - "vace_context_scale": vace_context_scale, - "return_dict": False, - } - negative_kwargs = ( - { - "hidden_states": latent_model_input, - "timestep": timestep, - "encoder_hidden_states": negative_prompt_embeds, - "attention_kwargs": attention_kwargs, - "vace_context": vace_context, - "vace_context_scale": vace_context_scale, - "return_dict": False, - } - if do_true_cfg - else None - ) - - noise_pred = self.predict_noise_maybe_with_cfg( - do_true_cfg=do_true_cfg, - true_cfg_scale=guidance_scale, - positive_kwargs=positive_kwargs, - negative_kwargs=negative_kwargs, - cfg_normalize=False, - ) - - latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) - pbar.update() + latents = self.diffuse( + latents=latents, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + dtype=dtype, + attention_kwargs=attention_kwargs, + vace_context=vace_context, + vace_context_scale=vace_context_scale, + ) self._current_timestep = None diff --git 
a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py new file mode 100644 index 0000000000..25444044c2 --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import numpy as np +import torch + + +@dataclass +class WanEulerSchedulerOutput: + prev_sample: torch.FloatTensor + + +def _unsqueeze_to_ndim(in_tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + if in_tensor.ndim >= target_ndim: + return in_tensor + return in_tensor[(...,) + (None,) * (target_ndim - in_tensor.ndim)] + + +def _get_timesteps(num_steps: int, max_steps: int = 1000) -> np.ndarray: + # Keep num_steps + 1 points so Euler update can always access sigma_next. + return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32) + + +def _timestep_shift(timesteps: torch.Tensor, shift: float = 1.0) -> torch.Tensor: + return shift * timesteps / (1 + (shift - 1) * timesteps) + + +class WanEulerScheduler: + order = 1 + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + device: torch.device | str = "cpu", + ) -> None: + self.num_train_timesteps = int(num_train_timesteps) + self._shift = float(shift) + self.device = device + self.config = SimpleNamespace(num_train_timesteps=self.num_train_timesteps) + self.init_noise_sigma = 1.0 + + self._step_index: int | None = None + self._begin_index: int | None = None + + self.timesteps = torch.empty(0, dtype=torch.float32) + self.sigmas = torch.empty(0, dtype=torch.float32) + self.timesteps_ori = torch.empty(0, dtype=torch.float32) + + self.set_timesteps(num_inference_steps=self.num_train_timesteps, device=self.device) + + @property + def step_index(self) -> int | None: + return self._step_index + + @property + def begin_index(self) -> int | None: + return self._begin_index + + def set_begin_index(self, begin_index: int = 0) -> None: + self._begin_index = int(begin_index) + + def index_for_timestep(self, timestep: torch.Tensor) -> int: + indices = (self.timesteps == timestep).nonzero() + if len(indices) > 0: + pos = 1 if len(indices) > 1 else 0 + return int(indices[pos].item()) + # Fallback for tiny float drift + return int(torch.argmin(torch.abs(self.timesteps - timestep)).item()) + + def _init_step_index(self, timestep: float | torch.Tensor) -> None: + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep_t = timestep.to(self.timesteps.device, dtype=self.timesteps.dtype) + else: + timestep_t = torch.tensor(timestep, device=self.timesteps.device, dtype=self.timesteps.dtype) + self._step_index = self.index_for_timestep(timestep_t) + else: + self._step_index = self._begin_index + + def set_shift(self, shift: float = 1.0) -> None: + # Compute shifted sigma schedule on [0, 1]. + sigmas_full = self.timesteps_ori / float(self.num_train_timesteps) + sigmas_full = _timestep_shift(sigmas_full, shift=float(shift)) + self.sigmas = sigmas_full + # Public timesteps are the first N points; next point is consumed as sigma_next. 
+ self.timesteps = self.sigmas[:-1] * self.num_train_timesteps + self._shift = float(shift) + + def set_timesteps( + self, + num_inference_steps: int, + device: torch.device | str | int | None = None, + **kwargs, # noqa: ARG002 - kept for scheduler API compatibility + ) -> None: + timesteps = _get_timesteps( + num_steps=int(num_inference_steps), + max_steps=self.num_train_timesteps, + ) + self.timesteps_ori = torch.from_numpy(timesteps).to( + dtype=torch.float32, + device=device or self.device, + ) + self.set_shift(self._shift) + self._step_index = None + self._begin_index = None + + def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor: # noqa: ARG002 + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: float | torch.FloatTensor, + sample: torch.FloatTensor, + return_dict: bool = True, + **kwargs, # noqa: ARG002 - kept for scheduler API compatibility + ) -> WanEulerSchedulerOutput | tuple[torch.FloatTensor]: + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + "Passing integer indices as timesteps is not supported. Use one value from scheduler.timesteps instead." + ) + + if self.step_index is None: + self._init_step_index(timestep) + assert self._step_index is not None + + sample_fp32 = sample.to(torch.float32) + sigma = _unsqueeze_to_ndim(self.sigmas[self._step_index], sample_fp32.ndim).to(sample_fp32.device) + sigma_next = _unsqueeze_to_ndim(self.sigmas[self._step_index + 1], sample_fp32.ndim).to(sample_fp32.device) + + prev_sample = sample_fp32 + (sigma_next - sigma) * model_output + prev_sample = prev_sample.to(model_output.dtype) + + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + return WanEulerSchedulerOutput(prev_sample=prev_sample) + + def __len__(self) -> int: + return self.num_train_timesteps diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index 65a2d4390a..d4d81b78eb 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -11,7 +11,6 @@ from diffusers.models.attention import FeedForward from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.normalization import FP32LayerNorm from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -29,6 +28,8 @@ SequenceParallelOutput, ) from vllm_omni.diffusion.forward_context import get_forward_context +from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm +from vllm_omni.diffusion.layers.norm import LayerNorm, RMSNorm from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) @@ -235,9 +236,9 @@ class WanImageEmbedding(nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len: int | None = None): super().__init__() - self.norm1 = FP32LayerNorm(in_features) + self.norm1 = LayerNorm(in_features) self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") - self.norm2 = FP32LayerNorm(out_features) + self.norm2 = LayerNorm(out_features) if pos_embed_seq_len is not None: self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) else: @@ -377,8 +378,12 @@ def __init__( self.tp_inner_dim = self.num_heads * head_dim # QK normalization using vLLM's RMSNorm - self.norm_q = 
DistributedRMSNorm(self.tp_inner_dim, eps=eps) - self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + if get_tensor_model_parallel_world_size() > 1: + self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + else: + self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps) self.to_out = RowParallelLinear( self.inner_dim, @@ -497,8 +502,12 @@ def __init__( self.tp_inner_dim = self.num_heads * head_dim # QK normalization - self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) - self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + if get_tensor_model_parallel_world_size() > 1: + self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + else: + self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps) # Optional added KV projections for I2V (image embeddings) self.added_kv_proj_dim = added_kv_proj_dim @@ -517,7 +526,10 @@ def __init__( gather_output=False, return_bias=False, ) - self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + if get_tensor_model_parallel_world_size() > 1: + self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + else: + self.norm_added_k = RMSNorm(self.tp_inner_dim, eps=eps) else: self.add_k_proj = None self.add_v_proj = None @@ -620,7 +632,7 @@ def __init__( head_dim = dim // num_heads # 1. Self-attention - self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) self.attn1 = WanSelfAttention( dim=dim, num_heads=num_heads, @@ -636,11 +648,11 @@ def __init__( eps=eps, added_kv_proj_dim=added_kv_proj_dim, ) - self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.norm2 = LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() # 3. Feed-forward self.ffn = WanFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim) - self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.norm3 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) # Scale-shift table for modulation self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) @@ -656,7 +668,7 @@ def forward( if temb.ndim == 4: # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table.unsqueeze(0) + temb.float() + self.scale_shift_table.unsqueeze(0) + temb ).chunk(6, dim=2) shift_msa = shift_msa.squeeze(2) scale_msa = scale_msa.squeeze(2) @@ -667,25 +679,23 @@ def forward( else: # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table + temb.float() + self.scale_shift_table + temb ).chunk(6, dim=1) # 1. Self-attention - norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + norm_hidden_states = self.norm1(hidden_states, scale_msa, shift_msa).type_as(hidden_states) attn_output = self.attn1(norm_hidden_states, rotary_emb, hidden_states_mask) - hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states) # 2. 
Cross-attention - norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + norm_hidden_states = self.norm2(hidden_states).type_as(hidden_states) attn_output = self.attn2(norm_hidden_states, encoder_hidden_states) hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( - hidden_states - ) + norm_hidden_states = self.norm3(hidden_states, c_scale_msa, c_shift_msa).type_as(hidden_states) ff_output = self.ffn(norm_hidden_states) - hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states) return hidden_states @@ -854,7 +864,7 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = AdaLayerNorm(inner_dim, elementwise_affine=False, eps=eps) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) # SP helper modules @@ -942,7 +952,7 @@ def forward( shift = shift.unsqueeze(1) scale = scale.unsqueeze(1) - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.norm_out(hidden_states, scale, shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( @@ -1015,6 +1025,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if ".to_out.0." in lookup_name: lookup_name = lookup_name.replace(".to_out.0.", ".to_out.") + # Compatibility: some Wan conversion pipelines still keep + # block modulation keys as `blocks.N.modulation` instead of + # `blocks.N.scale_shift_table`. + if lookup_name.endswith(".modulation"): + modulation_alias = lookup_name[: -len(".modulation")] + ".scale_shift_table" + if modulation_alias in params_dict: + lookup_name = modulation_alias + if lookup_name not in params_dict: logger.warning(f"Skipping weight {original_name} -> {lookup_name}") continue diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py index 4f4217dabf..c48938e1ba 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py @@ -239,7 +239,7 @@ def forward( shift = shift.unsqueeze(1) scale = scale.unsqueeze(1) - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.norm_out(hidden_states, scale, shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 3ffad221ba..c36ea74665 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -214,12 +214,14 @@ def __init__( super().__init__() if mid_size is None: mid_size = out_size + # Time embedding MLP is kept full precision (quant_config=None) — + # small layers that feed adaLN; precision-sensitive (see #2728). 
self.mlp = nn.Sequential( ReplicatedLinear( frequency_embedding_size, mid_size, bias=True, - quant_config=quant_config, + quant_config=None, return_bias=False, ), nn.SiLU(), @@ -227,7 +229,7 @@ def __init__( mid_size, out_size, bias=True, - quant_config=quant_config, + quant_config=None, return_bias=False, ), ) @@ -426,9 +428,16 @@ def __init__( self.modulation = modulation if modulation: + # Modulation linear is kept at full precision (quant_config=None) + # — it produces scale/gate values that are precision-sensitive + # (see #2728, mirrors OmniGen2 fix). self.adaLN_modulation = nn.Sequential( ReplicatedLinear( - min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config + min(dim, ADALN_EMBED_DIM), + 4 * dim, + bias=True, + quant_config=None, + return_bias=False, ), ) @@ -485,14 +494,24 @@ class FinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # Final output projection and its modulation are precision-sensitive + # (produce the output latent); keep at full precision (see #2728). self.linear = ReplicatedLinear( - hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False + hidden_size, + out_channels, + bias=True, + quant_config=None, + return_bias=False, ) self.adaLN_modulation = nn.Sequential( nn.SiLU(), ReplicatedLinear( - min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False + min(hidden_size, ADALN_EMBED_DIM), + hidden_size, + bias=True, + quant_config=None, + return_bias=False, ), ) @@ -673,11 +692,13 @@ def __init__( all_x_embedder = {} all_final_layer = {} for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + # x_embedder (patch embed) is a small precision-sensitive entry + # layer; keep full precision (see #2728). x_embedder = ReplicatedLinear( f_patch_size * patch_size * patch_size * in_channels, dim, bias=True, - quant_config=quant_config, + quant_config=None, return_bias=False, ) all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder @@ -720,9 +741,17 @@ def __init__( ] ) self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024, quant_config=quant_config) + # Caption embedder maps text features -> hidden; keep full precision + # (see #2728). 
self.cap_embedder = nn.Sequential( RMSNorm(cap_feat_dim, eps=norm_eps), - ReplicatedLinear(cap_feat_dim, dim, bias=True, return_bias=False, quant_config=quant_config), + ReplicatedLinear( + cap_feat_dim, + dim, + bias=True, + quant_config=None, + return_bias=False, + ), ) self.x_pad_token = nn.Parameter(torch.empty((1, dim))) diff --git a/vllm_omni/diffusion/postprocess/__init__.py b/vllm_omni/diffusion/postprocess/__init__.py new file mode 100644 index 0000000000..e6fe5b2d22 --- /dev/null +++ b/vllm_omni/diffusion/postprocess/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Diffusion post-processing helpers.""" + +from vllm_omni.diffusion.postprocess.rife_interpolator import ( + FrameInterpolator, + interpolate_video_tensor, +) + +__all__ = ["FrameInterpolator", "interpolate_video_tensor"] diff --git a/vllm_omni/diffusion/postprocess/rife_interpolator.py b/vllm_omni/diffusion/postprocess/rife_interpolator.py new file mode 100644 index 0000000000..89297d0a44 --- /dev/null +++ b/vllm_omni/diffusion/postprocess/rife_interpolator.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +RIFE 4.22.lite frame interpolation for vLLM-Omni video generation. + +RIFE model code is vendored and adapted from: + - https://github.com/hzwer/ECCV2022-RIFE (MIT License) + - https://github.com/hzwer/Practical-RIFE (MIT License) + Copyright (c) 2021 Zhewei Huang + +The FrameInterpolator wrapper and vLLM-Omni integration code are original work. +""" + +from __future__ import annotations + +import os +import threading +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vllm.logger import init_logger + +logger = init_logger(__name__) + +_DEFAULT_RIFE_HF_REPO = "elfgum/RIFE-4.22.lite" +_MODEL_CACHE: dict[tuple[str, str], Model] = {} +_MODEL_CACHE_LOCK = threading.Lock() + + +def warp(ten_input: torch.Tensor, ten_flow: torch.Tensor) -> torch.Tensor: + """Warp input tensor by optical flow using grid_sample.""" + ten_horizontal = ( + torch.linspace(-1.0, 1.0, ten_flow.shape[3], device=ten_flow.device) + .view(1, 1, 1, ten_flow.shape[3]) + .expand(ten_flow.shape[0], -1, ten_flow.shape[2], -1) + ) + ten_vertical = ( + torch.linspace(-1.0, 1.0, ten_flow.shape[2], device=ten_flow.device) + .view(1, 1, ten_flow.shape[2], 1) + .expand(ten_flow.shape[0], -1, -1, ten_flow.shape[3]) + ) + ten_grid = torch.cat([ten_horizontal, ten_vertical], dim=1) + + ten_flow = torch.cat( + [ + ten_flow[:, 0:1, :, :] / ((ten_input.shape[3] - 1.0) / 2.0), + ten_flow[:, 1:2, :, :] / ((ten_input.shape[2] - 1.0) / 2.0), + ], + dim=1, + ) + grid = (ten_grid + ten_flow).permute(0, 2, 3, 1) + return F.grid_sample( + input=ten_input, + grid=grid, + mode="bilinear", + padding_mode="border", + align_corners=True, + ) + + +def _conv( + in_planes: int, + out_planes: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + dilation: int = 1, +) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.LeakyReLU(0.2, True), + ) + + +class ResConv(nn.Module): + """Residual convolution block with learnable beta scaling.""" + + def __init__(self, c: int, dilation: int = 1): + super().__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) + self.beta = 
nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.relu(self.conv(x) * self.beta + x) + + +class IFBlock(nn.Module): + """Single-scale optical flow, mask, and feature block.""" + + def __init__(self, in_planes: int, c: int = 64): + super().__init__() + self.conv0 = nn.Sequential( + _conv(in_planes, c // 2, 3, 2, 1), + _conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential( + nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1), + nn.PixelShuffle(2), + ) + + def forward( + self, + x: torch.Tensor, + flow: torch.Tensor | None = None, + scale: float = 1.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = ( + F.interpolate( + flow, + scale_factor=1.0 / scale, + mode="bilinear", + align_corners=False, + ) + * 1.0 + / scale + ) + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + feat = tmp[:, 5:] + return flow, mask, feat + + +class Head(nn.Module): + """Feature encoder producing four-channel features at full resolution.""" + + def __init__(self): + super().__init__() + self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1) + self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x0 = self.cnn0(x) + x = self.relu(x0) + x1 = self.cnn1(x) + x = self.relu(x1) + x2 = self.cnn2(x) + x = self.relu(x2) + x3 = self.cnn3(x) + return x3 + + +class IFNet(nn.Module): + """Four-scale IFNet optical flow network.""" + + def __init__(self): + super().__init__() + self.block0 = IFBlock(7 + 8, c=192) + self.block1 = IFBlock(8 + 4 + 8 + 8, c=128) + self.block2 = IFBlock(8 + 4 + 8 + 8, c=64) + self.block3 = IFBlock(8 + 4 + 8 + 8, c=32) + self.encode = Head() + + def forward( + self, + x: torch.Tensor, + timestep: float = 0.5, + scale_list: list[float] | None = None, + ) -> tuple[list[torch.Tensor], torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]]: + if scale_list is None: + scale_list = [8, 4, 2, 1] + + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + else: + timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) + + f0 = self.encode(img0[:, :3]) + f1 = self.encode(img1[:, :3]) + + flow_list: list[torch.Tensor] = [] + merged: list[tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = [] + mask_list: list[torch.Tensor] = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + + for i, block in enumerate([self.block0, self.block1, self.block2, self.block3]): + if flow is None: + flow, mask, feat = block( + torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), + None, + scale=scale_list[i], + ) + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0, feat = block( + torch.cat( + ( + warped_img0[:, :3], + warped_img1[:, :3], + wf0, + wf1, + timestep, + mask, + feat, + ), + 1, + ), + flow, + scale=scale_list[i], + 
) + mask = m0 + flow = flow + fd + + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + + mask = torch.sigmoid(mask) + merged[3] = warped_img0 * mask + warped_img1 * (1 - mask) + return flow_list, mask_list[3], merged + + +class Model: + """Wraps IFNet and exposes RIFE-compatible load/inference helpers.""" + + def __init__(self): + self.flownet = IFNet() + + def eval(self) -> Model: + self.flownet.eval() + return self + + def device(self) -> torch.device: + return next(self.flownet.parameters()).device + + def load_model(self, path: str) -> None: + flownet_path = os.path.join(path, "flownet.pkl") + if not os.path.isfile(flownet_path): + raise FileNotFoundError( + f"RIFE weight file not found: {flownet_path}. Expected layout: /flownet.pkl" + ) + + state = torch.load(flownet_path, map_location="cpu", weights_only=False) + state = {k.removeprefix("module."): v for k, v in state.items()} + self.flownet.load_state_dict(state, strict=False) + logger.info("Loaded RIFE weights from %s", flownet_path) + + def inference( + self, + img0: torch.Tensor, + img1: torch.Tensor, + scale: float = 1.0, + timestep: float = 0.5, + ) -> torch.Tensor: + _n, _c, h, w = img0.shape + ph = ((h - 1) // 32 + 1) * 32 + pw = ((w - 1) // 32 + 1) * 32 + pad = (0, pw - w, 0, ph - h) + img0 = F.pad(img0, pad) + img1 = F.pad(img1, pad) + + imgs = torch.cat((img0, img1), 1) + scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale] + with torch.no_grad(): + _flow_list, _mask, merged = self.flownet( + imgs, + timestep=timestep, + scale_list=scale_list, + ) + return merged[3][:, :, :h, :w] + + +def _resolve_rife_model_path(model_path: str | None) -> str: + model_path = model_path or _DEFAULT_RIFE_HF_REPO + if os.path.isdir(model_path): + return model_path + from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, + ) + + return download_weights_from_hf_specific( + model_path, + cache_dir=None, + allow_patterns=["flownet.pkl"], + require_all=True, + ) + + +def _select_torch_device() -> torch.device: + try: + from vllm_omni.platforms import current_omni_platform + + return current_omni_platform.get_torch_device() + except Exception as exc: + logger.warning("Failed to resolve current vLLM-Omni torch device: %s", exc) + + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +def _normalize_video_tensor_layout(video: torch.Tensor) -> tuple[torch.Tensor, Any]: + if video.ndim == 5: + if video.shape[1] in (3, 4): + return video, lambda out: out + if video.shape[2] in (3, 4): + return video.permute(0, 2, 1, 3, 4), lambda out: out.permute(0, 2, 1, 3, 4) + elif video.ndim == 4: + if video.shape[0] in (3, 4): + return video.unsqueeze(0), lambda out: out.squeeze(0) + if video.shape[1] in (3, 4): + return video.permute(1, 0, 2, 3).unsqueeze(0), lambda out: out.squeeze(0).permute(1, 0, 2, 3) + raise ValueError(f"Unsupported video tensor shape for interpolation: {tuple(video.shape)}") + + +def _normalize_video_tensor_range(video: torch.Tensor) -> tuple[torch.Tensor, Any]: + original_dtype = video.dtype + video = video.detach() + if video.is_floating_point(): + video = video.to(torch.float32) + if torch.amin(video) < 0.0 or torch.amax(video) > 1.0: + return video.clamp(-1.0, 1.0) * 0.5 + 0.5, lambda out: (out * 2.0 - 1.0).to(original_dtype) + return video.clamp(0.0, 1.0), lambda out: out.to(original_dtype) + return 
video.to(torch.float32) / 255.0, lambda out: (out * 255.0).round().clamp(0, 255).to(original_dtype) + + +class FrameInterpolator: + """Lazy-loaded RIFE 4.22.lite frame interpolator.""" + + def __init__(self, model_path: str | None = None): + self._model_path = model_path + self._resolved_path: str | None = None + + def _ensure_model_loaded(self, preferred_device: torch.device | None = None) -> Model: + resolved_path = _resolve_rife_model_path(self._model_path) + self._resolved_path = resolved_path + device = preferred_device or _select_torch_device() + cache_key = (resolved_path, str(device)) + + with _MODEL_CACHE_LOCK: + if cache_key in _MODEL_CACHE: + return _MODEL_CACHE[cache_key] + + model = Model() + model.load_model(resolved_path) + model.eval() + model.flownet = model.flownet.to(device) + _MODEL_CACHE[cache_key] = model + logger.info("RIFE model loaded on device: %s", device) + return model + + def _make_inference( + self, + model: Model, + img0: torch.Tensor, + img1: torch.Tensor, + n: int, + scale: float, + ) -> list[torch.Tensor]: + if n == 1: + return [model.inference(img0, img1, scale=scale)] + mid = model.inference(img0, img1, scale=scale) + return ( + self._make_inference(model, img0, mid, n // 2, scale) + + [mid] + + self._make_inference(model, mid, img1, n // 2, scale) + ) + + def interpolate_tensor( + self, + video: torch.Tensor, + exp: int = 1, + scale: float = 1.0, + ) -> tuple[torch.Tensor, int]: + if exp < 1: + raise ValueError(f"frame interpolation exp must be >= 1, got {exp}") + if scale <= 0: + raise ValueError(f"frame interpolation scale must be > 0, got {scale}") + + video, restore_layout = _normalize_video_tensor_layout(video) + if video.shape[2] < 2: + return restore_layout(video), 1 + + video, restore_range = _normalize_video_tensor_range(video) + # A CPU tensor may be transport/offload state rather than an execution + # choice, so only trust it when it is already on an accelerator. 
+ preferred_device = video.device + if preferred_device.type == "cpu": + preferred_device = _select_torch_device() + model = self._ensure_model_loaded(preferred_device=preferred_device) + video = video.to(model.device()) + intermediates_per_pair = 2**exp // 2 + + result_frames: list[torch.Tensor] = [] + for idx in range(video.shape[2] - 1): + img0 = video[:, :, idx, :, :] + img1 = video[:, :, idx + 1, :, :] + result_frames.append(img0) + result_frames.extend(self._make_inference(model, img0, img1, intermediates_per_pair, scale)) + result_frames.append(video[:, :, -1, :, :]) + result = torch.stack(result_frames, dim=2) + return restore_layout(restore_range(result)), 2**exp + + +def interpolate_video_tensor( + video: torch.Tensor, + exp: int = 1, + scale: float = 1.0, + model_path: str | None = None, +) -> tuple[torch.Tensor, int]: + """Interpolate a video tensor and return the FPS multiplier.""" + interpolator = FrameInterpolator(model_path=model_path) + return interpolator.interpolate_tensor(video, exp=exp, scale=scale) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 97bc7fa292..0bf8c04517 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -119,8 +119,8 @@ "FluxKontextPipeline", ), "HunyuanImage3ForCausalMM": ( - "hunyuan_image_3", - "pipeline_hunyuan_image_3", + "hunyuan_image3", + "pipeline_hunyuan_image3", "HunyuanImage3Pipeline", ), "Flux2KleinPipeline": ( @@ -375,6 +375,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_post_process_func", "MagiHumanPipeline": "get_magi_human_post_process_func", "OmniVoicePipeline": "get_omnivoice_post_process_func", + "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = { diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py index cd7159b683..480d113d19 100644 --- a/vllm_omni/diffusion/stage_diffusion_client.py +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -34,6 +34,24 @@ logger = init_logger(__name__) +def create_diffusion_client( + model: str, + od_config: OmniDiffusionConfig, + metadata: StageMetadata, + stage_init_timeout: int, + batch_size: int = 1, + use_inline: bool = False, +) -> Any: + """Factory to create either an inline or out-of-process diffusion client.""" + if use_inline: + from vllm_omni.diffusion.inline_stage_diffusion_client import InlineStageDiffusionClient + + return InlineStageDiffusionClient(model, od_config, metadata, batch_size=batch_size) + return StageDiffusionClient( + model, od_config, metadata, stage_init_timeout=stage_init_timeout, batch_size=batch_size + ) + + class StageDiffusionClient: """Communicates with StageDiffusionProc via ZMQ for use inside the Orchestrator. @@ -50,11 +68,12 @@ def __init__( model: str, od_config: OmniDiffusionConfig, metadata: StageMetadata, + stage_init_timeout: int, batch_size: int = 1, ) -> None: # Spawn StageDiffusionProc subprocess and wait for READY. 
proc, handshake_address, request_address, response_address = spawn_diffusion_proc(model, od_config) - complete_diffusion_handshake(proc, handshake_address) + complete_diffusion_handshake(proc, handshake_address, stage_init_timeout) self._initialize_client(metadata, request_address, response_address, proc=proc, batch_size=batch_size) @classmethod @@ -153,6 +172,13 @@ def _drain_responses(self) -> None: "error": True, "reason": error_msg, } + elif req_id is not None: + error_output = OmniRequestOutput.from_diffusion( + request_id=req_id, + images=[], + ) + error_output.error = error_msg + self._output_queue.put_nowait(error_output) # Fields that are subprocess-local and cannot be serialized across # process boundaries. They are recreated in the subprocess with diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py index 2bba419250..eced444fd3 100644 --- a/vllm_omni/diffusion/stage_diffusion_proc.py +++ b/vllm_omni/diffusion/stage_diffusion_proc.py @@ -19,12 +19,11 @@ import zmq.asyncio from PIL import Image from vllm.logger import init_logger -from vllm.transformers_utils.config import get_hf_file_to_dict from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx from vllm.utils.system_utils import get_mp_context from vllm.v1.utils import shutdown -from vllm_omni.diffusion.data import DiffusionRequestAbortedError, TransformerConfig +from vllm_omni.diffusion.data import DiffusionRequestAbortedError from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.distributed.omni_connectors.utils.serialization import ( @@ -39,8 +38,6 @@ logger = init_logger(__name__) -_HANDSHAKE_POLL_TIMEOUT_S = 600 - class StageDiffusionProc: """Subprocess entry point for diffusion inference. @@ -68,47 +65,8 @@ def initialize(self) -> None: logger.info("StageDiffusionProc initialized with model: %s", self._model) def _enrich_config(self) -> None: - """Load model metadata from HuggingFace and populate od_config fields. - - Diffusers-style models expose ``model_index.json`` with ``_class_name``. - Non-diffusers models (e.g. Bagel, NextStep) only have ``config.json``, - so we fall back to reading that and mapping model_type manually. 
- """ - od_config = self._od_config - - try: - config_dict = get_hf_file_to_dict("model_index.json", od_config.model) - if config_dict is not None: - if od_config.model_class_name is None: - od_config.model_class_name = config_dict.get("_class_name", None) - od_config.update_multimodal_support() - - tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) - else: - raise FileNotFoundError("model_index.json not found") - except (AttributeError, OSError, ValueError, FileNotFoundError): - cfg = get_hf_file_to_dict("config.json", od_config.model) - if cfg is None: - raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") - - od_config.tf_model_config = TransformerConfig.from_dict(cfg) - model_type = cfg.get("model_type") - architectures = cfg.get("architectures") or [] - - if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: - od_config.model_class_name = "BagelPipeline" - od_config.tf_model_config = TransformerConfig() - od_config.update_multimodal_support() - elif model_type == "nextstep": - if od_config.model_class_name is None: - od_config.model_class_name = "NextStep11Pipeline" - od_config.tf_model_config = TransformerConfig() - od_config.update_multimodal_support() - elif architectures and len(architectures) == 1: - od_config.model_class_name = architectures[0] - else: - raise + """Load model metadata from HuggingFace and populate od_config fields.""" + self._od_config.enrich_config() # ------------------------------------------------------------------ # Request processing @@ -619,13 +577,14 @@ def spawn_diffusion_proc( def complete_diffusion_handshake( proc: BaseProcess, handshake_address: str, + handshake_timeout: int, ) -> None: """Wait for the diffusion subprocess to signal READY. On failure the process is terminated before re-raising. """ try: - _perform_diffusion_handshake(proc, handshake_address) + _perform_diffusion_handshake(proc, handshake_address, handshake_timeout) except Exception: shutdown([proc]) raise @@ -634,6 +593,7 @@ def complete_diffusion_handshake( def _perform_diffusion_handshake( proc: BaseProcess, handshake_address: str, + handshake_timeout: int, ) -> None: """Run the handshake with the diffusion subprocess.""" with zmq_socket_ctx(handshake_address, zmq.ROUTER, bind=True) as handshake_socket: @@ -641,11 +601,15 @@ def _perform_diffusion_handshake( poller.register(handshake_socket, zmq.POLLIN) poller.register(proc.sentinel, zmq.POLLIN) - timeout_ms = _HANDSHAKE_POLL_TIMEOUT_S * 1000 + timeout_ms = handshake_timeout * 1000 while True: events = dict(poller.poll(timeout=timeout_ms)) if not events: - raise TimeoutError("Timed out waiting for READY from StageDiffusionProc") + raise TimeoutError( + f"Timed out waiting for READY from StageDiffusionProc after {handshake_timeout}s. " + f"This typically indicates model loading or warmup is taking too long. " + f"Consider increasing `stage_init_timeout` for large models." 
+ ) if handshake_socket in events: identity, raw = handshake_socket.recv_multipart() msg = msgspec.msgpack.decode(raw) diff --git a/vllm_omni/diffusion/utils/media_utils.py b/vllm_omni/diffusion/utils/media_utils.py index f96a28fbd7..a09cd45953 100644 --- a/vllm_omni/diffusion/utils/media_utils.py +++ b/vllm_omni/diffusion/utils/media_utils.py @@ -20,6 +20,7 @@ def mux_video_audio_bytes( video_codec: str = "h264", audio_codec: str = "aac", crf: str = "18", + video_codec_options: dict[str, str] | None = None, ) -> bytes: """Mux video frames and optional audio waveform into MP4 bytes. @@ -42,7 +43,11 @@ def mux_video_audio_bytes( v_stream.width = video_frames.shape[2] v_stream.height = video_frames.shape[1] v_stream.pix_fmt = "yuv420p" - v_stream.options = {"crf": crf} + + options = {"crf": str(crf)} + if video_codec_options: + options.update(video_codec_options) + v_stream.options = options a_stream = None if audio_waveform is not None: diff --git a/vllm_omni/diffusion/utils/prompt_utils.py b/vllm_omni/diffusion/utils/prompt_utils.py new file mode 100644 index 0000000000..fc1769f4d5 --- /dev/null +++ b/vllm_omni/diffusion/utils/prompt_utils.py @@ -0,0 +1,38 @@ +import torch + + +def validate_prompt_sequence_lengths( + attention_mask: torch.Tensor, + *, + max_sequence_length: int, + supported_max_sequence_length: int, + prompt_name: str = "prompt", + length_offset: int = 0, + baseline_attention_mask: torch.Tensor | None = None, + error_context: str, +) -> None: + sequence_lengths = attention_mask.sum(dim=1) + if baseline_attention_mask is not None: + # Some callers need to validate only the user-controlled portion of a + # templated prompt. In those cases we subtract the fully-tokenized + # template baseline instead of only removing a fixed prefix length, + # because the template may also contribute a suffix or image markers. + baseline_lengths = baseline_attention_mask.sum(dim=1) + if baseline_lengths.shape[0] == 1 and sequence_lengths.shape[0] > 1: + baseline_lengths = baseline_lengths.expand(sequence_lengths.shape[0]) + sequence_lengths = sequence_lengths - baseline_lengths + if length_offset: + sequence_lengths = sequence_lengths - length_offset + sequence_lengths = torch.clamp(sequence_lengths, min=0) + too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False) + if too_long.numel() == 0: + return + + batch_idx = int(too_long[0].item()) + actual_length = int(sequence_lengths[batch_idx].item()) + prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`" + raise ValueError( + f"{prompt_ref} is too long {error_context}: got {actual_length} tokens, but " + f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase " + f"`max_sequence_length` up to {supported_max_sequence_length}." 
+ ) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 32ea5bf64d..535f053c38 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -35,11 +35,12 @@ from vllm_omni.diffusion.worker.utils import DiffusionRequestState, RunnerOutput from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.platforms import current_omni_platform +from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin logger = init_logger(__name__) -class DiffusionModelRunner: +class DiffusionModelRunner(OmniConnectorModelRunnerMixin): """ Model runner that handles model loading and execution for diffusion models. diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index ea4b9d96f7..160309e0d8 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -20,6 +20,7 @@ from vllm.config import CompilationConfig, DeviceConfig, VllmConfig, set_current_vllm_config from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger +from vllm.profiler.wrapper import CudaProfilerWrapper, WorkerProfiler from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.mem_utils import GiB_bytes from vllm.v1.worker.workspace import init_workspace_manager @@ -83,15 +84,7 @@ def __init__( od_config=self.od_config, device=self.device, ) - # Initialize profiler if configured - self.profiler: OmniTorchProfilerWrapper | None = None - profiler_config = self.od_config.profiler_config - if profiler_config and profiler_config.profiler == "torch": - self.profiler = create_omni_profiler( - profiler_config=profiler_config, - worker_name=f"diffusion_worker_{self.rank}", - local_rank=self.local_rank, - ) + self.profiler: WorkerProfiler | None = self._create_profiler() if not skip_load_model: self.load_model(load_format=self.od_config.diffusion_load_format) self.init_lora_manager() @@ -122,6 +115,7 @@ def init_device(self) -> None: vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel + vllm_config.profiler_config = self.od_config.profiler_config self.vllm_config = vllm_config # Initialize distributed environment @@ -147,6 +141,24 @@ def init_device(self) -> None: ) init_workspace_manager(self.device) + def _create_profiler(self) -> WorkerProfiler | None: + profiler_config = self.od_config.profiler_config + profiler_type = getattr(profiler_config, "profiler", None) + if profiler_type == "torch": + return create_omni_profiler( + profiler_config=profiler_config, + worker_name=f"diffusion_rank{self.rank}", + local_rank=self.local_rank, + ) + if profiler_type == "cuda": + return CudaProfilerWrapper(profiler_config) + if profiler_type is not None: + logger.warning("Unknown profiler backend %r on diffusion worker %s", profiler_type, self.rank) + return None + + def _get_profiler(self) -> WorkerProfiler | None: + return getattr(self, "profiler", None) + def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None) -> None: """Load the diffusion model using DiffusionModelRunner.""" 
with ( @@ -192,27 +204,21 @@ def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> N Args: is_start: True to start profiling, False to stop. - profile_prefix: Optional prefix for trace filename (vLLM compat). - - Note: - Matches vLLM's worker.profile() signature for consistency. - Traces are saved automatically via on_trace_ready callback. + profile_prefix: Optional prefix for trace filename. """ - if self.profiler is None: - logger.warning("Profiler not initialized, skipping profile(%s)", is_start) + profiler = self._get_profiler() + if profiler is None: return if is_start: - from vllm_omni.profiler import OmniTorchProfilerWrapper - - if isinstance(self.profiler, OmniTorchProfilerWrapper): + if isinstance(profiler, OmniTorchProfilerWrapper): import time - filename = profile_prefix or f"diffusion_{int(time.time())}" - self.profiler.set_trace_filename(filename) - self.profiler.start() + filename = profile_prefix or f"diffusion_rank{self.rank}_{int(time.time())}" + profiler.set_trace_filename(filename) + profiler.start() else: - self.profiler.stop() + profiler.stop() def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput: """Execute a forward pass by delegating to the model runner.""" @@ -224,7 +230,13 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfi if req.sampling_params.lora_request is not None: raise logger.warning("LoRA activation skipped: %s", exc) - return self.model_runner.execute_model(req) + profiler = self._get_profiler() + ctx = profiler.annotate_context_manager("diffusion_forward") if profiler else nullcontext() + with ctx: + output = self.model_runner.execute_model(req) + if profiler: + profiler.step() + return output def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: """Execute one diffusion step by delegating to the model runner.""" @@ -236,8 +248,13 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs): raise ValueError("Step mode does not support LoRA yet.") - - return self.model_runner.execute_stepwise(scheduler_output) + profiler = self._get_profiler() + ctx = profiler.annotate_context_manager("diffusion_step") if profiler else nullcontext() + with ctx: + output = self.model_runner.execute_stepwise(scheduler_output) + if profiler: + profiler.step() + return output def load_weights(self, weights) -> set[str]: """Load weights by delegating to the model runner.""" diff --git a/vllm_omni/distributed/omni_connectors/connectors/base.py b/vllm_omni/distributed/omni_connectors/connectors/base.py index 83edb2ab0a..0df428f2ff 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/base.py +++ b/vllm_omni/distributed/omni_connectors/connectors/base.py @@ -34,13 +34,21 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[ pass @abstractmethod - def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tuple[Any, int] | None: + def get( + self, from_stage: str, to_stage: str, get_key: str, metadata: dict[str, Any] | None = None + ) -> tuple[Any, int] | None: """Retrieve Python object and payload size (bytes). Args: from_stage: Source stage identifier to_stage: Destination stage identifier get_key: Unique request identifier + metadata: Optional transport-specific metadata. When provided, + the connector uses it directly (e.g. 
source_host, source_port, + data_size) instead of querying the sender. For heterogeneous + TP the manager may supply partial metadata (host/port only); + the connector will query the sender at that address to fill + in data_size. Returns: Tuple of (Python object, serialized byte size) if found, None otherwise diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py index c672e35f79..fa1fc3286d 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py @@ -78,7 +78,24 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[ try: serialized_data = self.serialize_obj(data) key = self._make_key(put_key, from_stage, to_stage) - self.store.put(key, serialized_data, self.pin) + put_rc = self.store.put(key, serialized_data, self.pin) + + if isinstance(put_rc, bool): + put_ok = put_rc + else: + put_ok = put_rc is None or put_rc == 0 + + if not put_ok: + self._metrics["errors"] += 1 + logger.error( + "MooncakeStoreConnector put failed for %s (%s -> %s), rc=%r, %d bytes", + key, + from_stage, + to_stage, + put_rc, + len(serialized_data), + ) + return False, 0, None self._metrics["puts"] += 1 self._metrics["bytes_transferred"] += len(serialized_data) diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py index 96a528963f..bd4160f3e6 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py @@ -230,16 +230,19 @@ class MooncakeTransferEngineConnector(OmniConnectorBase): sender immediately cleans up the buffer (``cleanup()``), so only the first receiver to pull a given key will succeed. Broadcast / multicast (1 sender → N receivers sharing the same data) is not yet supported. - - **1 receiver → 1 sender**: ``update_sender_info()`` stores a single - ``(sender_host, sender_zmq_port)`` pair, so a receiver can only query - metadata from one sender at a time. + - **1 receiver → N senders**: Supported via partial metadata. The + manager constructs metadata with the target sender's + ``source_host`` / ``source_port`` (computed from ``from_rank``) + and passes it to ``get(metadata=...)``. The connector detects + that ``data_size`` is missing, queries the specified sender at + the given address to fill it in, then performs the RDMA pull. + This enables heterogeneous TP (sender TP > receiver TP) where a + single receiver must pull KV shards from multiple sender ranks. Future work: - Support 1 sender → N receivers (e.g. reference-counted buffers, or explicit ``retain()`` / ``release()`` semantics so the buffer survives multiple pulls). - - Support 1 receiver → N senders (e.g. a sender registry mapping - ``get_key`` prefixes to different sender endpoints). 
""" # RDMA connector copies raw bytes/tensor directly to the memory pool @@ -267,6 +270,7 @@ def __init__(self, config: dict[str, Any]): self._req_local = threading.local() self._worker_local = threading.local() self._last_ttl_check: float = _time_mod.monotonic() + self._sender_endpoints: dict[int, tuple[str, int]] = {} self._metrics = { "puts": 0, @@ -408,16 +412,38 @@ def get_connection_info(self) -> dict[str, Any]: "can_put": self.can_put, } - def update_sender_info(self, sender_host: str, sender_zmq_port: int) -> None: - """ - Inject the sender's ZMQ endpoint into the receiver connector. - Used for NO METADATA GET calls.(E.g: KV-cache transfer path) - Must be called before using get() without metadata! - Otherwise, get() will raise an error. + def update_sender_info( + self, + sender_host: str, + sender_zmq_port: int, + sender_rank: int | None = None, + ) -> None: + """Inject a sender's ZMQ endpoint into the receiver connector. + + When ``sender_rank`` is ``None`` (default), sets the single default + sender used by ``get()`` when no rank is specified — this preserves + backward-compatible 1:1 semantics. + + When ``sender_rank`` is an integer, the endpoint is stored in a + per-rank registry for internal use (e.g. by + ``_query_metadata_from_sender(sender_rank=R)``). """ - self.sender_host = sender_host - self.sender_zmq_port = sender_zmq_port - logger.info(f"Sender info updated: host={sender_host!r}, zmq_port={sender_zmq_port}") + if sender_rank is not None: + self._sender_endpoints[sender_rank] = (sender_host, sender_zmq_port) + logger.info( + "Sender info updated for rank %s: host=%r, zmq_port=%s", + sender_rank, + sender_host, + sender_zmq_port, + ) + else: + self.sender_host = sender_host + self.sender_zmq_port = sender_zmq_port + logger.info( + "Sender info updated (default): host=%r, zmq_port=%s", + sender_host, + sender_zmq_port, + ) def _get_local_ip(self) -> str: """ @@ -657,56 +683,75 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[ logger.error(f"RDMA Put failed for {put_key}: {e}", exc_info=True) return False, 0, None - def _query_metadata_from_sender(self, get_key: str) -> dict[str, Any] | None: - """Query metadata from sender via ZMQ (fallback when ``metadata=None``). - - ``get()`` supports two metadata resolution paths:: - - get(metadata=?) - ├── metadata provided (adapter path) - │ → use metadata directly (source_host/port/data_size) - │ → RDMA pull - └── metadata=None (KV-transfer polling path) - → _query_metadata_from_sender(get_key) ← this method - │ - ├── sender_host resolved (via update_sender_info) - │ → ZMQ query → get data_size/is_fast_path - │ → construct metadata → RDMA pull - └── sender_host unresolved ("auto" / None) - → return None → caller retries or times out + def _resolve_sender_endpoint(self, sender_rank: int | None = None) -> tuple[str, int] | None: + """Return ``(host, zmq_port)`` for *sender_rank*. - For the second path, the caller must call - :meth:`update_sender_info` before ``get()`` to resolve the sender's ZMQ endpoint. - Support the two paths in case that the orchestrator pushes the request info - to different stages at the same time knowing metadata or not. + Resolution order: + 1. Per-rank registry (``_sender_endpoints[sender_rank]``) + 2. Default sender (``sender_host`` / ``sender_zmq_port``) + 3. ``None`` if nothing is configured. 
+ """ + if sender_rank is not None and sender_rank in self._sender_endpoints: + return self._sender_endpoints[sender_rank] + host = getattr(self, "sender_host", None) + port = getattr(self, "sender_zmq_port", None) + if host and port and str(host).lower() != "auto": + return (host, int(port)) + return None + + def _query_metadata_at(self, get_key: str, host: str, port: int) -> dict[str, Any] | None: + """Query metadata from a sender endpoint via ZMQ. + + Returns ``{source_host, source_port, data_size, is_fast_path}`` + or ``None`` when the key is not found / the query fails. """ - zmq_addr = f"tcp://{self.sender_host}:{self.sender_zmq_port}" + zmq_addr = f"tcp://{host}:{port}" req_socket = self._get_req_socket(zmq_addr, timeout_ms=5000) - try: - # Send query request - query = QueryRequest(request_id=get_key) - req_socket.send(QUERY_INFO + msgspec.msgpack.encode(query)) + req_socket.send(QUERY_INFO + msgspec.msgpack.encode(QueryRequest(request_id=get_key))) resp = req_socket.recv() - if resp == INFO_NOT_FOUND: return None - - # Parse response query_resp = msgspec.msgpack.decode(resp, type=QueryResponse) return { - # source_host/source_port are used for verification - "source_host": self.sender_host, - "source_port": self.sender_zmq_port, + "source_host": host, + "source_port": port, "data_size": query_resp.data_size, "is_fast_path": query_resp.is_fast_path, } except Exception as e: - # Socket may be stuck in bad state after timeout; discard it self._invalidate_req_socket(zmq_addr) - logger.debug(f"Failed to query metadata for {get_key}: {e}") + logger.debug("Failed to query metadata at %s for %s: %s", zmq_addr, get_key, e) return None + def _query_metadata_from_sender(self, get_key: str, sender_rank: int | None = None) -> dict[str, Any] | None: + """Query metadata from sender via ZMQ (fallback when ``metadata=None``). + + ``get()`` supports three metadata resolution paths:: + + get(metadata=?) + ├── Path 1: metadata has data_size (adapter path) + │ → use metadata directly → RDMA pull + ├── Path 2: metadata has source_host/port but no data_size + │ → _query_metadata_at(host, port) → get data_size → RDMA pull + └── Path 3: metadata=None (KV-transfer polling path) + → _query_metadata_from_sender(get_key) ← this method + │ + ├── sender endpoint resolved (via update_sender_info) + │ → ZMQ query → get data_size/is_fast_path + │ → construct metadata → RDMA pull + └── sender endpoint unresolved + → return None → caller retries or times out + + When *sender_rank* is provided, the query is routed to that + rank's endpoint (registered via ``update_sender_info(rank=...)``). + Otherwise the default sender is used. + """ + endpoint = self._resolve_sender_endpoint(sender_rank) + if endpoint is None: + return None + return self._query_metadata_at(get_key, *endpoint) + def get( self, from_stage: str, @@ -714,12 +759,18 @@ def get( get_key: str, metadata: dict[str, Any] | None = None, ) -> tuple[Any, int] | None: - """ - Consumer Side. - Allocates from local pool and pulls data via RDMA. + """Consumer Side. Allocates from local pool and pulls data via RDMA. + + Metadata resolution: - If metadata is not provided, will attempt to query it from sender - using configured sender_host/sender_zmq_port. + 1. ``metadata`` provided **with** ``data_size`` → use directly (RDMA pull). + 2. ``metadata`` provided with ``source_host``/``source_port`` but + **without** ``data_size`` → query that specific sender for + ``data_size`` / ``is_fast_path``, then RDMA pull. 
This is the + heterogeneous-TP path where the manager knows the target sender + endpoint but not the payload size. + 3. ``metadata=None`` → query the default sender (set via + ``update_sender_info()``) for the full metadata. Returns: ``(data, size)`` on success, ``None`` on failure. @@ -727,9 +778,6 @@ def get( - **is_fast_path=True** (tensor *or* bytes payload): Returns ``(ManagedBuffer, size)``. **CALLER MUST call ``ManagedBuffer.release()`` after consuming.** - Note: even if the producer ``put()`` raw ``bytes``, the consumer - receives a ``ManagedBuffer`` — use ``buf.to_bytes()`` to obtain - a ``bytes`` copy, or ``buf.tensor`` for zero-copy access. - **is_fast_path=False** (serialized Python object): Returns ``(DeserializedObject, size)``. Buffer is auto-released internally after deserialization. @@ -741,9 +789,8 @@ def get( _t0 = _time_mod.perf_counter() - # If no metadata provided, try to query from sender if not metadata: - # Must insert sender info before using get() without metadata. + # Path 3: no metadata at all — query default sender if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto": raise RuntimeError( f"get(metadata=None) requires sender info to be resolved, " @@ -753,6 +800,21 @@ def get( metadata = self._query_metadata_from_sender(get_key) if not metadata: return None + elif "data_size" not in metadata: + # Path 2: partial metadata (host/port only) — query that sender + partial_host = metadata.get("source_host") + partial_port = metadata.get("source_port") + if not partial_host or not partial_port: + logger.warning( + "get(%s): partial metadata missing source_host/source_port, cannot resolve data_size. metadata=%s", + get_key, + metadata, + ) + return None + queried = self._query_metadata_at(get_key, str(partial_host), int(partial_port)) + if not queried: + return None + metadata = queried _t1 = _time_mod.perf_counter() _query_ms = (_t1 - _t0) * 1000 diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py index 5c7384c1f8..6cf5c2f15b 100644 --- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py +++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py @@ -15,9 +15,13 @@ class SharedMemoryConnector(OmniConnectorBase): - """ - Connector that uses SharedMemory for large objects and inline data for small objects. - Acts as a unified replacement for the legacy IPC fallback logic. + """Key-addressed local shared-memory connector. + + SHM is a local-only transport: it reads/writes POSIX shared memory + segments identified purely by *key*. It does **not** understand + remote-transport metadata such as ``source_host`` / ``source_port`` + (that is the RDMA connector's job). When such metadata is passed in, + the connector silently falls back to key-based lookup. 
""" def __init__(self, config: dict[str, Any]): @@ -25,6 +29,7 @@ def __init__(self, config: dict[str, Any]): self.stage_id = config.get("stage_id", -1) self.device = config.get("device", "cuda:0") self.threshold = int(config.get("shm_threshold_bytes", 65536)) + self._pending_keys: set[str] = set() self._metrics = { "puts": 0, "gets": 0, @@ -59,6 +64,7 @@ def put( # meta contains {'name': ..., 'size': ...} metadata = {"shm": meta, "size": size} + self._pending_keys.add(put_key) self._metrics["shm_writes"] += 1 else: # Inline - pass bytes directly to avoid double serialization of the object @@ -93,6 +99,28 @@ def _get_data_with_lock(self, lock_file: str, shm_handle: dict): if obj and os.path.exists(lock_file): os.remove(lock_file) + def _get_by_key(self, get_key: str) -> tuple[Any, int] | None: + """Read a SHM segment addressed purely by *get_key*.""" + shm = None + try: + shm = shm_pkg.SharedMemory(name=get_key) + if shm is None or shm.size == 0: + return None + lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock" + shm_handle = {"name": get_key, "size": shm.size} + result = self._get_data_with_lock(lock_file, shm_handle) + if result is not None: + self._pending_keys.discard(get_key) + return result + except FileNotFoundError: + return None + except Exception: + logger.debug("_get_by_key: unexpected error reading SHM segment %s", get_key, exc_info=True) + return None + finally: + if shm: + shm.close() + def get( self, from_stage: str, @@ -101,16 +129,16 @@ def get( metadata=None, ) -> tuple[Any, int] | None: if metadata is not None: - # Some callers may wrap metadata by request id. if isinstance(metadata, dict) and get_key in metadata: metadata = metadata.get(get_key) if not isinstance(metadata, dict): - return None + return self._get_by_key(get_key) if "inline_bytes" in metadata: try: obj = self.deserialize_obj(metadata["inline_bytes"]) + self._pending_keys.discard(get_key) return obj, int(metadata.get("size", 0)) except Exception as e: logger.error(f"SharedMemoryConnector inline get failed for req {get_key}: {e}") @@ -119,33 +147,64 @@ def get( if "shm" in metadata: shm_handle = metadata["shm"] lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock" - return self._get_data_with_lock(lock_file, shm_handle) + result = self._get_data_with_lock(lock_file, shm_handle) + if result is not None: + self._pending_keys.discard(get_key) + return result - return None - shm = None - try: - shm = shm_pkg.SharedMemory(name=get_key) - if shm is None or shm.size == 0: - return None - lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock" - shm_handle = {"name": get_key, "size": shm.size} - return self._get_data_with_lock(lock_file, shm_handle) - except Exception: - return None - finally: - if shm: - shm.close() + # Metadata is a dict but has no SHM-specific handle (e.g. RDMA- + # style source_host/source_port). Fall back to key-based read. + return self._get_by_key(get_key) + + return self._get_by_key(get_key) def cleanup(self, request_id: str) -> None: - # SHM segments are automatically unlinked during 'get' (shm_read_bytes). - # If 'get' is never called (e.g. error flow), the SHM segment might leak. - # A robust implementation might track created segments and unlink them here - # if they haven't been consumed. - # For now, we rely on the consumer to read and unlink. - pass + """Best-effort cleanup of unconsumed SHM segments for *request_id*. + + Matches pending keys where *request_id* appears as the full key, + as a ``_``-delimited prefix, or as a ``_``-delimited suffix. 
+ If ``get()`` was never called, we unlink it here so /dev/shm + doesn't leak. + """ + stale = [ + k + for k in self._pending_keys + if k == request_id or k.startswith(request_id + "_") or k.endswith("_" + request_id) + ] + for key in stale: + self._pending_keys.discard(key) + try: + seg = shm_pkg.SharedMemory(name=key) + seg.close() + seg.unlink() + logger.debug("cleanup: unlinked unconsumed SHM segment %s", key) + except FileNotFoundError: + pass + except Exception as e: + logger.debug("cleanup: failed to unlink SHM segment %s: %s", key, e) + lock_file = f"/dev/shm/shm_{key}_lockfile.lock" + if os.path.exists(lock_file): + try: + os.remove(lock_file) + except OSError: + pass def close(self) -> None: - pass + """Unlink all remaining tracked SHM segments.""" + for key in list(self._pending_keys): + try: + seg = shm_pkg.SharedMemory(name=key) + seg.close() + seg.unlink() + except Exception: + pass + lock_file = f"/dev/shm/shm_{key}_lockfile.lock" + if os.path.exists(lock_file): + try: + os.remove(lock_file) + except OSError: + pass + self._pending_keys.clear() def health(self) -> dict[str, Any]: return {"status": "healthy", "threshold": self.threshold, **self._metrics} diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py index 1958c9d40a..ad008c3971 100644 --- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py +++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py @@ -14,8 +14,20 @@ from .factory import OmniConnectorFactory from .utils.config import ConnectorSpec -from .utils.initialization import KV_TRANSFER_PORT_OFFSET -from .utils.kv_utils import normalize_layer_kv +from .utils.initialization import KV_RANK_PORT_STRIDE +from .utils.kv_utils import ( + KVTPTopology, + build_rank_aware_recv_keys, + build_rank_aware_send_keys, + get_kv_target_ranks, + get_local_tp_rank, + get_tp_world_size, + kv_zmq_port, + merge_received_rank_shards, + normalize_layer_kv, + slice_layer_blocks, + slice_received_rank_shard, +) logger = init_logger(__name__) @@ -57,6 +69,8 @@ class OmniKVCacheConfig: need_recv_cache: bool = False need_send_cache: bool = False recv_timeout: float = 30.0 + from_tp: int = 1 + to_tp: int = 1 @dataclass @@ -72,82 +86,44 @@ def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" return asdict(self) - def to_bytes(self) -> bytes: - """Convert to compact binary format for fast transfer.""" - tensors_desc: list[dict[str, Any]] = [] - tensor_bufs: list[bytes] = [] - data_offset = 0 - - for cache_name in ("key_cache", "value_cache"): - cache_list = self.layer_blocks.get(cache_name, []) - for layer_idx, tensor in enumerate(cache_list): - if tensor is None: - tensors_desc.append({"n": f"{cache_name}_{layer_idx}", "x": True}) - continue - - t = tensor.detach().cpu().contiguous() - dtype_str = str(t.dtype).removeprefix("torch.") - raw = t.view(torch.uint8).numpy().tobytes() - tensors_desc.append( - { - "n": f"{cache_name}_{layer_idx}", - "i": layer_idx, - "d": dtype_str, - "s": list(t.shape), - "o": data_offset, - "b": len(raw), - } - ) - tensor_bufs.append(raw) - data_offset += len(raw) - - header = json.dumps( - { - "rid": self.request_id, - "bids": self.block_ids, - "meta": self.metadata, - "td": tensors_desc, - "nl": len(self.layer_blocks.get("key_cache", [])), - }, - separators=(",", ":"), - ).encode("utf-8") - return b"".join([struct.pack(">I", len(header)), header] + tensor_bufs) + def _build_tensors_desc(self, *, cpu: bool) -> tuple[list[dict[str, 
Any]], list, int, torch.device | None]: + """Iterate layer blocks and build tensor descriptors + data chunks. - def to_gpu_tensor(self) -> torch.Tensor: - """Convert to a packed GPU tensor for raw-data connectors.""" + Returns ``(tensors_desc, chunks, total_bytes, device)``. + *chunks* contains ``bytes`` when *cpu* is True, flat uint8 GPU tensors otherwise. + """ tensors_desc: list[dict[str, Any]] = [] - gpu_tensors: list[torch.Tensor] = [] + chunks: list = [] data_offset = 0 device = None for cache_name in ("key_cache", "value_cache"): - cache_list = self.layer_blocks.get(cache_name, []) - for layer_idx, tensor in enumerate(cache_list): + for layer_idx, tensor in enumerate(self.layer_blocks.get(cache_name, [])): if tensor is None: tensors_desc.append({"n": f"{cache_name}_{layer_idx}", "x": True}) continue - t = tensor.detach().contiguous() - if device is None and t.is_cuda: + if cpu: + t = t.cpu() + elif device is None and t.is_cuda: device = t.device - dtype_str = str(t.dtype).removeprefix("torch.") nbytes = t.numel() * t.element_size() tensors_desc.append( { "n": f"{cache_name}_{layer_idx}", "i": layer_idx, - "d": dtype_str, + "d": str(t.dtype).removeprefix("torch."), "s": list(t.shape), "o": data_offset, "b": nbytes, } ) - gpu_tensors.append(t.view(torch.uint8).flatten()) + chunks.append(t.view(torch.uint8).numpy().tobytes() if cpu else t.view(torch.uint8).flatten()) data_offset += nbytes - if device is None: - raise RuntimeError("No CUDA tensors found, use to_bytes() instead") + return tensors_desc, chunks, data_offset, device + def _build_header_bytes(self, tensors_desc: list[dict[str, Any]]) -> bytes: header = json.dumps( { "rid": self.request_id, @@ -158,19 +134,26 @@ def to_gpu_tensor(self) -> torch.Tensor: }, separators=(",", ":"), ).encode("utf-8") + return struct.pack(">I", len(header)) + header - header_prefix = struct.pack(">I", len(header)) + header - total_size = len(header_prefix) + data_offset - output = torch.empty(total_size, dtype=torch.uint8, device=device) - header_tensor = torch.frombuffer(bytearray(header_prefix), dtype=torch.uint8) - output[: len(header_prefix)].copy_(header_tensor) + def to_bytes(self) -> bytes: + """Convert to compact binary format for fast transfer.""" + tensors_desc, chunks, _, _ = self._build_tensors_desc(cpu=True) + return b"".join([self._build_header_bytes(tensors_desc)] + chunks) + def to_gpu_tensor(self) -> torch.Tensor: + """Convert to a packed GPU tensor for raw-data connectors.""" + tensors_desc, chunks, data_offset, device = self._build_tensors_desc(cpu=False) + if device is None: + raise RuntimeError("No CUDA tensors found, use to_bytes() instead") + header_prefix = self._build_header_bytes(tensors_desc) + output = torch.empty(len(header_prefix) + data_offset, dtype=torch.uint8, device=device) + output[: len(header_prefix)].copy_(torch.frombuffer(bytearray(header_prefix), dtype=torch.uint8)) pos = len(header_prefix) - for t_flat in gpu_tensors: + for t_flat in chunks: n = t_flat.numel() output[pos : pos + n].copy_(t_flat) pos += n - return output @staticmethod @@ -237,11 +220,8 @@ def _resolve_layer_idx(info: dict[str, Any], num_layers: int) -> int: return layer_idx @staticmethod - def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]: - """Reconstruct KV cache data from the packed bytes format.""" - raw_mv = memoryview(raw) if not isinstance(raw, memoryview) else raw - header, tensor_data_mv = KVCacheTransferData._load_header_from_memoryview(raw_mv) - + def _populate_caches(header: dict[str, Any], get_tensor: 
callable) -> dict[str, Any]:
+        """Shared deserialization loop for both CPU and GPU paths."""
         num_layers = header["nl"]
         key_cache: list[torch.Tensor | None] = [None] * num_layers
         value_cache: list[torch.Tensor | None] = [None] * num_layers
@@ -249,20 +229,9 @@ def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
         for info in header["td"]:
             if info.get("x"):
                 continue
             name: str = info["n"]
             torch_dtype = KVCacheTransferData._resolve_torch_dtype(info["d"])
-            offset, nbytes = KVCacheTransferData._validate_tensor_span(name, info, len(tensor_data_mv))
-            t = (
-                torch.frombuffer(
-                    tensor_data_mv,
-                    dtype=torch.uint8,
-                    offset=offset,
-                    count=nbytes,
-                )
-                .view(torch_dtype)
-                .reshape(info["s"])
-            )
+            t = get_tensor(info).view(torch_dtype).reshape(info["s"])
             layer_idx = KVCacheTransferData._resolve_layer_idx(info, num_layers)
             if name.startswith("key_cache_"):
                 key_cache[layer_idx] = t
@@ -276,37 +245,30 @@ def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
             "metadata": header["meta"],
         }

+    @staticmethod
+    def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
+        """Reconstruct KV cache data from the packed bytes format."""
+        raw_mv = memoryview(raw) if not isinstance(raw, memoryview) else raw
+        header, tensor_data_mv = KVCacheTransferData._load_header_from_memoryview(raw_mv)
+        data_len = len(tensor_data_mv)
+
+        def _get(info: dict) -> torch.Tensor:
+            offset, nbytes = KVCacheTransferData._validate_tensor_span(info["n"], info, data_len)
+            return torch.frombuffer(tensor_data_mv, dtype=torch.uint8, offset=offset, count=nbytes)
+
+        return KVCacheTransferData._populate_caches(header, _get)
+
     @staticmethod
     def from_bytes_gpu(gpu_tensor: torch.Tensor) -> dict[str, Any]:
         """Reconstruct KV cache data from a packed GPU tensor."""
         header, data_start = KVCacheTransferData._load_header_from_tensor(gpu_tensor)
+        data_len = int(gpu_tensor.numel()) - data_start

-        num_layers = header["nl"]
-        key_cache: list[torch.Tensor | None] = [None] * num_layers
-        value_cache: list[torch.Tensor | None] = [None] * num_layers
-        tensor_data_bytes = int(gpu_tensor.numel()) - data_start
+        def _get(info: dict) -> torch.Tensor:
+            offset, nbytes = KVCacheTransferData._validate_tensor_span(info["n"], info, data_len)
+            return gpu_tensor[data_start + offset : data_start + offset + nbytes].clone()

-        for info in header["td"]:
-            if info.get("x"):
-                continue
-
-            name: str = info["n"]
-            torch_dtype = KVCacheTransferData._resolve_torch_dtype(info["d"])
-            offset, nbytes = KVCacheTransferData._validate_tensor_span(name, info, tensor_data_bytes)
-            t = gpu_tensor[data_start + offset : data_start + offset + nbytes].clone()
-            t = t.view(torch_dtype).reshape(info["s"])
-            layer_idx = KVCacheTransferData._resolve_layer_idx(info, num_layers)
-            if name.startswith("key_cache_"):
-                key_cache[layer_idx] = t
-            elif name.startswith("value_cache_"):
-                value_cache[layer_idx] = t
-
-        return {
-            "request_id": header["rid"],
-            "layer_blocks": {"key_cache": key_cache, "value_cache": value_cache},
-            "block_ids": header["bids"],
-            "metadata": header["meta"],
-        }
+        return KVCacheTransferData._populate_caches(header, _get)


 class OmniKVTransferManager:
@@ -341,6 +303,30 @@ def __init__(self, config: OmniKVCacheConfig):
             else (None, None)
         )

+        local_rank = get_local_tp_rank()
+
+        if config.from_tp <= 1 and config.to_tp <= 1:
+            detected_tp = get_tp_world_size()
+            from_tp = detected_tp
+            to_tp = detected_tp
+        else:
+            from_tp = config.from_tp
+            to_tp = config.to_tp
+
+        self._tp_topo = KVTPTopology(source_tp_size=from_tp,
target_tp_size=to_tp, local_rank=local_rank) + + # Injectable hooks (compatible with PR #2677 OmniConnectorModelRunnerMixin). + self.kv_send_key_builder: Callable | None = None + self.kv_recv_key_builder: Callable | None = None + self.kv_payload_merger: Callable | None = None + self.kv_payload_slicer: Callable | None = None + + # Base sender endpoint (rank-0 host/port) stored during + # update_sender_info(). Used by the receive path to construct + # per-rank metadata for heterogeneous TP without querying a registry. + self._sender_base_host: str | None = None + self._sender_base_zmq_port: int | None = None + if config.need_send_cache and config.connector_config: try: _ = self.connector @@ -348,11 +334,20 @@ def __init__(self, config: OmniKVCacheConfig): except Exception as e: logger.warning("Failed to eagerly initialize sender connector: %s", e) + # ------------------------------------------------------------------ # + # Factory helpers + # ------------------------------------------------------------------ # + @classmethod def _create(cls, cfg: dict | None) -> "OmniKVTransferManager": """Create manager from raw config dict.""" if not cfg or not isinstance(cfg, dict): return cls(OmniKVCacheConfig()) + + rank_mapping = cfg.get("rank_mapping", {}) + if not isinstance(rank_mapping, dict): + rank_mapping = {} + return cls( OmniKVCacheConfig( connector_config=cfg.get("connector_config"), @@ -363,19 +358,18 @@ def _create(cls, cfg: dict | None) -> "OmniKVTransferManager": need_recv_cache=cfg.get("need_recv_cache", False), need_send_cache=cfg.get("need_send_cache", False), recv_timeout=cfg.get("recv_timeout", 30.0), + from_tp=int(rank_mapping.get("from_tp", 1)), + to_tp=int(rank_mapping.get("to_tp", 1)), ) ) - @classmethod - def from_model_config(cls, config: Any) -> "OmniKVTransferManager": - """Create from model config (for AR model runner).""" - return cls._create(getattr(config, "omni_kv_config", None)) - @classmethod def from_od_config(cls, config: Any) -> "OmniKVTransferManager": - """Create from OmniDiffusion config (for diffusion runner).""" + """Create from model or OmniDiffusion config.""" return cls._create(getattr(config, "omni_kv_config", None)) + from_model_config = from_od_config + @classmethod def from_vllm_config(cls, vllm_config: Any, model_config: Any) -> "OmniKVTransferManager": """Create from vllm config with fallback to kv_transfer_config.""" @@ -417,45 +411,33 @@ def connector(self): ) c_extra["to_stage"] = str(self.config.to_stage) if self.config.to_stage is not None else "1" + try: + stage_int = int(self.config.from_stage) if self.config.from_stage is not None else 0 + except (TypeError, ValueError): + stage_int = 0 + zmq_port = kv_zmq_port(base_port, stage_int, self._tp_topo.local_rank) + if self.config.need_send_cache: c_extra["role"] = "sender" - from_stage = self.config.from_stage - if from_stage is not None: - try: - c_extra["zmq_port"] = base_port + KV_TRANSFER_PORT_OFFSET + int(from_stage) - except (TypeError, ValueError): - c_extra["zmq_port"] = base_port + KV_TRANSFER_PORT_OFFSET + c_extra["zmq_port"] = zmq_port elif self.config.need_recv_cache: c_extra["role"] = "receiver" - from_stage = self.config.from_stage - sender_port = base_port + KV_TRANSFER_PORT_OFFSET - if from_stage is not None: - try: - sender_port = base_port + KV_TRANSFER_PORT_OFFSET + int(from_stage) - except (TypeError, ValueError): - pass c_extra.setdefault("sender_host", c_extra.get("host", "127.0.0.1")) - c_extra.setdefault("sender_zmq_port", sender_port) + c_extra.setdefault("sender_zmq_port", 
zmq_port) logger.info( - "Initializing OmniConnector (purpose=kv_transfer) with config: %s, role: %s", - cfg, + "Initializing OmniConnector type=%s role=%s", + c_type, c_extra.get("role", "N/A"), ) self._connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra)) - except Exception as e: - logger.error(f"Failed to initialize OmniConnector: {e}") - import traceback - - traceback.print_exc() - # Cache failure sentinel to avoid repeated initialization attempts in hot paths. + except Exception: + logger.exception("Failed to initialize OmniConnector") self._connector = False return self._connector if self._connector else None - def get_connector(self): - """Get connector (compatibility wrapper for existing code).""" - return self.connector + get_connector = property(lambda self: self.connector) def _resolve_sender_info( self, sender_info: dict[str, Any], sender_stage_id: str | int | None = None @@ -513,8 +495,187 @@ def _clone_received_payload_tensors(data: dict[str, Any]) -> dict[str, Any]: cache_list[idx] = tensor.clone() return data + def _slice_transfer_data_for_target(self, kv_data: KVCacheTransferData, target_rank: int) -> KVCacheTransferData: + """Pre-slice sender payload for one target rank when sender TP < receiver TP.""" + topo = self._tp_topo + ratio = topo.target_tp_size // topo.source_tp_size + offset_in_sender = target_rank % ratio + metadata = dict(kv_data.metadata) if isinstance(kv_data.metadata, dict) else {} + metadata["tp_head_slice"] = { + "applied": True, + "side": "sender", + "target_rank": target_rank, + "source_rank": topo.local_rank, + "from_tp": topo.source_tp_size, + "to_tp": topo.target_tp_size, + "offset_in_shard": offset_in_sender, + "num_slices": ratio, + } + return KVCacheTransferData( + request_id=kv_data.request_id, + layer_blocks=slice_layer_blocks(kv_data.layer_blocks, offset_in_sender, ratio), + block_ids=list(kv_data.block_ids), + metadata=metadata, + ) + + def _serialize_transfer_payload(self, kv_data: KVCacheTransferData) -> torch.Tensor | bytes | dict[str, Any]: + """Serialize KV transfer data using the connector's fastest supported path.""" + if getattr(self.connector, "supports_raw_data", False): + try: + return kv_data.to_gpu_tensor() + except Exception: + pass + try: + return kv_data.to_bytes() + except Exception: + return kv_data.to_dict() + + @staticmethod + def _collect_request_kv_payload(req: Any) -> dict[str, object]: + """Collect request-side KV objects for object broadcast.""" + kv_payload: dict[str, object] = {} + for attr in ("past_key_values", "kv_metadata"): + val = getattr(req, attr, None) + if val is not None: + kv_payload[attr] = val + + if hasattr(req, "sampling_params") and req.sampling_params is not None: + for key in list(vars(req.sampling_params).keys()): + if key in ("past_key_values", "kv_metadata") or ( + key.startswith("cfg_") + and ( + key.endswith("_past_key_values") + or key.endswith("_kv_metadata") + or key + in ( + "cfg_kv_request_ids", + "cfg_active_branch", + "cfg_branch_roles", + "cfg_branch_past_key_values", + "cfg_branch_kv_metadata", + ) + ) + ): + val = getattr(req.sampling_params, key, None) + if val is not None: + kv_payload[f"sp.{key}"] = val + + return kv_payload + + @staticmethod + def _apply_request_kv_payload( + req: Any, + kv_payload: dict[str, object], + target_device: torch.device | None = None, + ) -> None: + """Apply a broadcast KV payload back onto a request object.""" + for attr in ("past_key_values", "kv_metadata"): + val = kv_payload.get(attr) + if val is not None: + 
if target_device is not None: + val = _move_to_device(val, target_device) + setattr(req, attr, val) + + if hasattr(req, "sampling_params") and req.sampling_params is not None: + for key, val in kv_payload.items(): + if key.startswith("sp."): + if target_device is not None: + val = _move_to_device(val, target_device) + setattr(req.sampling_params, key[3:], val) + + @staticmethod + def _discover_cfg_branch_roles(req: Any) -> list[str]: + """Discover CFG branch roles in a stable order.""" + sampling_params = getattr(req, "sampling_params", None) + if sampling_params is None: + return [] + + roles: list[str] = [] + branch_map = getattr(sampling_params, "cfg_branch_past_key_values", None) or {} + for preferred_role in ("cfg_text", "cfg_img"): + if ( + preferred_role in branch_map + or getattr(sampling_params, f"{preferred_role}_past_key_values", None) is not None + ): + roles.append(preferred_role) + + for role in branch_map.keys(): + if role not in roles and branch_map.get(role) is not None: + roles.append(role) + + for key in vars(sampling_params).keys(): + if not (key.startswith("cfg_") and key.endswith("_past_key_values")): + continue + role = key.removesuffix("_past_key_values") + if role in ("cfg_branch",) or role in roles: + continue + if getattr(sampling_params, key, None) is not None: + roles.append(role) + + return roles + + @classmethod + def _build_cfg_rank_local_payloads(cls, req: Any, cfg_size: int) -> list[dict[str, object] | None]: + """Build per-cfg-rank payloads so each rank receives only its branch KV.""" + full_payload = cls._collect_request_kv_payload(req) + payloads: list[dict[str, object] | None] = [] + + main_payload = { + key: value + for key, value in full_payload.items() + if key in ("past_key_values", "kv_metadata", "sp.past_key_values", "sp.kv_metadata") + } + branch_roles = cls._discover_cfg_branch_roles(req) + if branch_roles: + main_payload["sp.cfg_branch_roles"] = list(branch_roles) + main_payload["sp.cfg_active_branch"] = None + payloads.append(main_payload or None) + + sampling_params = getattr(req, "sampling_params", None) + branch_map = getattr(sampling_params, "cfg_branch_past_key_values", None) or {} + branch_metadata_map = getattr(sampling_params, "cfg_branch_kv_metadata", None) or {} + + for role in branch_roles: + if sampling_params is None: + payloads.append(None) + continue + + branch_kv = branch_map.get(role) + if branch_kv is None: + branch_kv = getattr(sampling_params, f"{role}_past_key_values", None) + branch_metadata = branch_metadata_map.get(role) + if branch_metadata is None: + branch_metadata = getattr(sampling_params, f"{role}_kv_metadata", None) + if branch_kv is None: + payloads.append(None) + continue + + local_payload = dict(main_payload) + local_payload["sp.cfg_active_branch"] = role + local_payload["sp.cfg_branch_roles"] = list(branch_roles) + local_payload["sp.cfg_branch_past_key_values"] = {role: branch_kv} + local_payload[f"sp.{role}_past_key_values"] = branch_kv + if branch_metadata is not None: + local_payload["sp.cfg_branch_kv_metadata"] = {role: branch_metadata} + local_payload[f"sp.{role}_kv_metadata"] = branch_metadata + + payloads.append(local_payload) + + while len(payloads) < cfg_size: + payloads.append(None) + + return payloads[:cfg_size] + def update_sender_info(self, sender_info: dict[str, Any], sender_stage_id: str | int | None = None) -> None: - """Update receiver-side sender info before loading remote KV cache.""" + """Update receiver-side sender info before loading remote KV cache. 
+ + The orchestrator always reports rank-0's ZMQ port. When TP > 1 the + receiver must offset the port so that each TP rank connects to the + corresponding sender rank's port. + + The base host/port are also stored so that the receive path can + construct per-rank metadata for heterogeneous TP scenarios. + """ if not self.config.need_recv_cache: return @@ -523,18 +684,39 @@ def update_sender_info(self, sender_info: dict[str, Any], sender_stage_id: str | logger.warning("Invalid sender_info format: %s", sender_info) return + sender_host = actual_info.get("host") + base_zmq_port = actual_info.get("zmq_port") + + # Store base sender info for per-rank metadata construction. + self._sender_base_host = sender_host + if base_zmq_port is not None: + self._sender_base_zmq_port = int(base_zmq_port) + + # --- Default sender: offset to match this receiver's corresponding sender rank --- + zmq_port = base_zmq_port + if zmq_port is not None and self._tp_topo.local_rank > 0: + zmq_port = int(zmq_port) + self._tp_topo.local_rank * KV_RANK_PORT_STRIDE + if self.config.connector_config: - self.config.connector_config["sender_host"] = actual_info.get("host") - self.config.connector_config["sender_zmq_port"] = actual_info.get("zmq_port") + self.config.connector_config["sender_host"] = sender_host + self.config.connector_config["sender_zmq_port"] = zmq_port if self._connector and hasattr(self._connector, "update_sender_info"): try: - self._connector.update_sender_info(actual_info.get("host"), actual_info.get("zmq_port")) + self._connector.update_sender_info(sender_host, zmq_port) except Exception: if hasattr(self._connector, "sender_host"): - self._connector.sender_host = actual_info.get("host") + self._connector.sender_host = sender_host if hasattr(self._connector, "sender_zmq_port"): - self._connector.sender_zmq_port = actual_info.get("zmq_port") + self._connector.sender_zmq_port = zmq_port + + logger.info( + "Sender info updated: host=%s, base_port=%s, adjusted_port=%s (local_rank=%s)", + sender_host, + base_zmq_port, + zmq_port, + self._tp_topo.local_rank, + ) def handle_finished_requests_kv_transfer( self, @@ -692,35 +874,54 @@ def _transfer_kv_cache(self, kv_data: KVCacheTransferData, transfer_req_id: str) kv_data.request_id = transfer_req_id serialization_start = time.perf_counter() - transfer_data: torch.Tensor | bytes | dict[str, Any] - supports_raw = getattr(self.connector, "supports_raw_data", False) + topo = self._tp_topo + send_keys = build_rank_aware_send_keys( + transfer_req_id, from_stage, to_stage, topo, hook=self.kv_send_key_builder + ) + sender_slice_active = ( + topo.source_tp_size < topo.target_tp_size and len(send_keys) > 1 and not callable(self.kv_send_key_builder) + ) + per_key_payloads: list[tuple[str, torch.Tensor | bytes | dict[str, Any]]] = [] - try: - if supports_raw: - transfer_data = kv_data.to_gpu_tensor() + if sender_slice_active: + target_ranks = get_kv_target_ranks(topo) + if len(target_ranks) != len(send_keys): + logger.warning( + "Skip sender-side KV slicing because target rank count does not match send key count: " + "target_ranks=%s send_keys=%s", + len(target_ranks), + len(send_keys), + ) + sender_slice_active = False else: - raise RuntimeError("Connector does not support raw tensor") - except Exception: - try: - transfer_data = kv_data.to_bytes() - except Exception: - data_dict = kv_data.to_dict() - data_dict["request_id"] = transfer_req_id - transfer_data = data_dict + for put_key, target_rank in zip(send_keys, target_ranks, strict=False): + sliced_kv_data = 
self._slice_transfer_data_for_target(kv_data, target_rank) + per_key_payloads.append((put_key, self._serialize_transfer_payload(sliced_kv_data))) + + if not per_key_payloads: + transfer_data = self._serialize_transfer_payload(kv_data) + per_key_payloads = [(put_key, transfer_data) for put_key in send_keys] serialization_ms = (time.perf_counter() - serialization_start) * 1000 logger.info("KV cache serialized for %s in %.1f ms", transfer_req_id, serialization_ms) transfer_start = time.perf_counter() - success, size, _ = self._transfer_with_retry(from_stage, to_stage, f"kv_cache_{transfer_req_id}", transfer_data) + total_size = 0 + all_succeeded = True + for put_key, transfer_data in per_key_payloads: + success, size, _ = self._transfer_with_retry(from_stage, to_stage, put_key, transfer_data) + total_size += size + all_succeeded = all_succeeded and success + elapsed = time.perf_counter() - transfer_start - if success: - mbps = (size / 1024 / 1024) / elapsed if elapsed > 0 else 0 + if all_succeeded: + mbps = (total_size / 1024 / 1024) / elapsed if elapsed > 0 else 0 logger.info( - "KV transfer OK: %s, %s bytes, %.3fs, %.1f MB/s", + "KV transfer OK: %s, %s bytes across %s key(s), %.3fs, %.1f MB/s", transfer_req_id, - size, + total_size, + len(send_keys), elapsed, mbps, ) @@ -731,7 +932,7 @@ def _transfer_with_retry( self, from_stage: str, to_stage: str, - request_id: str, + put_key: str, data: "dict[str, Any] | bytes | torch.Tensor", max_retries: int = 3, ) -> tuple[bool, int, dict[str, Any] | None]: @@ -740,7 +941,7 @@ def _transfer_with_retry( Args: from_stage: Source stage identifier to_stage: Target stage identifier - request_id: Request identifier for the key + put_key: Pre-built connector key (rank-aware when TP > 1) data: Data to transfer max_retries: Maximum number of retry attempts @@ -749,14 +950,12 @@ def _transfer_with_retry( """ for attempt in range(max_retries): try: - # Build the full key for connector - full_request_id = f"omni_{from_stage}_to_{to_stage}_{request_id}" success, size, metadata = self.connector.put( - from_stage=from_stage, to_stage=to_stage, put_key=full_request_id, data=data + from_stage=from_stage, to_stage=to_stage, put_key=put_key, data=data ) if success: return success, size, metadata - logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}") + logger.warning(f"Transfer attempt {attempt + 1} failed for {put_key}") except Exception as e: logger.warning(f"Transfer attempt {attempt + 1} exception: {e}") @@ -801,22 +1000,46 @@ def receive_kv_cache_for_request( poll_interval = 0.01 max_poll_interval = 0.5 - logger.info(f"Wait for KV cache for request {request_id} from stage {from_stage} to {to_stage}...") + topo = self._tp_topo + recv_key_pairs = build_rank_aware_recv_keys( + request_id, from_stage, to_stage, topo, hook=self.kv_recv_key_builder + ) + pending_pairs = list(recv_key_pairs) + received_payloads: dict[str, tuple[dict[str, Any], int]] = {} + + logger.info( + "Wait for KV cache for request %s from stage %s to %s via %s key(s)...", + request_id, + from_stage, + to_stage, + len(recv_key_pairs), + ) try: while True: - # Build the full key for connector - full_request_id = f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}" link_start = time.perf_counter() - result = self.connector.get( - from_stage=from_stage, - to_stage=to_stage, - get_key=full_request_id, - ) - if result: + for get_key, from_rank in list(pending_pairs): + # Construct per-rank metadata so the connector queries + # the correct sender endpoint (heterogeneous TP path). 
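+                        # e.g. (illustrative) from_rank=2 resolves to
+                        # source_port = base_port + 2 * KV_RANK_PORT_STRIDE.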
+ # When from_rank is None (TP<=1), metadata stays None + # and the connector falls back to its default sender. + rank_metadata: dict[str, Any] | None = None + if from_rank is not None and self._sender_base_host and self._sender_base_zmq_port is not None: + rank_metadata = { + "source_host": self._sender_base_host, + "source_port": self._sender_base_zmq_port + from_rank * KV_RANK_PORT_STRIDE, + } + + result = self.connector.get( + from_stage=from_stage, + to_stage=to_stage, + get_key=get_key, + metadata=rank_metadata, + ) + if not result: + continue + raw_data, size = result - elapsed = time.time() - start_time - link_ms = (time.perf_counter() - link_start) * 1000 managed_buffer = None if hasattr(raw_data, "tensor") and hasattr(raw_data, "release"): @@ -844,6 +1067,21 @@ def receive_kv_cache_for_request( else: data = raw_data + received_payloads[get_key] = (data, size) + pending_pairs.remove((get_key, from_rank)) + + if not pending_pairs and received_payloads: + elapsed = time.time() - start_time + link_ms = (time.perf_counter() - link_start) * 1000 + ordered_payloads = [received_payloads[key][0] for key, _ in recv_key_pairs] + total_size = sum(received_payloads[key][1] for key, _ in recv_key_pairs) + + if len(ordered_payloads) == 1: + data = ordered_payloads[0] + else: + data = merge_received_rank_shards(ordered_payloads, merger=self.kv_payload_merger) + data = slice_received_rank_shard(data, topo, slicer=self.kv_payload_slicer) + try: if isinstance(data, dict) and "layer_blocks" in data: layer_blocks = data["layer_blocks"] @@ -856,18 +1094,18 @@ def receive_kv_cache_for_request( continue if target_device is not None and tensor.device != target_device: cache_list[i] = tensor.to(target_device).contiguous() - finally: - if managed_buffer is not None: - managed_buffer.release() + except Exception: + logger.exception("Failed to move KV cache tensors to target device") logger.info( - "Successfully received KV cache for %s, %s bytes, wait=%.3fs, link=%.1fms", + "Successfully received KV cache for %s, %s bytes across %s key(s), wait=%.3fs, link=%.1fms", request_id, - size, + total_size, + len(recv_key_pairs), elapsed, link_ms, ) - return data, size + return data, total_size if time.time() - start_time > timeout: logger.error(f"Timeout waiting for KV cache for request {request_id} after {timeout}s") @@ -876,11 +1114,8 @@ def receive_kv_cache_for_request( time.sleep(poll_interval) poll_interval = min(poll_interval * 2, max_poll_interval) - except Exception as e: - logger.error(f"Error receiving KV cache for {request_id}: {e}") - import traceback - - traceback.print_exc() + except Exception: + logger.exception("Error receiving KV cache for %s", request_id) return None, 0 def apply_kv_cache_to_request(self, req: Any, data: dict[str, Any]) -> None: @@ -994,73 +1229,79 @@ def receive_multi_kv_cache_distributed( cfg_kv_collect_func: Callable | None = None, target_device: torch.device | None = None, ) -> bool: - """Broadcast-aware wrapper around :meth:`receive_multi_kv_cache`. - - SharedMemory connector is single-reader: once rank 0 consumes the - segment it is deleted. For multi-GPU stages (e.g. sequence-parallel) - only rank 0 receives; the result is then broadcast to every other - rank via the world process-group. - - For single-worker stages this is equivalent to calling - :meth:`receive_multi_kv_cache` directly. + """Distributed wrapper around :meth:`receive_multi_kv_cache`. 
+ + TP-aware path selection: + - world size 1: direct receive + - TP active, cfg size 1: each rank independently receives + - TP active, cfg size > 1: cfg-rank 0 receives, then broadcasts to + peers that share the same TP rank + - TP inactive: legacy rank-0 receive then world broadcast """ - from vllm_omni.diffusion.distributed.parallel_state import get_world_group + from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, + get_world_group, + ) world = get_world_group() if world.world_size <= 1: return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device) - # --- rank 0: receive to CPU (needed for pickle-based broadcast) --- - if world.rank_in_group == 0: - self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu")) + topo = self._tp_topo + tp_active = topo.source_tp_size > 1 or topo.target_tp_size > 1 + cfg_size = 1 + cfg_rank = 0 + cfg_group = None + try: + cfg_size = get_classifier_free_guidance_world_size() + cfg_rank = get_classifier_free_guidance_rank() + cfg_group = get_cfg_group() + except Exception: + cfg_size = 1 + cfg_rank = 0 + cfg_group = None - kv_payload: dict[str, object] = {} - for attr in ("past_key_values", "kv_metadata"): - val = getattr(req, attr, None) - if val is not None: - kv_payload[attr] = val + if tp_active and cfg_size <= 1: + logger.info( + "Rank-aware KV receive: rank %s independently receiving (from_tp=%s, to_tp=%s)", + topo.local_rank, + topo.source_tp_size, + topo.target_tp_size, + ) + return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device) - if hasattr(req, "sampling_params") and req.sampling_params is not None: - for key in list(vars(req.sampling_params).keys()): - if (key.startswith("cfg_") and key.endswith("_past_key_values")) or key in ( - "past_key_values", - "kv_metadata", - ): - val = getattr(req.sampling_params, key, None) - if val is not None: - kv_payload[f"sp.{key}"] = val - - payload_list = [kv_payload] - # Use broadcast_object_list (pickle-based) instead of broadcast_tensor_dict - # because the KV cache is a heterogeneous nested structure (NaiveCache objects - # with metadata + tensors), not a flat tensor dict. This runs once before - # the denoising loop so the serialization cost is negligible. 
- torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group) - kv_payload = payload_list[0] - else: - payload_list: list[dict[str, object] | None] = [None] - torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group) - kv_payload = payload_list[0] + if tp_active and cfg_size > 1 and cfg_group is not None: + kv_payload: dict[str, object] | None = None + if cfg_rank == 0: + received = self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu")) + rank_payloads = self._build_cfg_rank_local_payloads(req, cfg_size) if received else [None] * cfg_size + kv_payload = rank_payloads[0] + for dst_rank in range(1, cfg_size): + cfg_group.send_object(rank_payloads[dst_rank], dst_rank) + else: + kv_payload = cfg_group.recv_object(0) - # --- apply on ALL ranks (rank 0 also needs CPU→GPU move) --- - if not kv_payload: - return False + if not kv_payload: + return False - for attr in ("past_key_values", "kv_metadata"): - val = kv_payload.get(attr) - if val is not None: - if target_device is not None: - val = _move_to_device(val, target_device) - setattr(req, attr, val) + self._apply_request_kv_payload(req, kv_payload, target_device) + return True - if hasattr(req, "sampling_params") and req.sampling_params is not None: - for key, val in kv_payload.items(): - if key.startswith("sp."): - if target_device is not None: - val = _move_to_device(val, target_device) - setattr(req.sampling_params, key[3:], val) + kv_payload: dict[str, object] | None = None + if world.rank_in_group == 0: + received = self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu")) + if received: + kv_payload = self._collect_request_kv_payload(req) + + kv_payload = world.broadcast_object(kv_payload, src=0) + + if not kv_payload: + return False + self._apply_request_kv_payload(req, kv_payload, target_device) return True diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py index 393d0e8013..1d5fc398a4 100644 --- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py +++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py @@ -58,6 +58,7 @@ def __init__(self, vllm_config: Any): self.waiting_for_chunk_waiting_requests: deque[Any] = deque() self.waiting_for_chunk_running_requests: deque[Any] = deque() self.requests_with_ready_chunks = set() + self.requests_origin_status = {} @classmethod def create_connector(cls, model_config: Any): @@ -279,6 +280,7 @@ def cleanup_receiver(self, request_id: str) -> None: self.get_req_chunk.pop(request_id, None) self.requests_with_ready_chunks.discard(request_id) self.request_ids_mapping.pop(request_id, None) + self.requests_origin_status.pop(request_id, None) self._cancelled_load_reqs.add(request_id) self._finished_load_reqs.discard(request_id) @@ -408,6 +410,7 @@ def _process_chunk_queue( self.requests_with_ready_chunks.add(request.request_id) continue queue.remove(request) + self.requests_origin_status[request.request_id] = target_status waiting_for_chunk_list.append(request) def _clear_chunk_ready(self, scheduler_output: Any) -> None: @@ -420,3 +423,23 @@ def _clear_chunk_ready(self, scheduler_output: Any) -> None: for req_id in scheduler_output.scheduled_cached_reqs.req_ids: if req_id in self.requests_with_ready_chunks: self.requests_with_ready_chunks.remove(req_id) + + def finish_requests( + self, request_ids: Any, 
finished_status: RequestStatus, requests: dict[str, Request] | None = None + ) -> list[tuple[str, int]]: + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids,) + elif request_ids is not None: + request_ids = set(request_ids) + else: + request_ids = requests.keys() + + # First pass: collect requests to remove from queues + for req_id in request_ids: + request = requests.get(req_id) if requests else None + if request is None or request.is_finished(): + # Invalid request ID. + continue + if req_id in self.requests_origin_status: + request.status = self.requests_origin_status.pop(req_id) diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py index 37b7d0d7f8..f012af3c9c 100644 --- a/vllm_omni/distributed/omni_connectors/utils/initialization.py +++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py @@ -23,6 +23,11 @@ # collide with request-forwarding endpoints that share the same base port. KV_TRANSFER_PORT_OFFSET = 100 +# Port stride between TP ranks so each worker binds a unique ZMQ port +# when TP > 1. Must be larger than the maximum number of pipeline stages. +# Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage +KV_RANK_PORT_STRIDE = 16 + def initialize_connectors_from_config( config_path: str | Path | None = None, @@ -201,6 +206,19 @@ def load_omni_transfer_config( if config_dict is None: return None + # Normalize new-schema (top-level ``connectors`` + ``stages``) into the + # legacy ``runtime.connectors`` + ``stage_args`` shape the parser reads. + if "stages" in config_dict and "stage_args" not in config_dict: + normalized: dict[str, Any] = dict(config_dict) + runtime = dict(normalized.get("runtime") or {}) + if "connectors" in normalized and "connectors" not in runtime: + runtime["connectors"] = normalized["connectors"] + if "edges" in normalized and "edges" not in runtime: + runtime["edges"] = normalized["edges"] + normalized["runtime"] = runtime + normalized["stage_args"] = normalized["stages"] + config_dict = normalized + # Parse connectors connectors = {} runtime_config = config_dict.get("runtime", {}) diff --git a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py index 2cb48a8b34..12b9b3d4f7 100644 --- a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py +++ b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py @@ -1,15 +1,380 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Utility helpers for KV cache manipulation.""" +"""Utility helpers for KV cache manipulation, TP routing, and merge/slice.""" + +from __future__ import annotations + +import os +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any import torch +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger +from .initialization import KV_RANK_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET + logger = init_logger(__name__) LayerKV = torch.Tensor | tuple[torch.Tensor, torch.Tensor] +# ------------------------------------------------------------------ # +# TP Topology +# ------------------------------------------------------------------ # + + +@dataclass(frozen=True) +class KVTPTopology: + """Immutable descriptor for a KV-transfer parallel mapping. 
+ + Captures sender/receiver parallel sizes and the local rank within + that parallel dimension. Works for any divisible parallel dimension + (TP, SP, Ring Attention). + """ + + source_tp_size: int + target_tp_size: int + local_rank: int + + def __post_init__(self) -> None: + if self.source_tp_size <= 0 or self.target_tp_size <= 0: + raise ValueError( + f"Parallel sizes must be positive: " + f"source_tp_size={self.source_tp_size}, target_tp_size={self.target_tp_size}" + ) + if self.local_rank < 0: + raise ValueError(f"local_rank must be non-negative, got {self.local_rank}") + + @property + def is_heterogeneous(self) -> bool: + return self.source_tp_size != self.target_tp_size + + @property + def ratio(self) -> int: + """Larger parallel size divided by smaller. Always >= 1.""" + return max(self.source_tp_size, self.target_tp_size) // min(self.source_tp_size, self.target_tp_size) + + +# ------------------------------------------------------------------ # +# Runtime TP detection +# ------------------------------------------------------------------ # + + +def get_local_tp_rank() -> int: + """Return the TP-local rank of this worker process. + + Uses ``get_tensor_model_parallel_rank()`` which returns the rank + within the TP group only, not the stage-global rank. + """ + try: + return get_tensor_model_parallel_rank() + except Exception: + logger.debug("TP parallel state not initialized, falling back to LOCAL_RANK env", exc_info=True) + try: + return int(os.environ.get("LOCAL_RANK", "0")) + except (ValueError, TypeError): + return 0 + + +def get_tp_world_size() -> int: + """Return the TP world size (tensor-parallel dimension only). + + Uses ``get_tensor_model_parallel_world_size()`` so that + cfg_parallel, SP, PP etc. are not included in the count. + """ + try: + return get_tensor_model_parallel_world_size() + except Exception: + logger.debug("TP parallel state not initialized, defaulting world_size=1", exc_info=True) + return 1 + + +# ------------------------------------------------------------------ # +# ZMQ port computation +# ------------------------------------------------------------------ # + + +def kv_zmq_port(base_port: int, from_stage: int, local_rank: int = 0) -> int: + """Compute the ZMQ port for a KV-transfer connector. + + Each TP rank gets its own port so that TP > 1 deployments do not + cause ``EADDRINUSE`` when multiple sender workers bind on the same + host. The formula is backward-compatible: rank 0 produces the same + port as the previous ``base + OFFSET + stage`` formula. 
+ """ + return base_port + KV_TRANSFER_PORT_OFFSET + local_rank * KV_RANK_PORT_STRIDE + from_stage + + +# ------------------------------------------------------------------ # +# TP topology validation and rank routing +# ------------------------------------------------------------------ # + + +def validate_kv_tp_topology(topo: KVTPTopology) -> None: + """Reject heterogeneous TP mappings that cannot be routed losslessly.""" + larger = max(topo.source_tp_size, topo.target_tp_size) + smaller = min(topo.source_tp_size, topo.target_tp_size) + if larger % smaller != 0: + raise ValueError( + f"KV TP mapping must be divisible: " + f"source_tp_size={topo.source_tp_size}, " + f"target_tp_size={topo.target_tp_size}" + ) + + +def get_kv_target_ranks(topo: KVTPTopology) -> list[int]: + """Which remote ranks this local rank sends KV shards to (send side).""" + validate_kv_tp_topology(topo) + if topo.source_tp_size == topo.target_tp_size: + return [topo.local_rank] + if topo.source_tp_size > topo.target_tp_size: + return [topo.local_rank // (topo.source_tp_size // topo.target_tp_size)] + ratio = topo.target_tp_size // topo.source_tp_size + return [topo.local_rank * ratio + i for i in range(ratio)] + + +def get_kv_source_ranks(topo: KVTPTopology) -> list[int]: + """Which remote ranks this local rank receives KV shards from (recv side).""" + validate_kv_tp_topology(topo) + if topo.source_tp_size == topo.target_tp_size: + return [topo.local_rank] + if topo.source_tp_size > topo.target_tp_size: + ratio = topo.source_tp_size // topo.target_tp_size + return [topo.local_rank * ratio + i for i in range(ratio)] + return [topo.local_rank // (topo.target_tp_size // topo.source_tp_size)] + + +# ------------------------------------------------------------------ # +# Rank-aware connector key building +# ------------------------------------------------------------------ # + + +def get_kv_connector_key( + req_id: str, + from_stage: int | str, + chunk_id: int, + from_rank: int, + to_rank: int, +) -> str: + """Build connector key that includes rank info for KV transfers. + + Format matches PR #2677: ``{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}`` + """ + return f"{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}" + + +def build_rank_aware_send_keys( + request_id: str, + from_stage: str, + to_stage: str, + topo: KVTPTopology, + hook: Callable | None = None, +) -> list[str]: + """Build send-side connector keys, checking injectable hook first.""" + if callable(hook): + keys = list(hook(request_id, from_stage, to_stage)) + if keys: + return keys + if topo.source_tp_size <= 1 and topo.target_tp_size <= 1: + return [f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}"] + target_ranks = get_kv_target_ranks(topo) + return [get_kv_connector_key(request_id, from_stage, 0, topo.local_rank, r) for r in target_ranks] + + +def build_rank_aware_recv_keys( + request_id: str, + from_stage: str, + to_stage: str, + topo: KVTPTopology, + hook: Callable | None = None, +) -> list[tuple[str, int | None]]: + """Build recv-side connector keys with sender rank info. + + Returns a list of ``(key, from_rank)`` tuples. ``from_rank`` is + ``None`` when TP <= 1 (single sender, no per-rank routing needed). + For TP > 1, ``from_rank`` identifies which sender rank owns the + key so that the connector can route metadata queries to the + correct endpoint. + """ + if callable(hook): + raw = list(hook(request_id, from_stage, to_stage)) + if raw: + if isinstance(raw[0], tuple): + return raw + # Hook returned plain strings (e.g. 
OmniConnectorModelRunnerMixin. + # get_rank_aware_kv_keys). Reconstruct from_rank from topology so + # Mooncake connector can route metadata queries to the correct + # sender endpoint in heterogeneous TP. + # TODO: have the mixin return (key, from_rank) tuples directly + # to avoid this indirect reconstruction. + source_ranks = get_kv_source_ranks(topo) + if len(raw) == len(source_ranks): + return list(zip(raw, source_ranks)) + return [(k, None) for k in raw] + if topo.source_tp_size <= 1 and topo.target_tp_size <= 1: + return [(f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}", None)] + source_ranks = get_kv_source_ranks(topo) + return [(get_kv_connector_key(request_id, from_stage, 0, r, topo.local_rank), r) for r in source_ranks] + + +# ------------------------------------------------------------------ # +# KV tensor head slicing (heterogeneous TP) +# ------------------------------------------------------------------ # + + +def slice_kv_tensor_heads( + tensor: torch.Tensor | None, + offset_in_shard: int, + num_slices: int, +) -> torch.Tensor | None: + """Slice one KV tensor along its head dimension (dim 1).""" + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + return tensor + if tensor.dim() < 2: + raise ValueError(f"Expected KV tensor with a head dimension, got shape={tuple(tensor.shape)}") + if num_slices <= 0: + raise ValueError(f"num_slices must be > 0, got {num_slices}") + if not (0 <= offset_in_shard < num_slices): + raise ValueError(f"offset_in_shard must be in [0, {num_slices}), got {offset_in_shard}") + + heads_in_shard = tensor.shape[1] + if heads_in_shard % num_slices != 0: + raise ValueError( + "KV head count must be divisible for heterogeneous TP slicing: " + f"heads_in_shard={heads_in_shard}, num_slices={num_slices}" + ) + + heads_per_slice = heads_in_shard // num_slices + start = offset_in_shard * heads_per_slice + end = start + heads_per_slice + return tensor[:, start:end, ...].contiguous() + + +def slice_layer_blocks( + layer_blocks: dict[str, Any], + offset_in_shard: int, + num_slices: int, +) -> dict[str, list[torch.Tensor | None]]: + """Slice all KV layers for one logical receiver rank.""" + sliced_blocks: dict[str, list[torch.Tensor | None]] = {} + for cache_name in ("key_cache", "value_cache"): + cache_list = layer_blocks.get(cache_name, []) + sliced_blocks[cache_name] = [ + slice_kv_tensor_heads(tensor, offset_in_shard, num_slices) for tensor in cache_list + ] + return sliced_blocks + + +# ------------------------------------------------------------------ # +# Multi-rank merge and receiver-side slice +# ------------------------------------------------------------------ # + + +def merge_received_rank_shards( + payloads: list[dict[str, Any]], + merger: Callable | None = None, +) -> dict[str, Any] | None: + """Merge multiple source-rank KV shards for one target rank. + + When *merger* is provided (injectable hook), it is called directly. + Otherwise the default merges along the head dimension (dim 1). 
+ """ + if callable(merger): + return merger(payloads) + if not payloads: + return None + if len(payloads) == 1: + return payloads[0] + + base_payload = payloads[0] + if not isinstance(base_payload, dict) or "layer_blocks" not in base_payload: + return base_payload + + merged: dict[str, Any] = { + "request_id": base_payload.get("request_id"), + "block_ids": list(base_payload.get("block_ids", [])), + "metadata": dict(base_payload.get("metadata", {})), + } + merged_layer_blocks: dict[str, list[torch.Tensor | None]] = {} + + for cache_name in ("key_cache", "value_cache"): + cache_lists = [payload.get("layer_blocks", {}).get(cache_name, []) for payload in payloads] + num_layers = max((len(cache_list) for cache_list in cache_lists), default=0) + merged_cache: list[torch.Tensor | None] = [] + + for layer_idx in range(num_layers): + layer_tensors = [ + cache_list[layer_idx] + for cache_list in cache_lists + if layer_idx < len(cache_list) and cache_list[layer_idx] is not None + ] + if not layer_tensors: + merged_cache.append(None) + elif len(layer_tensors) == 1 or not isinstance(layer_tensors[0], torch.Tensor): + merged_cache.append(layer_tensors[0]) + else: + merged_cache.append(torch.cat(layer_tensors, dim=1).contiguous()) + + merged_layer_blocks[cache_name] = merged_cache + + merged["layer_blocks"] = merged_layer_blocks + return merged + + +def slice_received_rank_shard( + payload: dict[str, Any] | None, + topo: KVTPTopology, + slicer: Callable | None = None, +) -> dict[str, Any] | None: + """Optionally slice a received payload to extract this rank's portion. + + Used when ``to_tp > from_tp``: the sender sent full heads and each + receiver rank slices out its own subset. + """ + if callable(slicer): + return slicer(payload) + if not payload or topo.target_tp_size <= topo.source_tp_size or "layer_blocks" not in payload: + return payload + + metadata = payload.get("metadata", {}) + slice_metadata = metadata.get("tp_head_slice") if isinstance(metadata, dict) else None + if isinstance(slice_metadata, dict) and slice_metadata.get("applied"): + tagged_rank = slice_metadata.get("target_rank") + if tagged_rank is not None and tagged_rank != topo.local_rank: + logger.warning( + "Received pre-sliced KV payload for unexpected target rank: expected=%s got=%s", + topo.local_rank, + tagged_rank, + ) + return payload + + ratio = topo.target_tp_size // topo.source_tp_size + offset_in_sender = topo.local_rank % ratio + updated_metadata = dict(metadata) if isinstance(metadata, dict) else {} + updated_metadata["tp_head_slice"] = { + "applied": True, + "side": "receiver", + "target_rank": topo.local_rank, + "from_tp": topo.source_tp_size, + "to_tp": topo.target_tp_size, + "offset_in_shard": offset_in_sender, + "num_slices": ratio, + } + return { + "request_id": payload.get("request_id"), + "layer_blocks": slice_layer_blocks(payload["layer_blocks"], offset_in_sender, ratio), + "block_ids": list(payload.get("block_ids", [])), + "metadata": updated_metadata, + } + + def normalize_layer_kv( layer_kv: LayerKV, *, diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index c8a96e6d25..6c92d7952d 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -79,6 +79,10 @@ class OmniEngineCoreRequest(EngineCoreRequest): class OmniEngineCoreOutput(EngineCoreOutput): pooling_output: dict[str, torch.Tensor] | None = None + # Finished flag for streaming input segment + is_segment_finished: bool | None = False + # Streaming update prompt length + new_prompt_len_snapshot: int | None = None 
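For illustration only (not part of the patch): a minimal sketch of how the kv_utils.py helpers introduced earlier in this diff (KVTPTopology, the rank-routing functions, kv_zmq_port, and slice_kv_tensor_heads) fit together for a 2-to-4 TP re-sharding. The import path follows the file header of this diff; the rank values, port numbers, and tensor shapes are made-up examples, not values taken from the patch.

import torch

from vllm_omni.distributed.omni_connectors.utils.kv_utils import (
    KVTPTopology,
    get_kv_source_ranks,
    get_kv_target_ranks,
    kv_zmq_port,
    slice_kv_tensor_heads,
)

# Sender rank 1 of a 2-way stage feeding a 4-way stage fans out to two receivers.
send_topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=1)
assert get_kv_target_ranks(send_topo) == [2, 3]

# Receiver rank 3 of the 4-way stage pulls from exactly one sender rank.
recv_topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=3)
assert get_kv_source_ranks(recv_topo) == [1]

# Per-rank ZMQ ports: rank 0 keeps the legacy base + OFFSET + stage value,
# higher ranks are offset by KV_RANK_PORT_STRIDE (16) per rank.
assert kv_zmq_port(5555, from_stage=1, local_rank=0) == 5555 + 100 + 1
assert kv_zmq_port(5555, from_stage=1, local_rank=2) == 5555 + 100 + 2 * 16 + 1

# Head slicing along dim 1: receiver rank 3 keeps the second half of the
# sender shard's 8 KV heads (offset 1 of 2 slices), as in slice_received_rank_shard.
kv = torch.randn(4, 8, 16)  # (blocks, kv_heads_in_shard, head_dim) -- illustrative shape
assert slice_kv_tensor_heads(kv, offset_in_shard=1, num_slices=2).shape == (4, 4, 16)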
class OmniEngineCoreOutputs(EngineCoreOutputs): diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index d43f1b8fdc..d98ce7d419 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -3,7 +3,7 @@ import json import os import tempfile -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import Any from vllm.engine.arg_utils import EngineArgs @@ -20,6 +20,8 @@ _ARCH_TO_MODEL_TYPE: dict[str, str] = { "CosyVoice3Model": "cosyvoice3", "OmniVoiceModel": "omnivoice", + "VoxCPM2TalkerForConditionalGeneration": "voxcpm2", + "VoxCPMForConditionalGeneration": "voxcpm", } # Maps model architecture names to tokenizer subfolder paths within HF repos. @@ -40,6 +42,8 @@ def _register_omni_hf_configs() -> None: from vllm_omni.model_executor.models.voxtral_tts.configuration_voxtral_tts import ( VoxtralTTSConfig, ) + from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig + from vllm_omni.transformers_utils.configs.voxcpm2 import VoxCPM2Config except Exception as exc: # pragma: no cover - best-effort optional registration logger.warning("Skipping omni HF config registration due to import error: %s", exc) return @@ -57,6 +61,8 @@ def _register_omni_hf_configs() -> None: ("cosyvoice3", CosyVoice3Config), ("omnivoice", OmniVoiceConfig), ("voxtral_tts", VoxtralTTSConfig), + ("voxcpm", VoxCPMConfig), + ("voxcpm2", VoxCPM2Config), ]: try: AutoConfig.register(model_type, config_cls) @@ -121,6 +127,9 @@ class OmniEngineArgs(EngineArgs): (e.g. ["text", "audio"]). If None, all modalities supported by the model are used. log_stats: Whether to log engine statistics. Defaults to False. + custom_pipeline_args: Dictionary of arguments for custom pipeline + initialization (e.g., ``{"pipeline_class": "my.Module"}``). + Passed through to the diffusion stage engine. """ stage_id: int = 0 @@ -140,6 +149,7 @@ class OmniEngineArgs(EngineArgs): stage_configs_path: str | None = None output_modalities: list[str] | None = None log_stats: bool = False + custom_pipeline_args: dict[str, Any] | None = None def __post_init__(self) -> None: load_omni_general_plugins() @@ -290,3 +300,254 @@ def create_model_config(self) -> OmniModelConfig: def output_modality(self) -> OutputModality: """Parse engine_output_type into a type-safe OutputModality flag.""" return OutputModality.from_string(self.engine_output_type) + + +# ============================================================================ +# CLI argument routing +# ============================================================================ +# +# vLLM-Omni's CLI flags live in three buckets: +# +# ┌──────────────────┐ ┌──────────────────┐ ┌──────────────────┐ +# │ OrchestratorArgs │ │ OmniEngineArgs │ │ (upstream vllm) │ +# │ │ │ │ │ server/api │ +# │ stage_timeout │ │ max_num_seqs │ │ host, port │ +# │ worker_backend │ │ gpu_mem_util │ │ ssl_keyfile │ +# │ deploy_config │ │ dtype, quant │ │ api_key │ +# │ ... │ │ ... │ │ ... │ +# └──────────────────┘ └──────────────────┘ └──────────────────┘ +# │ │ │ +# ▼ ▼ ▼ +# orchestrator each stage uvicorn / +# consumes engine FastAPI +# +# Fields in ``SHARED_FIELDS`` (e.g. ``model``, ``log_stats``) flow to BOTH +# orchestrator and engine by design. +# +# Invariants enforced by ``tests/test_arg_utils.py``: +# +# 1. ``OrchestratorArgs`` ∩ ``OmniEngineArgs`` ⊆ ``SHARED_FIELDS`` +# 2. Every CLI flag is classifiable into one of the three buckets +# 3. 
User-typed flags that match none of the above are logged as dropped +# +# Adding a new orchestrator-only flag → add a field to ``OrchestratorArgs``. +# Everything else is automatic. + + +@dataclass(frozen=True) +class OrchestratorArgs: + """CLI flags consumed by the orchestrator. + + Contract: every field here is either + (a) orchestrator-only (never needed by a stage engine), OR + (b) orchestrator-read-then-redistributed (e.g. ``async_chunk`` is read + from CLI, written to ``DeployConfig``, then propagated to every + stage via ``merge_pipeline_deploy`` — not via direct kwargs + forwarding). + + Fields that BOTH orchestrator and engine genuinely need (e.g. ``model``, + ``log_stats``) should be listed in ``SHARED_FIELDS`` below; ``split_kwargs`` + will copy them to both buckets. + """ + + # === Lifecycle === + stage_init_timeout: int = 300 + init_timeout: int = 600 + + # === Cross-stage Communication === + shm_threshold_bytes: int = 65536 + batch_timeout: int = 10 + + # === Cluster / Backend === + worker_backend: str = "multi_process" + ray_address: str | None = None + + # === Config Files === + stage_configs_path: str | None = None + deploy_config: str | None = None + stage_overrides: str | None = None # raw JSON string; parsed downstream + + # === Mode Switches (orchestrator reads, DeployConfig redistributes) === + async_chunk: bool | None = None + + # === Observability === + log_stats: bool = False + + # === Headless Mode (also forwarded to engine — see SHARED_FIELDS) === + stage_id: int | None = None + + # === Pre-built Objects === + parallel_config: Any = None + + # === Multi-stage guards === + # --tokenizer is captured here so it does not propagate to every stage + # uniformly (different stages often need different tokenizers, e.g. + # qwen3_omni thinker vs talker). Users wanting a per-stage tokenizer + # should set it in the deploy YAML. + tokenizer: str | None = None + + +# Fields that live in BOTH OrchestratorArgs and OmniEngineArgs by design. +# Changes to this set are a review red flag — revisit the contract. +SHARED_FIELDS: frozenset[str] = frozenset( + { + "model", # orch: detect model_type; engine: load weights + "stage_id", # orch: route (headless); engine: identity + "log_stats", # both want the flag + "stage_configs_path", # orch: load legacy YAML; engine: may reference for validation + } +) + + +def orchestrator_field_names() -> frozenset[str]: + """Return the names of every field on OrchestratorArgs.""" + return frozenset(f.name for f in fields(OrchestratorArgs)) + + +def internal_blacklist_keys() -> frozenset[str]: + """Return the set of CLI keys that must never be forwarded as per-stage + engine overrides. + + Derived from ``OrchestratorArgs`` fields minus ``SHARED_FIELDS``, so + adding a new orchestrator-owned flag is a one-line change to the + dataclass — this function updates automatically. + """ + return orchestrator_field_names() - SHARED_FIELDS + + +def split_kwargs( + kwargs: dict[str, Any], + *, + engine_cls: type | None = None, + user_typed: set[str] | None = None, + strict: bool = False, +) -> tuple[OrchestratorArgs, dict[str, Any]]: + """Partition CLI kwargs into (orchestrator, engine) buckets. + + Args: + kwargs: Raw dict, typically ``vars(args)``. + engine_cls: Engine dataclass used to whitelist-filter the engine + bucket. Defaults to ``OmniEngineArgs``. Pass a custom class + for testing. + user_typed: Keys the user actually typed on the command line. Used + to warn when a user-typed flag is unclassifiable. 
+ strict: If True, raise ``ValueError`` on ambiguous (double-classified + but not in ``SHARED_FIELDS``) fields. Default False to keep the + rollout non-breaking; flip to True in tests and CI. + + Returns: + ``(orchestrator_args, engine_kwargs)``. ``engine_kwargs`` has already + been whitelist-filtered against ``engine_cls`` — safe to pass directly + to ``engine_cls(**engine_kwargs)``. + """ + if engine_cls is None: + engine_cls = OmniEngineArgs + + orch_fields = orchestrator_field_names() + engine_fields = {f.name for f in fields(engine_cls)} + + orch_kwargs: dict[str, Any] = {} + engine_candidate: dict[str, Any] = {} + shared_values: dict[str, Any] = {} + unclassified: dict[str, Any] = {} + + for key, value in kwargs.items(): + in_orch = key in orch_fields + in_engine = key in engine_fields + is_shared = key in SHARED_FIELDS + + if is_shared: + shared_values[key] = value + elif in_orch and in_engine: + # Declared in both but not marked shared → ambiguous. + msg = ( + f"Field {key!r} is defined on both OrchestratorArgs and " + f"{engine_cls.__name__} but is not in SHARED_FIELDS. " + f"This causes double-routing. Either remove the duplicate or " + f"add {key!r} to SHARED_FIELDS if the sharing is intentional." + ) + if strict: + raise ValueError(msg) + logger.error(msg) + # Default: treat as orchestrator-only to preserve existing behavior. + orch_kwargs[key] = value + elif in_orch: + orch_kwargs[key] = value + elif in_engine: + engine_candidate[key] = value + else: + unclassified[key] = value + + # Warn on user-typed but unclassifiable flags so we don't silently drop + # something the user cared about (fixes the class of bug that spawned #873). + if unclassified and user_typed: + user_typed_unknown = sorted(k for k in unclassified if k in user_typed) + if user_typed_unknown: + logger.warning( + "CLI flags not consumed by vllm-omni and dropped before " + "per-stage engine construction: %s. If these are vllm " + "frontend/uvicorn flags (host, port, ssl_*, api_key, …) this " + "is expected; otherwise check your spelling.", + user_typed_unknown, + ) + + # Engine bucket: shared + engine-only. We do NOT pass through unclassified + # fields — that's exactly the server/uvicorn noise we want to shed. + engine_kwargs = {**shared_values, **engine_candidate} + + # Construct the orchestrator dataclass. Shared fields that OrchestratorArgs + # also declares get copied into its constructor. + orch_init: dict[str, Any] = dict(orch_kwargs) + for key, value in shared_values.items(): + if key in orch_fields: + orch_init[key] = value + orch_args = OrchestratorArgs(**orch_init) + + return orch_args, engine_kwargs + + +def derive_server_dests_from_vllm_parser() -> frozenset[str]: + """Derive the set of argparse dests that belong to vllm's frontend/server. + + Returns every dest registered by ``make_arg_parser`` that is NOT a field + of ``OmniEngineArgs`` and NOT a field of ``OrchestratorArgs``. Useful for + CI tests to assert all CLI flags are classifiable without maintaining + a hardcoded server list. + + Returns empty frozenset if vllm's parser cannot be built (e.g. in a + minimal test environment). 
+ """ + try: + from vllm.entrypoints.openai.cli_args import make_arg_parser + from vllm.utils.argparse_utils import FlexibleArgumentParser + except ImportError: + logger.debug("Cannot import vllm parser — server-dest derivation skipped") + return frozenset() + + try: + parser = make_arg_parser(FlexibleArgumentParser()) + all_dests = {a.dest for a in parser._actions if a.dest and a.dest != "help"} + except Exception as exc: + logger.debug("Failed to build vllm parser: %s", exc) + return frozenset() + + engine_fields = {f.name for f in fields(OmniEngineArgs)} + orch_fields = orchestrator_field_names() + + return frozenset(all_dests - engine_fields - orch_fields - SHARED_FIELDS) + + +def orchestrator_args_from_argparse(args: Any) -> OrchestratorArgs: + """Build an ``OrchestratorArgs`` from an ``argparse.Namespace``. + + Only copies attributes that exist on the namespace — missing fields fall + back to the dataclass default. Useful when the full parser is already + built and ``vars(args)`` would include noise. + """ + kwargs: dict[str, Any] = {} + for f in fields(OrchestratorArgs): + if hasattr(args, f.name): + value = getattr(args, f.name) + if value is not None or f.default is None: + kwargs[f.name] = value + return OrchestratorArgs(**kwargs) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 77b386124c..35b6be70f3 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -25,12 +25,15 @@ import janus import torch from omegaconf import OmegaConf +from vllm import envs as vllm_envs +from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.tokenizers import cached_tokenizer_from_config from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.input_processor import InputProcessor +from vllm_omni.config.stage_config import strip_parent_engine_args from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient from vllm_omni.diffusion.stage_diffusion_proc import ( @@ -63,6 +66,7 @@ ) from vllm_omni.engine.stage_init_utils import ( StartedLlmStage, + _inject_inferred_kv_tp_topology, acquire_device_locks, build_diffusion_config, build_engine_args_dict, @@ -80,7 +84,11 @@ setup_stage_devices, terminate_alive_proc, ) -from vllm_omni.entrypoints.utils import load_and_resolve_stage_configs +from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin +from vllm_omni.entrypoints.utils import ( + inject_omni_kv_config, + load_and_resolve_stage_configs, +) from vllm_omni.inputs.preprocess import OmniInputPreprocessor from vllm_omni.platforms import current_omni_platform @@ -90,6 +98,27 @@ logger = init_logger(__name__) +# ============================================================================ +# Parent-EngineArgs field-routing contracts (consumed by +# AsyncOmniEngine._strip_parent_engine_args when ``stage_configs_path`` is set). +# ============================================================================ + +# Fields that must survive the "equal to default → strip" filter because +# diffusion stages need them even when equal to vllm's default value +# (e.g. colocate worker setup relies on worker_extension_cls being forwarded). 
+_PARENT_ARGS_KEEP: frozenset[str] = frozenset({"worker_extension_cls"}) + +# Omni orchestrator-level fields consumed by ``_resolve_stage_configs`` that +# must never leak into per-stage EngineArgs (``stage_configs_path`` would +# trigger the ``create_model_config`` guard). +_PARENT_ARGS_STRIP: frozenset[str] = frozenset({"stage_configs_path"}) + +# Fields always populated by callers (via ``from_cli_args`` / ``asdict``) so +# their presence as an override is never a surprise — suppress the +# "override ignored" warning for these. +_PARENT_ARGS_NO_WARN: frozenset[str] = frozenset({"model"}) + + def _patch_generation_config_if_needed(model_config: Any) -> None: """Ensure try_get_generation_config won't crash for models whose HF config.json lacks model_type (e.g. CosyVoice3). We probe it once; @@ -380,6 +409,12 @@ def _launch_llm_stage( omni_kv["omni_to_stage"] = omni_to omni_kv.setdefault("stage_id", metadata.stage_id) engine_args_dict["omni_kv_config"] = omni_kv + if self.stage_configs: + _inject_inferred_kv_tp_topology( + engine_args_dict.get("omni_kv_config"), + metadata.stage_id, + self.stage_configs, + ) vllm_config, executor_class = build_vllm_config( stage_cfg, self.model, @@ -426,21 +461,24 @@ def _launch_llm_stage( proc=proc, ) logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id) - # Keep the stage-specific device visibility until vLLM - # finishes starting all child processes. - if self.single_stage_mode and self._omni_master_server is not None: - launch_stack.close() - else: - assert proc is not None - assert handshake_address is not None - complete_stage_handshake(proc, handshake_address, addresses, vllm_config) - logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id) finally: if previous_visible_devices is None: current_omni_platform.unset_device_control_env_var() else: current_omni_platform.set_device_control_env_var(previous_visible_devices) + # After StageEngineCoreProc has been spawned it carries its + # stage-specific device visibility into descendants, so the + # slow HELLO/READY handshake can run without holding the + # process-wide launch lock. 
+ if self.single_stage_mode and self._omni_master_server is not None: + launch_stack.close() + else: + assert proc is not None + assert handshake_address is not None + complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout) + logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id) + assert started_stage is not None return started_stage except Exception: @@ -752,10 +790,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: setup_stage_devices(configured_stage_id, metadata.runtime_cfg) omni_conn_cfg, omni_from, omni_to = omni_kv_connector if omni_conn_cfg: - from vllm_omni.entrypoints.utils import inject_omni_kv_config - inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to) - inject_kv_stage_info(stage_cfg, configured_stage_id) + inject_kv_stage_info(stage_cfg, configured_stage_id, self.stage_configs) if self.single_stage_mode: assert self._omni_master_server is not None stage_clients[stage_idx] = self._launch_diffusion_stage( @@ -764,11 +800,14 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: self._omni_master_server, ) else: + use_inline = True if self.num_stages == 1 else False stage_clients[stage_idx] = initialize_diffusion_stage( self.model, stage_cfg, metadata, + stage_init_timeout=stage_init_timeout, batch_size=self.diffusion_batch_size, + use_inline=use_inline, ) logger.info( "[AsyncOmniEngine] Stage %s initialized (diffusion, batch_size=%d)", @@ -927,6 +966,7 @@ async def _run_orchestrator() -> None: if isinstance(_sc, StageEngineCoreClientBase): _sc.receiver_connectors = self.stage_receiver_connectors[_sid] + pd_config = self._detect_pd_config() orchestrator = Orchestrator( request_async_queue=self.request_queue.async_q, output_async_queue=self.output_queue.async_q, @@ -936,6 +976,7 @@ async def _run_orchestrator() -> None: output_processors=self.output_processors, stage_vllm_configs=self.stage_vllm_configs, connectors=self.omni_connectors, + pd_config=pd_config, ) if not startup_future.done(): startup_future.set_result(asyncio.get_running_loop()) @@ -1098,7 +1139,6 @@ def _enqueue_cfg_companions( params=companion_params, supported_tasks=self.supported_tasks, ) - request = _upgrade_to_omni_request(request, companion_prompt) request.external_req_id = cid self.output_processors[0].add_request( @@ -1158,12 +1198,64 @@ def _normalize_cache_config(cache_backend: str | None, cache_config: Any | None) cache_config = AsyncOmniEngine._get_default_cache_config(cache_backend) return cache_config + def _detect_pd_config(self) -> dict[str, Any] | None: + """Detect PD (Prefill-Decode) disaggregation config from stage_configs. + Returns a dict with 'pd_pair' and 'bootstrap_addr', or None. 
+ """ + pd_pair = PDDisaggregationMixin.detect_pd_separation_from_stage_configs(self.stage_configs) + if pd_pair is None: + return None + prefill_idx, decode_idx = pd_pair + + # Extract bootstrap address from prefill stage engine_args + bootstrap_addr: str | None = None + try: + prefill_cfg = self.stage_configs[prefill_idx] + ea = getattr(prefill_cfg, "engine_args", None) + kv_cfg = getattr(ea, "kv_transfer_config", None) if ea is not None else None + if kv_cfg is not None: + port = vllm_envs.VLLM_MOONCAKE_BOOTSTRAP_PORT + kv_ip = getattr(kv_cfg, "kv_ip", None) or "127.0.0.1" + bootstrap_addr = f"http://{kv_ip}:{port}" + except Exception as exc: + logger.warning("[AsyncOmniEngine] Could not extract PD bootstrap address: %s", exc) + + logger.info( + "[AsyncOmniEngine] PD disaggregation detected: prefill=stage-%d, decode=stage-%d, bootstrap=%s", + prefill_idx, + decode_idx, + bootstrap_addr, + ) + prefill_engine_id: str | None = None + try: + prefill_client = self.stage_clients[prefill_idx] + kv_cfg = getattr(getattr(prefill_client, "vllm_config", None), "kv_transfer_config", None) + prefill_engine_id = getattr(kv_cfg, "engine_id", None) + except Exception as exc: + logger.warning("[AsyncOmniEngine] Could not extract prefill engine_id: %s", exc) + + return { + "pd_pair": (prefill_idx, decode_idx), + "bootstrap_addr": bootstrap_addr, + "prefill_engine_id": prefill_engine_id, + } + @staticmethod def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: """Create a default single-stage diffusion config from kwargs.""" # We temporally create a default config for diffusion stage. # In the future, we should merge the default config with the user-provided config. normalized_kwargs = dict(kwargs) + default_sampling_params = normalized_kwargs.get("default_sampling_params") + if isinstance(default_sampling_params, str): + try: + default_sampling_params = json.loads(default_sampling_params) + except json.JSONDecodeError: + logger.warning("Invalid default_sampling_params JSON, ignoring stage defaults.") + default_sampling_params = None + if not isinstance(default_sampling_params, dict): + default_sampling_params = None + stage_default_sampling_params = default_sampling_params.get("0", {}) if default_sampling_params else {} # TODO: hack, convert dtype to string to avoid non-premitive omegaconf create error. 
if "dtype" in normalized_kwargs and not isinstance(normalized_kwargs["dtype"], str): @@ -1227,6 +1319,8 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False), "enforce_eager": kwargs.get("enforce_eager", False), + "boundary_ratio": kwargs.get("boundary_ratio", None), + "flow_shift": kwargs.get("flow_shift", None), "diffusion_load_format": kwargs.get("diffusion_load_format", "default"), "custom_pipeline_args": kwargs.get("custom_pipeline_args", None), "worker_extension_cls": kwargs.get("worker_extension_cls", None), @@ -1258,6 +1352,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "devices": devices, }, "engine_args": stage_engine_args, + "default_sampling_params": stage_default_sampling_params, "final_output": True, "final_output_type": "image", } @@ -1265,10 +1360,50 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion" return default_stage_cfg + @staticmethod + def _strip_single_engine_args(kwargs: dict[str, Any]) -> dict[str, Any]: + """Remove parent ``EngineArgs`` fields from *kwargs*. + + When ``stage_configs_path`` is set, per-stage engine args are defined + in the YAML. Top-level single-engine fields (``compilation_config``, + ``tensor_parallel_size``, …) must not leak into per-stage configs via + the ``base_engine_args`` merge in ``load_stage_configs_from_yaml`` — + they can cause type errors (e.g. ``compilation_config`` as a JSON + string rejected by ``VllmConfig``) or silently override YAML values. + + Logs a warning for any parent field whose value differs from the + dataclass default, so users know their explicit overrides are ignored. + See the module-level ``_PARENT_ARGS_*`` constants for the routing + contracts this method enforces. + """ + parent_fields: dict[str, dataclasses.Field] = {f.name: f for f in dataclasses.fields(EngineArgs)} + result, overridden = strip_parent_engine_args( + kwargs, + parent_fields=parent_fields, + keep_keys=_PARENT_ARGS_KEEP, + strip_keys=_PARENT_ARGS_STRIP, + no_warn_keys=_PARENT_ARGS_NO_WARN, + ) + + if overridden: + logger.warning( + "stage_configs_path is set — the following top-level engine " + "args are ignored (per-stage YAML takes precedence): %s", + ", ".join(sorted(overridden)), + ) + + return result + def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[str, list[Any]]: """Resolve stage configs and inject defaults shared by orchestrator/headless.""" stage_configs_path = kwargs.get("stage_configs_path", None) + deploy_config_path = kwargs.pop("deploy_config", None) + stage_overrides_json = kwargs.pop("stage_overrides", None) + # Set of CLI keys the user actually typed; ``None`` means we have no + # parser-level info (e.g. programmatic Omni() call) and the lower + # layers should treat all kwargs as explicit. + cli_explicit_keys = kwargs.pop("_cli_explicit_keys", None) explicit_stage_configs = kwargs.pop("stage_configs", None) if explicit_stage_configs is not None: logger.warning( @@ -1276,13 +1411,32 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st "Ignoring it and resolving stages from stage_configs_path/model factory." ) - # Use the legacy config loading path (load_and_resolve_stage_configs). - # StageConfigFactory wiring will be done in config refactor [2/N]. 
+ if stage_configs_path is not None: + base_kwargs = self._strip_single_engine_args(kwargs) + else: + base_kwargs = kwargs + + # Parse --stage-overrides JSON string if provided + stage_overrides = None + if stage_overrides_json: + if isinstance(stage_overrides_json, str): + try: + stage_overrides = json.loads(stage_overrides_json) + except json.JSONDecodeError as exc: + raise ValueError( + f"--stage-overrides is not valid JSON: {exc}. Got: {stage_overrides_json!r}" + ) from exc + else: + stage_overrides = stage_overrides_json + config_path, stage_configs = load_and_resolve_stage_configs( model, stage_configs_path, - kwargs, + base_kwargs, default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs), + deploy_config_path=deploy_config_path, + stage_overrides=stage_overrides, + cli_explicit_keys=cli_explicit_keys, ) # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. diff --git a/vllm_omni/engine/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py new file mode 100644 index 0000000000..b9dfae833e --- /dev/null +++ b/vllm_omni/engine/cfg_companion_tracker.py @@ -0,0 +1,125 @@ +"""CFG companion request tracker for the Omni orchestrator. + +Encapsulates all bookkeeping for Classifier-Free Guidance companion +requests (parent/companion ID mapping, completion tracking, +deferred forwarding, and cleanup). +""" + +from __future__ import annotations + +import logging +from typing import Any + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +logger = logging.getLogger(__name__) + + +class CfgCompanionTracker: + """Manages CFG companion request lifecycle in the orchestrator scheduling loop.""" + + def __init__(self) -> None: + self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id} + self._companion_ids: set[str] = set() + self._companion_to_parent: dict[str, str] = {} # companion -> parent + self._done: dict[str, set[str]] = {} # parent -> completed companion ids + self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result + + def is_companion(self, req_id: str) -> bool: + return req_id in self._companion_ids + + def has_companions(self, parent_id: str) -> bool: + return parent_id in self._companion_map + + def all_companions_done(self, parent_id: str) -> bool: + role_map = self._companion_map.get(parent_id, {}) + done_set = self._done.get(parent_id, set()) + return all(cid in done_set for cid in role_map.values()) + + def get_companion_request_ids(self, parent_id: str) -> dict[str, str]: + """Return ``{role: companion_request_id}`` for a parent.""" + return self._companion_map.get(parent_id, {}) + + def register_parent(self, parent_id: str) -> None: + self._companion_map.setdefault(parent_id, {}) + self._done.setdefault(parent_id, set()) + + def register_companion(self, parent_id: str, role: str, companion_id: str) -> None: + self.register_parent(parent_id) + self._companion_map[parent_id][role] = companion_id + self._companion_ids.add(companion_id) + self._companion_to_parent[companion_id] = parent_id + + def attach_cfg_request_ids(self, parent_id: str, sampling_params: Any) -> Any: + cfg_ids = self.get_companion_request_ids(parent_id) + if not cfg_ids: + return sampling_params + + if isinstance(sampling_params, OmniDiffusionSamplingParams): + sampling_params = sampling_params.clone() + sampling_params.cfg_kv_request_ids = cfg_ids + logger.info( + "Attaching cfg_kv_request_ids=%s to request %s", + cfg_ids, + parent_id, + ) + return sampling_params + + def 
on_companion_completed(self, companion_id: str) -> str | None: + """Mark done. Returns parent_id only if parent is pending and all companions finished.""" + parent_id = self._companion_to_parent.get(companion_id) + if parent_id is None: + return None + done_set = self._done.get(parent_id) + assert done_set is not None, f"Companion {companion_id} completed before parent {parent_id} was registered" + if companion_id in done_set: + return None + done_set.add(companion_id) + logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id) + if parent_id in self._pending_parents and self.all_companions_done(parent_id): + return parent_id + return None + + def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None: + """Hold parent result while waiting for companions to finish.""" + # TODO: Add timeout/error recovery when the orchestrator grows a + # companion-failure path. Today deferred parents are released only when + # companions finish or the external layer aborts the request. + self._pending_parents[parent_id] = { + "engine_outputs": engine_outputs, + "stage_id": stage_id, + } + logger.debug("Parent %s deferred, waiting for CFG companions", parent_id) + + def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None: + return self._pending_parents.pop(parent_id, None) + + def cleanup_parent(self, parent_id: str) -> list[str]: + companion_ids = list(self._companion_map.pop(parent_id, {}).values()) + for companion_id in companion_ids: + self._companion_ids.discard(companion_id) + self._companion_to_parent.pop(companion_id, None) + self._done.pop(parent_id, None) + self._pending_parents.pop(parent_id, None) + return companion_ids + + def abort_parents(self, request_ids: list[str]) -> list[str]: + all_request_ids = list(request_ids) + seen = set(all_request_ids) + parents_to_cleanup: set[str] = set() + + for req_id in request_ids: + # The orchestrator calls this with parent request IDs. If a raw + # companion ID is passed here, keep it as a direct abort target and + # avoid tearing down parent tracking state implicitly. + if req_id not in self._companion_ids: + parents_to_cleanup.add(req_id) + + for parent_id in parents_to_cleanup: + companion_ids = self.cleanup_parent(parent_id) + for companion_id in companion_ids: + if companion_id not in seen: + seen.add(companion_id) + all_request_ids.append(companion_id) + + return all_request_ids diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index d955b30251..aea7b9a8b7 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -27,6 +27,7 @@ from vllm_omni.engine import ( OmniEngineCoreRequest, ) +from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.serialization import serialize_additional_information from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics from vllm_omni.metrics.stats import StageStats @@ -41,6 +42,8 @@ def build_engine_core_request_from_tokens( params: SamplingParams | PoolingParams, arrival_time: float | None = None, model_config: ModelConfig | None = None, + resumable: bool = False, + mm_features: list | None = None, ) -> OmniEngineCoreRequest: """Build an OmniEngineCoreRequest directly from an OmniTokensPrompt. 
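For illustration only (not part of the patch): a minimal sketch of the CfgCompanionTracker lifecycle that the orchestrator changes below rely on. The request IDs and the deferred payload are placeholders; only the call order and return values reflect the class defined in cfg_companion_tracker.py above.

from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker

tracker = CfgCompanionTracker()
tracker.register_companion("req-0", "negative", "req-0-neg")
tracker.register_companion("req-0", "unconditional", "req-0-uncond")

# Parent finished first: its output is parked until every companion reports in.
tracker.defer_parent("req-0", engine_outputs={"stage_0_output": None}, stage_id=0)

assert tracker.on_companion_completed("req-0-neg") is None        # one companion still pending
assert tracker.on_companion_completed("req-0-uncond") == "req-0"  # last one: parent is ready

deferred = tracker.pop_pending_parent("req-0")
assert deferred is not None and deferred["stage_id"] == 0

# Cleanup returns the companion IDs so the caller can drop their request state too.
assert sorted(tracker.cleanup_parent("req-0")) == ["req-0-neg", "req-0-uncond"]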
@@ -75,7 +78,7 @@ def build_engine_core_request_from_tokens( return OmniEngineCoreRequest( request_id=request_id, prompt_token_ids=prompt_token_ids, - mm_features=None, + mm_features=mm_features, sampling_params=sampling_params, pooling_params=pooling_params, arrival_time=arrival_time, @@ -83,6 +86,7 @@ def build_engine_core_request_from_tokens( cache_salt=None, data_parallel_rank=None, prompt_embeds=prompt_embeds, + resumable=resumable, additional_information=additional_info_payload, ) @@ -103,6 +107,22 @@ class OrchestratorRequestState: # Metrics: timestamp when request was submitted to each stage stage_submit_ts: dict[int, float] = field(default_factory=dict) + mm_processor_kwargs: dict | None = None + mm_features: list | None = None + + streaming: StreamingInputState = field(default_factory=lambda: StreamingInputState()) + + +@dataclass +class StreamingInputState: + # Flag of streaming input request + enabled: bool = False + # Flag of segment of streaming input finished + segment_finished: bool = False + # Streaming update prompt length + new_prompt_len_snapshot: int | None = None + # Model/bridge-specific runtime states (e.g., thinker->talker) + bridge_states: dict[str, Any] = field(default_factory=dict) class Orchestrator: @@ -123,6 +143,7 @@ def __init__( *, async_chunk: bool = False, connectors: dict[tuple[str, str], Any] | None = None, + pd_config: dict[str, Any] | None = None, ) -> None: self.request_async_queue = request_async_queue self.output_async_queue = output_async_queue @@ -138,15 +159,21 @@ def __init__( # Sender-side omni connectors keyed by (from_stage, to_stage) self.connectors: dict[tuple[str, str], Any] = connectors or {} + # PD disaggregation state + self._pd_pair: tuple[int, int] | None = None + self._pd_bootstrap_addr: str | None = None + self._pd_prefill_engine_id: str | None = None + self._pd_kv_params: dict[str, Any] = {} + if pd_config is not None: + self._pd_pair = pd_config.get("pd_pair") + self._pd_bootstrap_addr = pd_config.get("bootstrap_addr") + self._pd_prefill_engine_id = pd_config.get("prefill_engine_id") + # Per-request state self.request_states: dict[str, OrchestratorRequestState] = {} # CFG companion tracking - self._companion_map: dict[str, dict[str, str]] = {} - self._companion_to_parent: dict[str, str] = {} - self._companion_ids: set[str] = set() - self._companion_done: dict[str, set[str]] = {} - self._deferred_parents: dict[str, dict[str, Any]] = {} + self._cfg_tracker = CfgCompanionTracker() # Per-stage metrics accumulators. 
self._batch_seq: list[int] = [0] * self.num_stages @@ -250,6 +277,23 @@ async def _orchestration_loop(self) -> None: idle = False req_state = self.request_states.get(output.request_id) if req_state is not None: + if getattr(output, "error", None) is not None: + parent_id = self._companion_to_parent.get(output.request_id, output.request_id) + await self.output_async_queue.put( + { + "type": "error", + "request_id": parent_id, + "stage_id": stage_id, + "error": output.error, + } + ) + role_map = self._companion_map.get(parent_id, {}) + for cid in role_map.values(): + self.request_states.pop(cid, None) + self._cleanup_companion_state(parent_id) + self.request_states.pop(parent_id, None) + continue + stage_metrics = self._build_stage_metrics(stage_id, output.request_id, [output], req_state) await self._route_output(stage_id, output, req_state, stage_metrics) continue @@ -321,7 +365,7 @@ async def _route_output( # CFG companion handling: companions don't produce user-visible output # and don't forward to the next stage directly. - if finished and req_id in self._companion_ids: + if finished and self._cfg_tracker.is_companion(req_id): await self._handle_cfg_companion_ready(req_id) self.request_states.pop(req_id, None) return @@ -349,63 +393,80 @@ async def _route_output( } ) + # PD disaggregation: extract KV transfer params from prefill stage output + if self._pd_pair is not None and finished and stage_id == self._pd_pair[0]: + kv_params = getattr(output, "kv_transfer_params", None) + if kv_params is not None: + self._pd_kv_params[req_id] = kv_params if isinstance(kv_params, dict) else dict(kv_params) + logger.debug( + "[Orchestrator][PD] stored kv_transfer_params for req=%s (keys=%s)", + req_id, + list(self._pd_kv_params[req_id].keys()), + ) + else: + logger.warning( + "[Orchestrator][PD] prefill stage output for req=%s has no kv_transfer_params; " + "KV transfer may fail. 
Ensure apply_mooncake_connector_patch() was called.", + req_id, + ) + if ( - finished + (finished or (req_state.streaming.enabled and req_state.streaming.segment_finished)) and stage_id < req_state.final_stage_id and not self.async_chunk - and not self._next_stage_already_submitted(stage_id, req_state) + and (not self._next_stage_already_submitted(stage_id, req_state) or req_state.streaming.enabled) ): - if req_id in self._companion_map and not self._all_companions_done(req_id): - self._deferred_parents[req_id] = { - "stage_id": stage_id, - "output": output, - } + if ( + finished + and self._cfg_tracker.has_companions(req_id) + and not self._cfg_tracker.all_companions_done(req_id) + ): + self._cfg_tracker.defer_parent(req_id, output, stage_id) else: - await self._forward_to_next_stage(req_id, stage_id, output, req_state) + await self._forward_to_next_stage( + req_id, + stage_id, + output, + req_state, + is_streaming_session=req_state.streaming.enabled, + is_final_update=False, + ) + if req_state.streaming.enabled and finished: + # For streaming sessions, send the terminal (resumable=False) update only on a finish + await self._forward_to_next_stage( + req_id, + stage_id, + output, + req_state, + is_streaming_session=True, + is_final_update=True, + ) if finished and stage_id == req_state.final_stage_id: - self._cleanup_companion_state(req_id) + # PD: clean up any lingering KV params for this request + self._pd_kv_params.pop(req_id, None) + self._cfg_tracker.cleanup_parent(req_id) self.request_states.pop(req_id, None) - def _cleanup_companion_state(self, parent_id: str) -> None: - """Remove all companion tracking state for a completed parent.""" - role_map = self._companion_map.pop(parent_id, {}) - for cid in role_map.values(): - self._companion_ids.discard(cid) - self._companion_to_parent.pop(cid, None) - self._companion_done.pop(parent_id, None) - self._deferred_parents.pop(parent_id, None) - - def _all_companions_done(self, parent_id: str) -> bool: - """Check whether all CFG companions for a parent request have finished.""" - role_map = self._companion_map.get(parent_id, {}) - if not role_map: - return True - done_set = self._companion_done.get(parent_id, set()) - return all(cid in done_set for cid in role_map.values()) - def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool: return (stage_id + 1) in req_state.stage_submit_ts async def _handle_cfg_companion_ready(self, req_id: str) -> None: """Mark a CFG companion as done; if all companions are done, flush deferred parent.""" - parent_id = self._companion_to_parent.get(req_id) + parent_id = self._cfg_tracker.on_companion_completed(req_id) if parent_id is None: return - done_set = self._companion_done.setdefault(parent_id, set()) - if req_id in done_set: + deferred = self._cfg_tracker.pop_pending_parent(parent_id) + if deferred is None: return - done_set.add(req_id) - if parent_id in self._deferred_parents and self._all_companions_done(parent_id): - deferred = self._deferred_parents.pop(parent_id) - parent_state = self.request_states.get(parent_id) - if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state): - await self._forward_to_next_stage( - parent_id, - deferred["stage_id"], - deferred["output"], - parent_state, - ) + parent_state = self.request_states.get(parent_id) + if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state): + await self._forward_to_next_stage( + parent_id, + 
deferred["stage_id"], + deferred["engine_outputs"], + parent_state, + ) async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> None: """Forward split requests once stage-0 KV is ready, not only when decode fully finishes.""" @@ -419,21 +480,63 @@ async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineC req_state = self.request_states.get(req_id) if req_state is None: continue - if req_id in self._companion_ids: + if self._cfg_tracker.is_companion(req_id): await self._handle_cfg_companion_ready(req_id) continue if stage_id >= req_state.final_stage_id: continue if self._next_stage_already_submitted(stage_id, req_state): continue - if req_id in self._companion_map and not self._all_companions_done(req_id): - self._deferred_parents[req_id] = { - "stage_id": stage_id, - "output": raw_output, - } + if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id): + self._cfg_tracker.defer_parent(req_id, raw_output, stage_id) else: await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state) + def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any: + """Build decode-side sampling params with KV transfer params for PD routing. + + Clones the sampling params and injects kv_transfer_params that tell the + decode engine where to pull the KV cache from (prefill engine's bootstrap addr). + """ + sp = sp.clone() + if sp.extra_args is None: + sp.extra_args = {} + + # Get KV params captured from the prefill output (must include remote_request_id). + kv_prefill_params = self._pd_kv_params.pop(req_id, None) + if not kv_prefill_params or "remote_request_id" not in kv_prefill_params: + raise RuntimeError( + f"[Orchestrator][PD] Missing prefill kv_transfer_params.remote_request_id for req={req_id}" + ) + + decode_kv_params: dict[str, Any] = { + "transfer_id": f"xfer-{req_id}", + } + + if self._pd_bootstrap_addr: + decode_kv_params["remote_bootstrap_addr"] = self._pd_bootstrap_addr + + if self._pd_prefill_engine_id: + decode_kv_params["remote_engine_id"] = self._pd_prefill_engine_id + + # Overlay params from prefill side (includes remote_request_id set by monkey patch). + decode_kv_params.update(kv_prefill_params) + + # Ensure these flags are set correctly after any overlay. + decode_kv_params["do_remote_prefill"] = True + decode_kv_params["do_remote_decode"] = False + if not decode_kv_params.get("transfer_id"): + decode_kv_params["transfer_id"] = f"xfer-{req_id}" + + sp.extra_args["kv_transfer_params"] = decode_kv_params + + logger.debug( + "[Orchestrator][PD] decode kv_transfer_params for req=%s: %s", + req_id, + decode_kv_params, + ) + return sp + def _build_stage_metrics( self, stage_id: int, @@ -511,6 +614,9 @@ async def _forward_to_next_stage( stage_id: int, output: Any, req_state: OrchestratorRequestState, + *, + is_streaming_session: bool = False, + is_final_update: bool = False, ) -> None: """Forward output from current stage to the next stage. 
@@ -520,6 +626,7 @@ async def _forward_to_next_stage( next_stage_id = stage_id + 1 next_client = self.stage_clients[next_stage_id] params = req_state.sampling_params_list[next_stage_id] + next_stage_resumable = is_streaming_session and not is_final_update if next_client.stage_type == "diffusion": self.stage_clients[stage_id].set_engine_outputs([output]) @@ -535,20 +642,7 @@ async def _forward_to_next_stage( else: diffusion_prompt = req_state.prompt - # Attach CFG companion KV request IDs so the diffusion model - # runner can fetch companion KV caches alongside the primary one. - cfg_ids = self._companion_map.get(req_id) - if cfg_ids: - from vllm_omni.inputs.data import OmniDiffusionSamplingParams - - if isinstance(params, OmniDiffusionSamplingParams): - params = copy.deepcopy(params) - params.cfg_kv_request_ids = cfg_ids - logger.info( - "[Orchestrator] Attaching cfg_kv_request_ids=%s to req %s", - cfg_ids, - req_id, - ) + params = self._cfg_tracker.attach_cfg_request_ids(req_id, params) source_stage_ids = list(getattr(next_client, "engine_input_source", None) or [stage_id]) kv_sender_info = self._build_kv_sender_info(sender_stage_ids=source_stage_ids) @@ -569,6 +663,52 @@ async def _forward_to_next_stage( req_state.stage_submit_ts[next_stage_id] = _time.time() return + # PD disaggregation: prefill → decode routing uses original prompt + KV transfer params + if self._pd_pair is not None and (stage_id, next_stage_id) == self._pd_pair: + # Save prefill stage outputs so thinker2talker can merge embeddings later + self.stage_clients[stage_id].set_engine_outputs([output]) + + params = self._build_pd_decode_params(req_id, params) + + # Use the original user prompt for the decode stage (not processed embeddings) + original_prompt = req_state.prompt + raw_decode_inputs = [original_prompt] if not isinstance(original_prompt, list) else original_prompt + + decode_inputs: list[dict[str, Any]] = [] + for decode_input in raw_decode_inputs: + if isinstance(decode_input, dict): + decode_inputs.append(decode_input) + continue + prompt_token_ids = getattr(decode_input, "prompt_token_ids", None) + if prompt_token_ids is None: + raise TypeError( + "[Orchestrator][PD] decode input must be dict or have prompt_token_ids, " + f"got {type(decode_input).__name__} for req={req_id}" + ) + decode_inputs.append({"prompt_token_ids": list(prompt_token_ids)}) + + for decode_input in decode_inputs: + request = build_engine_core_request_from_tokens( + request_id=req_id, + prompt=decode_input, + params=params, + model_config=self.stage_vllm_configs[next_stage_id].model_config, + mm_features=req_state.mm_features, # Pass mm_features for M-RoPE + ) + request.external_req_id = request.request_id + + self.output_processors[next_stage_id].add_request( + request=request, + prompt=None, + parent_req=None, + request_index=0, + queue=None, + ) + await next_client.add_request_async(request) + + req_state.stage_submit_ts[next_stage_id] = _time.time() + return + self.stage_clients[stage_id].set_engine_outputs([output]) # Process inputs for next stage @@ -576,6 +716,7 @@ async def _forward_to_next_stage( next_inputs = next_client.process_engine_inputs( stage_list=self.stage_clients, prompt=req_state.prompt, + streaming_context=(req_state.streaming if req_state.streaming.enabled else None), ) except Exception: logger.exception( @@ -630,11 +771,17 @@ async def _forward_to_next_stage( next_stage_id, ) + # Only AR thinker stages consume encoder mm_features; downstream + # (talker/code2wav/…) must not see them (avoids encoder-cache misses). 
+ _ms = getattr(next_client, "model_stage", None) + _mm_features = req_state.mm_features if _ms == "thinker" else None request = build_engine_core_request_from_tokens( request_id=req_id, prompt=next_input, params=params, model_config=self.stage_vllm_configs[next_stage_id].model_config, + mm_features=_mm_features, + resumable=next_stage_resumable, ) # TODO: Here we directly use the req id to assign. @@ -676,6 +823,13 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut raw_outputs.timestamp, None, ) + for eco in raw_outputs.outputs: + if not hasattr(eco, "request_id"): + continue + req_state = self.request_states.get(eco.request_id) + if req_state: + req_state.streaming.segment_finished = eco.is_segment_finished + req_state.streaming.new_prompt_len_snapshot = eco.new_prompt_len_snapshot if processed.reqs_to_abort: await self.stage_clients[stage_id].abort_requests_async(processed.reqs_to_abort) @@ -711,19 +865,22 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None: # Track request state - use original_prompt so downstream stages # (e.g. thinker2talker) can access the raw dict with multi_modal_data. + request = prompt + is_streaming = bool(getattr(request, "resumable", False)) req_state = OrchestratorRequestState( request_id=request_id, prompt=original_prompt, sampling_params_list=sampling_params_list, final_stage_id=final_stage_id, + mm_features=getattr(prompt, "mm_features", None), # Save mm_features for PD ) + req_state.streaming.enabled = is_streaming req_state.stage_submit_ts[stage_id] = _time.time() self.request_states[request_id] = req_state # Stage-0 prompt is already a fully-formed OmniEngineCoreRequest # (pre-processed by AsyncOmniEngine.add_request, output processor # already registered there) - submit directly. 
- request = prompt stage_client = self.stage_clients[stage_id] if stage_client.stage_type == "diffusion": if isinstance(prompt, list): @@ -759,6 +916,7 @@ async def _handle_streaming_update(self, msg: dict[str, Any]) -> None: if "sampling_params_list" in msg and msg["sampling_params_list"]: req_state.sampling_params_list = msg["sampling_params_list"] + req_state.streaming.enabled = True req_state.stage_submit_ts[stage_id] = _time.time() stage_client = self.stage_clients[stage_id] @@ -847,13 +1005,7 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None: companion_prompt = msg["prompt"] sampling_params_list = msg["sampling_params_list"] - # Register companion mapping - if parent_id not in self._companion_map: - self._companion_map[parent_id] = {} - self._companion_map[parent_id][role] = companion_id - self._companion_ids.add(companion_id) - self._companion_to_parent[companion_id] = parent_id - self._companion_done.setdefault(parent_id, set()) + self._cfg_tracker.register_companion(parent_id, role, companion_id) companion_state = OrchestratorRequestState( request_id=companion_id, @@ -878,22 +1030,10 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None: async def _handle_abort(self, msg: dict[str, Any]) -> None: """Handle an abort message from the main thread.""" request_ids = msg["request_ids"] - # Also abort any CFG companions for aborted parents - companion_ids_to_abort: list[str] = [] - for req_id in request_ids: - role_map = self._companion_map.pop(req_id, {}) - for cid in role_map.values(): - companion_ids_to_abort.append(cid) - self._companion_ids.discard(cid) - self._companion_to_parent.pop(cid, None) - self.request_states.pop(cid, None) - self._companion_done.pop(req_id, None) - self._deferred_parents.pop(req_id, None) - - all_ids_to_abort = list(request_ids) + companion_ids_to_abort + all_ids_to_abort = self._cfg_tracker.abort_parents(request_ids) for stage_id in range(self.num_stages): await self.stage_clients[stage_id].abort_requests_async(all_ids_to_abort) - for req_id in request_ids: + for req_id in all_ids_to_abort: self.request_states.pop(req_id, None) logger.info("[Orchestrator] Aborted request(s) %s", request_ids) diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 43d02e85b8..67b4dd1650 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -118,9 +118,10 @@ def _consolidate_multimodal_tensors(self) -> None: if isinstance(v, list) and v and isinstance(v[0], torch.Tensor): try: if k == "audio": - # When the audio tensor shape is inconsistent, torch.cat will fail. - # We need to use torch.cat in -1 dimension. - continue + # Concatenate delta audio chunks (1-D) into the full waveform. + # Each entry is a per-step slice; flatten to -1 so chunks with + # inconsistent leading dims can still be joined on the sample axis. + self.mm_accumulated[k] = torch.cat([t.reshape(-1) for t in v], dim=0) elif k == "sr": # Sample rate is a constant scalar, keep last value. self.mm_accumulated[k] = v[-1] @@ -232,10 +233,9 @@ def _new_completion_output( # Reuse base text/logprobs logic, then annotate with pooling_result. 
base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason, routed_experts) try: + if not hasattr(base_output, "multimodal_output"): + setattr(base_output, "multimodal_output", {}) if self.mm_accumulated is not None: - # Attach accumulated multimodal dict on the completion output - if not hasattr(base_output, "multimodal_output"): - setattr(base_output, "multimodal_output", {}) mm_out = getattr(base_output, "multimodal_output") if isinstance(mm_out, dict): for k, v in self.mm_accumulated.items(): diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py index 193ab6a656..073c2340a4 100644 --- a/vllm_omni/engine/stage_engine_core_client.py +++ b/vllm_omni/engine/stage_engine_core_client.py @@ -14,7 +14,9 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import AsyncMPClient, DPLBAsyncMPClient -from vllm_omni.distributed.omni_connectors.utils.initialization import KV_TRANSFER_PORT_OFFSET +from vllm_omni.distributed.omni_connectors.utils.initialization import ( + KV_TRANSFER_PORT_OFFSET, +) from vllm_omni.engine.stage_init_utils import StageMetadata if TYPE_CHECKING: @@ -294,6 +296,8 @@ def _initialize_kv_sender_endpoint(self) -> None: from_stage = omni_kv_config.get("omni_from_stage", from_stage) try: + # Orchestrator always reports rank-0's port; receiver + # workers add their own local_rank * KV_RANK_PORT_STRIDE. sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage) except (TypeError, ValueError): logger.warning( @@ -332,6 +336,7 @@ def get_kv_sender_info( self._kv_sender_host = self._resolve_contact_host() if self._kv_sender_host is None: return None + # rank-0 base port; receiver workers adjust per KV_RANK_PORT_STRIDE. return { "host": self._kv_sender_host, "zmq_port": base_port + kv_transfer_port_offset + int(self.stage_id), @@ -345,11 +350,21 @@ def process_engine_inputs( self, stage_list: list[Any], prompt: OmniTokensPrompt | list[OmniTokensPrompt] | None = None, + streaming_context: Any | None = None, ) -> list[OmniTokensPrompt]: """Process inputs from upstream stages.""" from vllm_omni.inputs.data import OmniTokensPrompt if self.custom_process_input_func is not None: + # Keep legacy arg call for non-streaming processors. + if bool(getattr(streaming_context, "enabled", False)): + return self.custom_process_input_func( + stage_list, + self.engine_input_source, + prompt, + self.requires_multimodal_data, + streaming_context, + ) return self.custom_process_input_func( stage_list, self.engine_input_source, diff --git a/vllm_omni/engine/stage_engine_core_proc.py b/vllm_omni/engine/stage_engine_core_proc.py index 05d8f107c2..689378a798 100644 --- a/vllm_omni/engine/stage_engine_core_proc.py +++ b/vllm_omni/engine/stage_engine_core_proc.py @@ -37,8 +37,6 @@ logger = init_logger(__name__) -_HANDSHAKE_POLL_TIMEOUT_S = 600 - class StageEngineCoreProc(EngineCoreProc): """Stage-specific engine core process for vLLM-Omni. @@ -145,13 +143,14 @@ def complete_stage_handshake( handshake_address: str, addresses: EngineZmqAddresses, vllm_config: VllmConfig, + handshake_timeout: int, ) -> None: """Perform the HELLO/INIT/READY handshake with an already-spawned proc. On failure the process is terminated before re-raising. 
""" try: - _perform_handshake(proc, handshake_address, addresses, vllm_config) + _perform_handshake(proc, handshake_address, addresses, vllm_config, handshake_timeout) except Exception: shutdown([proc]) raise @@ -162,6 +161,7 @@ def _perform_handshake( handshake_address: str, addresses: EngineZmqAddresses, vllm_config: VllmConfig, + handshake_timeout: int, ) -> None: """Run the HELLO / INIT / READY handshake with the subprocess.""" with zmq_socket_ctx(handshake_address, zmq.ROUTER, bind=True) as handshake_socket: @@ -169,7 +169,7 @@ def _perform_handshake( poller.register(handshake_socket, zmq.POLLIN) poller.register(proc.sentinel, zmq.POLLIN) - identity, msg = _recv(poller, handshake_socket, proc, "HELLO") + identity, msg = _recv(poller, handshake_socket, proc, "HELLO", handshake_timeout) if msg.get("status") != "HELLO": raise RuntimeError(f"Expected HELLO, got: {msg}") @@ -179,7 +179,7 @@ def _perform_handshake( ) handshake_socket.send_multipart([identity, msgspec.msgpack.encode(init_payload)]) - identity, msg = _recv(poller, handshake_socket, proc, "READY") + identity, msg = _recv(poller, handshake_socket, proc, "READY", handshake_timeout) if msg.get("status") != "READY": raise RuntimeError(f"Expected READY, got: {msg}") num_gpu_blocks = msg.get("num_gpu_blocks") @@ -192,13 +192,18 @@ def _recv( handshake_socket: zmq.Socket, proc: BaseProcess, expected: str, + timeout_s: int = 600, ) -> tuple[bytes, dict]: """Wait for one handshake message; raise if the process dies first.""" - timeout_ms = _HANDSHAKE_POLL_TIMEOUT_S * 1000 + timeout_ms = timeout_s * 1000 while True: events = dict(poller.poll(timeout=timeout_ms)) if not events: - raise TimeoutError(f"Timed out waiting for {expected} from StageEngineCoreProc") + raise TimeoutError( + f"Timed out waiting for {expected} from StageEngineCoreProc after {timeout_s}s. " + f"This typically indicates model loading or initialization is taking too long. " + f"Consider increasing `stage_init_timeout` for large models." 
+ ) if handshake_socket in events: identity, raw = handshake_socket.recv_multipart() return identity, msgspec.msgpack.decode(raw) diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index 09195faeca..89dfdc163c 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -13,7 +13,7 @@ import multiprocessing as mp import os import time -from collections.abc import Callable +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Any, Literal @@ -101,8 +101,110 @@ def resolve_worker_cls(engine_args: dict[str, Any]) -> None: raise ValueError(f"Unknown worker_type: {worker_type}") -def inject_kv_stage_info(stage_cfg: Any, stage_id: int) -> None: - """Inject stage metadata into omni_kv_config when present.""" +def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any: + """Read *key* from *obj* regardless of whether it's a dict or object.""" + if hasattr(obj, "get"): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _tp_size_for_stage(stage_configs: Sequence[Any], stage_id: Any) -> int | None: + """Resolve tensor_parallel_size for *stage_id* from the loaded stage configs.""" + id_strs = {str(stage_id)} + try: + id_strs.add(str(int(stage_id))) + except (TypeError, ValueError): + pass + + for stage_cfg in stage_configs: + if str(getattr(stage_cfg, "stage_id", None)) not in id_strs: + continue + engine_args = getattr(stage_cfg, "engine_args", None) + if engine_args is None: + return 1 + parallel_config = _get_attr_or_item(engine_args, "parallel_config") + if parallel_config is not None: + tp = _get_attr_or_item(parallel_config, "tensor_parallel_size", 1) + else: + tp = _get_attr_or_item(engine_args, "tensor_parallel_size", 1) + try: + return max(1, int(tp)) + except (TypeError, ValueError): + return 1 + return None + + +def _inject_inferred_kv_tp_topology( + omni_kv: Any, + stage_id: int, + stage_configs: Sequence[Any], + engine_input_source: Sequence[int] | None = None, +) -> None: + """Infer adjacent-stage TP topology and inject it into omni_kv_config. + + This keeps heterogeneous TP working without requiring user-authored + rank_mapping blocks in config files. 
+ """ + if omni_kv is None: + return + + if hasattr(omni_kv, "get"): + need_send = bool(omni_kv.get("need_send_cache", False)) + need_recv = bool(omni_kv.get("need_recv_cache", False)) + omni_from_stage = omni_kv.get("omni_from_stage") + omni_to_stage = omni_kv.get("omni_to_stage") + rank_mapping = omni_kv.get("rank_mapping") + else: + need_send = bool(getattr(omni_kv, "need_send_cache", False)) + need_recv = bool(getattr(omni_kv, "need_recv_cache", False)) + omni_from_stage = getattr(omni_kv, "omni_from_stage", None) + omni_to_stage = getattr(omni_kv, "omni_to_stage", None) + rank_mapping = getattr(omni_kv, "rank_mapping", None) + + if not need_send and not need_recv: + return + + current_tp = _tp_size_for_stage(stage_configs, stage_id) + if current_tp is None: + return + + peer_stage_id = None + from_tp = None + to_tp = None + if str(omni_from_stage) == str(stage_id): + peer_stage_id = omni_to_stage + from_tp = current_tp + to_tp = _tp_size_for_stage(stage_configs, peer_stage_id) + elif str(omni_to_stage) == str(stage_id): + peer_stage_id = omni_from_stage + from_tp = _tp_size_for_stage(stage_configs, peer_stage_id) + to_tp = current_tp + elif need_recv and engine_input_source: + peer_stage_id = engine_input_source[0] + from_tp = _tp_size_for_stage(stage_configs, peer_stage_id) + to_tp = current_tp + + if from_tp is None or to_tp is None: + return + + if not isinstance(rank_mapping, dict): + rank_mapping = {} + rank_mapping.setdefault("from_tp", int(from_tp)) + rank_mapping.setdefault("to_tp", int(to_tp)) + + if hasattr(omni_kv, "__setitem__"): + omni_kv["rank_mapping"] = rank_mapping + else: + setattr(omni_kv, "rank_mapping", rank_mapping) + + +def inject_kv_stage_info(stage_cfg: Any, stage_id: int, stage_configs: Sequence[Any] | None = None) -> None: + """Inject stage_id, engine_input_source, and inferred TP topology into omni_kv_config. + + When *stage_configs* is provided, also infers from_tp/to_tp for + heterogeneous TP topologies so the KV transfer manager can compute + rank mappings automatically. + """ try: engine_args = stage_cfg.engine_args if hasattr(engine_args, "get"): @@ -125,6 +227,14 @@ def inject_kv_stage_info(stage_cfg: Any, stage_id: int) -> None: omni_kv.setdefault("engine_input_source", list(engine_input_source)) elif hasattr(omni_kv, "__setitem__") and "engine_input_source" not in omni_kv: omni_kv["engine_input_source"] = list(engine_input_source) + + if stage_configs: + _inject_inferred_kv_tp_topology( + omni_kv, + stage_id=stage_id, + stage_configs=stage_configs, + engine_input_source=engine_input_source, + ) except Exception as e: logger.debug("Failed to inject stage info into omni_kv_config: %s", e) @@ -168,6 +278,20 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: stage_id: int = stage_config.stage_id stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm") engine_args = stage_config.engine_args + + if current_omni_platform.is_rocm(): + if engine_args.get("attention_backend") is None: + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled(): + engine_args["attention_backend"] = "ROCM_AITER_FA" + # Before vLLM v0.19.0, the default attention backend is TRITON_ATTN for ROCm. + # Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm. + # However, the compatibility of ROCM_ATTN with Omni is not guaranteed. + # Therefore, we still use TRITON_ATTN as the default attention backend, + # when the selected_backend is not specified. 
+ engine_args["attention_backend"] = "TRITON_ATTN" + runtime_cfg = getattr(stage_config, "runtime", {}) engine_input_source: list[int] = getattr(stage_config, "engine_input_source", []) final_output: bool = getattr(stage_config, "final_output", False) @@ -178,8 +302,9 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata: default_sampling_params: OmniSamplingParams = SPClass(**default_sp) custom_process_input_func: Callable | None = None - if hasattr(stage_config, "custom_process_input_func"): - mod_path, fn_name = stage_config.custom_process_input_func.rsplit(".", 1) + _cpif_path = getattr(stage_config, "custom_process_input_func", None) + if _cpif_path: + mod_path, fn_name = _cpif_path.rsplit(".", 1) custom_process_input_func = getattr(importlib.import_module(mod_path), fn_name) prompt_expand_func: Callable | None = None @@ -309,6 +434,20 @@ def build_vllm_config( filtered_engine_args_dict = filter_dataclass_kwargs(OmniEngineArgs, engine_args_dict) omni_engine_args = OmniEngineArgs(**filtered_engine_args_dict) + + # Multi-stage pipelines (qwen3_tts code2wav, etc.) set max_model_len + # larger than HF max_position_embeddings by design. vLLM's validator + # rejects that without the env flag. + if filtered_engine_args_dict.get("max_model_len") is not None and not os.environ.get( + "VLLM_ALLOW_LONG_MAX_MODEL_LEN" + ): + os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + logger.debug( + "Auto-set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for stage %s (max_model_len=%s).", + stage_config.stage_id, + filtered_engine_args_dict["max_model_len"], + ) + vllm_config = omni_engine_args.create_engine_config( usage_context=UsageContext.LLM_CLASS, headless=headless, @@ -321,7 +460,7 @@ def build_vllm_config( def acquire_device_locks( stage_id: int, engine_args_dict: dict[str, Any], - stage_init_timeout: int = 300, + stage_init_timeout: int, ) -> list[int]: """Acquire exclusive file locks on devices needed by this stage. @@ -514,7 +653,9 @@ def initialize_diffusion_stage( model: str, stage_cfg: Any, metadata: StageMetadata, + stage_init_timeout: int, batch_size: int = 1, + use_inline: bool = False, ) -> Any: """Build a diffusion stage client. @@ -522,14 +663,16 @@ def initialize_diffusion_stage( model: Model name or path. stage_cfg: Stage configuration. metadata: Extracted stage metadata. + stage_init_timeout: Timeout in seconds for stage initialization handshake batch_size: Maximum number of requests to batch together in the diffusion engine. Passed through to ``StageDiffusionClient`` and ultimately to ``AsyncOmni``. + use_inline: If True, uses the inline diffusion client instead of subprocess. 
""" - from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient + from vllm_omni.diffusion.stage_diffusion_client import create_diffusion_client od_config = build_diffusion_config(model, stage_cfg, metadata) - return StageDiffusionClient(model, od_config, metadata, batch_size=batch_size) + return create_diffusion_client(model, od_config, metadata, stage_init_timeout, batch_size, use_inline) def _shutdown_or_close_resource(resource: Any, resource_name: str, stage_id: int) -> None: diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 129ef3c99d..9606cc80d0 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -78,7 +78,6 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None: self.final_output_task: asyncio.Task | None = None self.config_path = self.engine.config_path - self.stage_configs = self.engine.stage_configs self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None) self.input_processor = self.engine.input_processor @@ -209,6 +208,13 @@ async def generate( # Start final output dispatcher on the first call to generate() self._final_output_handler() + # Expand sampling params for PD disaggregation (user may provide N-1 params) + if ( + sampling_params_list is not None + and isinstance(sampling_params_list, Sequence) + and not isinstance(sampling_params_list, (str, bytes)) + ): + sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list)) sampling_params_list = self.resolve_sampling_params_list(sampling_params_list) # Track per-request metrics @@ -228,20 +234,27 @@ async def generate( req_state.metrics = metrics self.request_states[request_id] = req_state + # PD disaggregation: modify prefill-stage sampling params per request + req_sp_list = list(sampling_params_list) + pd_pair = self._get_pd_separation_pair() + if pd_pair is not None: + p_id = pd_pair[0] + req_sp_list[p_id] = self._prepare_prefill_sampling_params(request_id, req_sp_list[p_id]) + # Add request(s) to stage 0. For streaming inputs, submit # chunks incrementally through streaming_update. if isinstance(prompt, AsyncGenerator): input_stream_task = await self._add_streaming_input_request( request_id=request_id, input_stream=prompt, - sampling_params_list=sampling_params_list, + sampling_params_list=req_sp_list, final_stage_id=final_stage_id_for_e2e, ) else: await self.engine.add_request_async( request_id=request_id, prompt=prompt, - sampling_params_list=sampling_params_list, + sampling_params_list=req_sp_list, final_stage_id=final_stage_id_for_e2e, ) submit_ts = time.time() @@ -296,7 +309,6 @@ async def _add_streaming_input_request( if not stage0_params.skip_clone: stage0_params = stage0_params.clone() stage0_params.skip_clone = True - stage0_params.output_kind = RequestOutputKind.DELTA has_submitted_first_chunk = False diff --git a/vllm_omni/entrypoints/cfg_companion_tracker.py b/vllm_omni/entrypoints/cfg_companion_tracker.py deleted file mode 100644 index 9c2e835f07..0000000000 --- a/vllm_omni/entrypoints/cfg_companion_tracker.py +++ /dev/null @@ -1,233 +0,0 @@ -"""CFG companion request tracker for the Omni orchestrator. - -Encapsulates all bookkeeping for Classifier-Free Guidance companion -requests (prompt expansion, parent/companion ID mapping, completion -tracking, deferred forwarding, failure propagation, and timeouts) -so that ``Omni._run_generation`` stays clean. 
-""" - -from __future__ import annotations - -import copy -import logging -import os -import time -from collections.abc import Callable, Sequence -from typing import Any - -from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector -from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams - -logger = logging.getLogger(__name__) - - -class CfgCompanionTracker: - """Manages CFG companion request lifecycle in the orchestrator scheduling loop.""" - - def __init__( - self, - prompt_expand_func: Callable[..., Any] | None, - stage0_sampling_params: Any, - timeout_s: float | None = None, - ) -> None: - self._expand_func = prompt_expand_func - self._sp0 = stage0_sampling_params - self._timeout_s = ( - timeout_s if timeout_s is not None else float(os.environ.get("VLLM_CFG_PENDING_TIMEOUT_S", "120")) - ) - - self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id} - self._companion_ids: set[str] = set() - self._companion_to_parent: dict[str, str] = {} # companion -> parent - self._done: dict[str, set[str]] = {} # parent -> completed companion ids - self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result - self._failed_parents: set[str] = set() - - @property - def is_active(self) -> bool: - return bool(self._companion_ids) - - @property - def num_companions(self) -> int: - return len(self._companion_ids) - - @property - def stage0_sampling_params(self) -> Any: - return self._sp0 - - def expand_prompts( - self, - request_id_to_prompt: dict[str, Any], - ) -> list[tuple[str, Any]]: - """Expand user prompts into ``(companion_id, prompt)`` pairs via model-specific func.""" - if not self._expand_func: - return [] - - pairs: list[tuple[str, Any]] = [] - for rid, prompt in request_id_to_prompt.items(): - expanded = self._expand_func(prompt, self._sp0) - if not expanded: - continue - role_map: dict[str, str] = {} - for ep in expanded: - cid = f"{rid}{ep.request_id_suffix}" - role_map[ep.role] = cid - self._companion_ids.add(cid) - self._companion_to_parent[cid] = rid - pairs.append((cid, ep.prompt)) - self._companion_map[rid] = role_map - self._done[rid] = set() - - logger.debug( - "CFG expansion: %d parent(s) -> %d companion(s)", - len(self._companion_map), - len(self._companion_ids), - ) - return pairs - - def is_companion(self, req_id: str) -> bool: - return req_id in self._companion_ids - - def has_companions(self, parent_id: str) -> bool: - return parent_id in self._companion_map - - def all_companions_done(self, parent_id: str) -> bool: - role_map = self._companion_map.get(parent_id, {}) - done_set = self._done.get(parent_id, set()) - return all(cid in done_set for cid in role_map.values()) - - def get_companion_request_ids(self, parent_id: str) -> dict[str, str]: - """Return ``{role: companion_request_id}`` for a parent.""" - return self._companion_map.get(parent_id, {}) - - def is_parent_failed(self, parent_id: str) -> bool: - return parent_id in self._failed_parents - - # -- Lifecycle events -- - - def on_companion_error(self, companion_id: str) -> tuple[str | None, bool]: - """Record failure. 
Returns ``(parent_id, parent_was_aborted)``.""" - parent_id = self._companion_to_parent.get(companion_id) - if parent_id is None: - return None, False - self._failed_parents.add(parent_id) - logger.error("CFG companion %s failed; marking parent %s as failed", companion_id, parent_id) - aborted = parent_id in self._pending_parents - if aborted: - self._pending_parents.pop(parent_id, None) - return parent_id, aborted - - def on_companion_completed(self, companion_id: str) -> str | None: - """Mark done. Returns parent_id only if parent is pending and all companions finished.""" - parent_id = self._companion_to_parent.get(companion_id) - if parent_id is None: - return None - self._done[parent_id].add(companion_id) - logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id) - if parent_id in self._pending_parents and self.all_companions_done(parent_id): - return parent_id - return None - - def consume_parent_failure(self, parent_id: str) -> None: - self._failed_parents.discard(parent_id) - - # -- Deferred parent management -- - - def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None: - """Hold parent result while waiting for companions to finish.""" - self._pending_parents[parent_id] = { - "engine_outputs": engine_outputs, - "stage_id": stage_id, - "pending_since": time.time(), - } - logger.debug("Parent %s deferred, waiting for CFG companions", parent_id) - - def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None: - return self._pending_parents.pop(parent_id, None) - - def check_timeouts(self) -> list[str]: - """Return and remove parent IDs that exceeded the pending timeout.""" - if not self._pending_parents: - return [] - now = time.time() - timed_out: list[str] = [] - for pid in list(self._pending_parents): - pending_since = self._pending_parents[pid].get("pending_since", now) - if now - pending_since > self._timeout_s: - self._pending_parents.pop(pid) - self._failed_parents.discard(pid) - timed_out.append(pid) - logger.error("Parent %s timed out waiting for CFG companions (>%.0fs)", pid, self._timeout_s) - return timed_out - - # -- Forward parent with CFG KV -- - - def forward_parent_with_cfg( - self, - req_id: str, - parent_result: dict[str, Any], - stage_list: Sequence[Any], - connectors: dict[tuple[str, str], Any], - sampling_params_list: Sequence[OmniSamplingParams], - request_id_to_prompt: dict[str, Any], - final_stage_id_to_prompt: dict[str, int], - metrics: Any, - remaining_by_stage: list[int], - ) -> bool: - """Forward a parent request to the next stage with CFG KV request IDs attached.""" - stage_id = parent_result["stage_id"] - next_stage_id = stage_id + 1 - if next_stage_id > final_stage_id_to_prompt.get(req_id, 0): - return True - - next_stage = stage_list[next_stage_id] - try: - with metrics.stage_postprocess_timer(stage_id, req_id): - next_inputs = next_stage.process_engine_inputs( - stage_list, - [request_id_to_prompt[req_id]], - source_outputs_override=parent_result["engine_outputs"], - ) - except Exception as e: - logger.exception( - "Process engine inputs error for req %s at stage %d: %s", - req_id, - next_stage_id, - e, - ) - return False - - sp_next = copy.deepcopy(sampling_params_list[next_stage_id]) - if isinstance(sp_next, OmniDiffusionSamplingParams): - sp_next.cfg_kv_request_ids = self.get_companion_request_ids(req_id) - logger.info( - "Attaching cfg_kv_request_ids=%s to request %s", - sp_next.cfg_kv_request_ids, - req_id, - ) - - connector_key = (str(stage_id), str(next_stage_id)) - connector = 
connectors.get(connector_key) - sent_via_connector = False - if connector: - sent_via_connector = try_send_via_connector( - connector=connector, - stage_id=stage_id, - next_stage_id=next_stage_id, - req_id=req_id, - next_inputs=next_inputs, - sampling_params=sp_next, - original_prompt=request_id_to_prompt[req_id], - next_stage_queue_submit_fn=stage_list[next_stage_id].submit, - metrics=metrics, - ) - - if not sent_via_connector: - raise RuntimeError( - f"Failed to send CFG request {req_id} to stage-{next_stage_id} via connector. " - "Configure a connector for this edge or inspect connector logs for details." - ) - - logger.debug("Forwarded CFG-enabled request %s to stage-%d", req_id, next_stage_id) - remaining_by_stage[next_stage_id] += 1 - return True diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py index 8970e58984..4c3d311ec5 100644 --- a/vllm_omni/entrypoints/chat_utils.py +++ b/vllm_omni/entrypoints/chat_utils.py @@ -2,7 +2,7 @@ async def extract_audio_from_video_async(video_url: str) -> tuple[np.ndarray, int | float]: - """Extract audio from a video URL using librosa. + """Extract audio from a video URL using vllm's load_audio. Returns a (audio_array, sample_rate) tuple compatible with audio format. All blocking I/O operations are run in a thread pool. @@ -26,9 +26,9 @@ def _write_temp_file_sync(data: bytes, suffix: str) -> str: return temp_file.name def _load_audio_sync(file_path: str) -> tuple[np.ndarray, int | float]: - import librosa + from vllm.multimodal.media.audio import load_audio - return librosa.load(file_path, sr=16000) + return load_audio(file_path, sr=16000) def _cleanup_file_sync(file_path: str) -> None: try: diff --git a/vllm_omni/entrypoints/cli/benchmark/serve.py b/vllm_omni/entrypoints/cli/benchmark/serve.py index 906e8851a4..d281432e59 100644 --- a/vllm_omni/entrypoints/cli/benchmark/serve.py +++ b/vllm_omni/entrypoints/cli/benchmark/serve.py @@ -1,4 +1,5 @@ import argparse +import os from vllm.benchmarks.serve import add_cli_args @@ -6,15 +7,149 @@ from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase +def add_daily_omni_cli_args(parser: argparse.ArgumentParser) -> None: + """Add CLI arguments specific to Daily-Omni dataset. + + This function should be called by the CLI entrypoint to add additional + arguments for daily-omni benchmark support. + + Args: + parser: The ArgumentParser instance to extend + """ + # Daily-Omni specific arguments + daily_omni_group = parser.add_argument_group("Daily-Omni Dataset Options") + + daily_omni_group.add_argument( + "--daily-omni-qa-json", + type=str, + default=None, + help="Path to local upstream qa.json. When set, QA rows are read from this file and " + "the HuggingFace dataset is not loaded (no network). Use with --daily-omni-video-dir " + "for fully offline runs. --dataset-path / Hub split flags are then ignored for QA loading.", + ) + daily_omni_group.add_argument( + "--daily-omni-video-dir", + type=str, + default=None, + help="Root directory of extracted Daily-Omni videos (contents of Videos.tar: " + "each video_id in its own subdir with {video_id}_video.mp4). 
" + "When using file URLs, you MUST start the vLLM server with " + "--allowed-local-media-path set to this same directory (or a parent), " + "otherwise requests fail with 'Cannot load local files without " + "--allowed-local-media-path'.", + ) + daily_omni_group.add_argument( + "--daily-omni-inline-local-video", + action="store_true", + default=False, + help="For local videos only: embed MP4 as base64 data URLs in benchmark " + "requests so the server does not need --allowed-local-media-path. " + "Increases request size and client memory; use for small --num-prompts. " + "When using --daily-omni-input-mode audio or all, local WAV files are " + "embedded the same way.", + ) + daily_omni_group.add_argument( + "--daily-omni-input-mode", + type=str, + choices=["all", "visual", "audio"], + default="all", + help="Daily-Omni input protocol (mirrors upstream Lliar-liar/Daily-Omni " + "--input_mode). 'visual': video only (default). 'audio': WAV only, " + "requires {video_id}/{video_id}_audio.wav under --daily-omni-video-dir. " + "'all': video + WAV together. Sets mm_processor_kwargs.use_audio_in_video=false " + "and matches official separate video/audio streams.", + ) + daily_omni_group.add_argument( + "--daily-omni-save-eval-items", + action="store_true", + default=False, + help="Include per-request Daily-Omni accuracy rows (gold/predicted/correct) " + "in the saved JSON under key daily_omni_eval_items. " + "Alternatively set env DAILY_OMNI_SAVE_EVAL_ITEMS=1.", + ) + + # Note: --dataset-name daily-omni via get_samples patch; use either Hub (--dataset-path + # liarliar/Daily-Omni) or local --daily-omni-qa-json (offline). + + +def add_seed_tts_cli_args(parser: argparse.ArgumentParser) -> None: + """CLI for Seed-TTS zero-shot TTS benchmark (``--dataset-name seed-tts``).""" + g = parser.add_argument_group("Seed-TTS Dataset Options") + g.add_argument( + "--seed-tts-locale", + type=str, + choices=["en", "zh"], + default="en", + help="Which Seed-TTS split to load: en/meta.lst or zh/meta.lst under the dataset root.", + ) + g.add_argument( + "--seed-tts-root", + type=str, + default=None, + help="Override root directory that contains en/ and zh/ (meta.lst + prompt-wavs). " + "If set, --dataset-path can still name the HF repo for logging; this path is used for files.", + ) + g.add_argument( + "--seed-tts-file-ref-audio", + action="store_true", + default=False, + help="Send ref_audio as file:// URIs (smaller HTTP bodies). Requires the API server " + "to be started with --allowed-local-media-path covering the Seed-TTS dataset root. " + "Default is inline data:audio/wav;base64 so Qwen3-TTS works without that flag.", + ) + g.add_argument( + "--seed-tts-inline-ref-audio", + action="store_true", + default=False, + help=argparse.SUPPRESS, + ) + g.add_argument( + "--seed-tts-system-prompt", + type=str, + default=None, + help="Override chat system message for --backend openai-chat-omni (Qwen3-Omni TTS). " + "Default follows official Qwen3-Omni identity + zero-shot voice-clone instructions.", + ) + g.add_argument( + "--seed-tts-wer-eval", + action="store_true", + default=False, + help="Keep synthesized audio as 24 kHz mono PCM for WER (works with " + "--backend openai-audio-speech or openai-chat-omni). Scoring follows " + "BytedanceSpeech/seed-tts-eval (Whisper-large-v3 / Paraformer-zh + jiwer). " + "Sets SEED_TTS_WER_EVAL=1. Install: pip install 'vllm-omni[seed-tts-eval]'. 
" + "Optional: SEED_TTS_EVAL_DEVICE, SEED_TTS_HF_WHISPER_MODEL.", + ) + g.add_argument( + "--seed-tts-wer-save-items", + action="store_true", + default=False, + help="Include per-utterance ASR rows in the saved JSON under key seed_tts_wer_eval_items. " + "Or set SEED_TTS_WER_SAVE_ITEMS=1.", + ) + + class OmniBenchmarkServingSubcommand(OmniBenchmarkSubcommandBase): """The `serve` subcommand for vllm bench.""" name = "serve" - help = "Benchmark the online serving throughput." + help = "Benchmark the online serving throughput. Supports Daily-Omni and Seed-TTS datasets." @classmethod def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: add_cli_args(parser) + + # Add Daily-Omni specific arguments + add_daily_omni_cli_args(parser) + add_seed_tts_cli_args(parser) + + for action in parser._actions: + if action.dest == "dataset_name" and action.choices is not None: + extra = [c for c in ("daily-omni", "seed-tts") if c not in action.choices] + if extra: + action.choices = list(action.choices) + extra + + # Update help messages for omni-specific features for action in parser._actions: if action.dest == "percentile_metrics": action.help = ( @@ -48,4 +183,10 @@ def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: @staticmethod def cmd(args: argparse.Namespace) -> None: + if getattr(args, "daily_omni_save_eval_items", False): + os.environ["DAILY_OMNI_SAVE_EVAL_ITEMS"] = "1" + if getattr(args, "seed_tts_wer_eval", False): + os.environ["SEED_TTS_WER_EVAL"] = "1" + if getattr(args, "seed_tts_wer_save_items", False): + os.environ["SEED_TTS_WER_SAVE_ITEMS"] = "1" main(args) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 6e9adc2461..8bccfbb591 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -9,6 +9,7 @@ import json import os import signal +import sys from types import FrameType from typing import Any @@ -21,6 +22,7 @@ from vllm_omni.entrypoints.cli.logo import log_logo from vllm_omni.entrypoints.openai.api_server import omni_run_server +from vllm_omni.entrypoints.utils import detect_explicit_cli_keys logger = init_logger(__name__) @@ -79,6 +81,9 @@ class OmniServeCommand(CLISubcommand): """The `serve` subcommand for the vLLM CLI.""" name = "serve" + # Parser stashed at subparser_init so ``cmd`` can resolve each user-typed + # flag to its real ``dest`` via the parser's action table. + _parser: FlexibleArgumentParser | None = None @staticmethod def cmd(args: argparse.Namespace) -> None: @@ -90,6 +95,10 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, "model_tag") and args.model_tag is not None: args.model = args.model_tag + # Stash the set of long-option keys the user actually typed so the + # stage-config factory can give YAML precedence over argparse defaults. + args._cli_explicit_keys = detect_explicit_cli_keys(sys.argv[1:], OmniServeCommand._parser) + if args.headless: run_headless(args) else: @@ -138,11 +147,33 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Default task type for TTS models (CustomVoice, VoiceDesign, or Base). " "If not specified, will be inferred from model path.", ) + # TODO(@lishunyang12): deprecate once all models migrate to --deploy-config omni_config_group.add_argument( "--stage-configs-path", type=str, default=None, - help="Path to the stage configs file. 
If not specified, the stage configs will be loaded from the model.", + help="[Deprecated — will be removed in a future release] Path to a legacy " + "stage configs YAML (stage_args format). Prefer --deploy-config for new-format deploy YAMLs.", + ) + omni_config_group.add_argument( + "--deploy-config", + type=str, + default=None, + help="Path to a deploy config YAML (new format with stages/engine_args). " + "Mutually exclusive with --stage-configs-path.", + ) + omni_config_group.add_argument( + "--stage-overrides", + type=str, + default=None, + help="Per-stage JSON overrides. Example: " + '\'{"0": {"gpu_memory_utilization": 0.8}, "2": {"enforce_eager": true}}\'', + ) + omni_config_group.add_argument( + "--async-chunk", + action=argparse.BooleanOptionalAction, + default=None, + help="Override the deploy YAML's ``async_chunk:`` bool. Unset leaves the YAML value in force.", ) omni_config_group.add_argument( "--stage-id", @@ -406,6 +437,9 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu action="store_true", help="Enable diffusion pipeline profiler to display stage durations.", ) + # Stash via type(self) so the docs hook (which execs this function in a + # sandboxed globals dict via ``DummySelf``) doesn't fail on a NameError. + type(self)._parser = serve_parser return serve_parser @@ -461,10 +495,15 @@ def run_headless(args: argparse.Namespace) -> None: raise ValueError("headless mode requires worker_backend=multi_process") args_dict = vars(args).copy() + # Preserve the explicit-keys set captured at parse time so per-stage yaml + # values (e.g. stage 1's ``gpu_memory_utilization: 0.5``) are not + # overwritten by argparse defaults for flags the user didn't type. + cli_explicit_keys = args_dict.pop("_cli_explicit_keys", None) config_path, stage_configs = load_and_resolve_stage_configs( model, args_dict.get("stage_configs_path"), args_dict, + cli_explicit_keys=cli_explicit_keys, ) # Locate the stage config that matches stage_id. 
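The YAML-vs-CLI precedence handling above depends on knowing which long options the user actually typed; `detect_explicit_cli_keys` itself is imported from `vllm_omni/entrypoints/utils.py` and is not part of this diff. The sketch below is only an illustration of that idea, assuming an argparse parser whose actions expose `option_strings` and `dest`; the parser-less fallback mirrors the name-based heuristic whose limitations the `from_cli_args` docstring later in this diff calls out.

import argparse

def detect_explicit_cli_keys_sketch(argv: list[str], parser: argparse.ArgumentParser | None = None) -> set[str]:
    """Illustrative only: resolve user-typed long flags to their argparse dests."""
    typed = {tok.split("=", 1)[0] for tok in argv if tok.startswith("--")}
    if parser is None:
        # Name-based fallback: misses dest= overrides, alias flags, and --no-X pairs.
        return {flag.lstrip("-").replace("-", "_") for flag in typed}
    explicit: set[str] = set()
    for action in parser._actions:  # resolve each typed flag via the parser's action table
        if any(opt in typed for opt in action.option_strings):
            explicit.add(action.dest)
    return explicit

With a parser available, `--no-async-chunk` resolves to the dest `async_chunk` through its `BooleanOptionalAction` entry, whereas the name-based fallback would report it as `no_async_chunk`.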
diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index a3bfe98ce2..8ef7e2ee5b 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -66,6 +66,13 @@ def generate( py_generator: bool = False, use_tqdm: bool | Callable[..., tqdm] = True, ) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]: + # Expand sampling params for PD disaggregation (user may provide N-1 params) + if ( + sampling_params_list is not None + and isinstance(sampling_params_list, Sequence) + and not isinstance(sampling_params_list, (str, bytes)) + ): + sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list)) sampling_params_list = self.resolve_sampling_params_list(sampling_params_list) try: if py_generator: @@ -125,10 +132,17 @@ def _run_generation( req_state.metrics = metrics self.request_states[req_id] = req_state + # PD disaggregation: modify stage-0 (prefill) sampling params per request + req_sp_list = list(sampling_params_list) + pd_pair = self._get_pd_separation_pair() + if pd_pair is not None: + p_id = pd_pair[0] + req_sp_list[p_id] = self._prepare_prefill_sampling_params(req_id, req_sp_list[p_id]) + self.engine.add_request( request_id=req_id, prompt=prompt, - sampling_params_list=sampling_params_list, + sampling_params_list=req_sp_list, final_stage_id=final_stage_id, ) submit_ts = time.time() diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index 1a7ffc4a50..dca494efe7 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -1,6 +1,8 @@ from __future__ import annotations +import argparse import os +import sys import time import types import weakref @@ -14,7 +16,8 @@ from vllm_omni.engine.async_omni_engine import AsyncOmniEngine from vllm_omni.entrypoints.client_request_state import ClientRequestState -from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e +from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin +from vllm_omni.entrypoints.utils import detect_explicit_cli_keys, get_final_stage_id_for_e2e from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific from vllm_omni.outputs import OmniRequestOutput @@ -65,9 +68,51 @@ def omni_snapshot_download(model_id: str) -> str: OutputMessageHandleResult = tuple[Literal[True], None, None, None] | tuple[Literal[False], str, int, ClientRequestState] -class OmniBase: +class OmniBase(PDDisaggregationMixin): """Shared runtime foundation for AsyncOmni and Omni.""" + @classmethod + def from_cli_args( + cls, + args: argparse.Namespace, + *, + parser: argparse.ArgumentParser | None = None, + **overrides: Any, + ) -> OmniBase: + """Construct an ``Omni`` / ``AsyncOmni`` from an ``argparse.Namespace``. + + Mirrors the ``EngineArgs.from_cli_args`` pattern used upstream and in + ``OmniEngineArgs.from_cli_args``. This is the recommended entry point + for any argparse-based caller (offline scripts, tests, CI): it + expands ``vars(args)`` into kwargs and automatically captures which + flags the user typed on the command line so that argparse defaults + do not silently override deploy YAML values. + + Passing ``parser`` is strongly recommended: without it, flag-to-dest + resolution falls back to a name-based heuristic that misidentifies + flags with ``dest=`` overrides, alias flags, and ``--disable-X`` / + ``store_false`` pairs. See :func:`detect_explicit_cli_keys`. 
+ + Args: + args: Parsed argparse namespace from ``parser.parse_args()``. + parser: The argparse parser used to produce ``args``. When + provided, each user-typed flag is resolved to its real + ``dest`` via the parser's action table. + **overrides: Extra keyword arguments that take precedence over + attributes on ``args``. + + Example:: + + parser = FlexibleArgumentParser() + parser.add_argument("--model", required=True) + args = parser.parse_args() + omni = Omni.from_cli_args(args, parser=parser) # preferred + omni = Omni.from_cli_args(args, parser=parser, model="other") + """ + kwargs: dict[str, Any] = {**vars(args), **overrides} + kwargs["_cli_explicit_keys"] = detect_explicit_cli_keys(sys.argv[1:], parser) + return cls(**kwargs) + def __init__( self, model: str, @@ -77,16 +122,24 @@ def __init__( stage_init_timeout = kwargs.pop("stage_init_timeout", 300) init_timeout = kwargs.pop("init_timeout", 600) log_stats = kwargs.pop("log_stats", False) - async_chunk = kwargs.pop("async_chunk", False) + # NOTE: read-only lookup — must NOT pop. Popping here drops the key + # before it reaches ``StageConfigFactory._create_from_registry``, so + # ``--no-async-chunk`` (``async_chunk=False``) silently fails to + # override the deploy YAML's ``async_chunk: true`` default. + async_chunk = kwargs.get("async_chunk") output_modalities = kwargs.pop("output_modalities", None) diffusion_batch_size: int = kwargs.pop("diffusion_batch_size", 1) if "log_requests" in kwargs: raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.") model = omni_snapshot_download(model) + self._name = self.__class__.__name__ self.model = model self.log_stats = log_stats - self.async_chunk = async_chunk + # Provisional value (mirrors the CLI/caller kwarg); the engine resolves + # pipeline + deploy YAML + CLI precedence below and the final value is + # re-assigned from ``self.engine.async_chunk`` after init. + self.async_chunk = bool(async_chunk) if async_chunk is not None else False self.output_modalities = output_modalities or [] self.tts_batch_max_items: int = kwargs.pop("tts_batch_max_items", 32) @@ -104,7 +157,11 @@ def __init__( self._weak_finalizer = weakref.finalize(self, _weak_shutdown_engine, self.engine) et = time.time() logger.info("[%s] AsyncOmniEngine initialized in %.2f seconds", self.__class__.__name__, et - st) - self.async_chunk = bool(self.async_chunk or getattr(self.engine, "async_chunk", False)) + # Authoritative: ``AsyncOmniEngine`` resolves (pipeline + deploy YAML + + # CLI overrides) through ``StageConfigFactory`` and stores the final + # value on ``engine.async_chunk``; mirror it here so ``--no-async-chunk`` + # (explicit ``False``) is not fallen-back-through by ``or``. 
+ self.async_chunk = bool(getattr(self.engine, "async_chunk", False)) self.request_states: dict[str, ClientRequestState] = {} @@ -125,10 +182,18 @@ def __init__( model, ) + # PD disaggregation state (detects if a prefill/decode stage pair is configured) + self._init_pd_state() + @property def num_stages(self) -> int: return self.engine.num_stages + @property + def stage_configs(self) -> list: + """Expose engine stage configs for PD disaggregation detection and validation.""" + return self.engine.stage_configs + @property def is_running(self) -> bool: return self.engine.is_alive() diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 4519ae8c0c..745b719d5b 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -18,6 +18,7 @@ from typing import Annotated, Any, Literal, cast import httpx +import numpy as np import vllm.envs as envs from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, Request, UploadFile, WebSocket from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse @@ -52,7 +53,6 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.orca_metrics import metrics_header -from vllm.entrypoints.openai.realtime.connection import RealtimeConnection from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses from vllm.entrypoints.openai.server_utils import get_uvicorn_log_config @@ -107,6 +107,7 @@ VideoListResponse, VideoResponse, ) +from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler @@ -120,6 +121,7 @@ logger = init_logger(__name__) router = APIRouter() +MAX_UINT32_SEED = 2**32 - 1 profiler_router = APIRouter() @@ -1202,6 +1204,22 @@ async def streaming_speech(websocket: WebSocket): @router.websocket("/v1/realtime") async def realtime_websocket(websocket: WebSocket): """WebSocket endpoint for OpenAI-style realtime interactions.""" + engine_client = getattr(websocket.app.state, "engine_client", None) + if engine_client is not None and getattr(engine_client, "async_chunk", False): + await websocket.accept() + await websocket.send_json( + { + "type": "error", + "error": ( + "The /v1/realtime API is not supported when async_chunk is enabled on the server. " + "Use a stage configuration with async_chunk disabled and restart the server before using " + "this endpoint." + ), + "code": "unsupported", + } + ) + await websocket.close() + return serving = getattr(websocket.app.state, "openai_serving_realtime", None) if serving is None: await websocket.accept() @@ -1303,14 +1321,64 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) # Get engine client (AsyncOmni) from app state engine_client, model_name, stage_configs = _get_engine_and_model(raw_request) - # Validate model field (warn if mismatch, don't error) if request.model is not None and request.model != model_name: - logger.warning( - f"Model mismatch: request specifies '{request.model}' but " - f"server is running '{model_name}'. Using server model." 
+ raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=(f"Model mismatch: request specifies '{request.model}' but server is running '{model_name}'."), ) try: + # Unify request construction for any multi-stage pipeline to avoid + # divergence between /v1/images and /v1/chat/completions. + if len(stage_configs) > 1: + chat_handler = getattr(raw_request.app.state, "openai_serving_chat", None) + if chat_handler is None: + logger.warning("openai_serving_chat is not initialized for multi-stage /v1/images/generations") + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + detail="openai_serving_chat is not initialized for multi-stage image generation.", + ) + + effective_seed = request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED) + extra_body: dict[str, Any] = { + "seed": effective_seed, + "num_outputs_per_prompt": request.n, + } + if request.size is not None: + parse_size(request.size) + width, height = parse_size(request.size) + app_state_args = getattr(raw_request.app.state, "args", None) + _check_max_generated_image_size(app_state_args, width, height) + extra_body["size"] = request.size + if request.negative_prompt is not None: + extra_body["negative_prompt"] = request.negative_prompt + if request.num_inference_steps is not None: + extra_body["num_inference_steps"] = request.num_inference_steps + if request.guidance_scale is not None: + extra_body["guidance_scale"] = request.guidance_scale + if request.true_cfg_scale is not None: + extra_body["true_cfg_scale"] = request.true_cfg_scale + if request.generator_device is not None: + extra_body["generator_device"] = request.generator_device + if request.lora is not None: + # Keep /images validation semantics: invalid LoRA should fail with 400. + _parse_lora_request(request.lora) + extra_body["lora"] = request.lora + + generation_result = await chat_handler.generate_diffusion_images( + prompt=request.prompt, + extra_body=extra_body, + request_id=f"img_gen-{random_uuid()}", + ) + if isinstance(generation_result, ErrorResponse): + return JSONResponse( + status_code=generation_result.error.code if generation_result.error else 400, + content=generation_result.model_dump(), + ) + flat_images, _, _ = generation_result + image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images] + return ImageGenerationResponse(created=int(time.time()), data=image_data) + # Build params - pass through user values directly prompt: OmniTextPrompt = {"prompt": request.prompt} if request.negative_prompt is not None: @@ -1351,7 +1419,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) # This fixes issues where using the default global generator # might produce blurry images in some environments. 
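 # MAX_UINT32_SEED (2**32 - 1) is the same bound as the previous inline
 # literal; the named constant is shared with the multi-stage branch above
 # and with /v1/images/edits below.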
_update_if_not_none( - gen_params, "seed", request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + gen_params, "seed", request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED) ) _update_if_not_none(gen_params, "generator_device", request.generator_device) _update_if_not_none(gen_params, "layers", request.layers) @@ -1425,6 +1493,9 @@ async def edit_images( background: str | None = Form("auto"), output_compression: Annotated[int, Form(ge=0, le=100)] = 100, user: str | None = Form(None), # unused now + # vllm-omni extensions for image editing + mask_image: str | UploadFile | None = None, + reference_image: str | UploadFile | None = None, # vllm-omni extensions for diffusion control negative_prompt: str | None = Form(None), num_inference_steps: int | None = Form(None), @@ -1445,8 +1516,9 @@ async def edit_images( # 1. get engine and model engine_client, model_name, stage_configs = _get_engine_and_model(raw_request) if model is not None and model != model_name: - logger.warning( - f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model." + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=(f"Model mismatch: request specifies '{model}' but server is running '{model_name}'."), ) # 2. get output format & compression output_format = _choose_output_format(output_format, background) @@ -1469,15 +1541,35 @@ async def edit_images( input_images_list.extend(urls) if not input_images_list: raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required") - pil_images = await _load_input_images(input_images_list) - if len(pil_images) > 1 and not _supports_multimodal_image_inputs(raw_request, engine_client): + # Reject oversized multi-image edit requests before fetching or decoding + # any inputs. This keeps over-limit URL requests from burning network, + # CPU, and memory on work that will be rejected anyway. + max_input_images = _get_max_edit_input_images(raw_request, engine_client) + if max_input_images is not None and len(input_images_list) > max_input_images: + detail = ( + "Received multiple input images. Only a single image is supported by this model." + if max_input_images == 1 + else ( + f"Received {len(input_images_list)} input images. " + f"At most {max_input_images} images are supported by this model." + ) + ) raise HTTPException( status_code=HTTPStatus.BAD_REQUEST.value, - detail="Received multiple input images. Only a single image is supported by this model.", + detail=detail, ) + pil_images = await _load_input_images(input_images_list) prompt["multi_modal_data"] = {} prompt["multi_modal_data"]["image"] = pil_images + if mask_image is not None: + loaded = await _load_input_images([mask_image]) + prompt["multi_modal_data"]["mask_image"] = loaded[0] + + if reference_image is not None: + loaded = await _load_input_images([reference_image]) + prompt["multi_modal_data"]["reference_image"] = loaded[0] + # 3 Build sample params gen_params = OmniDiffusionSamplingParams() # 3.0 Init with system default values @@ -1549,7 +1641,7 @@ async def edit_images( # a proper generator is initialized in the backend. # This fixes issues where using the default global generator # might produce blurry images in some environments. 
- _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, 2**32 - 1)) + _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, MAX_UINT32_SEED)) _update_if_not_none(gen_params, "generator_device", generator_device) _update_if_not_none(gen_params, "layers", layers) _update_if_not_none(gen_params, "resolution", resolution) @@ -1639,18 +1731,25 @@ def _get_engine_and_model(raw_request: Request): return engine_client, model_name, normalized_stage_configs -def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool: +def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any: diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None) - od_config = ( + return ( get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None) ) + +def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int | None: + od_config = _get_diffusion_od_config(raw_request, engine_client) if od_config is None: # Preserve the existing compatibility behavior when the diffusion # config is not exposed on the serving surface. - return True - return bool(getattr(od_config, "supports_multimodal_inputs", False)) + return None + + if not bool(getattr(od_config, "supports_multimodal_inputs", False)): + return 1 + + return getattr(od_config, "max_multimodal_image_inputs", None) def _get_lora_from_json_str(lora_body): @@ -1767,6 +1866,34 @@ def _update_if_not_none(object: Any, key: str, val: Any) -> None: setattr(object, key, val) +def _normalize_image(image: Any) -> Any: + """Normalize a single image output to a PIL-compatible format.""" + if isinstance(image, Image.Image): + return image + if not isinstance(image, np.ndarray): + raise ValueError(f"Unsupported image type: {type(image)}") + if not np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.floating): + raise ValueError(f"Unsupported dtype: {image.dtype}") + if isinstance(image, np.ndarray): + while image.ndim > 3: + image = image[0] + if image.min() < 0: + if image.min() < -1.01 or image.max() > 1.01: + logger.warning( + f"Image float range [{image.min():.2f}, {image.max():.2f}] outside expected [-1, 1]. " + f"Clipping to [-1, 1] before normalization." + ) + image = np.clip(image, -1.0, 1.0) * 0.5 + 0.5 + elif image.max() > 1.01: + logger.warning( + f"Image float range [{image.min():.2f}, {image.max():.2f}] outside expected [0, 1]. " + f"Clipping to [0, 1] before normalization." + ) + image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8) + image = Image.fromarray(image) + return image + + def _extract_images_from_result(result: Any) -> list[Any]: images = [] if hasattr(result, "images") and result.images: @@ -1777,6 +1904,10 @@ def _extract_images_from_result(result: Any) -> list[Any]: images = request_output["images"] elif hasattr(request_output, "images") and request_output.images: images = request_output.images + # Handle when generate more than one image + if images and isinstance(images[0], np.ndarray) and images[0].shape[0] > 1 and images[0].ndim == 5: + # Unwrap batch: (N, T, H, W, C) -> [img1, img2, ...] + images = list(images[0]) # Flatten nested lists (e.g., from layered models like Qwen-Image-Layered). # Note: This only flattens one level deep. Deeper nesting is not supported. 
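 # e.g. [[layer_0, layer_1], image_2] -> [layer_0, layer_1, image_2]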
flattened = [] @@ -1785,7 +1916,7 @@ def _extract_images_from_result(result: Any) -> list[Any]: flattened.extend(img) else: flattened.append(img) - return flattened + return [_normalize_image(img) for img in flattened] async def _load_input_images( @@ -1955,18 +2086,6 @@ def video_response_from_request(model_name: str, req: VideoGenerationRequest) -> return resp -async def decode_and_save_video_output(output: Any, file_name: str) -> str: - if not output.b64_json: - raise RuntimeError(f"Video output for {file_name} did not include b64_json content.") - - try: - video_bytes = base64.b64decode(output.b64_json) - except Exception as decode_exc: - raise RuntimeError(f"Failed to decode generated video payload for {file_name}") from decode_exc - - return await STORAGE_MANAGER.save(video_bytes, file_name) - - def _cleanup_video(video_id: str, output_path: str | None): try: if output_path is not None: @@ -1990,15 +2109,12 @@ async def _run_video_generation_job( started_at = time.perf_counter() output_path = None try: - response = await handler.generate_videos(request, video_id, reference_image=reference_image) - if not response.data: - raise RuntimeError("Video generation completed but returned no outputs.") - - if (video_count := len(response.data)) > 1: - logger.warning("Video request %s generated %s outputs but we only expected one.", video_id, video_count) + video_bytes, stage_durations, peak_memory_mb = await handler.generate_video_bytes( + request, video_id, reference_image=reference_image + ) file_name = f"{video_id}.{job.file_extension}" - output_path = await decode_and_save_video_output(response.data[0], file_name) + output_path = await STORAGE_MANAGER.save(video_bytes, file_name) logger.info("Video request %s persisted %s output file.", video_id, output_path) await VIDEO_STORE.update_fields( @@ -2009,6 +2125,8 @@ async def _run_video_generation_job( "file_name": file_name, "completed_at": int(time.time()), "inference_time_s": time.perf_counter() - started_at, + "stage_durations": stage_durations, + "peak_memory_mb": peak_memory_mb, }, ) except Exception as exc: @@ -2055,6 +2173,10 @@ async def _parse_video_form( true_cfg_scale: float | None = Form(default=None), seed: int | None = Form(default=None), negative_prompt: str | None = Form(default=None), + enable_frame_interpolation: bool = Form(default=False), + frame_interpolation_exp: int = Form(default=1, ge=1), + frame_interpolation_scale: float = Form(default=1.0, gt=0.0), + frame_interpolation_model_path: str | None = Form(default=None), lora: str | None = Form(default=None), extra_params: str | None = Form(default=None), ) -> tuple[VideoGenerationRequest, "OmniOpenAIServingVideo", str, ReferenceImage | None]: @@ -2091,6 +2213,10 @@ async def _parse_video_form( "true_cfg_scale": true_cfg_scale, "seed": seed, "negative_prompt": negative_prompt, + "enable_frame_interpolation": enable_frame_interpolation, + "frame_interpolation_exp": frame_interpolation_exp, + "frame_interpolation_scale": frame_interpolation_scale, + "frame_interpolation_model_path": frame_interpolation_model_path, "lora": _parse_form_json(lora, expected_type=dict), "extra_params": _parse_form_json(extra_params, expected_type=dict), } @@ -2108,10 +2234,12 @@ async def _parse_video_form( app_model_name, app_stage_configs = _resolve_video_runtime_context(raw_request) effective_model_name = handler.model_name or app_model_name or request.model or "unknown" if request.model is not None and effective_model_name is not None and request.model != effective_model_name: - 
logger.warning( - "Model mismatch: request specifies '%s' but server is running '%s'. Using server model.", - request.model, - effective_model_name, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=( + f"Model mismatch: request specifies '{request.model}' but server is running " + f"'{effective_model_name}'." + ), ) handler.set_stage_configs_if_missing(app_stage_configs) except HTTPException: @@ -2182,7 +2310,7 @@ async def create_video_sync( request_id = f"video_sync-{random_uuid()}" started_at = time.perf_counter() try: - video_bytes = await asyncio.wait_for( + video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for( handler.generate_video_bytes(request, request_id, reference_image=reference_image), timeout=VIDEO_SYNC_TIMEOUT_S, ) @@ -2208,6 +2336,8 @@ async def create_video_sync( "X-Request-Id": request_id, "X-Model": effective_model_name, "X-Inference-Time-S": f"{inference_time_s:.3f}", + "X-Stage-Durations": json.dumps(stage_durations, separators=(",", ":")), + "X-Peak-Memory-MB": f"{peak_memory_mb:.3f}", }, ) diff --git a/vllm_omni/entrypoints/openai/audio_utils_mixin.py b/vllm_omni/entrypoints/openai/audio_utils_mixin.py index 13df32ebe0..b626f7eeb2 100644 --- a/vllm_omni/entrypoints/openai/audio_utils_mixin.py +++ b/vllm_omni/entrypoints/openai/audio_utils_mixin.py @@ -1,6 +1,8 @@ from io import BytesIO import numpy as np +import torch +import torchaudio from vllm.logger import init_logger from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio @@ -10,11 +12,6 @@ except ImportError: soundfile = None -try: - import librosa -except ImportError: - librosa = None - logger = init_logger(__name__) @@ -74,20 +71,53 @@ def create_audio(self, audio_obj: CreateAudio) -> AudioResponse: return AudioResponse(audio_data=audio_data, media_type=media_type) def _apply_speed_adjustment(self, audio_tensor: np.ndarray, speed: float, sample_rate: int): - """Apply speed adjustment to the audio tensor while preserving pitch.""" + """Apply speed adjustment to the audio tensor while preserving pitch. + + Uses torchaudio's phase vocoder (Spectrogram → TimeStretch → + InverseSpectrogram) to stretch/compress audio in time without + changing pitch. + """ if speed == 1.0: return audio_tensor, sample_rate - if librosa is None: - raise ImportError("librosa is required for speed adjustment. Please install it with: pip install librosa") - try: - # librosa.effects.time_stretch requires a float audio tensor. if not np.issubdtype(audio_tensor.dtype, np.floating): audio_tensor = audio_tensor.astype(np.float32) - stretched_audio = librosa.effects.time_stretch(y=audio_tensor, rate=speed) - return stretched_audio, sample_rate + # Stereo numpy arrays use channels-last (T, C); + # torch expects channels-first (C, T). 
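+ # e.g. a 2 s stereo chunk of shape (96000, 2) at 48 kHz becomes (2, 96000)
+ # for the torchaudio transforms and is transposed back after reconstruction.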
+ channels_last = audio_tensor.ndim == 2 + if channels_last: + waveform = torch.from_numpy(audio_tensor.T) + else: + waveform = torch.from_numpy(audio_tensor).unsqueeze(0) + + # Match librosa.stft defaults: n_fft=2048, hop_length=n_fft//4 + n_fft = 2048 + hop_length = n_fft // 4 + to_spec = torchaudio.transforms.Spectrogram( + n_fft=n_fft, + hop_length=hop_length, + power=None, + ) + stretch = torchaudio.transforms.TimeStretch( + n_freq=n_fft // 2 + 1, + hop_length=hop_length, + ) + to_wave = torchaudio.transforms.InverseSpectrogram( + n_fft=n_fft, + hop_length=hop_length, + ) + + spec = to_spec(waveform) + stretched = stretch(spec, speed) + expected_length = int(audio_tensor.shape[0] / speed) + result = to_wave(stretched, length=expected_length) + + result = result.squeeze(0).numpy() + if channels_last: + result = result.T + return result, sample_rate except Exception as e: logger.error(f"An error occurred during speed adjustment: {e}") raise ValueError("Failed to apply speed adjustment.") from e diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index e180bef229..7c2c3164d9 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -150,6 +150,29 @@ class VideoGenerationRequest(BaseModel): ) seed: int | None = Field(default=None, description="Random seed for reproducibility") + # vllm-omni extensions for post-generation frame interpolation. + enable_frame_interpolation: bool = Field( + default=False, + description="Enable post-generation RIFE frame interpolation before MP4 encoding.", + ) + frame_interpolation_exp: int = Field( + default=1, + ge=1, + description="Interpolation exponent: 1=2x temporal resolution, 2=4x, etc.", + ) + frame_interpolation_scale: float = Field( + default=1.0, + gt=0.0, + description="RIFE inference scale. Use 0.5 for high-resolution inputs to save memory.", + ) + frame_interpolation_model_path: str | None = Field( + default=None, + description=( + "Local directory or Hugging Face repo ID containing RIFE flownet.pkl weights. " + "Defaults to elfgum/RIFE-4.22.lite." + ), + ) + # vllm-omni extension for per-request LoRA. 
lora: dict[str, Any] | None = Field( default=None, @@ -201,6 +224,14 @@ class VideoGenerationResponse(BaseModel): created: int = Field(..., description="Unix timestamp of when the generation completed") data: list[VideoData] = Field(..., description="Array of generated videos") + stage_durations: dict[str, float] = Field( + default_factory=dict, + description="Profiler stage durations reported by the diffusion pipeline.", + ) + peak_memory_mb: float = Field( + default=0.0, + description="Peak device memory usage in MB reported by the diffusion pipeline.", + ) class VideoError(BaseModel): @@ -250,6 +281,14 @@ class VideoResponse(BaseModel): description="Filename of the saved output video files for this job.", ) inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.") + stage_durations: dict[str, float] = Field( + default_factory=dict, + description="Profiler stage durations reported by the diffusion pipeline.", + ) + peak_memory_mb: float = Field( + default=0.0, + description="Peak device memory usage in MB reported by the diffusion pipeline.", + ) @property def file_extension(self) -> str: diff --git a/vllm_omni/entrypoints/openai/realtime_connection.py b/vllm_omni/entrypoints/openai/realtime_connection.py new file mode 100644 index 0000000000..1d5470f569 --- /dev/null +++ b/vllm_omni/entrypoints/openai/realtime_connection.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import asyncio +import base64 +import json +from collections.abc import AsyncGenerator +from uuid import uuid4 + +import numpy as np +from vllm.entrypoints.openai.engine.protocol import UsageInfo +from vllm.entrypoints.openai.realtime.connection import RealtimeConnection as VllmRealtimeConnection +from vllm.entrypoints.openai.realtime.protocol import TranscriptionDelta, TranscriptionDone +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class RealtimeConnection(VllmRealtimeConnection): + """Omni realtime connection with audio-only server events. + + Reuses upstream vLLM websocket/session lifecycle and only customizes + generation output handling to emit audio deltas. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Last audio buffer seen for this realtime generation (cumulative or concatenation + # of increments); used to turn server cumulative PCM into true deltas. + self._realtime_audio_ref: np.ndarray | None = None + + async def start_generation(self): + await super().start_generation() + + @staticmethod + def _tensor_to_numpy(value) -> np.ndarray | None: + if value is None: + return None + if isinstance(value, np.ndarray): + arr = value + elif hasattr(value, "detach"): + arr = value.detach().float().cpu().numpy() + else: + try: + arr = np.asarray(value) + except Exception: + return None + if arr.ndim > 1: + arr = arr.reshape(-1) + return arr.astype(np.float32, copy=False) + + @staticmethod + def _numpy_audio_prefix_match(prev: np.ndarray, curr: np.ndarray) -> bool: + n = prev.shape[0] + if n == 0: + return True + if curr.shape[0] < n: + return False + return bool(np.allclose(curr[:n], prev, rtol=1e-3, atol=2e-4)) + + def _raw_waveform_to_deltas(self, arr: np.ndarray) -> list[np.ndarray]: + """Convert one streaming PCM f32 chunk into incremental piece(s) for the client. + + Some engine paths emit a growing cumulative waveform each step; others emit + true per-step deltas. We support both without duplicating audio on the client. 
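+
+ For example, if step 1 yields ``[a, b]`` and step 2 yields ``[a, b, c, d]``
+ (cumulative), only ``[c, d]`` is forwarded; if step 2 instead yields
+ ``[c, d]`` (a true per-step delta), it is forwarded unchanged.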
+ """ + if arr.size == 0: + return [] + ref = self._realtime_audio_ref + if ref is None: + self._realtime_audio_ref = arr.copy() + return [arr] + if self._numpy_audio_prefix_match(ref, arr): + delta = arr[ref.shape[0] :] + self._realtime_audio_ref = arr.copy() + return [delta] if delta.size > 0 else [] + # True per-step delta (not a prefix extension of what we have seen). + self._realtime_audio_ref = np.concatenate([ref, arr]) + return [arr] + + def _extract_audio_chunks(self, output) -> tuple[list[np.ndarray], int]: + mm = getattr(output, "multimodal_output", None) + if not isinstance(mm, dict): + return [], 24000 + + sr = mm.get("sr") or mm.get("sample_rate") or mm.get("audio_sample_rate") or 24000 + key = "audio" if "audio" in mm else ("model_outputs" if "model_outputs" in mm else None) + if key is None: + return [], int(sr) + + raw_audio = mm.get(key) + chunks: list[np.ndarray] = [] + if isinstance(raw_audio, (list, tuple)): + if len(raw_audio) > 0: + arr = self._tensor_to_numpy(raw_audio[-1]) + if arr is not None and arr.size > 0: + chunks.extend(self._raw_waveform_to_deltas(arr)) + else: + arr = self._tensor_to_numpy(raw_audio) + if arr is not None and arr.size > 0: + chunks.extend(self._raw_waveform_to_deltas(arr)) + return chunks, int(sr) + + @staticmethod + def _pcm16_b64(audio_f32: np.ndarray) -> str: + clipped = np.clip(audio_f32, -1.0, 1.0) + pcm16 = (clipped * 32767.0).astype(np.int16) + return base64.b64encode(pcm16.tobytes()).decode("utf-8") + + async def _run_generation( + self, + streaming_input_gen: AsyncGenerator, + input_stream: asyncio.Queue[list[int]], + ): + request_id = f"rt-{self.connection_id}-{uuid4()}" + sent_audio = False + audio_done_sent = False + full_text = "" + sent_text_len = 0 + prompt_token_ids_len = 0 + completion_tokens_len = 0 + self._realtime_audio_ref = None + + try: + result_gen = self.serving.engine_client.generate( + prompt=streaming_input_gen, + request_id=request_id, + ) + + async for output in result_gen: + if output.outputs and len(output.outputs) > 0: + output0 = output.outputs[0] + token_ids = list(output0.token_ids) + if token_ids: + input_stream.put_nowait(token_ids) + # token_ids are cumulative per request + completion_tokens_len = len(token_ids) + if not prompt_token_ids_len and output.prompt_token_ids: + prompt_token_ids_len = len(output.prompt_token_ids) + cumulative_text = output0.text or "" + if cumulative_text: + if len(cumulative_text) >= sent_text_len: + delta_text = cumulative_text[sent_text_len:] + else: + delta_text = cumulative_text + sent_text_len = len(cumulative_text) + full_text = cumulative_text + else: + delta_text = "" + + if delta_text: + await self.send(TranscriptionDelta(delta=delta_text)) + + audio_chunks, sample_rate = self._extract_audio_chunks(output) + + for chunk in audio_chunks: + sent_audio = True + await self.send_json( + { + "type": "response.audio.delta", + "audio": self._pcm16_b64(chunk), + "format": "pcm16", + "sample_rate_hz": sample_rate, + } + ) + + if not self._is_connected: + break + + usage = UsageInfo( + prompt_tokens=prompt_token_ids_len, + completion_tokens=completion_tokens_len, + total_tokens=prompt_token_ids_len + completion_tokens_len, + ) + await self.send(TranscriptionDone(text=full_text, usage=usage)) + + if sent_audio: + await self.send_json({"type": "response.audio.done", "has_audio": True}) + audio_done_sent = True + except Exception as e: + logger.exception("Error in generation: %s", e) + await self.send_error(str(e), "processing_error") + finally: + # Always send terminal event so 
clients don't hang forever. + if self._is_connected and not audio_done_sent: + try: + await self.send_json({"type": "response.audio.done", "has_audio": sent_audio}) + except Exception: + logger.exception("Failed to send response.audio.done") + while not self.audio_queue.empty(): + self.audio_queue.get_nowait() + + async def send_json(self, payload: dict): + await self.websocket.send_text(json.dumps(payload)) diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index e84a49aac2..8cddac6a6c 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -85,7 +85,12 @@ from vllm_omni.entrypoints.openai.image_api_utils import validate_layered_layers from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio -from vllm_omni.entrypoints.openai.utils import parse_lora_request +from vllm_omni.entrypoints.openai.utils import ( + get_stage_type, + get_supported_speakers_from_hf_config, + parse_lora_request, + validate_requested_speaker, +) from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput @@ -106,6 +111,7 @@ class OmniOpenAIServingChat(OpenAIServingChat, AudioMixin): _diffusion_mode: bool = False _diffusion_engine: AsyncOmni | None = None _diffusion_model_name: str = "" + _supported_speakers: set[str] | None = None @classmethod def for_diffusion( @@ -132,6 +138,18 @@ def for_diffusion( instance._diffusion_model_name = model_name return instance + def _get_supported_speakers(self) -> set[str]: + """Load supported speakers from model config (cached).""" + if self._supported_speakers is not None: + return self._supported_speakers + try: + self._supported_speakers = get_supported_speakers_from_hf_config(self.model_config.hf_config) + return self._supported_speakers + except Exception as e: + logger.warning("Could not load speakers from model config: %s", e) + self._supported_speakers = set() + return self._supported_speakers + async def create_chat_completion( self, request: ChatCompletionRequest, @@ -260,7 +278,10 @@ async def create_chat_completion( except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(f"{e} {e.__cause__}") + message = str(e) + if e.__cause__ is not None: + message = f"{message} {e.__cause__}" + return self.create_error_response(message) request_id = f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" @@ -274,6 +295,8 @@ async def create_chat_completion( ) num_inference_steps = None + cfg_text_scale = None + cfg_img_scale = None # Omni multistage image generation: Stage-0 (AR) should receive a clean # text prompt (and optional conditioning image/size) so the model's own # processor can construct the correct inputs. 
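The hunk below reads ``cfg_text_scale`` and ``cfg_img_scale`` from the request's ``extra_body`` and copies them into the per-stage sampling params' ``extra_args``. A minimal client-side sketch of exercising the new knobs, assuming a running vllm-omni OpenAI-compatible server (base URL, model name, and scale values are illustrative, not part of this diff):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="my-omni-image-model",  # hypothetical model name
        messages=[{"role": "user", "content": "A watercolor fox in a misty forest"}],
        extra_body={
            "negative_prompt": "blurry, low quality",
            "cfg_text_scale": 4.0,  # copied into each stage's extra_args when supported
            "cfg_img_scale": 1.5,   # copied into each stage's extra_args when supported
        },
    )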
@@ -322,6 +345,8 @@ async def create_chat_completion( except Exception: pass negative_prompt = extra_body.get("negative_prompt") + cfg_text_scale = extra_body.get("cfg_text_scale") + cfg_img_scale = extra_body.get("cfg_img_scale") engine_prompt_image: dict[str, Any] | None = None is_img2img = False @@ -377,14 +402,18 @@ async def create_chat_completion( sampling_params_list = self._build_sampling_params_list_from_request(request) # Apply user-specified overrides to diffusion stage(s) for image generation - if _image_gen_height is not None or _image_gen_width is not None or num_inference_steps is not None: - for idx, sp in enumerate(sampling_params_list): - if hasattr(sp, "height") and _image_gen_height is not None: - sp.height = _image_gen_height - if hasattr(sp, "width") and _image_gen_width is not None: - sp.width = _image_gen_width - if hasattr(sp, "num_inference_steps") and num_inference_steps is not None: - sp.num_inference_steps = num_inference_steps + for idx, sp in enumerate(sampling_params_list): + if hasattr(sp, "height") and _image_gen_height is not None: + sp.height = _image_gen_height + if hasattr(sp, "width") and _image_gen_width is not None: + sp.width = _image_gen_width + if hasattr(sp, "num_inference_steps") and num_inference_steps is not None: + sp.num_inference_steps = num_inference_steps + if hasattr(sp, "extra_args") and sp.extra_args is not None: + if cfg_text_scale is not None: + sp.extra_args["cfg_text_scale"] = cfg_text_scale + if cfg_img_scale is not None: + sp.extra_args["cfg_img_scale"] = cfg_img_scale self._log_inputs( request_id, @@ -540,10 +569,11 @@ async def _preprocess_chat( engine_prompt["cache_salt"] = request.cache_salt speaker = getattr(request, "speaker", None) - if speaker is not None and isinstance(speaker, str) and speaker.strip(): + normalized = validate_requested_speaker(speaker, self._get_supported_speakers()) + if normalized is not None: if "additional_information" not in engine_prompt or engine_prompt["additional_information"] is None: engine_prompt["additional_information"] = {} - engine_prompt["additional_information"]["speaker"] = [speaker.lower().strip()] + engine_prompt["additional_information"]["speaker"] = [normalized] language = getattr(request, "language", None) if language is not None and isinstance(language, str) and language.strip(): @@ -698,11 +728,17 @@ def _apply_request_overrides( for field_name in self._OPENAI_SAMPLING_FIELDS: value = getattr(request, field_name, None) - if value is not None: + if (value is not None and not isinstance(value, list)) or (isinstance(value, list) and len(value) > 0): setattr(params, field_name, value) return params + @staticmethod + def _set_if_supported(obj: Any, **kwargs: Any) -> None: + for key, value in kwargs.items(): + if value is not None and hasattr(obj, key): + setattr(obj, key, value) + def _build_sampling_params_list_from_request( self, request: ChatCompletionRequest, @@ -1549,6 +1585,7 @@ async def chat_completion_full_generator( role, reasoning_parser, ) + final_res = omni_outputs.request_output elif omni_outputs.final_output_type == "audio": choices_data = self._create_audio_choice(omni_outputs, role, request, stream=False) elif omni_outputs.final_output_type == "image": @@ -2027,6 +2064,254 @@ def _create_image_choice( return choices # ==================== Diffusion Mode Methods ==================== + def _build_multistage_generation_inputs( + self, + *, + engine: AsyncOmni, + prompt: str, + extra_body: dict[str, Any], + reference_images: list[Image.Image], + gen_params: 
OmniDiffusionSamplingParams, + ) -> tuple[OmniTextPrompt, list[Any]]: + """Build the shared multistage generation prompt and stage params.""" + stage_configs = getattr(engine, "stage_configs", None) or [] + default_params_list = list(getattr(engine, "default_sampling_params_list", []) or []) + + height = gen_params.height + width = gen_params.width + seed = gen_params.seed + generator_device = gen_params.generator_device + num_outputs_per_prompt = gen_params.num_outputs_per_prompt + num_inference_steps = extra_body.get("num_inference_steps") + guidance_scale = extra_body.get("guidance_scale") + true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale") + negative_prompt = extra_body.get("negative_prompt") + num_frames = extra_body.get("num_frames") + guidance_scale_2 = extra_body.get("guidance_scale_2") + lora_body = extra_body.get("lora") + layers = extra_body.get("layers") + resolution = extra_body.get("resolution") + + engine_prompt_data: dict[str, Any] | None = None + modalities = ["image"] + if reference_images: + if len(reference_images) == 1: + engine_prompt_data = {"img2img": reference_images[0]} + modalities = ["img2img"] + else: + engine_prompt_data = {"image": reference_images} + + engine_prompt: OmniTextPrompt = {"prompt": prompt} + engine_prompt["modalities"] = modalities + if negative_prompt is not None: + engine_prompt["negative_prompt"] = negative_prompt + + mm_processor_kwargs: dict[str, Any] = {} + if height is not None: + mm_processor_kwargs["target_h"] = height + if width is not None: + mm_processor_kwargs["target_w"] = width + if mm_processor_kwargs: + engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs + if engine_prompt_data is not None: + engine_prompt["multi_modal_data"] = engine_prompt_data + + comprehension_idx = None + for idx, stage in enumerate(stage_configs): + if getattr(stage, "is_comprehension", False): + comprehension_idx = idx + break + + sampling_params_list: list[Any] = [] + for idx, stage_cfg in enumerate(stage_configs): + stage_type = get_stage_type(stage_cfg) + if idx < len(default_params_list): + default_stage_params = default_params_list[idx] + if hasattr(default_stage_params, "clone"): + try: + default_stage_params = default_stage_params.clone() + except Exception: + pass + elif stage_type == "diffusion": + default_stage_params = gen_params.clone() + else: + default_stage_params = SamplingParams() + + if ( + comprehension_idx is not None + and idx == comprehension_idx + and seed is not None + and hasattr(default_stage_params, "seed") + ): + default_stage_params.seed = seed + + if stage_type == "diffusion": + self._set_if_supported( + default_stage_params, + height=height, + width=width, + seed=seed, + generator_device=generator_device, + num_outputs_per_prompt=num_outputs_per_prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + true_cfg_scale=true_cfg_scale, + num_frames=num_frames, + guidance_scale_2=guidance_scale_2, + layers=layers, + resolution=resolution, + ) + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + default_stage_params.lora_request = lora_req + if lora_scale is not None: + default_stage_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + sampling_params_list.append(default_stage_params) + + return engine_prompt, sampling_params_list + + async def generate_diffusion_images( + self, + *, + 
prompt: str, + extra_body: dict[str, Any] | None = None, + reference_images: list[str] | None = None, + request_id: str | None = None, + ) -> tuple[list[Image.Image], dict[str, Any], float] | ErrorResponse: + """Generate diffusion images and return raw images plus generation stats.""" + if request_id is None: + request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}" + if extra_body is None: + extra_body = {} + if reference_images is None: + reference_images = [] + + engine = self._diffusion_engine if self._diffusion_engine is not None else self.engine_client + + height = extra_body.get("height") + width = extra_body.get("width") + if "size" in extra_body: + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except ValueError: + logger.warning("Invalid size format: %s", extra_body.get("size")) + + seed = extra_body.get("seed") + generator_device = extra_body.get("generator_device") + negative_prompt = extra_body.get("negative_prompt") + num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1) + lora_body = extra_body.get("lora") + + pil_images: list[Image.Image] = [] + for img_b64 in reference_images: + try: + img_bytes = base64.b64decode(img_b64) + pil_images.append(Image.open(BytesIO(img_bytes))) + except Exception as e: + logger.warning("Failed to decode reference image: %s", e) + + gen_params = OmniDiffusionSamplingParams( + height=height, + width=width, + num_outputs_per_prompt=num_outputs_per_prompt, + seed=seed, + ) + self._set_if_supported( + gen_params, + generator_device=generator_device, + num_inference_steps=extra_body.get("num_inference_steps"), + guidance_scale=extra_body.get("guidance_scale"), + true_cfg_scale=extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale"), + num_frames=extra_body.get("num_frames"), + guidance_scale_2=extra_body.get("guidance_scale_2"), + layers=extra_body.get("layers"), + resolution=extra_body.get("resolution"), + ) + + if lora_body and isinstance(lora_body, dict): + try: + lora_req, lora_scale = parse_lora_request(lora_body) + if lora_req is not None: + gen_params.lora_request = lora_req + if lora_scale is not None: + gen_params.lora_scale = lora_scale + except Exception as e: # pragma: no cover - safeguard + logger.warning("Failed to parse LoRA request: %s", e) + + gen_prompt: OmniTextPrompt = { + "prompt": prompt, + "negative_prompt": negative_prompt, + } + if pil_images: + if len(pil_images) == 1: + gen_prompt["multi_modal_data"] = {"image": pil_images[0]} + else: + od_config = getattr(engine, "od_config", None) + supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False) + if od_config is None: + supports_multimodal_inputs = True + if supports_multimodal_inputs: + gen_prompt["multi_modal_data"] = {"image": pil_images} + else: + return self._create_error_response( + "Multiple input images are not supported by the current diffusion model. 
" + "For multi-image editing, start the server with Qwen-Image-Edit-2509 " + "and send multiple images in the user message content.", + status_code=400, + ) + + if isinstance(engine, AsyncOmni): + diffusion_engine = cast(AsyncOmni, engine) + stage_configs = getattr(diffusion_engine, "stage_configs", None) or [] + if len(stage_configs) > 1: + engine_prompt, sampling_params_list = self._build_multistage_generation_inputs( + engine=diffusion_engine, + prompt=prompt, + extra_body=extra_body, + reference_images=pil_images, + gen_params=gen_params, + ) + else: + engine_prompt = gen_prompt + sampling_params_list = [gen_params] + + result = None + async for output in diffusion_engine.generate( + prompt=engine_prompt, + sampling_params_list=sampling_params_list, + request_id=request_id, + ): + result = output + if result is None: + return self._create_error_response("No output generated from AsyncOmni", status_code=500) + else: + result = await engine.generate( + prompt=gen_prompt, + sampling_params=gen_params, + request_id=request_id, + ) + + images = getattr(result.request_output, "images", []) + stage_durations = result.stage_durations + peak_memory_mb = result.peak_memory_mb + + flat_images: list[Image.Image] = [] + for item in images: + if isinstance(item, list): + flat_images.extend(item) + else: + flat_images.append(item) + + return flat_images, stage_durations, peak_memory_mb + async def _create_diffusion_chat_completion( self, request: ChatCompletionRequest, @@ -2087,6 +2372,8 @@ async def _create_diffusion_chat_completion( num_inference_steps = extra_body.get("num_inference_steps") guidance_scale = extra_body.get("guidance_scale") true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale") + cfg_text_scale = extra_body.get("cfg_text_scale") + cfg_img_scale = extra_body.get("cfg_img_scale") seed = extra_body.get("seed") negative_prompt = extra_body.get("negative_prompt") num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1) @@ -2141,6 +2428,10 @@ async def _create_diffusion_chat_completion( gen_params.guidance_scale = guidance_scale if true_cfg_scale is not None: gen_params.true_cfg_scale = true_cfg_scale + if cfg_text_scale is not None: + gen_params.extra_args["cfg_text_scale"] = cfg_text_scale + if cfg_img_scale is not None: + gen_params.extra_args["cfg_img_scale"] = cfg_img_scale if num_frames is not None: gen_params.num_frames = num_frames if guidance_scale_2 is not None: @@ -2161,7 +2452,7 @@ async def _create_diffusion_chat_completion( except Exception as e: # pragma: no cover - safeguard logger.warning("Failed to parse LoRA request: %s", e) - # Add reference image if provided + # Add reference image if provided (from messages content) if pil_images: if len(pil_images) == 1: gen_prompt["multi_modal_data"] = {} @@ -2185,10 +2476,30 @@ async def _create_diffusion_chat_completion( # Generate image diffusion_engine = cast(AsyncOmni, self._diffusion_engine) + stage_configs = list(getattr(diffusion_engine, "stage_configs", []) or []) + default_params_list = list(getattr(diffusion_engine, "default_sampling_params_list", []) or []) + + sampling_params_list: list[Any] = [] + for idx, stage_cfg in enumerate(stage_configs): + if get_stage_type(stage_cfg) == "diffusion": + sampling_params_list.append(gen_params) + continue + + default_stage_params = default_params_list[idx] if idx < len(default_params_list) else SamplingParams() + if hasattr(default_stage_params, "clone"): + try: + default_stage_params = default_stage_params.clone() + except Exception as e: + 
logger.warning("Failed to clone default params for stage %d: %s", idx, e) + sampling_params_list.append(default_stage_params) + + if not sampling_params_list: + sampling_params_list = [gen_params] + result = None async for output in diffusion_engine.generate( prompt=gen_prompt, - sampling_params_list=[gen_params], # Pass as single-stage params + sampling_params_list=sampling_params_list, request_id=request_id, ): result = output diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 87ef6a4e9b..ba8292f0c2 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -49,12 +49,16 @@ _FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"} _COSYVOICE3_TTS_MODEL_STAGES = {"cosyvoice3_talker"} _OMNIVOICE_TTS_MODEL_STAGES = {"omnivoice_generator"} +_VOXCPM_TTS_MODEL_STAGES = {"latent_generator", "vae"} +_VOXCPM2_TTS_MODEL_STAGES = {"latent_generator"} _TTS_MODEL_STAGES: set[str] = ( _VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES | _FISH_TTS_MODEL_STAGES | _COSYVOICE3_TTS_MODEL_STAGES | _OMNIVOICE_TTS_MODEL_STAGES + | _VOXCPM_TTS_MODEL_STAGES + | _VOXCPM2_TTS_MODEL_STAGES ) _TTS_LANGUAGES: set[str] = { "Auto", @@ -212,6 +216,8 @@ def __init__(self, *args, **kwargs): "Re-upload voices after each restart if needed." ) self._tts_tokenizer = None + self._voxcpm2_tokenizer = None + self._voxcpm2_split_map: dict[int, list[int]] = {} logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") @@ -280,6 +286,11 @@ def _detect_tts_model_type(self) -> str | None: if self._tts_stage is None: return None model_stage = getattr(self._tts_stage.engine_args, "model_stage", None) + model_arch = getattr(self._tts_stage.engine_args, "model_arch", None) + if model_arch == "VoxCPM2TalkerForConditionalGeneration": + return "voxcpm2" + if model_arch == "VoxCPMForConditionalGeneration": + return "voxcpm" if model_stage in _QWEN3_TTS_MODEL_STAGES: return "qwen3_tts" if model_stage in _VOXTRAL_TTS_MODEL_STAGES: @@ -290,6 +301,12 @@ def _detect_tts_model_type(self) -> str | None: return "cosyvoice3" if model_stage in _OMNIVOICE_TTS_MODEL_STAGES: return "omnivoice" + if model_stage in (_VOXCPM_TTS_MODEL_STAGES | _VOXCPM2_TTS_MODEL_STAGES): + has_vae_stage = any( + getattr(getattr(stage, "engine_args", None), "model_stage", None) == "vae" + for stage in self.engine_client.stage_configs + ) + return "voxcpm" if has_vae_stage or model_stage == "vae" else "voxcpm2" return None def _compute_max_instructions_length(self) -> int: @@ -314,6 +331,8 @@ def _compute_max_instructions_length(self) -> int: def _load_supported_speakers(self) -> set[str]: """Load supported speakers (case-insensitive) from the model configuration.""" try: + if self._tts_model_type == "voxcpm": + return set() if self._tts_model_type == "voxtral_tts": config = self.engine_client.model_config.hf_config.audio_config else: @@ -373,6 +392,8 @@ def _estimate_ref_code_len(self, ref_audio: object) -> int | None: def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: """Estimate prompt length so the placeholder matches model-side embeddings.""" try: + if self._tts_model_type == "voxcpm": + return 1 from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import ( Qwen3TTSTalkerForConditionalGeneration, ) @@ -436,6 +457,25 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object) logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", 
e) return 2048 + async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: + """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length).""" + from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt + + self._voxcpm2_encode("") # lazy-init tokenizer + split_map + ref_audio = None + ref_sr = None + if request.ref_audio is not None: + ref_audio, ref_sr = await self._resolve_ref_audio(request.ref_audio) + return build_voxcpm2_prompt( + hf_config=self.engine_client.model_config.hf_config, + tokenizer=self._voxcpm2_tokenizer, + split_map=self._voxcpm2_split_map, + text=request.input, + ref_audio=ref_audio, + ref_sr=ref_sr, + ref_text=request.ref_text, + ) + def _get_uploaded_audio_data(self, voice_name: str) -> str | None: """Get base64 encoded audio data for uploaded voice.""" voice_name_lower = voice_name.lower() @@ -787,8 +827,31 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return self._validate_fish_tts_request(request) if self._tts_model_type == "cosyvoice3": return self._validate_cosyvoice3_request(request) + if self._tts_model_type == "voxcpm": + return self._validate_voxcpm_request(request) + if self._tts_model_type == "voxcpm2": + return None # VoxCPM2 accepts any text input return self._validate_qwen_tts_request(request) + def _voxcpm2_encode(self, text: str) -> list[int]: + """Tokenize text for VoxCPM2, splitting multichar Chinese tokens.""" + from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( + build_cjk_split_map, + split_multichar_chinese, + ) + + if self._voxcpm2_tokenizer is None: + from transformers import AutoTokenizer + + model_name = self.engine_client.model_config.model + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + self._voxcpm2_split_map = build_cjk_split_map(tokenizer) + self._voxcpm2_tokenizer = tokenizer + logger.info("VoxCPM2 serving: built multichar split map (%d entries)", len(self._voxcpm2_split_map)) + + ids = self._voxcpm2_tokenizer.encode(text, add_special_tokens=True) + return split_multichar_chinese(ids, self._voxcpm2_split_map) + def _validate_ref_audio_format(self, ref_audio: str) -> str | None: """Validate ref_audio is a supported URI format. Returns error or None.""" if not ( @@ -826,6 +889,43 @@ def _validate_voxtral_tts_request(self, request: OpenAICreateSpeechRequest) -> s return None + def _validate_voxcpm_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate VoxCPM request parameters. 
Returns error message or None.""" + if not request.input or not request.input.strip(): + return "Input text cannot be empty" + + if request.voice is not None: + return "'voice' is not supported for VoxCPM" + if request.instructions is not None: + return "'instructions' is not supported for VoxCPM" + if request.language is not None: + return "'language' is not supported for VoxCPM" + if request.task_type not in (None, "Base"): + return "VoxCPM only supports plain TTS or voice cloning with ref_audio/ref_text" + if request.x_vector_only_mode is not None: + return "'x_vector_only_mode' is not supported for VoxCPM" + if request.speaker_embedding is not None: + return "'speaker_embedding' is not supported for VoxCPM" + if request.initial_codec_chunk_frames is not None: + return "'initial_codec_chunk_frames' is not supported for VoxCPM" + + if request.ref_audio is not None: + fmt_err = self._validate_ref_audio_format(request.ref_audio) + if fmt_err: + return fmt_err + if not request.ref_text or not request.ref_text.strip(): + return "Voice cloning requires 'ref_text' (transcript of the reference audio)" + elif request.ref_text is not None: + return "'ref_text' requires 'ref_audio' for VoxCPM voice cloning" + + if request.max_new_tokens is not None: + if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN: + return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}" + if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX: + return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}" + + return None + def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: """Validate Qwen TTS request parameters. Returns error message or None.""" # Infer Base task when ref_audio or ref_text is provided without explicit task_type. @@ -919,6 +1019,13 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str fmt_err = self._validate_ref_audio_format(request.ref_audio) if fmt_err: return fmt_err + if not getattr(request, "x_vector_only_mode", False) and ( + not request.ref_text or not request.ref_text.strip() + ): + return ( + "Base task requires non-empty 'ref_text' (transcript of " + "the reference audio) unless 'x_vector_only_mode' is enabled" + ) # Validate cross-parameter dependencies if task_type != "Base": @@ -1017,11 +1124,15 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int URLs, ``data:`` base64 URIs, and ``file:`` local paths (the latter gated by ``--allowed-local-media-path``). """ - model_config = self.model_config - connector = MediaConnector( - allowed_local_media_path=model_config.allowed_local_media_path, - allowed_media_domains=model_config.allowed_media_domains, - ) + # In diffusion mode, model_config may not be available + if self._diffusion_mode: + connector = MediaConnector() + else: + model_config = self.model_config + connector = MediaConnector( + allowed_local_media_path=model_config.allowed_local_media_path, + allowed_media_domains=model_config.allowed_media_domains, + ) wav_np, sr = await connector.fetch_audio_async(ref_audio_str) wav_np = np.asarray(wav_np, dtype=np.float32) if wav_np.ndim > 1: @@ -1152,6 +1263,18 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any Processes each parameter if present, skips if not. Values are wrapped in lists as required by the model. 
""" + if self._tts_model_type == "voxcpm": + params: dict[str, Any] = { + "text": [request.input], + "cfg_value": [2.0], + "inference_timesteps": [10], + "min_len": [2], + "max_new_tokens": [request.max_new_tokens or 4096], + } + if request.ref_text is not None: + params["ref_text"] = [request.ref_text] + return params + params: dict[str, Any] = {} # Text content (always required) @@ -1392,8 +1515,36 @@ async def _prepare_speech_generation( prompt = await self._build_fish_speech_prompt_async(request, ref_audio_data=ref_audio_data) tts_params = {} elif self._tts_model_type == "omnivoice": + if not request.input or not request.input.strip(): + raise ValueError("Input text cannot be empty") + tts_params = {} + prompt: dict[str, Any] = {"input": request.input} + # Resolve ref_audio: explicit request param or uploaded voice + ref_src = request.ref_audio + if not ref_src and request.voice: + vl = request.voice.lower() + if vl in self.uploaded_speakers: + sp = self.uploaded_speakers[vl] + if sp.get("embedding_source") == "audio": + ref_src = self._get_uploaded_audio_data(request.voice) + if not ref_src: + raise ValueError(f"Audio for voice '{request.voice}' missing") + prompt["ref_text"] = sp.get("ref_text") + if ref_src: + fmt_err = self._validate_ref_audio_format(ref_src) + if fmt_err: + raise ValueError(fmt_err) + wav, sr = await self._resolve_ref_audio(ref_src) + prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr) + if request.ref_text: + prompt["ref_text"] = request.ref_text + if request.language: + prompt["lang"] = request.language + if request.instructions: + prompt["instruct"] = request.instructions + elif self._tts_model_type == "voxcpm2": + prompt = await self._build_voxcpm2_prompt(request) tts_params = {} - prompt = request.input # Diffusion engine takes raw text elif self._is_tts: validation_error = self._validate_tts_request(request) if validation_error: @@ -1420,6 +1571,24 @@ async def _prepare_speech_generation( ph_len = await self._estimate_prompt_len_async(tts_params) prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params} else: + # Qwen omni models (Qwen3-Omni, Qwen2.5-Omni) use a "talker" + # stage whose preprocess requires chat-templated tokens. The + # async-chunk orchestrator prewarms the talker via + # compute_talker_prompt_ids_length(), which scans for Qwen + # chat-template markers (im_start_token_id 151644). A raw-text + # prompt produces a 1-token placeholder that crashes the talker's + # prefill/decode handoff. Reject early with an actionable message. + stage_names = { + getattr(getattr(s, "engine_args", None), "model_stage", None) for s in self.engine_client.stage_configs + } + if "talker" in stage_names: + raise ValueError( + "The /v1/audio/speech endpoint is only supported for " + "dedicated TTS models (e.g., Qwen3-TTS, Voxtral, Fish " + "Speech, CosyVoice3, OmniVoice, VoxCPM2). For omni " + "models like Qwen3-Omni, use /v1/chat/completions with " + '\'"modalities": ["audio"]\' instead.' 
+ ) tts_params = {} prompt = {"prompt": request.input} @@ -1430,6 +1599,10 @@ async def _prepare_speech_generation( model_type = "voxtral_tts" elif self._tts_model_type == "cosyvoice3": model_type = "cosyvoice3" + elif self._tts_model_type == "voxcpm": + model_type = "voxcpm" + elif self._tts_model_type == "voxcpm2": + model_type = "voxcpm2" elif self._is_tts: model_type = tts_params.get("task_type", ["unknown"])[0] else: @@ -1560,13 +1733,26 @@ async def _create_diffusion_speech( from vllm_omni.outputs import OmniRequestOutput try: + if not request.input or not request.input.strip(): + raise ValueError("Input text cannot be empty") + request_id = f"speech-{random_uuid()}" - prompt = request.input + prompt: dict[str, Any] = {"input": request.input} + if request.ref_audio: + wav, sr = await self._resolve_ref_audio(request.ref_audio) + prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr) + if request.ref_text: + prompt["ref_text"] = request.ref_text + if request.language: + prompt["lang"] = request.language + if request.instructions: + prompt["instruct"] = request.instructions logger.info( - "Diffusion TTS speech request %s: text=%r", + "Diffusion TTS speech request %s: text=%r, voice_clone=%s", request_id, - prompt[:50] + "..." if len(prompt) > 50 else prompt, + request.input[:50] + "..." if len(request.input) > 50 else request.input, + "ref_audio" in prompt, ) generator = self._diffusion_engine.generate( diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py index bddfd48003..741295c7c2 100644 --- a/vllm_omni/entrypoints/openai/serving_video.py +++ b/vllm_omni/entrypoints/openai/serving_video.py @@ -33,6 +33,18 @@ class ReferenceImage: data: Image.Image +@dataclass +class VideoGenerationArtifacts: + """Normalized outputs and profiler metadata extracted from one request.""" + + videos: list[Any] + audios: list[Any | None] + audio_sample_rate: int + output_fps: int + stage_durations: dict[str, float] + peak_memory_mb: float + + class OmniOpenAIServingVideo: """OpenAI-style video generation handler for omni diffusion models.""" @@ -77,12 +89,8 @@ async def _run_and_extract( reference_id: str, *, reference_image: ReferenceImage | None = None, - ) -> tuple[list[Any], list[Any | None], int, int]: - """Run the generation pipeline and extract video/audio outputs. - - Returns: - Tuple of (videos, audios, audio_sample_rate, output_fps). 
- """ + ) -> VideoGenerationArtifacts: + """Run the generation pipeline and extract video/audio/profiler outputs.""" prompt: OmniTextPrompt = OmniTextPrompt(prompt=request.prompt) if request.negative_prompt is not None: prompt["negative_prompt"] = request.negative_prompt @@ -105,6 +113,10 @@ async def _run_and_extract( if vp.fps is not None: gen_params.fps = vp.fps gen_params.frame_rate = float(vp.fps) + gen_params.enable_frame_interpolation = request.enable_frame_interpolation + gen_params.frame_interpolation_exp = request.frame_interpolation_exp + gen_params.frame_interpolation_scale = request.frame_interpolation_scale + gen_params.frame_interpolation_model_path = request.frame_interpolation_model_path if request.num_inference_steps is not None: gen_params.num_inference_steps = request.num_inference_steps @@ -152,8 +164,15 @@ async def _run_and_extract( videos = self._extract_video_outputs(result) audios = self._extract_audio_outputs(result, expected_count=len(videos)) audio_sample_rate = self._resolve_audio_sample_rate(result) - output_fps = vp.fps or self._resolve_fps(result) or 24 - return videos, audios, audio_sample_rate, output_fps + output_fps = (vp.fps or self._resolve_fps(result) or 24) * self._resolve_video_fps_multiplier(result) + return VideoGenerationArtifacts( + videos=videos, + audios=audios, + audio_sample_rate=audio_sample_rate, + output_fps=output_fps, + stage_durations=self._extract_stage_durations(result), + peak_memory_mb=self._extract_peak_memory_mb(result), + ) async def generate_videos( self, @@ -162,28 +181,38 @@ async def generate_videos( *, reference_image: ReferenceImage | None = None, ) -> VideoGenerationResponse: - videos, audios, audio_sample_rate, output_fps = await self._run_and_extract( - request, reference_id, reference_image=reference_image - ) + artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image) + + video_codec_options = {"preset": "ultrafast", "threads": "0"} + if request.extra_params is not None and isinstance(request.extra_params, dict): + if "video_codec_options" in request.extra_params: + video_codec_options = request.extra_params["video_codec_options"] + _t_encode_start = time.perf_counter() video_data = [ VideoData( b64_json=( - encode_video_base64(video, fps=output_fps) - if audios[idx] is None + encode_video_base64(video, fps=artifacts.output_fps, video_codec_options=video_codec_options) + if artifacts.audios[idx] is None else encode_video_base64( video, - fps=output_fps, - audio=audios[idx], - audio_sample_rate=audio_sample_rate, + fps=artifacts.output_fps, + audio=artifacts.audios[idx], + audio_sample_rate=artifacts.audio_sample_rate, + video_codec_options=video_codec_options, ) ) ) - for idx, video in enumerate(videos) + for idx, video in enumerate(artifacts.videos) ] _t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000 logger.info("Video response encoding (MP4+base64): %.2f ms", _t_encode_ms) - return VideoGenerationResponse(created=int(time.time()), data=video_data) + return VideoGenerationResponse( + created=int(time.time()), + data=video_data, + stage_durations=artifacts.stage_durations, + peak_memory_mb=artifacts.peak_memory_mb, + ) async def generate_video_bytes( self, @@ -191,25 +220,48 @@ async def generate_video_bytes( reference_id: str, *, reference_image: ReferenceImage | None = None, - ) -> bytes: + ) -> tuple[bytes, dict[str, float], float]: """Generate a video and return raw MP4 bytes, bypassing base64 encoding.""" - videos, audios, audio_sample_rate, output_fps = await 
self._run_and_extract( - request, reference_id, reference_image=reference_image - ) - if len(videos) > 1: + artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image) + if len(artifacts.videos) > 1: logger.warning( - "Video request %s generated %d outputs; returning only the first.", reference_id, len(videos) + "Video request %s generated %d outputs; returning only the first.", + reference_id, + len(artifacts.videos), ) - audio = audios[0] + audio = artifacts.audios[0] + + video_codec_options = {"preset": "ultrafast", "threads": "0"} + if request.extra_params is not None and isinstance(request.extra_params, dict): + if "video_codec_options" in request.extra_params: + video_codec_options = request.extra_params["video_codec_options"] + _t_encode_start = time.perf_counter() video_bytes = _encode_video_bytes( - videos[0], - fps=output_fps, - **({"audio": audio, "audio_sample_rate": audio_sample_rate} if audio is not None else {}), + artifacts.videos[0], + fps=artifacts.output_fps, + **({"audio": audio, "audio_sample_rate": artifacts.audio_sample_rate} if audio is not None else {}), + video_codec_options=video_codec_options, ) _t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000 logger.info("Video response encoding (MP4 bytes): %.2f ms", _t_encode_ms) - return video_bytes + return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb + + @staticmethod + def _resolve_video_fps_multiplier(result: Any) -> int: + custom_output = getattr(result, "custom_output", None) + if isinstance(custom_output, dict): + multiplier = custom_output.get("video_fps_multiplier") + if multiplier is not None: + return int(multiplier) + request_output = getattr(result, "request_output", None) + if request_output is not None: + custom_output = getattr(request_output, "custom_output", None) + if isinstance(custom_output, dict): + multiplier = custom_output.get("video_fps_multiplier") + if multiplier is not None: + return int(multiplier) + return 1 @staticmethod def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None: @@ -483,3 +535,16 @@ def _coerce_audio_sample_rate(value: Any) -> int | None: return None return sample_rate if sample_rate > 0 else None + + @staticmethod + def _extract_stage_durations(result: Any) -> dict[str, float]: + stage_durations = getattr(result, "stage_durations", None) + return stage_durations if isinstance(stage_durations, dict) else {} + + @staticmethod + def _extract_peak_memory_mb(result: Any) -> float: + peak_memory_mb = getattr(result, "peak_memory_mb", 0.0) + try: + return float(peak_memory_mb or 0.0) + except (TypeError, ValueError): + return 0.0 diff --git a/vllm_omni/entrypoints/openai/utils.py b/vllm_omni/entrypoints/openai/utils.py index 84b28ef5b1..f411526fdb 100644 --- a/vllm_omni/entrypoints/openai/utils.py +++ b/vllm_omni/entrypoints/openai/utils.py @@ -53,3 +53,33 @@ def parse_lora_request(lora_body: Any) -> tuple[LoRARequest | None, float | None scale = float(lora_scale) if lora_scale is not None else None return LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)), scale + + +def get_supported_speakers_from_hf_config(hf_config: Any) -> set[str]: + """Extract supported speaker names from a model hf_config.""" + config = ( + hf_config.get("talker_config") if isinstance(hf_config, dict) else getattr(hf_config, "talker_config", None) + ) + if config is None: + return set() + + for spk_attr in ("speaker_id", "spk_id"): + speakers_dict = config.get(spk_attr) if isinstance(config, dict) else 
getattr(config, spk_attr, None) + if speakers_dict and isinstance(speakers_dict, dict): + return {speaker.lower() for speaker in speakers_dict} + return set() + + +def validate_requested_speaker(speaker: str | None, supported_speakers: set[str]) -> str | None: + """Normalize and validate an optional speaker value. + + Returns the normalized speaker string when provided, otherwise ``None``. + Raises ``ValueError`` when the speaker is not in the supported list. + """ + if not isinstance(speaker, str) or not speaker.strip(): + return None + + normalized = speaker.lower().strip() + if supported_speakers and normalized not in supported_speakers: + raise ValueError(f"Invalid speaker '{speaker}'. Supported: {', '.join(sorted(supported_speakers))}") + return normalized diff --git a/vllm_omni/entrypoints/openai/video_api_utils.py b/vllm_omni/entrypoints/openai/video_api_utils.py index 69178fb3d3..3fb991225c 100644 --- a/vllm_omni/entrypoints/openai/video_api_utils.py +++ b/vllm_omni/entrypoints/openai/video_api_utils.py @@ -202,7 +202,13 @@ def _coerce_audio_to_numpy(audio: Any) -> np.ndarray: return arr.astype(np.float32) -def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> bytes: +def _encode_video_bytes( + video: Any, + fps: int, + audio: Any | None = None, + audio_sample_rate: int | None = None, + video_codec_options: dict[str, str] | None = None, +) -> bytes: """Encode a video payload into MP4 bytes, optionally muxing audio.""" from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes @@ -213,7 +219,16 @@ def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sa frames_np = np.stack(frames, axis=0) if frames_np.ndim == 4 and frames_np.shape[-1] == 4: frames_np = frames_np[..., :3] - frames_u8 = (np.clip(frames_np, 0.0, 1.0) * 255).round().clip(0, 255).astype(np.uint8) + + if frames_np.dtype == np.uint8: + frames_u8 = frames_np + else: + frames_np = np.clip(frames_np, 0.0, 1.0) + frames_np *= 255.0 + frames_u8 = np.round(frames_np).astype(np.uint8) + + # Ensure contiguous memory layout for faster PyAV muxing + frames_u8 = np.ascontiguousarray(frames_u8) audio_np = _coerce_audio_to_numpy(audio) if audio is not None else None @@ -222,10 +237,19 @@ def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sa audio_np, fps=float(fps), audio_sample_rate=audio_sample_rate or 24000, + video_codec_options=video_codec_options, ) -def encode_video_base64(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> str: +def encode_video_base64( + video: Any, + fps: int, + audio: Any | None = None, + audio_sample_rate: int | None = None, + video_codec_options: dict[str, str] | None = None, +) -> str: """Encode a video (frames/array/tensor) to base64 MP4.""" - video_bytes = _encode_video_bytes(video, fps=fps, audio=audio, audio_sample_rate=audio_sample_rate) + video_bytes = _encode_video_bytes( + video, fps=fps, audio=audio, audio_sample_rate=audio_sample_rate, video_codec_options=video_codec_options + ) return base64.b64encode(video_bytes).decode("utf-8") diff --git a/vllm_omni/entrypoints/pd_utils.py b/vllm_omni/entrypoints/pd_utils.py index 0e3d65f553..413d5d6b44 100644 --- a/vllm_omni/entrypoints/pd_utils.py +++ b/vllm_omni/entrypoints/pd_utils.py @@ -23,9 +23,19 @@ class PDDisaggregationMixin: """Mixin supplying PD disaggregation helpers to OmniBase.""" + def _get_pd_separation_pair(self) -> tuple[int, int] | None: + """PD prefill/decode indices when 
``_init_pd_state`` ran; else ``None``. + + Partial test doubles may skip ``OmniBase.__init__``; treat missing state as + no PD disaggregation instead of raising ``AttributeError``. + """ + return getattr(self, "_pd_separation_pair", None) + def _init_pd_state(self) -> None: """Initialise PD disaggregation state.""" - self._pd_separation_pair: tuple[int, int] | None = self._detect_pd_separation() + self._pd_separation_pair: tuple[int, int] | None = self.detect_pd_separation_from_stage_configs( + self.stage_configs + ) self._pd_connector_info: dict[str, Any] | None = None self._pd_kv_params_by_req: dict[str, dict[str, Any]] = {} self._pd_kv_params_lock = threading.Lock() @@ -40,11 +50,19 @@ def _init_pd_state(self) -> None: d_id, ) - def _detect_pd_separation(self) -> tuple[int, int] | None: - """Scan stage_list for a prefill/decode pair. Returns (p_id, d_id) or None.""" + @staticmethod + def detect_pd_separation_from_stage_configs(stage_configs: list[Any]) -> tuple[int, int] | None: + """Scan stage configs for a prefill/decode pair. + + Returns: + (prefill_idx, decode_idx) if one pair exists, None if not found. + + Raises: + ValueError: if multiple candidate PD pairs are found. + """ prefill_by_id: dict[int, int] = {} decode_indices: list[int] = [] - for i, stage in enumerate(self.stage_list): + for i, stage in enumerate(stage_configs): if getattr(stage, "is_prefill_only", False): prefill_by_id[i] = i sid = getattr(stage, "stage_id", i) @@ -55,7 +73,7 @@ def _detect_pd_separation(self) -> tuple[int, int] | None: pd_pairs: list[tuple[int, int]] = [] for j in decode_indices: - source_ids = getattr(self.stage_list[j], "engine_input_source", []) + source_ids = getattr(stage_configs[j], "engine_input_source", []) for src in source_ids: if src in prefill_by_id: pd_pairs.append((prefill_by_id[src], j)) @@ -107,10 +125,11 @@ def _normalize_kv_transfer_params(self, kv_params: Any) -> dict[str, Any] | None def _validate_pd_separation_config(self) -> None: """Validate PD stage configurations are consistent.""" - assert self._pd_separation_pair is not None - p_id, d_id = self._pd_separation_pair - p_stage = self.stage_list[p_id] - d_stage = self.stage_list[d_id] + pair = self._get_pd_separation_pair() + assert pair is not None + p_id, d_id = pair + p_stage = self.stage_configs[p_id] + d_stage = self.stage_configs[d_id] def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]: ea = stage.engine_args @@ -158,11 +177,12 @@ def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]: def _get_pd_connector_info(self) -> dict[str, Any] | None: """Extract prefill engine KV connector info.""" - if self._pd_separation_pair is None: + pair = self._get_pd_separation_pair() + if pair is None: return None - p_id, _ = self._pd_separation_pair - p_stage = self.stage_list[p_id] + p_id, _ = pair + p_stage = self.stage_configs[p_id] ea = p_stage.engine_args kv_cfg = getattr(ea, "kv_transfer_config", None) @@ -241,18 +261,17 @@ def _extract_kv_transfer_params(self, engine_outputs: Any) -> dict[str, Any] | N def _is_pd_routing(self, stage_id: int, next_stage_id: int) -> bool: """True when edge stage_id → next_stage_id is the prefill→decode boundary.""" - return self._pd_separation_pair is not None and self._pd_separation_pair == ( - stage_id, - next_stage_id, - ) + pair = self._get_pd_separation_pair() + return pair is not None and pair == (stage_id, next_stage_id) def _maybe_expand_sampling_params(self, sampling_params_list: list) -> list: """Auto-duplicate thinker SP for decode stage when user provides N-1 params.""" - if 
self._pd_separation_pair is None: + pair = self._get_pd_separation_pair() + if pair is None: return sampling_params_list - if len(sampling_params_list) != len(self.stage_list) - 1: + if len(sampling_params_list) != len(self.stage_configs) - 1: return sampling_params_list - p_id, d_id = self._pd_separation_pair + p_id, d_id = pair sp_list = list(sampling_params_list) sp_list.insert(d_id, sp_list[p_id]) return sp_list diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 84391c2ea8..5757d38990 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -1,3 +1,4 @@ +import argparse import os import types from collections import Counter @@ -5,10 +6,12 @@ from pathlib import Path from typing import Any, get_args, get_origin +import yaml from vllm.logger import init_logger from vllm.transformers_utils.config import get_config, get_hf_file_to_dict from vllm.transformers_utils.repo_utils import file_or_path_exists +from vllm_omni.config.stage_config import StageConfigFactory from vllm_omni.config.yaml_util import create_config, load_yaml_config, merge_configs from vllm_omni.entrypoints.stage_utils import _to_dict from vllm_omni.platforms import current_omni_platform @@ -23,6 +26,65 @@ } +def detect_explicit_cli_keys( + argv: list[str], + parser: argparse.ArgumentParser | None = None, +) -> set[str]: + """Walk ``argv`` and return the set of ``dest`` attribute names the user + explicitly provided (e.g. ``--max-num-seqs 64`` → ``max_num_seqs``). + + Used to distinguish user-typed CLI args from argparse default values so + deploy YAMLs are not silently overridden by parser defaults. Shared + across online (``vllm serve``) and offline (scripts, examples, tests, + CI) entry points — offline callers that parse CLI args via argparse + should invoke this on ``sys.argv[1:]`` and pass the result through to + ``AsyncOmni`` / ``Omni`` via the ``_cli_explicit_keys`` kwarg. + + When ``parser`` is provided, each token is looked up in the parser's + action table to find its real ``dest``. This correctly handles flags + with ``dest=`` overrides, alias flags (e.g. ``--usp`` / + ``--ulysses-degree`` both mapping to ``ulysses_degree``), and + ``--disable-foo`` / ``store_false`` patterns that map to a differently + named dest. Callers with access to an ``argparse.ArgumentParser`` should + always pass it. + + When ``parser`` is ``None``, a name-based heuristic is used as a + fallback (hyphens → underscores, plus a ``no_`` prefix strip for + ``argparse.BooleanOptionalAction``). This is correct for simple flags + but silently misidentifies ``--disable-X``-style flags and explicit + ``dest=`` overrides, so prefer the parser-aware form. + """ + if parser is not None: + dest_map: dict[str, str] = {} + for action in parser._actions: + for opt in action.option_strings: + dest_map[opt] = action.dest + explicit: set[str] = set() + for tok in argv: + if not tok.startswith("--"): + continue + flag = tok.split("=", 1)[0] + dest = dest_map.get(flag) + if dest is not None: + explicit.add(dest) + return explicit + + # Fallback: name-based heuristic (legacy path for callers without a parser). + explicit = set() + for tok in argv: + if not tok.startswith("--"): + continue + name = tok[2:].split("=", 1)[0] + if not name: + continue + attr = name.replace("-", "_") + explicit.add(attr) + # BooleanOptionalAction: --no-foo records as dest `foo`, not `no_foo`. 
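+        # Record both spellings: the raw `no_`-prefixed name added above and
+        # the stripped dest added below.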
+ if attr.startswith("no_"): + explicit.add(attr[3:]) + return explicit + + def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None: """Inject connector configuration into stage engine arguments.""" # Prepare omni_kv_config dict @@ -273,29 +335,59 @@ def resolve_model_config_path(model: str) -> str: return str(stage_config_path) -def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list: +def load_stage_configs_from_model( + model: str, + base_engine_args: dict | None = None, + deploy_config_path: str | None = None, + stage_overrides: dict[str, dict[str, Any]] | None = None, + cli_explicit_keys: set[str] | None = None, +) -> list: """Load stage configurations from model's default config file. - .. deprecated:: - This is the legacy OmegaConf-based loading path. New code should use - ``StageConfigFactory.create_from_model()`` instead. This function will - be removed once all callers are migrated (see PR series [2/N]). + For models registered in the pipeline registry (new path), uses + ``StageConfigFactory.create_from_model()`` which merges + PipelineConfig + DeployConfig + CLI overrides. - Loads stage configurations based on the model type and device type. - First tries to load a device-specific YAML file from stage_configs/{device_type}/ - directory. If not found, falls back to the default config file. + For other models (legacy path), loads stage configs from YAML. Args: model: Model name or path (used to determine model_type) + base_engine_args: Base engine args to merge as CLI overrides. + deploy_config_path: Optional explicit deploy config path. + stage_overrides: Per-stage overrides from --stage-overrides. + cli_explicit_keys: Set of CLI keys the user actually typed. When + provided, only these keys override deploy YAML; argparse defaults + stay subordinate to YAML. ``None`` means treat every kwarg as + explicit (programmatic ``Omni()`` calls). Returns: List of stage configuration dictionaries - - Raises: - FileNotFoundError: If no stage config file exists for the model type """ if base_engine_args is None: base_engine_args = {} + + cli_overrides = _convert_dataclasses_to_dict(dict(base_engine_args)) + # Per-stage JSON overrides are always explicit (the user typed --stage-overrides). + explicit = set(cli_explicit_keys) if cli_explicit_keys is not None else None + if stage_overrides: + for stage_id_str, overrides in stage_overrides.items(): + for key, val in overrides.items(): + stage_key = f"stage_{stage_id_str}_{key}" + cli_overrides[stage_key] = val + if explicit is not None: + explicit.add(stage_key) + + stages = StageConfigFactory.create_from_model( + model, + cli_overrides=cli_overrides, + deploy_config_path=deploy_config_path, + cli_explicit_keys=explicit, + ) + if stages is not None: + # Convert StageConfig objects to OmegaConf for backward compat + return [stage.to_omegaconf() for stage in stages] + + # Legacy fallback: load from YAML stage_config_path = resolve_model_config_path(model) if stage_config_path is None: return [] @@ -312,10 +404,9 @@ def load_stage_configs_from_yaml( base_engine_args: dict | None = None, prefer_stage_engine_args: bool = True, ) -> list: - """Load stage configurations from a YAML file. + """Load stage configurations from a YAML file (legacy OmegaConf path). - .. deprecated:: - Legacy OmegaConf-based loader. Will be removed in PR series [2/N]. + TODO(@lishunyang12): remove once all models use PipelineConfig + DeployConfig. 
Args: config_path: Path to the YAML configuration file @@ -449,22 +540,75 @@ def load_and_resolve_stage_configs( stage_configs_path: str | None, kwargs: dict | None, default_stage_cfg_factory: Any = None, + deploy_config_path: str | None = None, + stage_overrides: dict[str, dict[str, Any]] | None = None, + cli_explicit_keys: set[str] | None = None, ) -> tuple[str, list]: """Load stage configurations from model or YAML file with fallback to defaults. Args: model: Model name or path - stage_configs_path: Optional path to YAML file containing stage configurations + stage_configs_path: Optional path to legacy YAML (stage_args format) kwargs: Engine arguments to merge with stage configs default_stage_cfg_factory: Optional callable that takes no args and returns default stage config list when no configs are found + deploy_config_path: Optional path to deploy YAML (new format). + Mutually exclusive with ``stage_configs_path``. + stage_overrides: Per-stage overrides from ``--stage-overrides`` JSON. + Keys are stage_id strings, values are dicts of overrides. Returns: Tuple of (config_path, stage_configs) """ - if stage_configs_path is None: + if stage_configs_path is not None and deploy_config_path is not None: + raise ValueError( + "--stage-configs-path and --deploy-config are mutually exclusive: " + "they use different path resolution rules and loading paths. " + "Use --deploy-config for new-format YAMLs (preferred); " + "--stage-configs-path is kept only for the legacy `stage_args` format " + "and will be removed in a future release." + ) + if stage_configs_path is not None and deploy_config_path is None: + if not os.path.exists(stage_configs_path): + raise FileNotFoundError( + f"--stage-configs-path {stage_configs_path!r} does not exist. " + "Legacy `stage_configs/` yamls were replaced by `vllm_omni/deploy/.yaml`; " + "use --deploy-config. See docs/configuration/stage_configs.md." + ) + with open(stage_configs_path, encoding="utf-8") as f: + _peek = yaml.safe_load(f) or {} + if "stages" in _peek and "stage_args" not in _peek: + deploy_config_path = stage_configs_path + stage_configs_path = None + else: + logger.warning( + "--stage-configs-path is deprecated; migrate %r and use --deploy-config.", + stage_configs_path, + ) + + if deploy_config_path is not None: + config_path = deploy_config_path + stage_configs = load_stage_configs_from_model( + model, + base_engine_args=kwargs, + deploy_config_path=deploy_config_path, + stage_overrides=stage_overrides, + cli_explicit_keys=cli_explicit_keys, + ) + if not stage_configs: + if default_stage_cfg_factory is not None: + default_stage_cfg = default_stage_cfg_factory() + stage_configs = create_config(default_stage_cfg) + else: + stage_configs = [] + elif stage_configs_path is None: config_path = resolve_model_config_path(model) - stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs) + stage_configs = load_stage_configs_from_model( + model, + base_engine_args=kwargs, + stage_overrides=stage_overrides, + cli_explicit_keys=cli_explicit_keys, + ) if not stage_configs: if default_stage_cfg_factory is not None: default_stage_cfg = default_stage_cfg_factory() diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py index 9cb6c44335..e4c33a58c2 100644 --- a/vllm_omni/inputs/data.py +++ b/vllm_omni/inputs/data.py @@ -227,6 +227,10 @@ class OmniDiffusionSamplingParams: frame_rate: float | None = None # Floating-point rate used by the diffusion model when it differs from `fps`. 
height_not_provided: bool = False width_not_provided: bool = False + enable_frame_interpolation: bool = False + frame_interpolation_exp: int = 1 + frame_interpolation_scale: float = 1.0 + frame_interpolation_model_path: str | None = None # Timesteps timesteps: torch.Tensor | None = None @@ -263,6 +267,10 @@ class OmniDiffusionSamplingParams: cfg_text_kv_metadata: dict[str, Any] | None = None cfg_img_kv_metadata: dict[str, Any] | None = None cfg_kv_request_ids: dict[str, str] | None = None + cfg_active_branch: str | None = None + cfg_branch_roles: list[str] | None = None + cfg_branch_past_key_values: dict[str, Any] | None = None + cfg_branch_kv_metadata: dict[str, dict[str, Any]] | None = None # Component modules modules: dict[str, Any] = field(default_factory=dict) diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py index acbbc28b4c..cbb775680c 100644 --- a/vllm_omni/model_executor/models/bagel/bagel.py +++ b/vllm_omni/model_executor/models/bagel/bagel.py @@ -1,4 +1,3 @@ -from collections import deque from collections.abc import Iterable, Mapping, Sequence from math import isqrt from typing import Any @@ -442,14 +441,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._pending_img2img_info: list[tuple[int, int, int, int]] = [] self._ropes_pending: list[dict[str, Any]] = [] self._ropes_metadata: dict[str, dict[str, Any]] = {} - self._cfg_companion_queue: deque[tuple[tuple[int, int, int, int], int]] = deque() - - # Per-request position offset for decode after img2img prefill. - # Prefill rewrites positions (VAE→0, ViT→1, text→2..N) but the model - # runner assigns decode positions starting from prefill_len, not N+1. - # offset = rope - prefill_len (a negative number). - self._pending_decode_offsets: list[int] = [] - self._decode_position_offsets: dict[str, int] = {} + self._last_img2img_info: tuple[int, int, int, int] | None = None from transformers import AutoTokenizer @@ -461,7 +453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._start_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_start|>")) self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>")) self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>")) - self._vae_token_mask: torch.Tensor | None = None self.device = get_local_device() self._install_mot_modules(config) @@ -540,9 +531,7 @@ def _clear_warmup_state(self): self._ropes_pending.clear() self._ropes_metadata.clear() self._pending_img2img_info.clear() - self._cfg_companion_queue.clear() - self._pending_decode_offsets.clear() - self._decode_position_offsets.clear() + self._last_img2img_info = None self._vae_token_mask = None def get_kv_transfer_metadata( @@ -554,12 +543,10 @@ def get_kv_transfer_metadata( meta = self._ropes_metadata.pop(req_id, None) if meta is None: return None - # In think-mode img2img the prefill rope doesn't account for decoded - # thinking tokens; correct it to num_computed_tokens + offset. - # Skip correction when num_computed_tokens is unavailable (None). 
- offset = self._decode_position_offsets.pop(req_id, 0) - if offset != 0 and "ropes" in meta and num_computed_tokens is not None: - meta["ropes"] = [num_computed_tokens + offset] + if num_computed_tokens is not None and "image_shape" in meta: + prefill_rope = meta["ropes"][0] if meta.get("ropes") else 0 + if num_computed_tokens > prefill_rope: + meta["ropes"] = [num_computed_tokens] return meta def prepare_runner_inputs( @@ -572,48 +559,29 @@ def prepare_runner_inputs( num_scheduled_tokens: list[int], input_ids_buffer: torch.Tensor | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - """Model-runner hook: adjust inputs before ``forward()``. - - Returns ``(input_ids, positions)`` — possibly modified. - - Two adjustments for BAGEL img2img: - - 1. **Restore input_ids** when ``inputs_embeds`` is present so that - ``_adjust_positions_for_img2img`` can locate the - ``<|fim_middle|>`` placeholder. - 2. **Decode position offset**: prefill rewrites positions to a - compact scheme (rope ≪ prefill_len). The runner assigns decode - positions from ``num_computed_tokens``, which is far too large; - apply the stored per-request offset. - """ + """Restore input_ids so _adjust_positions_for_img2img can locate + the <|fim_middle|> placeholder for thinking-mode pre_text_len + detection.""" if inputs_embeds is not None and input_ids is None and input_ids_buffer is not None: input_ids = input_ids_buffer - - if self._decode_position_offsets and positions is not None: - token_start = 0 - for i, rid in enumerate(req_ids): - sched = num_scheduled_tokens[i] - offset = self._decode_position_offsets.get(rid, 0) - if offset != 0 and num_computed_tokens[i] > 0: - positions[token_start : token_start + sched] += offset - token_start += sched - return input_ids, positions def flush_pending_metadata(self, req_ids: list[str]) -> None: - """Map pending metadata (batch order) to req_ids after forward().""" + """Map pending metadata (batch order) to req_ids after forward(). + + Guard: if a request already has metadata with ``image_shape`` + (written during img2img prefill), don't overwrite it with + decode-step metadata that lacks ``image_shape``. + """ pending = self._ropes_pending self._ropes_pending = [] for i, meta in enumerate(pending): if i < len(req_ids): - if req_ids[i] not in self._ropes_metadata: - self._ropes_metadata[req_ids[i]] = meta - - pending_offsets = self._pending_decode_offsets - self._pending_decode_offsets = [] - for i, offset in enumerate(pending_offsets): - if i < len(req_ids) and offset != 0: - self._decode_position_offsets[req_ids[i]] = offset + rid = req_ids[i] + existing = self._ropes_metadata.get(rid) + if existing and "image_shape" in existing and "image_shape" not in meta: + continue + self._ropes_metadata[rid] = meta def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -727,16 +695,7 @@ def _process_img2img_input(self, multimodal_input): num_vit = vit_emb.shape[0] + 2 info = (num_vae, num_vit, int(H), int(W)) self._pending_img2img_info.append(info) - # Only the gen (main) request should add a companion queue entry. - # Companion requests (cfg_text, cfg_img) also call this method with - # the same image, so guard by checking whether this exact info - # tuple is already enqueued. 
For batched img2img with multiple - # concurrent gen requests this correctly adds one entry per unique - # image; images with identical (num_vae, num_vit, H, W) that arrive - # in the same batch are indistinguishable here and will share one - # entry, but that is an uncommon edge case. - if not any(entry[0] == info for entry in self._cfg_companion_queue): - self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img + self._last_img2img_info = info return tuple(results) @@ -755,31 +714,18 @@ def forward( positions = self._adjust_positions_for_img2img(positions, input_ids) use_mot = True - elif self._cfg_companion_queue: - # Guard: if this looks like a pure decode step (small token count, - # no multimodal embeddings), the queue has stale entries from a - # previous prefill cycle — clear them instead of consuming. - if inputs_embeds is None and seq_len <= 2: - self._cfg_companion_queue.clear() - else: - cached, remaining = self._cfg_companion_queue[0] - remaining -= 1 - num_vae, num_vit, img_H, img_W = cached - num_img2img = num_vae + 1 + num_vit # +1 separator - seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0] - - if inputs_embeds is not None and seq_len >= num_img2img: - self._pending_img2img_info = [cached] - positions = self._adjust_positions_for_img2img(positions, input_ids) - use_mot = True - else: - rope = int(positions[seq_len - 1].item()) + 1 - self._ropes_pending.append({"ropes": [rope]}) + elif self._last_img2img_info is not None: + info = self._last_img2img_info + num_vae, num_vit, _, _ = info + num_img2img = num_vae + 1 + num_vit - if remaining == 0: - self._cfg_companion_queue.popleft() - else: - self._cfg_companion_queue[0] = (cached, remaining) + if seq_len >= num_img2img: + self._pending_img2img_info = [info] + positions = self._adjust_positions_for_img2img(positions, input_ids) + use_mot = True + else: + rope = int(positions[seq_len - 1].item()) + 1 + self._ropes_pending.append({"ropes": [rope]}) if use_mot: return self._mot_forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs) @@ -790,27 +736,18 @@ def _adjust_positions_for_img2img( positions: torch.Tensor, input_ids: torch.Tensor | None = None, ) -> torch.Tensor: - """Rewrite position IDs to match the original BAGEL position scheme: - - If there are ``pre_text_len`` text tokens before the img2img block:: - - pre_text → 0, 1, ..., M-1 - VAE → M (all share) - separator→ M - ViT → M+1 (all share) - post_text→ M+2, M+3, ... + """Rewrite position IDs for img2img. - When no text precedes the img2img block (M=0), this reduces to the - simpler scheme: VAE→0, ViT→1, text→2, 3, ... + Supports an optional ``pre_text_len`` prefix (thinking-mode) detected + via the ``<|fim_middle|>`` token in *input_ids*: - Also computes ``self._vae_token_mask`` (bool tensor, True for actual - VAE latent patches that should use gen-mode weights) and pushes - per-request ropes + image_shape to the FIFO consumed by - ``get_kv_transfer_metadata``. + pre_text -> 0 .. M-1 + VAE -> M (all share) + separator-> M + ViT -> M+1 (all share) + post_text-> M+2, M+3, ... - For img2img requests, also stores a decode position offset so that - subsequent autoregressive decode steps use positions that continue - from the rewritten scheme rather than from the original prefill length. + When M=0 (standard img2img) this reduces to VAE->0, ViT->1, text->2.. 
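+
+        Worked example (illustrative sizes only): with M=2 pre-text tokens,
+        num_vae=4, and num_vit=3, the rewritten positions are
+        [0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 5, ...] -- pre-text at 0..1, the
+        four VAE tokens plus the separator all at M=2, the three ViT tokens
+        at M+1=3, and post-text resuming at M+2=4.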
""" info_list = self._pending_img2img_info self._pending_img2img_info = [] @@ -836,70 +773,64 @@ def _adjust_positions_for_img2img( req_len = end - start if img2img_idx < len(info_list): - num_vae, num_vit, img_H, img_W = info_list[img2img_idx] + cur_info = info_list[img2img_idx] + elif self._last_img2img_info is not None: + cur_info = self._last_img2img_info + else: + cur_info = None + + if cur_info is not None: + num_vae, num_vit, img_H, img_W = cur_info num_img2img = num_vae + 1 + num_vit # +1 separator if req_len >= num_img2img: - # Detect offset of img2img tokens within this request - # by searching for the img2img placeholder token ID. pre_text_len = 0 if input_ids is not None: - req_ids = input_ids[start:end] - mask = req_ids == self._img2img_token_id - indices = mask.nonzero(as_tuple=True)[0] + req_ids_slice = input_ids[start:end] + indices = (req_ids_slice == self._img2img_token_id).nonzero(as_tuple=True)[0] if indices.numel() > 0: pre_text_len = int(indices[0].item()) - img_start = start + pre_text_len + M = pre_text_len + img_start = start + M post_text_start = img_start + num_img2img - # pre_text_pos: position base for image tokens - pre_text_pos = pre_text_len - # Pre-image text: sequential positions 0..pre_text_pos-1 - if pre_text_len > 0: + if M > 0: new_positions[start:img_start] = torch.arange( - 0, pre_text_pos, device=positions.device, dtype=positions.dtype + 0, M, device=positions.device, dtype=positions.dtype ) - # VAE tokens: all share position pre_text_pos - new_positions[img_start : img_start + num_vae] = pre_text_pos - # Separator: position pre_text_pos - new_positions[img_start + num_vae] = pre_text_pos - # ViT tokens: all share position pre_text_pos+1 + new_positions[img_start : img_start + num_vae] = M + new_positions[img_start + num_vae] = M # separator vit_start = img_start + num_vae + 1 - new_positions[vit_start : vit_start + num_vit] = pre_text_pos + 1 + new_positions[vit_start : vit_start + num_vit] = M + 1 - # Post-image text: sequential positions pre_text_pos+2, pre_text_pos+3, ... 
num_post_text = end - post_text_start if num_post_text > 0: new_positions[post_text_start:end] = torch.arange( - pre_text_pos + 2, - pre_text_pos + 2 + num_post_text, + M + 2, + M + 2 + num_post_text, device=positions.device, dtype=positions.dtype, ) - # VAE gen-mode mask: only actual VAE latent patches (not markers) - vae_patches_start = img_start + 1 # skip start_marker - vae_patches_end = img_start + num_vae - 1 # before end_marker + vae_patches_start = img_start + 1 + vae_patches_end = img_start + num_vae - 1 if vae_patches_end > vae_patches_start: vae_mask[vae_patches_start:vae_patches_end] = True - rope = pre_text_pos + 2 + num_post_text + rope = M + 2 + num_post_text self._ropes_pending.append( { "ropes": [rope], "image_shape": [img_H, img_W], } ) - decode_offset = rope - req_len - self._pending_decode_offsets.append(decode_offset) img2img_idx += 1 continue rope = int(new_positions[end - 1].item()) + 1 self._ropes_pending.append({"ropes": [rope]}) - self._pending_decode_offsets.append(0) self._vae_token_mask = vae_mask if vae_mask.any() else None return new_positions diff --git a/vllm_omni/model_executor/models/common/__init__.py b/vllm_omni/model_executor/models/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py new file mode 100644 index 0000000000..3a904442fa --- /dev/null +++ b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py @@ -0,0 +1,654 @@ +"""Qwen3 Code Predictor -- optimized re-prefill, no KV cache. + +Shared by Qwen3-Omni and Qwen3-TTS talker models. + +* SDPA attention (F.scaled_dot_product_attention) with native GQA support +* HF-compatible numerics (float32 RMSNorm, float32 RoPE, separate linear layers) +* Per-call embedding buffer to avoid cross-request aliasing +* Pre-allocated position_ids (read-only, safe to persist) +* torch.compile (epilogue_fusion=False) on inner transformer by default +* Optional manual CUDA graph capture per batch-size bucket +* Inline sampling (top-k + top-p) -- no custom op overhead +""" + +from __future__ import annotations + +import dataclasses +from collections.abc import Iterable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.platforms import current_omni_platform + +logger = init_logger(__name__) + + +# =================================================================== +# HF-numerics-compatible layers for code predictor +# =================================================================== +# +# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32, +# rotate_half RoPE) to produce outputs numerically identical to the +# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel, +# get_rope) introduce small precision differences that compound across +# the autoregressive steps of the code predictor, causing severe +# audio quality degradation. +# +# See: https://github.com/vllm-project/vllm-omni/issues/2274 + + +class _RMSNorm(nn.Module): + """RMSNorm matching HuggingFace's implementation exactly. + + Computes variance in float32 to avoid bfloat16 precision loss. 
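+
+    Equivalent to ``weight * x * rsqrt(mean(x**2, dim=-1, keepdim=True) + eps)``,
+    with the statistics computed in float32 and the result cast back to the
+    input dtype.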
+ """ + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class _RotaryEmbedding(nn.Module): + """RoPE matching HuggingFace's implementation exactly. + + Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False). + """ + + def __init__(self, config) -> None: + super().__init__() + head_dim = getattr( + config, + "head_dim", + config.hidden_size // config.num_attention_heads, + ) + rope_theta = getattr(config, "rope_theta", 10000.0) + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # position_ids: [batch, seq_len] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 (matching HF) + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# =================================================================== +# Attention +# =================================================================== + + +class CodePredictorAttention(nn.Module): + """Multi-head self-attention for code predictor. + + Uses ``F.scaled_dot_product_attention`` with HF-compatible RoPE and RMSNorm. + No KV cache -- the code predictor always re-prefills the full (short) + sequence each AR step. 
+ + Input : [B, seq_len, hidden_size] + Output: [B, seq_len, hidden_size] + """ + + def __init__(self, config, *, prefix: str = "") -> None: + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, + "head_dim", + config.hidden_size // config.num_attention_heads, + ) + self.hidden_size = config.hidden_size + self.scaling = self.head_dim**-0.5 + self._use_gqa = self.num_kv_heads != self.num_heads + + # Separate q/k/v projections matching HF (no fused packing) + bias = getattr(config, "attention_bias", False) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + bsz, seq_len, _ = hidden_states.shape + hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim) + hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim) + + q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2) + k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2) + v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2) + + cos, sin = position_embeddings + # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads + cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim] + sin = sin.unsqueeze(1) + q = (q * cos) + (_rotate_half(q) * sin) + k = (k * cos) + (_rotate_half(k) * sin) + + attn_out = F.scaled_dot_product_attention( + q, + k, + v, + scale=self.scaling, + is_causal=True, + enable_gqa=self._use_gqa, + ) + + attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1) + return self.o_proj(attn_out) + + +# =================================================================== +# MLP +# =================================================================== + + +class CodePredictorMLP(nn.Module): + """SiLU-gated MLP for code predictor, matching HF's implementation.""" + + def __init__(self, config, *, prefix: str = "") -> None: + super().__init__() + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + + +# =================================================================== +# Decoder Layer +# =================================================================== + + +class CodePredictorDecoderLayer(nn.Module): + """Transformer decoder layer (SDPA, no KV cache).""" + + def __init__(self, config, *, prefix: str = "") -> None: + super().__init__() + self.self_attn = CodePredictorAttention(config, prefix=f"{prefix}.self_attn") + self.mlp = CodePredictorMLP(config, prefix=f"{prefix}.mlp") + self.input_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +# =================================================================== +# Base Transformer Model (re-prefill, no KV cache) +# =================================================================== + + +class CodePredictorBaseModel(nn.Module): + """Inner transformer for code predictor. + + Signature: ``forward(inputs_embeds, position_ids) -> hidden_states`` + """ + + def __init__( + self, + config, + *, + embedding_dim: int | None = None, + use_parallel_embedding: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + emb_dim = int(embedding_dim) if embedding_dim is not None else int(config.hidden_size) + if use_parallel_embedding: + self.codec_embedding = nn.ModuleList( + [VocabParallelEmbedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)] + ) + else: + self.codec_embedding = nn.ModuleList( + [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)] + ) + + self.layers = nn.ModuleList( + [ + CodePredictorDecoderLayer(config, prefix=f"{prefix}.layers.{idx}") + for idx in range(config.num_hidden_layers) + ] + ) + self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = _RotaryEmbedding(config) + + def get_input_embeddings(self) -> nn.ModuleList: + return self.codec_embedding + + def forward( + self, + inputs_embeds: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings) + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# =================================================================== +# Wrapper Configuration +# =================================================================== + + +@dataclasses.dataclass +class CodePredictorWrapperConfig: + """Controls behavioral differences between model-specific code predictors.""" + + use_cuda_graphs: bool = False + use_parallel_embedding: bool = False + use_projection: bool = False + return_proj_buf: bool = False + sampling_mode: str = "stored" + + +# =================================================================== +# Code Predictor Wrapper (optimized re-prefill, persistent buffers) +# =================================================================== + + +class CodePredictorWrapper(nn.Module): + """Optimized code predictor -- re-prefill approach, no KV cache. 
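+
+    Given the layer-0 code, its embedding, and the last talker hidden state,
+    ``forward()`` autoregressively predicts the remaining codebooks and returns
+    codes for all ``num_code_groups`` groups (plus the filled projection buffer
+    when ``return_proj_buf`` is enabled).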
+ + Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1) + through the transformer. The extra O(T^2) FLOPs are negligible for + short sequences, and this avoids all KV-cache management overhead. + + Optimizations: + 1. Per-call embedding buffer -- avoids cross-request aliasing. + 2. Pre-allocated position_ids -- no torch.arange per step. + 3. Cached module references -- bypass ModuleList indexing. + 4. torch.compile on inner transformer. + 5. Inline sampling (top-k + top-p) -- no custom op overhead. + 6. Optional manual CUDA graph capture per batch-size bucket. + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + cp_config, + wrapper_config: CodePredictorWrapperConfig, + talker_hidden_size: int | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self._vllm_config = vllm_config + self.config = cp_config + self._wrapper_config = wrapper_config + self.prefix = prefix + + self._num_groups = int(cp_config.num_code_groups) + self._cp_hidden = int(cp_config.hidden_size) + + # For Omni backward compat (accessed by the talker) + self.num_code_groups = self._num_groups + + # Determine embedding dimension + _talker_hidden = int(talker_hidden_size) if talker_hidden_size is not None else self._cp_hidden + + self.model = CodePredictorBaseModel( + cp_config, + embedding_dim=_talker_hidden, + use_parallel_embedding=wrapper_config.use_parallel_embedding, + prefix=f"{prefix}.model" if prefix else "model", + ) + + self.lm_head = nn.ModuleList( + [nn.Linear(cp_config.hidden_size, cp_config.vocab_size, bias=False) for _ in range(self._num_groups - 1)] + ) + + # Projection: Identity when hidden sizes match or not needed + if wrapper_config.use_projection and _talker_hidden != self._cp_hidden: + self.small_to_mtp_projection = nn.Linear(_talker_hidden, self._cp_hidden, bias=True) + else: + self.small_to_mtp_projection = nn.Identity() + + # Sampling defaults for "stored" mode + self._top_k: int = 50 + self._top_p: float = 0.8 + + # Lazily initialised state + self._proj_buf: torch.Tensor | None = None + self._model_dtype: torch.dtype | None = None + self._compiled_model_fwd = None + self._bucket_sizes: list[int] = [] + self._bucket_pos_ids: dict[int, torch.Tensor] = {} + self._lm_heads_list: list[nn.Module] | None = None + self._codec_embeds_list: list[nn.Module] | None = None + self._cuda_graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor]] = {} + + def get_input_embeddings(self) -> nn.ModuleList: + return self.model.get_input_embeddings() + + def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8) -> None: + """Configure sampling parameters to maintain consistency with previous implementation.""" + self._top_k = top_k + self._top_p = top_p + logger.debug("Sampling parameters updated: top_k=%d, top_p=%.2f", top_k, top_p) + + # ------------------------------------------------------------------ + # Lazy-init helpers + # ------------------------------------------------------------------ + + def _ensure_buffers(self, device: torch.device, dtype: torch.dtype, bsz: int) -> None: + """Ensure the projection buffer can hold at least *bsz* rows.""" + max_seq = self._num_groups + 1 + if ( + self._proj_buf is not None + and self._proj_buf.device == device + and self._proj_buf.dtype == dtype + and self._proj_buf.shape[0] >= bsz + ): + return + self._proj_buf = torch.zeros(bsz, max_seq, self._cp_hidden, dtype=dtype, device=device) + + def _setup_compile(self) -> None: + """Lazily set up torch.compile with optional CUDA graph capture.""" + if 
self._compiled_model_fwd is not None: + return + + # Cache model parameter dtype so forward() doesn't need to query it + # on every call. Also ensures warmup buffers match model precision + # even when upstream modules produce a different dtype (#2385). + self._model_dtype = next(self.model.parameters()).dtype + self._lm_heads_list = list(self.lm_head) + self._codec_embeds_list = list(self.model.codec_embedding) + + if not current_omni_platform.supports_torch_inductor(): + logger.warning_once("code_predictor: torch.compile disabled") + self._compiled_model_fwd = self.model.forward + return + + # torch.compile fuses RMSNorm/RoPE in ways that lose float32 + # precision, compounding across AR steps. Use epilogue_fusion=False + # to disable the problematic fusions while still getting kernel + # fusion benefits for the linear layers and SDPA. + self._compiled_model_fwd = torch.compile( + self.model.forward, + dynamic=False, + options={"epilogue_fusion": False}, + ) + self._warmup_buckets() + + if self._wrapper_config.use_cuda_graphs: + self._capture_cuda_graphs() + logger.info("code_predictor: torch.compile (no epilogue fusion) + CUDA graphs") + else: + logger.info("code_predictor: torch.compile (dynamic=False, no epilogue fusion)") + + def _padded_bsz(self, bsz: int) -> int: + """Round batch size up to nearest power-of-2 bucket.""" + for bucket in self._bucket_sizes: + if bsz <= bucket: + return bucket + return bsz + + def _warmup_buckets(self) -> None: + """Warmup power-of-2 batch-size buckets to front-load Inductor compilation.""" + max_bsz = self._vllm_config.scheduler_config.max_num_seqs + bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz] + if max_bsz not in bucket_sizes: + bucket_sizes.append(max_bsz) + self._bucket_sizes = sorted(bucket_sizes) + + max_seq = self._num_groups + 1 + device = next(self.model.parameters()).device + + # Ensure proj_buf matches model parameter dtype to avoid dtype + # mismatch during warmup compilation (see #2385). 
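+        # The buffer is sized for the largest bucket; smaller buckets slice
+        # views out of the same allocation during warmup (and graph capture).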
+ self._ensure_buffers(device, self._model_dtype, max(self._bucket_sizes)) + proj_buf = self._proj_buf + + for bsz in self._bucket_sizes: + pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1).contiguous() + self._bucket_pos_ids[bsz] = pos_ids + for _ in range(3): + self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids) + logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes) + + def _capture_cuda_graphs(self) -> None: + """Capture a CUDA graph per bucket using vLLM's global graph pool.""" + from vllm.platforms import current_platform + + pool = current_platform.get_global_graph_pool() + max_seq = self._num_groups + 1 + proj_buf = self._proj_buf + + for bsz in self._bucket_sizes: + static_input = proj_buf[:bsz, :max_seq, :] + pos_ids = self._bucket_pos_ids[bsz] + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, pool=pool): + static_output = self._compiled_model_fwd(static_input, pos_ids) + + self._cuda_graphs[bsz] = (g, static_output) + + logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes) + + # ------------------------------------------------------------------ + # Forward -- re-prefill + inline sampling + # ------------------------------------------------------------------ + + @torch.inference_mode() + def forward( + self, + layer0_code: torch.Tensor, + layer0_embed: torch.Tensor, + last_talker_hidden: torch.Tensor, + do_sample: bool = True, + temperature: float = 0.9, + top_k: int = 50, + top_p: float = 1.0, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Predict residual codebooks 1..G-1 autoregressively via re-prefill.""" + bsz = int(layer0_code.shape[0]) + num_groups = self._num_groups + device = layer0_code.device + + # _setup_compile caches _model_dtype on first call; use it for buffers + # so they always match model weight precision (#2385). + self._setup_compile() + dtype = self._model_dtype + + padded_bsz = self._padded_bsz(bsz) + self._ensure_buffers(device, dtype, padded_bsz) + + proj_buf = self._proj_buf + max_seq = num_groups + 1 + projection = self.small_to_mtp_projection + model_fwd = self._compiled_model_fwd + lm_heads = self._lm_heads_list + codec_embeds = self._codec_embeds_list + + # Zero the padded region of the buffer + proj_buf[:padded_bsz].zero_() + + # Fill buffer positions 0 (talker hidden) & 1 (layer0 embed) + proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) + proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1) + + # Get pre-computed pos_ids for this bucket + full_pos_ids = self._bucket_pos_ids.get(padded_bsz) + if full_pos_ids is None: + full_pos_ids = ( + torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1).contiguous() + ) + + # Use captured CUDA graph if available, otherwise call compiled fn. + cuda_graph_entry = self._cuda_graphs.get(padded_bsz) + + # Prepare sampling parameters + stored_mode = self._wrapper_config.sampling_mode == "stored" + if stored_mode: + s_top_k = self._top_k + s_top_p = self._top_p + else: + use_sampling = do_sample and temperature > 0 + inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0 + if use_sampling and top_p != 1.0: + raise NotImplementedError( + "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0." 
+ ) + + # Output codes -- shape depends on return mode + if self._wrapper_config.return_proj_buf: + all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device) + all_codes[:, 0] = layer0_code.reshape(bsz, -1)[:, :1] + else: + all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device) + all_codes[:, 0] = layer0_code.reshape(bsz) + + # Autoregressive loop: predict layers 1..G-1 + for step in range(1, num_groups): + # Run transformer (CUDA graph replay or compiled forward) + if cuda_graph_entry is not None: + cuda_graph_entry[0].replay() + hidden_out = cuda_graph_entry[1] + else: + hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids) + + logits = lm_heads[step - 1](hidden_out[:bsz, step, :]) + + # Sample next code + if stored_mode: + # "stored" mode: top-k -> top-p -> softmax -> multinomial + if s_top_k > 0: + topk_vals, _ = logits.topk(s_top_k, dim=-1) + logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf")) + if s_top_p < 1.0: + sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True) + sorted_probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = sorted_probs.cumsum(dim=-1) + remove_mask = (cumulative_probs - sorted_probs) >= s_top_p + sorted_logits[remove_mask] = float("-inf") + logits = sorted_logits.scatter(1, sorted_idx, sorted_logits) + probs = F.softmax(logits, dim=-1) + code = torch.multinomial(probs, num_samples=1) + else: + # "per_call" mode: temperature-scaled + top-k + if use_sampling: + scaled = logits * inv_temperature + if top_k > 0: + topk_vals, _ = scaled.topk(top_k, dim=-1) + scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf")) + probs = F.softmax(scaled, dim=-1) + code = torch.multinomial(probs, num_samples=1) + else: + code = logits.argmax(dim=-1, keepdim=True) + + # Store code + if self._wrapper_config.return_proj_buf: + all_codes[:, step] = code + else: + all_codes[:, step] = code.reshape(bsz) + + # Embed predicted code -> project -> next buffer position + if step < num_groups - 1 or self._wrapper_config.return_proj_buf: + new_embed = codec_embeds[step - 1](code) + proj_buf[:bsz, step + 1, :] = projection(new_embed.reshape(bsz, 1, -1)).reshape(bsz, -1) + + if self._wrapper_config.return_proj_buf: + return all_codes, proj_buf[:bsz].clone() + return all_codes + + # ------------------------------------------------------------------ + # Weight loading + # ------------------------------------------------------------------ + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights directly (no fused projection remapping needed).""" + loaded: set[str] = set() + model_weights: list[tuple[str, torch.Tensor]] = [] + other_weights: list[tuple[str, torch.Tensor]] = [] + + for name, w in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.startswith("model."): + model_weights.append((name[len("model.") :], w)) + else: + other_weights.append((name, w)) + + loaded_model = self.model.load_weights(model_weights) + loaded |= {f"model.{n}" for n in loaded_model} + + params = dict(self.named_parameters(remove_duplicate=False)) + for name, w in other_weights: + param = params.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, w) + loaded.add(name) + + return loaded diff --git a/vllm_omni/model_executor/models/cosyvoice3/assets/mel_filters.npz b/vllm_omni/model_executor/models/cosyvoice3/assets/mel_filters.npz deleted file mode 100644 index 
28ea26909d..0000000000 Binary files a/vllm_omni/model_executor/models/cosyvoice3/assets/mel_filters.npz and /dev/null differ diff --git a/vllm_omni/model_executor/models/cosyvoice3/utils.py b/vllm_omni/model_executor/models/cosyvoice3/utils.py index 52c52655e8..0bf0cccb16 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/utils.py +++ b/vllm_omni/model_executor/models/cosyvoice3/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -import os from functools import cache, lru_cache import numpy as np @@ -9,7 +8,8 @@ import torch.nn.functional as F import torchaudio import torchaudio.compliance.kaldi as kaldi -from librosa.filters import mel as librosa_mel_fn + +from vllm_omni.utils.audio import mel_filter_bank logger = logging.getLogger(__name__) @@ -34,8 +34,13 @@ def _get_mel_basis( fmax: float | None, device_str: str, ) -> torch.Tensor: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - return torch.from_numpy(mel).float().to(torch.device(device_str)) + return mel_filter_bank( + sr=sampling_rate, + n_fft=n_fft, + n_mels=num_mels, + fmin=fmin, + fmax=fmax, + ).to(torch.device(device_str)) @lru_cache @@ -122,42 +127,8 @@ def exact_div(x, y): @cache def mel_filters(device, n_mels: int) -> torch.Tensor: - """ - load the mel filterbank matrix for projecting STFT into a Mel spectrogram. - Allows decoupling librosa dependency; saved using: - - np.savez_compressed( - "mel_filters.npz", - mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), - mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), - ) - """ - assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" - - filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") - if not os.path.exists(filters_path): - source_url = "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz" - os.makedirs(os.path.dirname(filters_path), exist_ok=True) - try: - import urllib.request - - with urllib.request.urlopen(source_url, timeout=30) as resp: - with open(filters_path, "wb") as f_out: - f_out.write(resp.read()) - logger.info("Downloaded mel_filters.npz from %s", source_url) - except Exception as e: - raise FileNotFoundError( - "Missing CosyVoice3 mel filter asset:\n" - f" {filters_path}\n" - "Auto-download failed. 
Download it manually from:\n" - f" {source_url}\n" - "Example:\n" - f" mkdir -p {os.path.dirname(filters_path)} && " - f"curl -L {source_url} -o {filters_path}" - ) from e - - with np.load(filters_path, allow_pickle=False) as f: - return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + """Compute mel filterbank matrix for projecting STFT into a Mel spectrogram.""" + return mel_filter_bank(sr=16000, n_fft=400, n_mels=n_mels).to(device) def log_mel_spectrogram( diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py index 8bbb643ebe..22a2744ff5 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py @@ -310,6 +310,7 @@ def __init__( self._compiled_model_fwd: object | None = None self._compile_attempted = False self._compile_failed = False + self._disable_compile_for_graph = False def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None: max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes @@ -327,11 +328,20 @@ def _setup_compile(self) -> None: if self._compile_attempted: return self._compile_attempted = True + if self._disable_compile_for_graph: + try: + self._compiled_model_fwd = torch.compile( + self.model.forward, + dynamic=True, + options={"epilogue_fusion": False}, + ) + except Exception as exc: + logger.warning("Fast AR torch.compile (graph mode) failed: %s", exc) + self._compiled_model_fwd = self.model.forward + return try: self._compiled_model_fwd = torch.compile( self.model.forward, - # Keep the helper compiler separate from vLLM's outer - # cudagraph-managed Stage-0 execution. mode="default", dynamic=True, fullgraph=False, @@ -366,10 +376,10 @@ def warmup_compile( @torch.inference_mode() def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor: - # Default-on compile only pays off for single-request decode. For - # batched decode, eager preserves loaded throughput and avoids the - # regression seen with batch>1 compiled execution. - model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward + if self._disable_compile_for_graph: + model_fwd = self._compiled_model_fwd or self.model.forward + else: + model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward try: return model_fwd(step_input, step_pos_ids) except Exception as exc: diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 3813597caa..62776cbb31 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -194,6 +194,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.has_postprocess = True self.mtp_hidden_size = int(self.text_config.hidden_size) self.talker_mtp_output_key = "audio_codes" + self.talker_mtp_graph_safe = True self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"} # Qwen3 transformer backbone. @@ -236,6 +237,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): slow_ar_config=self.text_config, prefix="fast_ar", ) + if self.talker_mtp_graph_safe: + self.fast_ar._disable_compile_for_graph = True # Constant logit mask: allow only semantic tokens + im_end. 
vocab = int(self.text_config.vocab_size) @@ -680,18 +683,13 @@ def talker_mtp( inputs_embeds_out = input_embeds.reshape(bsz, -1).clone() semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id) - if semantic_mask.any(): - semantic_codes = audio_codes[semantic_mask].clamp(min=0) - offsets = ( - torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size - ).unsqueeze(0) - codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) - - # Normalize by sqrt(num_codebooks + 1) as in the reference model - # (scale_codebook_embeddings=True for fish_qwen3_omni). - inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt( - self._num_codebooks + 1 - ) + semantic_codes = audio_codes.clamp(min=0, max=self._codebook_size - 1) + offsets = ( + torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size + ).unsqueeze(0) + codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16) + norm_embeds = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1) + inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), norm_embeds, inputs_embeds_out) return inputs_embeds_out, audio_codes.to(dtype=torch.long) @@ -802,14 +800,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if truncated: logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated) - try: - self.fast_ar.warmup_compile( - device=self.codebook_embeddings.weight.device, - dtype=torch.bfloat16, - batch_sizes=(1,), - ) - except Exception as exc: - logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) + if not getattr(self, "talker_mtp_graph_safe", False): + try: + self.fast_ar.warmup_compile( + device=self.codebook_embeddings.weight.device, + dtype=torch.bfloat16, + batch_sizes=(1,), + ) + except Exception as exc: + logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) codec_device = self.codebook_embeddings.weight.device _load_dac_codec( diff --git a/vllm_omni/model_executor/models/fish_speech/prompt_utils.py b/vllm_omni/model_executor/models/fish_speech/prompt_utils.py index 923e97b63a..8b8d8559ea 100644 --- a/vllm_omni/model_executor/models/fish_speech/prompt_utils.py +++ b/vllm_omni/model_executor/models/fish_speech/prompt_utils.py @@ -38,10 +38,7 @@ def _encode_plain_text(tokenizer: Any, text: str) -> list[int]: def _encode_control_token(tokenizer: Any, token: str) -> list[int]: - vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else {} - token_id = vocab.get(token) - if token_id is None: - token_id = tokenizer.convert_tokens_to_ids(token) + token_id = tokenizer.convert_tokens_to_ids(token) if token_id is None or token_id == getattr(tokenizer, "unk_token_id", None): raise ValueError(f"Fish Speech tokenizer is missing required control token: {token}") return [int(token_id)] diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py index 6d25274f90..6304eeab29 100644 --- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py +++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py @@ -77,7 +77,9 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils.tensor_schema import TensorSchema +from vllm.v1.outputs 
import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import Sampler from vllm_omni.model_executor.models.hunyuan_image3.autoencoder_kl_3d import AutoencoderKLConv3D from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import LightProjector, Siglip2VisionTransformer @@ -175,8 +177,11 @@ def contains_unexpected_keyword(name, keywords): return True return False + skipped_unexpected: set[str] = set() + for name, loaded_weight in weights: if contains_unexpected_keyword(name, unexpected_keywords): + skipped_unexpected.add(name) continue if "rotary_emb.inv_freq" in name: @@ -362,6 +367,17 @@ def contains_unexpected_keyword(name, keywords): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + + if skipped_unexpected: + logger.warning_once( + "Skipped %d weights matching unexpected_keywords " + "(e.g. vae, vision_model, patch_embed, timestep_emb). " + "If upstream renamed components, these may be silently " + "lost. Skipped names: %s", + len(skipped_unexpected), + sorted(skipped_unexpected)[:10], + ) + return loaded_params @@ -1149,6 +1165,8 @@ class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo HunyuanImage3Inputs: TypeAlias = HunyuanImage3PixelInputs + prefer_model_sampler = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1199,6 +1217,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.lm_head = PPMissingLayer() + # --- AR-stage components --- + # These are needed for image encoding in the AR stage. + # If a future text-only stage is added, gate on vllm_config.model_config.model_stage. + # vae self.vae = AutoencoderKLConv3D.from_config(config.vae) self.patch_embed = UNetDown( @@ -1226,6 +1248,63 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._mrope_joint_img_sep_token_id = tokenizer.convert_tokens_to_ids("") self._mrope_max_num_patches = config.vit_processor.get("max_num_patches", 729) + # Special token IDs for logits processors (stage transitions). + # These mirror the official tokenization_hunyuan_image_3.py setup. + self._end_of_think_id = tokenizer.convert_tokens_to_ids("") + self._recaption_id = tokenizer.convert_tokens_to_ids("") + self._end_of_recaption_id = tokenizer.convert_tokens_to_ids("") + self._answer_id = tokenizer.convert_tokens_to_ids("") + self._end_of_answer_id = tokenizer.convert_tokens_to_ids("") + image_base_size = getattr(config, "image_base_size", 1024) + self._size_token_id = tokenizer.convert_tokens_to_ids(f"") + self._start_ratio_id = tokenizer.convert_tokens_to_ids("") + self._end_ratio_id = tokenizer.convert_tokens_to_ids("") + ratio_33 = tokenizer.convert_tokens_to_ids("") + ratio_36 = tokenizer.convert_tokens_to_ids("") + self._ratio_other_slices = [(ratio_33, ratio_36 + 1)] + # Build the full set of ratio token IDs for use as stop tokens. + self._all_ratio_ids = set(range(self._start_ratio_id, self._end_ratio_id + 1)) + for s, e in self._ratio_other_slices: + self._all_ratio_ids.update(range(s, e)) + + # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I). + engine_output_type = getattr(vllm_config.model_config, "engine_output_type", None) + self._is_comprehension = engine_output_type in (None, "text") + + # For comprehension mode, block image generation tokens but allow + # text structure tokens (, , etc.) so the model can + # follow its natural generation pattern. Stop tokens in YAML will + # terminate at or EOS. 
+ self._blocked_token_ids: set[int] = set() + if self._is_comprehension: + self._blocked_token_ids.update( + [ + self._mrope_boi_token_id, # + self._mrope_eoi_token_id, # + self._size_token_id, # + ] + ) + self._blocked_token_ids.update(self._all_ratio_ids) + + # For generation mode, build stage transition map. + # Official logic: → [], + # → [, , ] + # After , restrict vocab to ratio tokens only. + # Stage-transition forced sequences, keyed by trigger token. + self._stage_transitions: dict[int, list[int]] = {} + if not self._is_comprehension: + self._stage_transitions[self._end_of_think_id] = [ + self._recaption_id, + ] + self._stage_transitions[self._end_of_recaption_id] = [ + self._answer_id, + self._mrope_boi_token_id, + self._size_token_id, + ] + + self._sampler: Sampler | None = None + self._eos_token_id: int = tokenizer.eos_token_id + self._replace_rotary_embeddings() def _replace_rotary_embeddings(self): @@ -1257,6 +1336,12 @@ def _replace_rotary_embeddings(self): head_dim, rope_theta, ) + if replaced == 0: + raise RuntimeError( + "HunyuanImage3: _replace_rotary_embeddings replaced 0 layers. " + "The custom interleaved 2D mRoPE is not active — model outputs " + "will be incorrect. Check that model.layers[*].self_attn.rotary_emb exists." + ) def _parse_and_validate_image_input( self, @@ -1274,6 +1359,10 @@ def _parse_and_validate_image_input( if vit_pixel_values is None or vae_pixel_values is None: return None + # Handle empty batch (e.g., during profiling with 0 images / T2T mode) + if vit_pixel_values.numel() == 0 or vae_pixel_values.numel() == 0: + return None + return HunyuanImage3PixelInputs( type="pixel_values", pixel_values={ @@ -1472,6 +1561,112 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits + # ------------------------------------------------------------------ + # Custom sampler — applies HunyuanImage3-specific logits processors + # before the standard sampling step. + # + # Comprehension (I2T / T2T): + # Block generation-specific special tokens so sampling can't + # accidentally produce , , ratio tokens, etc. + # + # Generation (IT2I / T2I think): + # 1. _StageTransitionLogitsProcessor — force token sequences at + # transition boundaries ( → , etc.) + # 2. _ConditionalSliceVocabLogitsProcessor — after , + # restrict vocab to ratio tokens only (greedy). 
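+    #   3. After a ratio token has been emitted, every logit except EOS
+    #      is masked, so the AR stage stops generating.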
+ # ------------------------------------------------------------------ + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput | None: + if logits is None or logits.numel() == 0: + return None + + if self._sampler is None: + self._sampler = Sampler() + + min_score = torch.finfo(logits.dtype).min + + assert logits.shape[0] == 1, f"HunyuanImage3 sampler requires max_num_seqs=1, got batch size {logits.shape[0]}" + + for req_idx in range(logits.shape[0]): + decoded_tokens: list[int] = ( + sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else [] + ) + last_token = decoded_tokens[-1] if decoded_tokens else -1 + + if self._is_comprehension: + for tid in self._blocked_token_ids: + logits[req_idx, tid] = min_score + else: + forced = self._get_forced_token(decoded_tokens) + if forced is not None: + logits[req_idx].fill_(min_score) + logits[req_idx, forced] = 0 + elif last_token == self._size_token_id: + self._apply_ratio_restriction(logits, req_idx, min_score) + elif last_token in self._all_ratio_ids: + logits[req_idx].fill_(min_score) + logits[req_idx, self._eos_token_id] = 0 + + return self._sampler(logits=logits, sampling_metadata=sampling_metadata) + + def _get_forced_token(self, decoded_tokens: list[int]) -> int | None: + """Derive the next forced token from output history (stateless). + + Scans decoded_tokens backwards for the most recent trigger token, + then prefix-matches the forced sequence against what followed. + Returns the next token to force, or None if the sequence is complete + or history has diverged from the expected forced sequence. + """ + for i in range(len(decoded_tokens) - 1, -1, -1): + trigger = decoded_tokens[i] + if trigger not in self._stage_transitions: + continue + + forced_seq = self._stage_transitions[trigger] + emitted = decoded_tokens[i + 1 :] + + matched = 0 + for expected, actual in zip(forced_seq, emitted): + if actual != expected: + # History diverged from the expected forced sequence. + # Stop applying transition forcing for safety. + return None + matched += 1 + + if matched < len(forced_seq): + return forced_seq[matched] + return None + + return None + + def _apply_ratio_restriction( + self, + logits: torch.Tensor, + req_idx: int, + min_score: float, + ) -> None: + """Port of official _ConditionalSliceVocabLogitsProcessor.__call__. + + After the size token, only allow ratio tokens and pick greedily. + """ + original = logits[req_idx].clone() + logits[req_idx].fill_(min_score) + # Allow primary ratio range. + logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = original[ + self._start_ratio_id : self._end_ratio_id + 1 + ] + # Allow extra ratio slices. + for s, e in self._ratio_other_slices: + logits[req_idx, s:e] = original[s:e] + # Force greedy: keep only the argmax. 
+ max_id = logits[req_idx].argmax().item() + logits[req_idx].fill_(min_score) + logits[req_idx, max_id] = 0 + def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device ) -> IntermediateTensors: @@ -1507,9 +1702,9 @@ def get_mrope_input_positions( input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] | None = None, *, - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, + hf_config: PretrainedConfig | None = None, + image_grid_thw: list[list[int]] | torch.Tensor | None = None, + video_grid_thw: list[list[int]] | torch.Tensor | None = None, second_per_grid_ts: list[float] | None = None, context_len: int = 0, seq_len: int | None = None, diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py index 56cb8788ee..85fe4b0051 100644 --- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py +++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py @@ -50,6 +50,7 @@ PromptUpdate, PromptUpdateDetails, ) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema @@ -150,7 +151,6 @@ def __init__(self, model: "MiMoAudioLLMForConditionalGeneration", max_batch_size dtype = next(model.hidden_states_downcast.parameters()).dtype hidden_size = model.local_config.hidden_size - self.pool = torch.cuda.graph_pool_handle() self.input_tensor = torch.zeros((max_batch_size, 1, hidden_size), dtype=dtype, device=device) self.sampler = MiMoLocalSamplerTensor( temperature=torch.ones(max_batch_size, dtype=torch.float32, device=device), @@ -231,7 +231,7 @@ def capture( cuda_graph = torch.cuda.CUDAGraph() if eager_run_first: model.base_local_forward(input_tensor, local_sampler=sampler) - with torch.cuda.graph(cuda_graph, buffer.pool): + with torch.cuda.graph(cuda_graph, pool=current_platform.get_global_graph_pool()): output_tensor = model.base_local_forward(input_tensor, local_sampler=sampler) return cls( @@ -263,7 +263,6 @@ def __init__(self, model: "MiMoAudioLLMForConditionalGeneration", max_batch_size hidden_size = model.input_local_config.hidden_size group_size = model.group_size - self.pool = torch.cuda.graph_pool_handle() self.input_tensor = torch.zeros((max_batch_size, group_size, hidden_size), dtype=dtype, device=device) self.lock = threading.Lock() @@ -311,7 +310,7 @@ def capture( out = model.input_local_transformer(inputs_embeds=input_tensor, return_dict=True, is_causal=False) _ = out.last_hidden_state - with torch.cuda.graph(cuda_graph, buffer.pool): + with torch.cuda.graph(cuda_graph, pool=current_platform.get_global_graph_pool()): out = model.input_local_transformer(inputs_embeds=input_tensor, return_dict=True, is_causal=False) output_tensor = out.last_hidden_state diff --git a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py new file mode 100644 index 0000000000..d7fa44fd7e --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. 
+ +from .ming_flash_omni import MingFlashOmniForConditionalGeneration +from .ming_flash_omni_thinker import ( + MingFlashOmniThinkerDummyInputsBuilder, + MingFlashOmniThinkerForConditionalGeneration, + MingFlashOmniThinkerMultiModalProcessor, + MingFlashOmniThinkerProcessingInfo, +) + +__all__ = [ + "MingFlashOmniForConditionalGeneration", + "MingFlashOmniThinkerForConditionalGeneration", + "MingFlashOmniThinkerProcessingInfo", + "MingFlashOmniThinkerMultiModalProcessor", + "MingFlashOmniThinkerDummyInputsBuilder", +] diff --git a/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py new file mode 100644 index 0000000000..6ca1990114 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. +# Copyright (c) 2022 OpenAI +# Adapted from Ming repository modeling_whisper_encoder.py +# https://github.com/inclusionAI/Ming + +import operator +from collections.abc import Iterable +from itertools import accumulate + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func +from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids + +logger = init_logger(__name__) + + +class MultiHeadAttention(nn.Module): + """Multi-head attention with packed sequence support. + Adapted from Qwen3-TTS WhisperEncoder. + """ + + def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + if use_flash_attn and not HAS_FLASH_ATTN: + logger.warning("flash-attn is not available. Fallback to manual PyTorch version") + self.use_flash_attn = use_flash_attn and HAS_FLASH_ATTN + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + """Forward pass with packed sequence support. + + Args: + x: [total_tokens, n_state] packed sequence + cu_seqlens: [num_seqs + 1] cumulative sequence lengths, e.g. [0, len1, len1+len2, ...] 
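+                For two clips of 3 and 5 frames this is [0, 3, 8].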
+ + Returns: + [total_tokens, n_state] attention output + """ + q = self.query(x) + k = self.key(x) + v = self.value(x) + + n_ctx, n_state = q.shape + head_dim = n_state // self.n_head + + q = q.view(n_ctx, self.n_head, head_dim) + k = k.view(n_ctx, self.n_head, head_dim) + v = v.view(n_ctx, self.n_head, head_dim) + + # Try flash attention varlen + if self.use_flash_attn and cu_seqlens is not None and q.dtype in [torch.float16, torch.bfloat16]: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen) + else: + attn_output = self._manual_attention(q, k, v, cu_seqlens) + + # Reshape back: [T, H, D] -> [T, H*D] + attn_output = attn_output.contiguous().view(n_ctx, n_state) + return self.out(attn_output) + + def _manual_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor + ) -> torch.Tensor: + """Manual attention for variable-length sequences (fallback).""" + _, n_head, head_dim = q.shape + scale = head_dim**-0.5 + + # Unpack sequences and pad to max length + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + batch_size = len(seqlens) + max_seqlen = max(seqlens) + + # Create padded tensors + q_padded = torch.zeros(batch_size, max_seqlen, n_head, head_dim, dtype=q.dtype, device=q.device) + k_padded = torch.zeros_like(q_padded) + v_padded = torch.zeros_like(q_padded) + + # Fill with actual sequences + for i in range(batch_size): + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + seq_len = seqlens[i] + q_padded[i, :seq_len] = q[start_idx:end_idx] + k_padded[i, :seq_len] = k[start_idx:end_idx] + v_padded[i, :seq_len] = v[start_idx:end_idx] + + # Transpose for attention: [B, H, T, D] + q_padded = q_padded.transpose(1, 2) + k_padded = k_padded.transpose(1, 2) + v_padded = v_padded.transpose(1, 2) + + # Create attention mask for variable lengths: 0 for valid positions, -inf for padding + padding_mask = ( + torch.arange(max_seqlen, device=q.device)[None, :] >= torch.tensor(seqlens, device=q.device)[:, None] + ) + attn_mask = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q.dtype, device=q.device) + attn_mask = attn_mask.masked_fill(padding_mask.unsqueeze(1).unsqueeze(2), -torch.finfo(q.dtype).max) + + # Compute attention + attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale + attn_scores = attn_scores + attn_mask + attn_weights = F.softmax(attn_scores, dim=-1) + context = torch.matmul(attn_weights, v_padded) + + # Transpose back: [B, H, T, D] -> [B, T, H, D] + context = context.transpose(1, 2).contiguous() + output_packed = torch.cat([context[i, : seqlens[i]] for i in range(batch_size)], dim=0) + + return output_packed + + +class ResidualAttentionBlock(nn.Module): + """Whisper-style residual attention block with packed sequence support. 
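+
+    Pre-LN self-attention followed by a 4x GELU MLP, each wrapped in a
+    residual connection.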
+ + Adapted from + https://github.com/openai/whisper/blob/v20250625/whisper/model.py + vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py + """ + + def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True): + super().__init__() + self.attn = MultiHeadAttention(n_state, n_head, use_flash_attn=use_flash_attn) + self.attn_ln = nn.LayerNorm(n_state) + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), + nn.GELU(), + Linear(n_mlp, n_state), + ) + self.mlp_ln = nn.LayerNorm(n_state) + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens) + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class WhisperAudioEncoder(nn.Module): + """Whisper audio encoder for Ming with packed sequence support. + + Adapted from + https://github.com/openai/whisper/blob/v20250625/whisper/model.py + vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py + """ + + def __init__( + self, + n_mels: int = 128, + n_ctx: int = 15000, + n_state: int = 1280, + n_head: int = 20, + n_layer: int = 32, + use_flash_attn: bool = True, + ): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + # self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + self.blocks = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head, use_flash_attn=use_flash_attn) for _ in range(n_layer)] + ) + self.ln_post = nn.LayerNorm(n_state) + self.audio_emb_dim = n_state + + self.n_layer = n_layer + self.n_mels = n_mels + self.use_flash_attn = use_flash_attn + + def forward( + self, + x_list: list[torch.Tensor], + audio_lens: list[int], + ) -> torch.Tensor: + """Forward pass with packed sequence format for variable-length inputs. 
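+
+        Each clip is downsampled along time by the stride-2 conv2, so its
+        encoded length is roughly half of its input frame count.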
+ + Args: + x_list: List of [n_mels, T_i] mel spectrogram features for each audio + audio_lens: List of original audio lengths in frames + + Returns: + [total_T', n_state] packed encoded audio features, where + total_T' is the sum of all encoded sequence lengths + """ + # Cast inputs to model dtype + target_dtype = self.conv1.weight.dtype + x_list = [x.to(target_dtype) for x in x_list] + + encoded_list = [] + encoded_lens = [] + for mel_spec in x_list: + # mel_spec: [n_mels, T] - process through conv layers + x = mel_spec.unsqueeze(0) # [1, n_mels, T] + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.squeeze(0).transpose(0, 1) # [T', n_state] + + # Add positional embedding + seq_len = x.shape[0] + positional_embedding = self.positional_embedding[:seq_len, :] + x = (x + positional_embedding).to(x.dtype) + + encoded_list.append(x) + encoded_lens.append(seq_len) + + x_packed = torch.cat(encoded_list, dim=0) # [total_T', n_state] + + cu_seqlens = list(accumulate(encoded_lens, func=operator.add, initial=0)) + cu_seqlens = torch.tensor(cu_seqlens, device=x_packed.device, dtype=torch.int32) + + for block in self.blocks: + x_packed = block(x_packed, cu_seqlens=cu_seqlens) + + x_packed = self.ln_post(x_packed) + return x_packed + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict: dict[str, torch.Tensor] = { + **dict(self.named_parameters(remove_duplicate=False)), + **dict(self.named_buffers()), + } + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name not in params_dict: + logger.warning("Skipping unknown audio encoder weight: %s", name) + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py new file mode 100644 index 0000000000..87728890b6 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved. +# Adapted from Ming repository modeling_bailingmm2.py +# https://github.com/inclusionAI/Ming +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Ming-flash-omni-2.0 unified model (thinker + imagegen + talker).""" + +from collections.abc import Iterable + +import torch +import torch.nn as nn +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import ( + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import ( + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin +from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights +from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config, MingFlashOmniConfig + +from .ming_flash_omni_thinker import ( + MingFlashOmniThinkerDummyInputsBuilder, + MingFlashOmniThinkerMultiModalProcessor, + MingFlashOmniThinkerProcessingInfo, +) + +logger = init_logger(__name__) + + +@MULTIMODAL_REGISTRY.register_processor( + MingFlashOmniThinkerMultiModalProcessor, + info=MingFlashOmniThinkerProcessingInfo, + dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder, +) +class MingFlashOmniForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsMRoPE, + CustomProcessMixin, +): + """Unified Ming-flash-omni-2.0 model combining thinker, imagegen, and talker.""" + + supports_multimodal = True + requires_raw_input_tokens: bool = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.have_multimodal_outputs = True + self.has_preprocess = False + self.has_postprocess = False + + config = vllm_config.model_config.hf_config + + self.vllm_config = vllm_config + self.config = config + + if isinstance(config, MingFlashOmniConfig): + thinker_config = config.thinker_config + else: + thinker_config = config + + self.thinker_config: BailingMM2Config = thinker_config + self.model_stage = vllm_config.model_config.model_stage + + if self.model_stage == "thinker": + thinker_vllm_config = vllm_config.with_hf_config( + thinker_config, architectures=["MingFlashOmniThinkerForConditionalGeneration"] + ) + self.thinker = init_vllm_registered_model( + vllm_config=thinker_vllm_config, + prefix=maybe_prefix(prefix, "thinker"), + architectures=["MingFlashOmniThinkerForConditionalGeneration"], + ) + self.model = self.thinker + self.imagegen = None + self.talker = None + + elif self.model_stage == "imagegen": + # TODO: Implement image generator stage + raise NotImplementedError( + "Image generation stage is not yet implemented. Please use model_stage='thinker' for now." + ) + + elif self.model_stage == "talker": + # TODO: Implement talker (TTS) stage + raise NotImplementedError( + "Talker (TTS) stage is not yet implemented. Please use model_stage='thinker' for now." + ) + + else: + raise ValueError( + f"Invalid model_stage: {self.model_stage}. 
Must be one of: 'thinker', 'imagegen', 'talker'" + ) + + # Set up intermediate tensors + self.make_empty_intermediate_tensors = ( + self.thinker.make_empty_intermediate_tensors if self.model_stage == "thinker" else lambda: None + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> OmniOutput: + return self.model.forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, + ) -> torch.Tensor | None: + if hasattr(self.model, "compute_logits"): + return self.model.compute_logits(hidden_states, sampling_metadata) + return None + + def sample( + self, + logits: torch.Tensor, + sampling_metadata, + ): + if hasattr(self.model, "sample"): + return self.model.sample(logits, sampling_metadata) + raise NotImplementedError("sample method not available on current stage") + + def get_mrope_input_positions(self, *args, **kwargs): + if hasattr(self.model, "get_mrope_input_positions"): + return self.model.get_mrope_input_positions(*args, **kwargs) + raise NotImplementedError("get_mrope_input_positions not available on current stage") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loaded_weights = set() + thinker_weights = [] + imagegen_weights = [] + talker_weights = [] + + for name, value in weights: + if name.startswith("thinker."): + thinker_weights.append((name, value)) + elif name.startswith("imagegen."): + imagegen_weights.append((name, value)) + elif name.startswith("talker."): + talker_weights.append((name, value)) + else: + # Weights without prefix go to thinker by default + thinker_weights.append((name, value)) + + if self.model_stage == "thinker" and thinker_weights: + # Remove "thinker." 
prefix before loading + thinker_weights_stripped = [ + (name.replace("thinker.", "", 1) if name.startswith("thinker.") else name, value) + for name, value in thinker_weights + ] + thinker_loaded = self.thinker.load_weights(thinker_weights_stripped) + thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker") + loaded_weights.update(thinker_loaded) + + # TODO: Load imagegen weights when implemented + # TODO: Load talker weights when implemented + + return loaded_weights + + def get_mm_mapping(self) -> MultiModelKeys: + return MultiModelKeys.from_string_field( + language_model="thinker.language_model", + connector=["thinker.linear_proj.", "thinker.linear_proj_audio."], + tower_model=["thinker.vision.", "thinker.audio."], + ) + + @property + def sampler(self): + if hasattr(self.model, "sampler"): + return self.model.sampler + return None + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + *, + is_multimodal=None, + ) -> torch.Tensor: + return self.model.embed_input_ids( + input_ids, + multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + def embed_multimodal(self, **kwargs): + return self.model.embed_multimodal(**kwargs) diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py new file mode 100644 index 0000000000..bde7477b94 --- /dev/null +++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py @@ -0,0 +1,893 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. +# Adapted from Ming repository modeling_bailingmm2.py and processing_bailingmm2.py +# https://github.com/inclusionAI/Ming + +"""Ming-flash-omni-2.0 Thinker stage implementation (multimodal understanding).""" + +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import Annotated, Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.feature_extraction_utils import BatchFeature +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.qwen2_5_vl import ( + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs, +) +from vllm.model_executor.models.qwen2_vl import ( + Qwen2VLProcessingInfo, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageProcessorItems, + MultiModalDataItems, + MultiModalDataParser, + VideoProcessorItems, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin +from vllm_omni.model_executor.models.output_templates import 
OmniOutput +from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config +from vllm_omni.transformers_utils.processors.ming import ( + PLACEHOLDER_AUDIO_TOKEN_IN_TEXT, + PLACEHOLDER_IMAGE_TOKEN_IN_TEXT, + PLACEHOLDER_VIDEO_TOKEN_IN_TEXT, + MingFlashOmniProcessor, + MingWhisperFeatureExtractor, +) + +from .audio_encoder import WhisperAudioEncoder +from .modeling_bailing_moe_v2 import BailingMoeV2ForCausalLM +from .projectors import AudioProjector, VisionProjector +from .vision_encoder import MingVisionEncoder + +logger = init_logger(__name__) + + +class MingAudioInput(TensorSchema): + """ + Dimensions: + - b: Batch size + - l: Total audio frames (clips concatenated along the time axis) + - nm: Number of mel bins + - N: Max number of audio clips per batch item + """ + + audio_feats: Annotated[ + torch.Tensor, + TensorShape("b", "l", "nm"), + ] + + audio_feats_lengths: Annotated[ + torch.Tensor, + TensorShape("b", "N"), + ] + + +class MingFlashOmniThinkerProcessingInfo(Qwen2VLProcessingInfo): + def get_hf_config(self) -> BailingMM2Config: + return self.ctx.get_hf_config(BailingMM2Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(MingFlashOmniProcessor, **kwargs) + + def get_target_channels(self) -> int: + # See `_normalize_audio_tensor` in vllm_omni/transformers_utils/processors/ming.py + return 1 + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None, "video": None, "audio": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + mm_counts = mm_counts or {} + requested_modalities = {m for m, c in mm_counts.items() if c > 0} + mm_max_tokens: dict[str, int] = {} + + if requested_modalities & {"image", "video"}: + vl_tokens = super().get_mm_max_tokens_per_item( + seq_len=seq_len, + mm_counts=mm_counts, + ) + mm_max_tokens.update({m: vl_tokens[m] for m in ["image", "video"] if m in requested_modalities}) + + if "audio" in requested_modalities: + # TODO: consider computing from audio config + mm_max_tokens["audio"] = 3000 + + return mm_max_tokens + + def get_feature_extractor(self, **kwargs: object) -> MingWhisperFeatureExtractor: + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.audio_processor + assert isinstance(feature_extractor, MingWhisperFeatureExtractor) + return feature_extractor + + def get_data_parser(self): + feature_extractor = self.get_feature_extractor() + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, + target_channels=self.get_target_channels(), + expected_hidden_size=self._get_expected_hidden_size(), + ) + + +class MingFlashOmniThinkerDummyInputsBuilder(BaseDummyInputsBuilder[MingFlashOmniThinkerProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + num_audios = mm_counts.get("audio", 0) + + hf_processor = self.info.get_hf_processor() + + audio_token: str = hf_processor.audio_token + image_token: str = hf_processor.image_token + video_token: str = hf_processor.video_token + + return image_token * num_images + video_token * num_videos + audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + num_audios = 
mm_counts.get("audio", 0) + + # Default dimensions for dummy data + image_width, image_height = 448, 448 + video_width, video_height = 448, 448 + num_frames = 8 + audio_duration = 3.0 # seconds + sample_rate = 16000 + + audio_length = int(audio_duration * sample_rate) + + mm_data: MultiModalDataDict = { + "image": self._get_dummy_images( + width=image_width, + height=image_height, + num_images=num_images, + ), + "video": self._get_dummy_videos( + width=video_width, + height=video_height, + num_frames=num_frames, + num_videos=num_videos, + ), + "audio": [(np.random.randn(audio_length).astype(np.float32), sample_rate) for _ in range(num_audios)], + } + + return mm_data + + +class MingFlashOmniThinkerMultiModalProcessor(BaseMultiModalProcessor[MingFlashOmniThinkerProcessingInfo]): + """Multimodal processor for Ming-flash-omni Thinker stage. + + Handles preprocessing of 1) image, 2) video, and 3) audio inputs, + and expands placeholder tokens to the correct number of patch tokens. + """ + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + # might want to add a fallback to resolve token ids + # vocab = tokenizer.get_vocab() + thinker_config = self.info.get_hf_config() + + # patch/delimiter token IDs (used in replacement sequences) + image_start_token_id = thinker_config.llm_config.image_start_token + image_patch_token_id = thinker_config.llm_config.image_patch_token + image_end_token_id = thinker_config.llm_config.image_end_token + + video_start_token_id = thinker_config.llm_config.video_start_token + frame_patch_token_id = thinker_config.llm_config.video_patch_token + video_end_token_id = thinker_config.llm_config.video_end_token + + audio_start_token_id = thinker_config.llm_config.audio_start_token + audio_patch_token_id = thinker_config.llm_config.audio_patch_token + audio_end_token_id = thinker_config.llm_config.audio_end_token + + vision_config = thinker_config.vision_config + spatial_merge_size = vision_config.spatial_merge_size if vision_config else 2 + + newline_token_ids: list[int] = tokenizer.encode("\n", add_special_tokens=False) + + out_mm_data = out_mm_kwargs.get_data() + + def get_replacement_image(item_idx: int) -> PromptUpdateDetails: + """Generate token sequence for an image.""" + grid_thw = out_mm_data.get("image_grid_thw") + if grid_thw is None: + raise ValueError( + "image_grid_thw missing from processor output; " + "cannot determine image patch count for prompt replacement." 
+                )
+            if isinstance(grid_thw, torch.Tensor):
+                thw = grid_thw[item_idx]
+                num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
+            else:
+                thw = grid_thw[item_idx]
+                num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
+
+            # Build token sequence: image_start, image_patch * N, image_end, \n
+            # the newline token is added on purpose, matching the original model's processing
+            tokens: list[int] = []
+            tokens.append(image_start_token_id)
+            tokens.extend([image_patch_token_id] * num_patches)
+            tokens.append(image_end_token_id)
+            # Refer to Ming's BailingMM2Processor._expand_image_tokens
+            # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py
+            tokens.extend(newline_token_ids)
+
+            # Only the image patch tokens receive multimodal embeddings
+            return PromptUpdateDetails.select_token_id(tokens, image_patch_token_id)
+
+        def get_replacement_video(item_idx: int) -> PromptUpdateDetails:
+            """Generate token sequence for a video."""
+            grid_thw = out_mm_data.get("video_grid_thw", None)
+            if grid_thw is None:
+                raise ValueError(
+                    "video_grid_thw missing from processor output; "
+                    "cannot determine video patch count for prompt replacement."
+                )
+            if isinstance(grid_thw, torch.Tensor):
+                thw = grid_thw[item_idx]
+                num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
+            else:
+                thw = grid_thw[item_idx]
+                num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
+
+            # Build token sequence: video_start, video_patch * N, video_end, \n
+            # the newline token is added on purpose, matching the original model's processing
+            tokens: list[int] = []
+            tokens.append(video_start_token_id)
+            tokens.extend([frame_patch_token_id] * num_patches)
+            tokens.append(video_end_token_id)
+            tokens.extend(newline_token_ids)
+
+            # Only the video patch tokens receive multimodal embeddings
+            return PromptUpdateDetails.select_token_id(tokens, frame_patch_token_id)
+
+        def get_replacement_audio(item_idx: int) -> PromptUpdateDetails:
+            """Generate token sequence for an audio clip."""
+            encoder_feats_lengths = out_mm_data.get("encoder_feats_lengths", None)
+            if encoder_feats_lengths is None:
+                raise ValueError(
+                    "encoder_feats_lengths missing from processor output; "
+                    "cannot determine audio patch count for prompt replacement."
+ ) + if isinstance(encoder_feats_lengths, torch.Tensor): + num_patches = int(encoder_feats_lengths[item_idx].item()) + else: + num_patches = encoder_feats_lengths[item_idx] + + # Build token sequence: + tokens: list[int] = [] + tokens.append(audio_start_token_id) + tokens.extend([audio_patch_token_id] * num_patches) + tokens.append(audio_end_token_id) + + # Only tokens receive multimodal embeddings + return PromptUpdateDetails.select_token_id(tokens, audio_patch_token_id) + + # Build prompt updates and process replacement + updates: list[PromptUpdate] = [] + + if "image" in mm_items and mm_items.get_items("image", ImageProcessorItems): + updates.append( + PromptReplacement( + modality="image", + target=PLACEHOLDER_IMAGE_TOKEN_IN_TEXT, + replacement=get_replacement_image, + ) + ) + if "video" in mm_items and mm_items.get_items("video", VideoProcessorItems): + updates.append( + PromptReplacement( + modality="video", + target=PLACEHOLDER_VIDEO_TOKEN_IN_TEXT, + replacement=get_replacement_video, + ) + ) + if "audio" in mm_items and mm_items.get_items("audio", AudioProcessorItems): + updates.append( + PromptReplacement( + modality="audio", + target=PLACEHOLDER_AUDIO_TOKEN_IN_TEXT, + replacement=get_replacement_audio, + ) + ) + return updates + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + config: dict[str, MultiModalFieldConfig] = {} + + # Image fields, pixel_values is flat (concatenated patches from all images) + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + if "pixel_values" in hf_inputs: + image_sizes = image_grid_thw.prod(-1) + config["pixel_values"] = MultiModalFieldConfig.flat_from_sizes( + "image", + image_sizes, + ) + if "image_grid_thw" in hf_inputs: + config["image_grid_thw"] = MultiModalFieldConfig.batched("image") + + # Video fields, same flat layout as images + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + if "pixel_values_videos" in hf_inputs: + video_sizes = video_grid_thw.prod(-1) + config["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes( + "video", + video_sizes, + ) + if "video_grid_thw" in hf_inputs: + config["video_grid_thw"] = MultiModalFieldConfig.batched("video") + + # Audio fields + if "audio_feats" in hf_inputs: + config["audio_feats"] = MultiModalFieldConfig.batched("audio") + if "audio_feats_lengths" in hf_inputs: + config["audio_feats_lengths"] = MultiModalFieldConfig.batched("audio") + if "encoder_feats_lengths" in hf_inputs: + config["encoder_feats_lengths"] = MultiModalFieldConfig.batched("audio") + if "placeholder_audio_loc_lens" in hf_inputs: + config["placeholder_audio_loc_lens"] = MultiModalFieldConfig.batched("audio") + + return config + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + """Call sub-processors for multimodal inputs and tokenize. + + We call the image/audio sub-processors directly (instead of going + through `MingFlashOmniProcessor.__call__`) so that the high-level + placeholder tokens remain **unexpanded** in the tokenized output. 
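+        vLLM then expands them via the prompt updates returned by
+        `_get_prompt_updates`, based on the per-item patch counts.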
+ """ + hf_processor = self.info.get_hf_processor() + tokenizer = self.info.get_tokenizer() + + data: dict[str, object] = {} + + images = mm_data.get("images", None) + if images is not None: + image_outputs = hf_processor.image_processor( + images=images, + videos=None, + return_tensors="pt", + ) + data.update(image_outputs) + + videos = mm_data.get("videos", None) + if videos is not None: + video_outputs = hf_processor.image_processor( + images=None, + videos=videos, + return_tensors="pt", + ) + # Rename keys to distinguish from images + if "pixel_values" in video_outputs: + video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values") + if "image_grid_thw" in video_outputs: + video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw") + data.update(video_outputs) + + audios = mm_data.get("audios", None) + if audios is not None: + # vLLM's AudioProcessorItems provides raw numpy arrays (already resampled). + # MingWhisperAudioProcessor expects (waveform, sr) tuples, + # so wrap them with the target sample rate. + target_sr = hf_processor.audio_processor.sampling_rate + audio_tuples = [(a, target_sr) if not isinstance(a, tuple) else a for a in audios] + + audio_outputs = hf_processor.audio_processor( + audio_tuples, + return_tensors="pt", + ) + data.update(audio_outputs) + + # Tokenize text with placeholders still intact + text_outputs = tokenizer(prompt, return_tensors="pt", **tok_kwargs) + data.update(text_outputs) + + return BatchFeature(data=data) + + +@MULTIMODAL_REGISTRY.register_processor( + MingFlashOmniThinkerMultiModalProcessor, + info=MingFlashOmniThinkerProcessingInfo, + dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder, +) +class MingFlashOmniThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsMRoPE, + CustomProcessMixin, +): + """Ming Thinker stage: multimodal understanding + (text + image + video + audio) -> text generation. + """ + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"model.": "language_model."}, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + # vllm_omni/transformers_utils/processors/ming.py + if modality.startswith("image"): + return "" + elif modality.startswith("video"): + return " + detokenize: True + +runtime: + enabled: true diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml new file mode 100644 index 0000000000..413e0f09cb --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml @@ -0,0 +1,74 @@ +# Stage config for HunyuanImage-3.0 Image+Text-to-Image (image editing). 
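An aside on the Ming processor above, before continuing with the HunyuanImage-3.0 stage configs: every replacement callback reduces to the same arithmetic, a merged-patch count derived from the grid dimensions plus a flat run of start/patch/end tokens, of which only the patch tokens are later marked to receive multimodal embeddings. A minimal, self-contained sketch of that arithmetic (plain Python with made-up token ids; this is not the vllm_omni API):

# Illustrative only: made-up token ids, mirroring the callbacks in the diff above.
IMAGE_START, IMAGE_PATCH, IMAGE_END, NEWLINE = 1001, 1002, 1003, 10

def expand_image(thw: tuple[int, int, int], spatial_merge_size: int) -> list[int]:
    t, h, w = thw
    num_patches = (t * h * w) // (spatial_merge_size**2)  # merged-patch count
    # start token, one patch token per merged patch, end token, trailing newline
    return [IMAGE_START] + [IMAGE_PATCH] * num_patches + [IMAGE_END, NEWLINE]

tokens = expand_image((1, 32, 32), spatial_merge_size=2)  # 1*32*32 // 4 = 256 patch tokens
assert tokens.count(IMAGE_PATCH) == 256 and len(tokens) == 259
# Only the IMAGE_PATCH positions would be selected for embedding injection,
# mirroring PromptUpdateDetails.select_token_id(tokens, image_patch_token_id).

The unmerged per-item products (image_grid_thw.prod(-1)) are what _get_mm_fields_config uses, via flat_from_sizes, to split the flat pixel_values tensor back into per-image slices.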
+# Stage 0: AR (HunyuanImage3ForConditionalGeneration) — reads (image, text), emits latent tokens +# Stage 1: Diffusion (HunyuanImage3Pipeline / DiT + VAE) — denoise + decode latents → image + +stage_args: + # Stage 0: AR Model + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0,1,2,3" + max_batch_size: 1 + requires_multimodal_data: true # AR needs the original image + engine_args: + model_stage: AR + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + engine_output_type: latent # AR outputs latent for DiT + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: false # Generation task, not comprehension + final_output: false # AR is not the final output + default_sampling_params: + temperature: 0.6 + top_p: 0.95 + top_k: 1024 + max_tokens: 4096 + stop_token_ids: [127957] # <|endoftext|> + detokenize: false + + # Stage 1: Diffusion (DiT + VAE) + # Receives latents from AR stage, performs denoising + VAE decode + - stage_id: 1 + stage_type: diffusion + runtime: + process: true + devices: "4,5,6,7" + max_batch_size: 1 + requires_multimodal_data: true # May need condition images + engine_args: + model_stage: dit + model_arch: HunyuanImage3ForCausalMM + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 4 + enable_expert_parallel: true + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] # Input from AR stage + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hunyuan_image3.ar2diffusion + final_output: true + final_output_type: image + default_sampling_params: + num_inference_steps: 50 + guidance_scale: 2.5 + +# Top-level runtime config +runtime: + enabled: true + edges: + - from: 0 # AR → Diffusion + to: 1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml index 51110c2858..586b601bc5 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml @@ -11,13 +11,9 @@ stage_args: max_batch_size: 1 engine_args: model_stage: dit - gpu_memory_utilization: 0.9 enforce_eager: true trust_remote_code: true - engine_output_type: image distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 quantization: "fp8" parallel_config: tensor_parallel_size: 2 @@ -34,6 +30,3 @@ stage_args: # Runtime edges runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml similarity index 80% rename from vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml rename to vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml index 0b812ff376..1d8c7f4812 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml @@ -11,13 +11,9 @@ stage_args: engine_args: max_num_seqs: 1 model_stage: dit - gpu_memory_utilization: 
0.65 enforce_eager: true trust_remote_code: true - engine_output_type: image distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 parallel_config: tensor_parallel_size: 4 enable_expert_parallel: true @@ -33,6 +29,3 @@ stage_args: # Runtime edges runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml similarity index 95% rename from vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml rename to vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml index e029c38362..41ed74ba62 100644 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml @@ -39,6 +39,3 @@ stage_args: runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml new file mode 100644 index 0000000000..a0a1a0dc1c --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml @@ -0,0 +1,42 @@ +# Stage config for HunyuanImage-3.0 Text-to-Text (T2T / pure text generation). +# Single LLM stage: AR model reads text prompt only, generates text output. +# Sampling params aligned with official generation_config.json. + +stage_args: + - stage_id: 0 + stage_type: llm + runtime: + process: true + devices: "0,1,2,3" + max_batch_size: 1 + requires_multimodal_data: false + engine_args: + model_stage: AR + max_num_seqs: 1 + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.95 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 0.95 + top_k: 1024 + max_tokens: 2048 + stop_token_ids: [127957, 128026] # <|endoftext|>, + detokenize: True + +runtime: + enabled: true diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml deleted file mode 100644 index 6f4ba306a5..0000000000 --- a/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml +++ /dev/null @@ -1,85 +0,0 @@ -# Stage config for running Hunyuan-Image3.0 for multi-stage omni runtime. -# Stage 0: AR Model (vLLM implementation) - -# The following config has been verified on 8x L40S-48G GPU. 
-modes: - - mode: text-to-image - stages: [1] - - mode: image-to-text - stages: [0] -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true # Run this stage in a separate process - devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) - engine_args: - model_stage: AR - max_num_seqs: 1 - model_arch: HunyuanImage3ForCausalMM - worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.3 - enforce_eager: true # Now we only support eager mode - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - max_num_batched_tokens: 32768 - tensor_parallel_size: 8 - pipeline_parallel_size: 1 - hf_overrides: - rope_parameters: - mrope_section: [0, 32, 32] - rope_type: default - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - stage_id: 1 - stage_type: diffusion - runtime: - process: true - devices: "0,1,2,3,4,5,6,7" - max_batch_size: 1 - engine_args: - model_stage: diffusion - gpu_memory_utilization: 0.9 - enforce_eager: true - engine_output_type: image - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - vae_use_slicing: false - vae_use_tiling: false - cache_backend: null - cache_config: null - enable_cache_dit_summary: false - parallel_config: - pipeline_parallel_size: 1 - data_parallel_size: 1 - tensor_parallel_size: 8 - enable_expert_parallel: false - sequence_parallel_size: 1 - ulysses_degree: 1 - ring_degree: 1 - cfg_parallel_size: 1 - vae_patch_parallel_size: 1 - use_hsdp: false - hsdp_shard_size: -1 - hsdp_replicate_size: 1 - final_output: true - final_output_type: image - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage diff --git a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml index b3c6bbbaf0..2fa1b982af 100644 --- a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml +++ b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml @@ -74,10 +74,6 @@ stage_args: runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 - connectors: connector_of_shared_memory: name: SharedMemoryConnector @@ -93,4 +89,3 @@ runtime: edges: - from: 0 to: 1 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/omnivoice.yaml b/vllm_omni/model_executor/stage_configs/omnivoice.yaml index 49f11e9674..546e3b3dc2 100644 --- a/vllm_omni/model_executor/stage_configs/omnivoice.yaml +++ b/vllm_omni/model_executor/stage_configs/omnivoice.yaml @@ -10,10 +10,8 @@ stage_args: engine_args: model_stage: dit model_class_name: "OmniVoicePipeline" - gpu_memory_utilization: 0.5 enforce_eager: true trust_remote_code: true - engine_output_type: audio distributed_executor_backend: "mp" dtype: "float32" final_output: true diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml deleted file mode 100644 index 0a307b4477..0000000000 --- 
a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml +++ /dev/null @@ -1,107 +0,0 @@ -# stage config for running qwen2.5-omni for multi-stage omni runtime. - -# The following config has been verified on 2x H100-80G GPU. -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true # Run this stage in a separate process - devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true # Now we only support eager mode - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - max_num_batched_tokens: 32768 - mm_processor_cache_gb: 0 - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - max_num_batched_tokens: 32768 - engine_output_type: latent - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker - default_sampling_params: - temperature: 0.9 - top_p: 0.8 - top_k: 40 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - stop_token_ids: [8294] - - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - gpu_memory_utilization: 0.15 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - max_num_batched_tokens: 32768 - async_scheduling: false - engine_output_type: audio - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage - - edges: - - from: 0 # thinker → talker: trigger only after receiving full input (-1) - to: 1 - window_size: -1 - - from: 1 # talker → code2wav: trigger only after receiving full input (-1) - to: 2 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml deleted file mode 100644 index b318aebe36..0000000000 --- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml +++ /dev/null @@ -1,151 +0,0 @@ -# stage config for 
running qwen2.5-omni for multi-stage omni runtime. - -# The following config has been verified on 1x H100-80G GPU. -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true # Run this stage in a separate process - devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true # Now we only support eager mode - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - mm_processor_cache_gb: 0 - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - # Distributed connector configuration (optional) - output_connectors: - to_stage_1: mooncake_connector - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: latent - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker - default_sampling_params: - temperature: 0.9 - top_p: 0.8 - top_k: 40 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - stop_token_ids: [8294] - # Distributed connector configuration (optional) - input_connectors: - from_stage_0: mooncake_connector - output_connectors: - to_stage_2: mooncake_connector - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - gpu_memory_utilization: 0.3 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - max_num_batched_tokens: 32768 - engine_output_type: audio - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - # Distributed connector configuration (optional) - input_connectors: - from_stage_1: mooncake_connector - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage - - # Distributed connectors configuration (optional) - # More connectors will be supported in the future. 
- connectors: - # Mooncake connector for cross-node/intra-node communication - mooncake_connector: - name: MooncakeStoreConnector - extra: - host: "127.0.0.1" - metadata_server: "http://10.90.67.86:8080/metadata" - master: "10.90.67.86:50051" - segment: 512000000 # 512MB - localbuf: 64000000 # 64MB - proto: "tcp" - - # Mori RDMA connector for cross-node/intra-node communication - mori_connector: - name: MoriTransferEngineConnector - extra: - host: "auto" - zmq_port: 50051 - device_name: "" - memory_pool_size: 536870912 # 512 MB - memory_pool_device: "cpu" - - # Yuanrong connector for cross-node/intra-node communication - yuanrong_connector: - name: YuanrongConnector - extra: - host: "127.0.0.1" - port: "35000" - - # SharedMemory connector for intra-node communication - # Alternative SHM connector with different threshold - shared_memory_connector: - name: SharedMemoryConnector - extra: - shm_threshold_bytes: 65536 # 64KB threshold - - edges: - - from: 0 # thinker → talker: trigger only after receiving full input (-1) - to: 1 - window_size: -1 - - from: 1 # talker → code2wav: trigger only after receiving full input (-1) - to: 2 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml deleted file mode 100644 index 0ce4f0c94f..0000000000 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml +++ /dev/null @@ -1,101 +0,0 @@ -# Stage config for running Qwen3-Omni-MoE with 3-stage architecture -# Stage 0: Thinker (multimodal understanding + text generation) -# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) -# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) - -# The following config has been verified on 2x H100-80G GPUs. -async_chunk: false -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "0" - engine_args: - model_stage: thinker - max_num_seqs: 64 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 - enforce_eager: false - trust_remote_code: true - engine_output_type: latent # Output hidden states for talker - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - hf_config_name: thinker_config - tensor_parallel_size: 1 - final_output: true - final_output_type: text - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 64 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 - enforce_eager: false - trust_remote_code: true - engine_output_type: latent # Output codec codes for code2wav - enable_prefix_caching: false - max_num_batched_tokens: 32768 - distributed_executor_backend: "mp" - hf_config_name: talker_config - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker - # final_output: true - # final_output_type: text - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: False - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 2 - stage_type: 
llm # Use llm stage type for AR stages - runtime: - devices: "1" - engine_args: - model_stage: code2wav - max_num_seqs: 32 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: false - enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.1 - distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - hf_config_name: thinker_config - engine_input_source: [1] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: True - repetition_penalty: 1.1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml deleted file mode 100644 index 38626fc081..0000000000 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml +++ /dev/null @@ -1,117 +0,0 @@ -# Stage config for running Qwen3-Omni-MoE with 3-stage architecture -# Stage 0: Thinker (multimodal understanding + text generation) -# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) -# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform) - -# The following config has been verified on 2x H100-80G GPUs. -async_chunk: true -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "0" - engine_args: - model_stage: thinker - max_num_seqs: 64 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 - enforce_eager: false - trust_remote_code: true - engine_output_type: latent # Output hidden states for talker - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - hf_config_name: thinker_config - tensor_parallel_size: 1 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk - final_output: true - final_output_type: text - is_comprehension: true - # Use named connector to apply runtime.connectors.extra. 
- output_connectors: - to_stage_1: connector_of_shared_memory - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 64 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 - enforce_eager: false - trust_remote_code: true - engine_output_type: latent # Output codec codes for code2wav - enable_prefix_caching: false - max_num_batched_tokens: 32768 - distributed_executor_backend: "mp" - hf_config_name: talker_config - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk - engine_input_source: [0] - # final_output: true - # final_output_type: text - # Distributed connector configuration - input_connectors: - from_stage_0: connector_of_shared_memory - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: False - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "1" - engine_args: - model_stage: code2wav - max_num_seqs: 64 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: false - enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.1 - distributed_executor_backend: "mp" - max_num_batched_tokens: 51200 # [TODO] if max_num_batch_tokens < max_num_seqs * 800, there will be precision problem. - hf_config_name: thinker_config - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - -runtime: - - connectors: - connector_of_shared_memory: - name: SharedMemoryConnector - extra: - # Align with Omni: small chunks with sufficient context overlap. - codec_chunk_frames: 25 # code2wav decode chunk size - codec_left_context_frames: 25 # code2wav left context size diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml deleted file mode 100644 index 6c2d2a7669..0000000000 --- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml +++ /dev/null @@ -1,143 +0,0 @@ -# Stage config for running Qwen3-Omni-MoE with 3-stage architecture -# Stage 0: Thinker (multimodal understanding + text generation) -# Stage 1: Talker (text embeddings -> 8-layer RVQ codec codes) -# Stage 2: Code2Wav (8-layer RVQ codes -> audio waveform) - -# The following config has been verified on 2x H100-80G GPUs. 
-stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "0" - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 - enforce_eager: true - trust_remote_code: true - engine_output_type: latent # Output hidden states for talker - distributed_executor_backend: "mp" - enable_prefix_caching: false - hf_config_name: thinker_config - tensor_parallel_size: 1 - final_output: true - final_output_type: text - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - # Distributed connector configuration - output_connectors: - to_stage_1: mooncake_connector - - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 - enforce_eager: true - trust_remote_code: true - engine_output_type: latent # Output codec codes for code2wav - # tensor_parallel_size: 2 - enable_prefix_caching: false - distributed_executor_backend: "mp" - hf_config_name: talker_config - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker - # final_output: true - # final_output_type: text - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: False - repetition_penalty: 1.05 - stop_token_ids: [2150] - # Distributed connector configuration - input_connectors: - from_stage_0: mooncake_connector - output_connectors: - to_stage_2: mooncake_connector - - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "1" - engine_args: - model_stage: code2wav - max_num_seqs: 64 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.1 - distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - hf_config_name: thinker_config - engine_input_source: [1] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - # Distributed connector configuration - input_connectors: - from_stage_1: mooncake_connector - -# Top-level runtime config: default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 - max_inflight: 1 - - # Distributed connectors configuration - connectors: - # Mooncake connector for cross-node/intra-node communication - mooncake_connector: - name: MooncakeStoreConnector - extra: - host: "127.0.0.1" - metadata_server: "http://10.90.67.86:8080/metadata" - master: "10.90.67.86:50051" - segment: 512000000 # 512MB - localbuf: 64000000 # 64MB - proto: "tcp" - - # SharedMemory connector for intra-node communication - shared_memory_connector: - name: 
SharedMemoryConnector - extra: - shm_threshold_bytes: 65536 # 64KB threshold - - edges: - - from: 0 - to: 1 - window_size: -1 - - from: 1 - to: 2 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml deleted file mode 100644 index 75b2bab3a2..0000000000 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml +++ /dev/null @@ -1,100 +0,0 @@ -# Same as qwen3_tts.yaml with batched talker and code2wav. -# Stage 0: max_num_seqs 4, stage 1: max_num_seqs 4. -# max_num_seqs must be a power of two to align with CUDA graph capture sizes -# (stage 0) and must match --batch-size in end2end.py / benchmark scripts. -async_chunk: true -stage_args: - - stage_id: 0 - stage_type: llm - is_comprehension: true - runtime: - devices: "0" - engine_args: - model_stage: qwen3_tts - max_num_seqs: 4 - model_arch: Qwen3TTSTalkerForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - enforce_eager: false - trust_remote_code: true - async_scheduling: true - enable_prefix_caching: false - engine_output_type: latent - gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" - max_num_batched_tokens: 512 - max_model_len: 4096 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk - # Use named connector to apply runtime.connectors.extra. - output_connectors: - to_stage_1: connector_of_shared_memory - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: false - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 1 - stage_type: llm - runtime: - devices: "0" - engine_args: - model_stage: code2wav - max_num_seqs: 4 - model_arch: Qwen3TTSCode2Wav - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: true - enable_prefix_caching: false - engine_output_type: audio - gpu_memory_utilization: 0.2 - distributed_executor_backend: "mp" - # Must be divisible by num_code_groups and cover (left_context + chunk). - max_num_batched_tokens: 65536 - # Flat codec prompt can exceed 32k tokens (Q * frames); align with max_tokens below. - max_model_len: 65536 - engine_input_source: [0] - final_output: true - final_output_type: audio - # Distributed connector configuration - input_connectors: - from_stage_0: connector_of_shared_memory - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: true - repetition_penalty: 1.0 - -runtime: - enabled: true - defaults: - window_size: -1 - max_inflight: 1 - - connectors: - connector_of_shared_memory: - name: SharedMemoryConnector - extra: - shm_threshold_bytes: 65536 - # Frame-aligned codec streaming transport. - codec_streaming: true - # Connector polling / timeout (unit: loop count, sleep interval in seconds). - connector_get_sleep_s: 0.01 - connector_get_max_wait_first_chunk: 3000 - connector_get_max_wait: 300 - # Match the decoder sliding attention window to avoid chunk-boundary noise. 
- codec_chunk_frames: 25 - codec_left_context_frames: 72 - - edges: - - from: 0 - to: 1 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml similarity index 94% rename from vllm_omni/model_executor/stage_configs/qwen3_tts.yaml rename to vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml index a0d38eb4b9..4ca8d11ad7 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml @@ -17,7 +17,6 @@ stage_args: enable_prefix_caching: false engine_output_type: latent gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" max_num_batched_tokens: 512 max_model_len: 4096 custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk @@ -49,7 +48,6 @@ stage_args: enable_prefix_caching: false engine_output_type: audio gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" # Must be divisible by num_code_groups and cover (left_context + chunk). # Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep headroom past 32k. max_num_batched_tokens: 65536 @@ -74,9 +72,6 @@ stage_args: runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 connectors: connector_of_shared_memory: @@ -96,4 +91,3 @@ runtime: edges: - from: 0 to: 1 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/voxcpm.yaml b/vllm_omni/model_executor/stage_configs/voxcpm.yaml new file mode 100644 index 0000000000..a5f324f660 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/voxcpm.yaml @@ -0,0 +1,69 @@ +# VoxCPM two-stage (latent → VAE) without async_chunk: one-shot latent then decode. +stage_args: + - stage_id: 0 + stage_type: llm + is_comprehension: true + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. 
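+      # (OmegaConf-style interpolation: ${oc.env:VAR,default} resolves to the value of
+      # VAR, and the trailing comma makes the default an empty string when VAR is unset.)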
+ hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: latent + gpu_memory_utilization: 0.7 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + max_model_len: 4096 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 4096 + stop_token_ids: [2] + seed: 42 + detokenize: false + repetition_penalty: 1.0 + final_output: false + + - stage_id: 1 + stage_type: llm + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} + worker_type: generation + scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: false + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.15 + distributed_executor_backend: "mp" + max_num_batched_tokens: 8192 + max_model_len: 4096 + engine_input_source: [0] + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae + final_output: true + final_output_type: audio + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 1 + seed: 42 + detokenize: true + repetition_penalty: 1.0 diff --git a/vllm_omni/model_executor/stage_configs/voxcpm2.yaml b/vllm_omni/model_executor/stage_configs/voxcpm2.yaml new file mode 100644 index 0000000000..7cc93d6b26 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/voxcpm2.yaml @@ -0,0 +1,36 @@ +# VoxCPM2 AR pipeline with per-request state batching. +# Uses native MiniCPM4 base_lm + per-request StaticKVCache. +# max_batch_size > 1 supported via KV cache save/restore. +stage_args: + - stage_id: 0 + stage_type: llm + is_comprehension: true + runtime: + devices: "0" + max_batch_size: 4 + engine_args: + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPM2TalkerForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + enforce_eager: true + trust_remote_code: true + async_scheduling: true + enable_prefix_caching: false + engine_output_type: audio + gpu_memory_utilization: 0.9 + distributed_executor_backend: "mp" + max_num_batched_tokens: 4096 + max_model_len: 4096 + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 4096 + seed: 42 + detokenize: false + repetition_penalty: 1.0 + stop_token_ids: [1] + final_output: true + final_output_type: audio diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml similarity index 60% rename from vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml rename to vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml index cd82d91b71..c6fd177a35 100644 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml @@ -1,3 +1,5 @@ +# VoxCPM two-stage streaming (align with qwen3_tts.yaml async_chunk pattern). +# Stage0 (latent_generator) emits latent in time chunks; Stage1 (VAE) decodes as chunks arrive. 
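+# Stage0 → Stage1 hand-off goes through the voxcpm_shm SharedMemoryConnector defined
+# under runtime.connectors below, streaming one codec frame at a time (codec_chunk_frames: 1).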
async_chunk: true stage_args: - stage_id: 0 @@ -5,42 +7,48 @@ stage_args: is_comprehension: true runtime: devices: "0" + max_batch_size: 1 engine_args: - model_stage: qwen3_tts - max_num_seqs: 1 - model_arch: Qwen3TTSTalkerForConditionalGeneration + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler enforce_eager: true trust_remote_code: true - async_scheduling: false + async_scheduling: true enable_prefix_caching: false engine_output_type: latent - gpu_memory_utilization: 0.3 + gpu_memory_utilization: 0.7 distributed_executor_backend: "mp" - max_num_batched_tokens: 512 + max_num_batched_tokens: 4096 max_model_len: 4096 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk - # Use named connector to apply runtime.connectors.extra. - output_connectors: - to_stage_1: connector_of_shared_memory + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk default_sampling_params: - temperature: 0.9 - top_k: 50 + temperature: 0.0 + top_p: 1.0 + top_k: -1 max_tokens: 4096 + stop_token_ids: [2] seed: 42 detokenize: false - repetition_penalty: 1.05 - stop_token_ids: [2150] + repetition_penalty: 1.0 + final_output: false + output_connectors: + to_stage_1: voxcpm_shm - stage_id: 1 stage_type: llm runtime: devices: "0" + max_batch_size: 1 engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen3TTSCode2Wav + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true @@ -48,35 +56,30 @@ stage_args: async_scheduling: false enable_prefix_caching: false engine_output_type: audio - gpu_memory_utilization: 0.2 + gpu_memory_utilization: 0.15 distributed_executor_backend: "mp" - max_num_batched_tokens: 65536 - max_model_len: 65536 + max_num_batched_tokens: 8192 + max_model_len: 4096 engine_input_source: [0] final_output: true final_output_type: audio - # Distributed connector configuration input_connectors: - from_stage_0: connector_of_shared_memory - tts_args: - max_instructions_length: 500 + from_stage_0: voxcpm_shm default_sampling_params: temperature: 0.0 top_p: 1.0 top_k: -1 - max_tokens: 65536 + max_tokens: 128 seed: 42 detokenize: true repetition_penalty: 1.0 + runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 connectors: - connector_of_shared_memory: + voxcpm_shm: name: SharedMemoryConnector extra: shm_threshold_bytes: 65536 @@ -87,10 +90,9 @@ runtime: connector_get_max_wait_first_chunk: 3000 connector_get_max_wait: 300 # Align with Omni: small chunks with sufficient context overlap. 
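The codec_chunk_frames / codec_left_context_frames pair retuned just below (25/72 for Qwen3-TTS, 1/1 here for VoxCPM) drives a windowed decode: each step decodes the left context plus one new chunk of codec frames but only keeps the audio for the new chunk, which is what avoids chunk-boundary noise. A schematic sketch (illustrative only; the decode callable and samples_per_frame are hypothetical, not the vllm_omni implementation):

# Schematic frame-chunked codec decoding.
def stream_decode(frames, decode, chunk=25, left_context=72, samples_per_frame=480):
    audio_out = []
    for start in range(0, len(frames), chunk):
        ctx_start = max(0, start - left_context)
        window = frames[ctx_start : start + chunk]  # context frames + new frames
        decoded = decode(window)                    # -> list of audio samples
        new_frames = min(chunk, len(frames) - start)
        audio_out.extend(decoded[-new_frames * samples_per_frame :])  # keep only the new part
    return audio_out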
- codec_chunk_frames: 25 - codec_left_context_frames: 72 + codec_chunk_frames: 1 + codec_left_context_frames: 1 edges: - from: 0 to: 1 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml index 31cccb9ccf..b0d9a81e78 100644 --- a/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml +++ b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml @@ -82,10 +82,6 @@ stage_args: # Top-level runtime config (concise): default windows and stage edges runtime: enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage - connectors: connector_of_shared_memory: name: SharedMemoryConnector @@ -102,4 +98,3 @@ runtime: edges: - from: 0 # language_model → acoustic_transformer: trigger only after receiving full input (-1) to: 1 - window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/bagel.py b/vllm_omni/model_executor/stage_input_processors/bagel.py index bfcff0ea0f..52cc14d3aa 100644 --- a/vllm_omni/model_executor/stage_input_processors/bagel.py +++ b/vllm_omni/model_executor/stage_input_processors/bagel.py @@ -82,6 +82,8 @@ def expand_cfg_prompts( neg_prompt = _get_negative_prompt(prompt, sampling_params) if "image" in modalities: + if not neg_prompt: + return [] neg_prompt_dict = { "prompt": neg_prompt, "modalities": prompt.get("modalities", []), @@ -166,6 +168,8 @@ def expand_cfg_prompts_think( companion_params = {"max_tokens": 1} if "image" in modalities: + if not neg_prompt: + return [] neg_prompt_dict = { "prompt": neg_prompt, "modalities": prompt.get("modalities", []), @@ -287,9 +291,10 @@ def _get_negative_prompt( ) -> str: """Resolve the negative prompt for CFG from prompt or sampling params. - An empty string is treated the same as absent (falls through to - the Bagel default token pair), because an empty negative prompt is - not meaningful for CFG guidance. + Returns the negative prompt string when one is supplied, otherwise an + empty string. Callers decide how to treat the empty case: text2img + skips the cfg_text companion entirely, while img2img substitutes it + into the cfg_text prompt template. """ neg = prompt.get("negative_prompt") if neg: @@ -300,4 +305,4 @@ def _get_negative_prompt( if neg: return neg - return "<|im_start|><|im_end|>" + return "" diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py new file mode 100644 index 0000000000..89a7a28f6c --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Stage input processor for HunyuanImage3: AR → Diffusion transition. + +In IT2I (image editing) mode: + - Stage 0 (AR) receives (image + edit instruction), generates CoT/latent tokens + - Stage 1 (DiT) receives the AR output + original image, denoises → edited image + +The ar2diffusion function bridges these two stages, following the same +signature pattern as glm_image.ar2diffusion. 
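+
+Illustrative output shape (hypothetical values): for a request whose original prompt
+was "a red car" and whose AR stage emitted token ids [1, 2, 3], ar2diffusion returns
+one dict along the lines of
+    {"prompt": "a red car", "height": 1024, "width": 1024,
+     "extra": {"ar_token_ids": tensor([1, 2, 3]), "ar_generated_text": "..."}},
+optionally extended with "pil_image" and any forwarded sampling parameters.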
+""" + +from typing import Any + +import torch +from vllm.inputs import TextPrompt +from vllm.logger import init_logger + +from vllm_omni.inputs.data import OmniTokensPrompt + +logger = init_logger(__name__) + + +def ar2diffusion( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | list | None = None, + requires_multimodal_data: bool = False, +) -> list[dict[str, Any]]: + """Process AR stage outputs to create Diffusion stage inputs. + + Args: + stage_list: List of stage clients (set by orchestrator). + engine_input_source: List of source stage IDs (from YAML). + prompt: Original user prompt (may contain multimodal data). + requires_multimodal_data: Whether to forward multimodal data. + + Returns: + List of dicts, each consumable by the HunyuanImage3 diffusion pipeline. + """ + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + if source_stage_id >= len(stage_list): + raise IndexError(f"Invalid source stage_id: {source_stage_id}") + + if stage_list[source_stage_id].engine_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + + ar_outputs = stage_list[source_stage_id].engine_outputs + diffusion_inputs = [] + + # Normalize prompt to list + if not isinstance(prompt, list): + prompt = [prompt] if prompt is not None else [{}] + + for i, ar_output in enumerate(ar_outputs): + output = ar_output.outputs[0] + generated_token_ids = output.token_ids + generated_text = getattr(output, "text", "") or "" + + # Get original prompt info + original_prompt = prompt[i] if i < len(prompt) else {} + if isinstance(original_prompt, dict): + pass + elif hasattr(original_prompt, "_asdict"): + original_prompt = original_prompt._asdict() + elif hasattr(original_prompt, "__dict__"): + original_prompt = vars(original_prompt) + else: + original_prompt = {} + + height = original_prompt.get("height", 1024) + width = original_prompt.get("width", 1024) + text_prompt = original_prompt.get("prompt", "") + + logger.info( + "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d", + i, + len(generated_token_ids), + len(generated_text), + height, + width, + ) + + token_tensor = torch.tensor(generated_token_ids, dtype=torch.long) + + diffusion_input: dict[str, Any] = { + "prompt": text_prompt, + "height": height, + "width": width, + "extra": { + "ar_token_ids": token_tensor, + "ar_generated_text": generated_text, + }, + } + + # Forward multimodal data (original image for IT2I conditioning) + mm_data = original_prompt.get("multi_modal_data") + if mm_data: + pil_image = mm_data.get("image") + if pil_image is None: + images = mm_data.get("images") + if images: + pil_image = images[0] if isinstance(images, list) else images + if pil_image is not None: + diffusion_input["pil_image"] = pil_image + + # Forward multimodal output from AR (if any) + if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output: + mm_output = ar_output.multimodal_output + if isinstance(mm_output, dict): + diffusion_input["extra"]["ar_multimodal_output"] = mm_output + + # Forward sampling params + for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]: + if key in original_prompt: + diffusion_input[key] = original_prompt[key] + + diffusion_inputs.append(diffusion_input) + + return diffusion_inputs diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py 
b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index f4828fddaa..699e4b194a 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -3,6 +3,8 @@ # Copyright 2025 The Qwen team. """Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition.""" +import logging +from dataclasses import dataclass, field from typing import Any import torch @@ -18,6 +20,12 @@ extract_speaker_from_request, ) +logger = logging.getLogger(__name__) + +# Pooling output layer keys: "0" = word embedding, "24" = accept_hidden_layer +_EMBED_LAYER_KEY = "0" +_HIDDEN_LAYER_KEY = "24" + def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int: im_start_token_id = 151644 @@ -84,6 +92,191 @@ def _validate_stage_inputs(stage_list, engine_input_source): return stage.engine_outputs +# ========================= +# PD disaggregation helpers +# ========================= + + +def _get_prefill_stage(stage_list: list[Any], source_stage_id: int) -> Any | None: + if source_stage_id <= 0: + return None + source_stage = stage_list[source_stage_id] + if not getattr(source_stage, "is_decode_only", False): + return None + prev_stage = stage_list[source_stage_id - 1] + if getattr(prev_stage, "is_prefill_only", False) and prev_stage.engine_outputs is not None: + return prev_stage + return None + + +def _merge_pd_embeddings( + decode_emb: torch.Tensor, + decode_hid: torch.Tensor, + prefill_mm: dict[str, Any], + device: torch.device, + expected_total: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Merge prefill prompt embeddings with decode generated embeddings. + + In PD mode the prefill engine processes the prompt and the decode engine + generates tokens starting from position 1. This function concatenates + them, removing the overlapping token(s): + + merged = prefill[:P] + decode[overlap:] + + where overlap = P + D - expected_total. + """ + try: + p_emb = prefill_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float) + p_hid = prefill_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float) + except (KeyError, AttributeError, TypeError) as exc: + available_keys = list(prefill_mm.keys()) if isinstance(prefill_mm, dict) else type(prefill_mm).__name__ + logger.error( + "_merge_pd_embeddings: failed to extract prefill embeddings (%s). " + "Expected keys %r and %r, got: %s. 
" + "Falling back to decode-only embeddings – talker user-segment will be degraded.", + exc, + _EMBED_LAYER_KEY, + _HIDDEN_LAYER_KEY, + available_keys, + ) + return decode_emb, decode_hid + + if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0: + return decode_emb, decode_hid + + raw_total = p_emb.shape[0] + decode_emb.shape[0] + overlap = max(0, raw_total - expected_total) if expected_total is not None else 0 + + merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0) + merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0) + return merged_emb, merged_hid + + +def _get_prefill_multimodal_output(prefill_stage: Any, output_index: int) -> dict[str, Any] | None: + """Return multimodal_output dict from the PD prefill stage for a given batch index.""" + try: + prefill_eos = prefill_stage.engine_outputs + prefill_eo = prefill_eos[min(output_index, len(prefill_eos) - 1)] + return prefill_eo.outputs[0].multimodal_output + except Exception: + return None + + +def _resolve_tts_token_embedding( + key: str, + *, + thinker_mm: dict[str, Any], + prefill_mm: dict[str, Any] | None, + device: torch.device, +) -> torch.Tensor | None: + """Return TTS BOS/EOS/PAD embedding tensors for the talker projection path. + + Values are taken from the current thinker (decode) ``multimodal_output``; in + PD mode, missing keys may be filled from the paired prefill stage output. + """ + val = thinker_mm.get(key) + if val is None and prefill_mm is not None: + val = prefill_mm.get(key) + return val.detach().to(device=device, dtype=torch.float) if val is not None else None + + +# ========================= +# Streaming input helpers +# ========================= + + +@dataclass +class _Thinker2TalkerStreamingState: + last_prompt_len: int = 0 + last_output_len: int = 0 + merged_sequences: list[int] = field(default_factory=list) + + +@dataclass +class _Qwen3OmniStreamingState: + thinker2talker: _Thinker2TalkerStreamingState = field(default_factory=_Thinker2TalkerStreamingState) + talker2code2wav_last_seq_len: int = 0 + + +def _get_qwen3_streaming_state( + request_id: str, + streaming_context: Any | None, +) -> _Qwen3OmniStreamingState: + bridge_states = getattr(streaming_context, "bridge_states", None) + per_model_state = bridge_states.setdefault("qwen3_omni", {}) + state = per_model_state.get(request_id) + if state is None: + state = _Qwen3OmniStreamingState() + per_model_state[request_id] = state + return state + + +def _get_streaming_talker_tokens( + request_id: str, + prompt_token_ids: list[int], + output_token_ids: list[int], + new_prompt_len_snapshot: int | None = None, + streaming_context: Any | None = None, + *, + clear_state: bool = False, +) -> tuple[list[int], list[int], list[int], list[int]]: + """Return streaming token slices and merged token views for thinker->talker. + e.g. For the second streaming input request: + merged_sequences: [input_prompt 1, output_tokens 1[:-1], input_prompt 2, output_tokens 2] + thinker_input_ids: [input_prompt 1, output_tokens 1[:-1], input_prompt 2] + Returns: + inc_prompt: prompt token delta for this segment. + inc_output: output token delta for this segment. + merged_sequences: full thinker_sequences to send downstream. + thinker_input_ids: full thinker_input_ids paired with merged_sequences. 
+ """ + state = _get_qwen3_streaming_state(request_id, streaming_context).thinker2talker + if new_prompt_len_snapshot: + prompt_token_ids = prompt_token_ids[:-new_prompt_len_snapshot] + cur_prompt_len = len(prompt_token_ids) + cur_output_len = len(output_token_ids) + + inc_prompt = prompt_token_ids[state.last_prompt_len :] + inc_output = output_token_ids[state.last_output_len :] + delta_sequences = inc_prompt + inc_output + cached_sequences = state.merged_sequences + + merged_sequences = cached_sequences + delta_sequences + thinker_input_ids = cached_sequences + inc_prompt + + # Persist history for next segment. Drop the latest sampled token to keep + # thinker_input_ids / thinker_sequences alignment with next-step append. + cached_sequences.extend(delta_sequences[:-1]) + + state.last_prompt_len = cur_prompt_len + state.last_output_len = cur_output_len + + if clear_state: + state.last_prompt_len = 0 + state.last_output_len = 0 + state.merged_sequences.clear() + + return inc_prompt, inc_output, merged_sequences, thinker_input_ids + + +def _get_streaming_codec_delta_len( + cur_seq_len: int, + request_id: str, + talker_output: Any, + streaming_context: Any | None = None, +) -> int: + """Return newly added seq_len for talker->code2wav in streaming mode.""" + state = _get_qwen3_streaming_state(request_id, streaming_context) + prev_seq_len = state.talker2code2wav_last_seq_len + seq_len = cur_seq_len - prev_seq_len + state.talker2code2wav_last_seq_len = cur_seq_len + 1 + if bool(getattr(talker_output, "finished", False)): + # Final segment: clear history to avoid cross-session carry-over. + state.talker2code2wav_last_seq_len = 0 + return seq_len + + # ========================= # Thinker -> Talker # ========================= @@ -111,8 +304,8 @@ def thinker2talker_async_chunk( all_token_ids = _ensure_list(all_token_ids) prompt_token_ids = _ensure_list(prompt_token_ids) talker_additional_info = { - "thinker_prefill_embeddings": pooling_output.get("0").detach().cpu(), - "thinker_hidden_states": pooling_output.get("24").detach().cpu(), + "thinker_prefill_embeddings": pooling_output.get(_EMBED_LAYER_KEY).detach().cpu(), + "thinker_hidden_states": pooling_output.get(_HIDDEN_LAYER_KEY).detach().cpu(), "thinker_sequences": all_token_ids, "thinker_input_ids": prompt_token_ids, # Provide thinker-side TTS token embeddings for talker projection @@ -161,7 +354,7 @@ def thinker2talker_async_chunk( if output_token_ids: talker_additional_info["override_keys"] = ["thinker_decode_embeddings", "thinker_output_token_ids"] - talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu() + talker_additional_info["thinker_decode_embeddings"] = pooling_output.get(_EMBED_LAYER_KEY).detach().cpu() talker_additional_info["thinker_output_token_ids"] = output_token_ids else: # When prefilling a chunked thinker, thinker_hidden_states needs to be updated. @@ -176,6 +369,7 @@ def thinker2talker( engine_input_source: list[int], prompt: OmniTokensPrompt | TextPrompt | None = None, requires_multimodal_data: bool = False, + streaming_context: Any | None = None, ) -> list[OmniTokensPrompt]: """ Process thinker outputs to create talker inputs. @@ -185,6 +379,9 @@ def thinker2talker( 2. Split hidden states into: prompt embeddings + generated embeddings 3. Package for talker with additional information + In PD disaggregation mode, merges prefill-stage prompt embeddings with + decode-stage generated embeddings before handing off to the talker. 
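+
+    (Illustrative sizes: a prompt of P=5 positions prefilled remotely plus D=4
+    decode-side embeddings toward an expected total of 8 positions gives
+    overlap = 5 + 4 - 8 = 1, so the first decode embedding is dropped on merge.)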
+ Args: stage_list: List of stage objects engine_input_source: Source stage IDs (typically [0] for thinker) @@ -199,21 +396,66 @@ def thinker2talker( device = torch.device(current_platform.device_type) + # PD disaggregation: look up the preceding prefill stage (if any) + source_stage_id = engine_input_source[0] + prefill_stage = _get_prefill_stage(stage_list, source_stage_id) + # Process each thinker output for i, thinker_output in enumerate(thinker_outputs): output = thinker_output.outputs[0] + req_id = str(getattr(thinker_output, "request_id", f"idx-{i}")) + prompt_token_ids = _ensure_list(thinker_output.prompt_token_ids) + output_ids = _ensure_list(output.token_ids) + is_streaming_session = bool(getattr(streaming_context, "enabled", False)) + if is_streaming_session: + prompt_token_ids, output_ids, thinker_sequences, thinker_input_ids = _get_streaming_talker_tokens( + req_id, + prompt_token_ids, + output_ids, + getattr(streaming_context, "new_prompt_len_snapshot", None), + streaming_context, + clear_state=bool(getattr(thinker_output, "finished", False)), + ) + else: + thinker_sequences = prompt_token_ids + output_ids + thinker_input_ids = prompt_token_ids + # For streaming input, just send incremental prefill and hidden states tensor to talker + # Equally applicable to non-streaming cases. + new_seq_length = len(prompt_token_ids + output_ids) - 1 + thinker_mm = output.multimodal_output + # Full thinker embedding sequence for the talker: single thinker engine in the + # non-PD path; after optional merge with prefill-side tensors in PD mode. + thinker_emb = thinker_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float)[-new_seq_length:] + thinker_hid = thinker_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float)[-new_seq_length:] + + prefill_mm: dict[str, Any] | None = None + if prefill_stage is not None: + prefill_mm = _get_prefill_multimodal_output(prefill_stage, i) + + if prefill_mm is not None: + expected_total = len(prompt_token_ids) + len(output_ids) + try: + thinker_emb, thinker_hid = _merge_pd_embeddings( + thinker_emb, thinker_hid, prefill_mm, device, expected_total=expected_total + ) + except Exception as exc: + logger.warning("[PD] Could not merge prefill embeddings: %s", exc) info = { - "thinker_prefill_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float), - "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float), - "thinker_sequences": ( - thinker_output.prompt_token_ids + output.token_ids - ), # the thinker_sequences is the whole ids - "thinker_input_ids": thinker_output.prompt_token_ids, + "thinker_prefill_embeddings": thinker_emb, + "thinker_hidden_states": thinker_hid, + "thinker_sequences": thinker_sequences, # the thinker_sequences is the whole ids + "thinker_input_ids": thinker_input_ids, # Provide thinker-side TTS token embeddings for talker projection - "tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float), - "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float), - "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float), + "tts_bos_embed": _resolve_tts_token_embedding( + "tts_bos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device + ), + "tts_eos_embed": _resolve_tts_token_embedding( + "tts_eos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device + ), + "tts_pad_embed": 
_resolve_tts_token_embedding( + "tts_pad_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device + ), } speaker = extract_speaker_from_prompt(prompt, index=i) if speaker is not None: @@ -314,6 +556,7 @@ def talker2code2wav( engine_input_source: list[int], prompt: OmniTokensPrompt | TextPrompt | None = None, requires_multimodal_data: bool = False, + streaming_context: Any | None = None, ) -> list[OmniTokensPrompt]: """ Process talker outputs to create code2wav inputs. @@ -335,9 +578,14 @@ def talker2code2wav( talker_outputs = _validate_stage_inputs(stage_list, engine_input_source) code2wav_inputs: list[OmniTokensPrompt] = [] # Process each talker output - for talker_output in talker_outputs: + for i, talker_output in enumerate(talker_outputs): output = talker_output.outputs[0] - seq_len = len(output.token_ids) - 1 + req_id = str(getattr(talker_output, "request_id", f"idx-{i}")) + cur_seq_len = len(output.token_ids) - 1 + seq_len = cur_seq_len + is_streaming_session = bool(getattr(streaming_context, "enabled", False)) + if is_streaming_session: + seq_len = _get_streaming_codec_delta_len(cur_seq_len, req_id, talker_output, streaming_context) # Extract codec codes from talker output # Expected shape: [8, seq_len] (8-layer RVQ codes) codec_codes = ( diff --git a/vllm_omni/model_executor/stage_input_processors/voxcpm.py b/vllm_omni/model_executor/stage_input_processors/voxcpm.py new file mode 100644 index 0000000000..c2fcf521bf --- /dev/null +++ b/vllm_omni/model_executor/stage_input_processors/voxcpm.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from typing import Any + +import torch +from vllm.inputs import TextPrompt + +from vllm_omni.inputs.data import OmniTokensPrompt + +_VOXCPM_LATENT_MAGIC = 131071 + + +def _serialize_latent_to_codes(latent: Any) -> list[int]: + latent_tensor = latent if isinstance(latent, torch.Tensor) else torch.as_tensor(latent) + latent_tensor = latent_tensor.detach().cpu().contiguous() + if latent_tensor.ndim == 3: + if latent_tensor.shape[0] != 1: + raise ValueError(f"Expected batch=1 latent tensor, got shape={tuple(latent_tensor.shape)}") + latent_tensor = latent_tensor.squeeze(0) + if latent_tensor.ndim != 2: + raise ValueError(f"Unsupported latent_audio_feat shape for async chunk: {tuple(latent_tensor.shape)}") + latent_dim, time_dim = int(latent_tensor.shape[0]), int(latent_tensor.shape[1]) + packed = latent_tensor.to(torch.bfloat16).contiguous().view(torch.uint16).reshape(-1).to(torch.int32) + return [_VOXCPM_LATENT_MAGIC, latent_dim, time_dim, *packed.tolist()] + + +def _coerce_finished_flag(value: Any) -> bool: + """Normalize VoxCPM async-chunk finished markers to a Python bool.""" + if value is None: + return False + if isinstance(value, torch.Tensor): + if value.numel() != 1: + raise ValueError(f"finished tensor must be scalar, got shape={tuple(value.shape)}") + return bool(value.detach().cpu().item()) + if isinstance(value, (list, tuple)): + if not value: + return False + if len(value) != 1: + raise ValueError(f"finished container must have one element, got len={len(value)}") + return _coerce_finished_flag(value[0]) + return bool(value) + + +def latent2vae( + stage_list: list[Any], + engine_input_source: list[int], + prompt: OmniTokensPrompt | TextPrompt | None = None, + requires_multimodal_data: bool = False, +) -> list[OmniTokensPrompt]: + del prompt, requires_multimodal_data + + if not engine_input_source: + raise ValueError("engine_input_source cannot be empty") + + source_stage_id = engine_input_source[0] + if source_stage_id 
>= len(stage_list): + raise IndexError(f"Invalid stage_id: {source_stage_id}") + + source_outputs = stage_list[source_stage_id].engine_outputs + if source_outputs is None: + raise RuntimeError(f"Stage {source_stage_id} has no outputs yet") + + vae_inputs: list[OmniTokensPrompt] = [] + for source_output in source_outputs: + output = source_output.outputs[0] + multimodal_output = getattr(output, "multimodal_output", None) + if not isinstance(multimodal_output, dict) or "latent_audio_feat" not in multimodal_output: + raise ValueError( + "VoxCPM latent stage output missing 'latent_audio_feat'. " + f"request_id={getattr(source_output, 'request_id', None)}" + ) + + additional_information = { + "latent_audio_feat": multimodal_output["latent_audio_feat"], + } + if "sr" in multimodal_output: + additional_information["sample_rate"] = [int(multimodal_output["sr"])] + + vae_inputs.append( + OmniTokensPrompt( + prompt_token_ids=[0], + additional_information=additional_information, + multi_modal_data=None, + mm_processor_kwargs=None, + ) + ) + + return vae_inputs + + +def latent2vae_async_chunk( + transfer_manager: Any, + pooling_output: dict[str, Any] | None, + request: Any, + is_finished: bool = False, +) -> dict[str, Any] | None: + """Stage-0 latent → stage-1 VAE under ``async_chunk`` (connector payload).""" + # Kept for callback signature compatibility with OmniChunkTransferAdapter. + _ = transfer_manager + finished_request = _coerce_finished_flag(is_finished) + if callable(getattr(request, "is_finished", None)): + finished_request = finished_request or _coerce_finished_flag(request.is_finished()) + if not isinstance(pooling_output, dict): + if finished_request: + return { + "code_predictor_codes": [], + "finished": torch.tensor(True, dtype=torch.bool), + } + return None + + latent = pooling_output.get("latent_audio_feat") + if isinstance(latent, torch.Tensor) and latent.numel() == 0: + latent = None + + if latent is None: + if finished_request: + return { + "code_predictor_codes": [], + "finished": torch.tensor(True, dtype=torch.bool), + } + return None + + serialized_codes = _serialize_latent_to_codes(latent) + out: dict[str, Any] = { + "code_predictor_codes": serialized_codes, + "finished": torch.tensor(finished_request, dtype=torch.bool), + } + return out diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 9a7bb67065..c02c0c1427 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -9,6 +9,33 @@ from vllm_omni.inputs.data import OmniPromptType +@dataclass +class OmniConnectorOutput: + """Communication results from Model Runner to Scheduler. + + Carries transfer readiness signals so the Scheduler can make scheduling + decisions without ever calling connector.put()/get() directly. + + Attributes: + chunk_ready_req_ids: Request IDs with newly arrived chunks this cycle. + chunk_finished_req_ids: Request IDs whose final chunk has arrived. + request_metadata: Lightweight scheduling metadata keyed by request ID + (e.g. next_stage_prompt_len, code_predictor_codes, left_context_size). + Full payloads are owned by the Model Runner's local cache. + kv_sent_req_ids: Request IDs whose KV cache was successfully sent. + stage_recv_req_ids: Request IDs that received batch stage inputs. + has_pending_kv_work: True if the mixin has pending, active, or + completed KV transfers that the scheduler should account for. 
+ """ + + chunk_ready_req_ids: set[str] = field(default_factory=set) + chunk_finished_req_ids: set[str] = field(default_factory=set) + request_metadata: dict[str, dict[str, Any]] = field(default_factory=dict) + kv_sent_req_ids: list[str] = field(default_factory=list) + stage_recv_req_ids: set[str] = field(default_factory=set) + has_pending_kv_work: bool = False + + class OmniModelRunnerOutput(ModelRunnerOutput): """Model runner output for omni models. @@ -24,6 +51,7 @@ class OmniModelRunnerOutput(ModelRunnerOutput): # IDs of requests whose KV cache has been extracted from GPU/NPU to CPU. # The Scheduler can safely free the block tables for these requests. kv_extracted_req_ids: list[str] | None = None + omni_connector_output: OmniConnectorOutput | None = None @dataclass @@ -72,6 +100,9 @@ class OmniRequestOutput: # memory usage info peak_memory_mb: float = 0.0 + # error handling + error: str | None = None + @classmethod def from_pipeline( cls, diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index eafff821a2..f6c483a92f 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -1,6 +1,8 @@ import sys +from functools import cached_property from aenum import extend_enum +from vllm.config import ModelConfig as _OriginalModelConfig from vllm.inputs import TokensPrompt as _OriginalTokensPrompt from vllm.model_executor.layers.rotary_embedding import ( MRotaryEmbedding as _OriginalMRotaryEmbedding, @@ -10,12 +12,63 @@ from vllm.v1.engine import EngineCoreRequest as _OriginalEngineCoreRequest from vllm.v1.request import Request as _OriginalRequest from vllm.v1.request import RequestStatus +from vllm.v1.request import StreamingUpdate as _OriginalStreamingUpdate import vllm_omni.logger # noqa: F401 from vllm_omni.engine import OmniEngineCoreOutput, OmniEngineCoreOutputs, OmniEngineCoreRequest from vllm_omni.inputs.data import OmniTokensPrompt from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding -from vllm_omni.request import OmniRequest +from vllm_omni.request import OmniRequest, OmniStreamingUpdate + +# ============================================================================= +# Patch ModelConfig.is_mm_prefix_lm to support omni-specific models +# ============================================================================= +# WHY: HunyuanImage-3.0 requires bidirectional attention for image tokens +# (cond_token_attn_type: "joint_full" in config.json). vLLM gates this on +# is_mm_prefix_lm, which checks an internal MM_PREFIX_LM_MODELS list that +# does not include "hunyuan_image_3_moe" (the upstream HF model_type). +# +# WHY NOT model-level: is_mm_prefix_lm is checked in vLLM core (scheduler, +# attention backend selection) before model code runs — no model-level hook. +# +# SCOPE: Only affects model_type in _OMNI_MM_PREFIX_LM_MODELS (currently +# just "hunyuan_image_3_moe"). All other models fall through to the +# original vLLM implementation unchanged. +# +# FRAGILITY: Relies on is_mm_prefix_lm being a cached_property on +# ModelConfig. The __dict__ access + __set_name__ dance works around a +# pydantic dataclass issue in vllm 0.19.0+. If vLLM changes +# is_mm_prefix_lm to a regular method or removes it, this will break. +# +# TODO: Upstream a configurable MM_PREFIX_LM_MODELS or a model_config flag +# so this patch can be removed. 
+_OMNI_MM_PREFIX_LM_MODELS = ("hunyuan_image_3_moe",) +# Access via __dict__ to avoid triggering cached_property.__get__ which fails +# with "Cannot use cached_property instance without calling __set_name__" in +# pydantic dataclasses (vllm 0.19.0+). +_cp = _OriginalModelConfig.__dict__["is_mm_prefix_lm"] +_original_is_mm_prefix_lm = _cp.func if hasattr(_cp, "func") else _cp.fget + + +def _patched_is_mm_prefix_lm(self): + if _original_is_mm_prefix_lm(self): + return True + model_type = getattr(self.hf_config, "model_type", "") + return model_type in _OMNI_MM_PREFIX_LM_MODELS + + +_patched_cp = cached_property(_patched_is_mm_prefix_lm) +_patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm") +_OriginalModelConfig.is_mm_prefix_lm = _patched_cp + +# Sanity check: verify the patch is active. If vLLM changes the descriptor +# type or __set_name__ semantics, this will fail loudly at import time +# rather than silently falling back to unpatched behavior. +_installed = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm") +assert _installed is _patched_cp, ( + "is_mm_prefix_lm patch failed to install — bidirectional attention " + "for HunyuanImage3 will not work. Check vLLM ModelConfig changes." +) # ============================================================================= # Patch GlmImageTextConfig to expose mrope_section in rope_parameters @@ -63,5 +116,7 @@ def _patched_glm_image_text_config_init(self, *args, **kwargs): module.MRotaryEmbedding = OmniMRotaryEmbedding if hasattr(module, "Request") and module.Request == _OriginalRequest: module.Request = OmniRequest + if hasattr(module, "StreamingUpdate") and module.StreamingUpdate == _OriginalStreamingUpdate: + module.StreamingUpdate = OmniStreamingUpdate if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest: module.EngineCoreRequest = OmniEngineCoreRequest diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py index 8f1e66747d..b69731a67d 100644 --- a/vllm_omni/platforms/interface.py +++ b/vllm_omni/platforms/interface.py @@ -64,7 +64,7 @@ def get_default_stage_config_path(cls) -> str: @classmethod def get_diffusion_model_impl_qualname(cls, op_name: str) -> str: if op_name == "hunyuan_fused_moe": - return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault" + return "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault" raise NotImplementedError(f"Unsupported diffusion model op: {op_name}") @classmethod diff --git a/vllm_omni/platforms/musa/platform.py b/vllm_omni/platforms/musa/platform.py index fe1ccc6d0b..64a70a9beb 100644 --- a/vllm_omni/platforms/musa/platform.py +++ b/vllm_omni/platforms/musa/platform.py @@ -39,7 +39,7 @@ def get_default_stage_config_path(cls) -> str: def get_diffusion_model_impl_qualname(cls, op_name: str) -> str: # MUSA uses default implementations for diffusion ops if op_name == "hunyuan_fused_moe": - return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault" + return "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault" return super().get_diffusion_model_impl_qualname(op_name) @classmethod diff --git a/vllm_omni/platforms/npu/platform.py b/vllm_omni/platforms/npu/platform.py index c40dd6fea1..53ffe6775a 100644 --- a/vllm_omni/platforms/npu/platform.py +++ b/vllm_omni/platforms/npu/platform.py @@ -69,6 +69,9 @@ def get_diffusion_attn_backend_cls( # Try FLASH_ATTN if mindiesd is available, otherwise fall back to 
SDPA if find_spec("mindiesd"): + # Configure ASCEND_CUSTOM_OPP_PATH for mindiesd custom ops upon import + import mindiesd # noqa: F401 + logger.info("Defaulting to diffusion attention backend FLASH_ATTN") return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path() diff --git a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml similarity index 94% rename from vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml rename to vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml index 053e8a8cca..0fd03949d1 100644 --- a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml +++ b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml @@ -33,6 +33,3 @@ stage_args: # Runtime defaults runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 diff --git a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml deleted file mode 100644 index 8f7af161d6..0000000000 --- a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml +++ /dev/null @@ -1,97 +0,0 @@ -# stage config for running qwen2.5-omni for multi-stage omni runtime. -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true # Run this stage in a separate process - devices: "0" # Visible devices for this stage - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: false - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true # haven't supported talker ACL graph on NPU - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: latent - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker - default_sampling_params: - temperature: 0.9 - top_p: 0.8 - top_k: 40 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - stop_token_ids: [8294] - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "2" # Example: use a different NPU than the previous stage; use "0" if single NPU - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - gpu_memory_utilization: 0.15 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: audio - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - -# Top-level runtime 
config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage - edges: - - from: 0 # thinker → talker: trigger only after receiving full input (-1) - to: 1 - window_size: -1 - - from: 1 # talker → code2wav: trigger only after receiving full input (-1) - to: 2 - window_size: -1 diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml deleted file mode 100644 index 2638c99cd4..0000000000 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml +++ /dev/null @@ -1,99 +0,0 @@ -# Stage config for running Qwen3-Omni-MoE with 3-stage architecture -# Stage 0: Thinker (multimodal understanding + text generation) -# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) -# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) - -# The following config has been verified on 5x A2/A3-64G NPUs. -stage_args: - - stage_id: 0 - runtime: - devices: "0,1" - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 - enforce_eager: false - trust_remote_code: true - engine_output_type: latent # Output hidden states for talker - distributed_executor_backend: "mp" - enable_prefix_caching: false - hf_config_name: thinker_config - tensor_parallel_size: 2 - # profiler_config: - # profiler: torch - # torch_profiler_dir: ./perf - final_output: true - final_output_type: text - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - - - stage_id: 1 - runtime: - devices: "2" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 - enforce_eager: true # haven't supported talker ACL graph on NPU - trust_remote_code: true - engine_output_type: latent # Output codec codes for code2wav - # tensor_parallel_size: 2 - enable_prefix_caching: false - distributed_executor_backend: "mp" - hf_config_name: talker_config - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker - # final_output: true - # final_output_type: text - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: False - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 2 - runtime: - devices: "2" - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: false - enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - hf_config_name: thinker_config - engine_input_source: [1] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 
- top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: True - repetition_penalty: 1.1 diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml deleted file mode 100644 index 9aa20baecf..0000000000 --- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml +++ /dev/null @@ -1,101 +0,0 @@ -# Stage config for running Qwen3-Omni-MoE with 3-stage architecture -# Stage 0: Thinker (multimodal understanding + text generation) -# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) -# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform) - -# The following config has been verified on 2x H100-80G GPUs. -async_chunk: true -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "0,1" - engine_args: - max_num_seqs: 10 - model_stage: thinker - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 - enforce_eager: false - trust_remote_code: true - engine_output_type: latent # Output hidden states for talker - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - hf_config_name: thinker_config - tensor_parallel_size: 2 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk - final_output: true - final_output_type: text - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "2" - engine_args: - max_num_seqs: 10 - model_stage: talker - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 - enforce_eager: true - trust_remote_code: true - engine_output_type: latent # Output codec codes for code2wav - enable_prefix_caching: false - max_num_batched_tokens: 32768 - distributed_executor_backend: "mp" - hf_config_name: talker_config - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk - engine_input_source: [0] - # final_output: true - # final_output_type: text - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: False - repetition_penalty: 1.0 - stop_token_ids: [2150] - - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "2" - engine_args: - max_num_seqs: 10 - model_stage: code2wav - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - async_scheduling: false - enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.3 - distributed_executor_backend: "mp" - max_num_batched_tokens: 51200 # [TODO] if max_num_batched_tokens < max_num_seqs * 800, there will be precision problem. 
- hf_config_name: thinker_config - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: True - repetition_penalty: 1.1 diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml similarity index 59% rename from vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml rename to vllm_omni/platforms/npu/stage_configs/voxcpm.yaml index 3f412fc4dc..dcd1f40517 100644 --- a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml +++ b/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml @@ -1,42 +1,47 @@ -async_chunk: false stage_args: - stage_id: 0 stage_type: llm is_comprehension: true runtime: devices: "0" + max_batch_size: 1 engine_args: - model_stage: qwen3_tts - max_num_seqs: 1 - model_arch: Qwen3TTSTalkerForConditionalGeneration + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - enforce_eager: false + enforce_eager: true trust_remote_code: true async_scheduling: false enable_prefix_caching: false engine_output_type: latent - gpu_memory_utilization: 0.3 + gpu_memory_utilization: 0.75 distributed_executor_backend: "mp" - max_num_batched_tokens: 512 + max_num_batched_tokens: 4096 max_model_len: 4096 default_sampling_params: - temperature: 0.9 - top_k: 50 + temperature: 0.0 + top_p: 1.0 + top_k: -1 max_tokens: 4096 seed: 42 detokenize: false - repetition_penalty: 1.05 - stop_token_ids: [2150] + repetition_penalty: 1.0 + final_output: false - stage_id: 1 stage_type: llm runtime: devices: "0" + max_batch_size: 1 engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen3TTSCode2Wav + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true @@ -44,21 +49,19 @@ stage_args: async_scheduling: false enable_prefix_caching: false engine_output_type: audio - gpu_memory_utilization: 0.2 + gpu_memory_utilization: 0.1 distributed_executor_backend: "mp" - max_num_batched_tokens: 65536 - max_model_len: 65536 + max_num_batched_tokens: 8192 + max_model_len: 4096 engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav + custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae final_output: true final_output_type: audio - tts_args: - max_instructions_length: 500 default_sampling_params: temperature: 0.0 top_p: 1.0 top_k: -1 - max_tokens: 65536 + max_tokens: 1 seed: 42 detokenize: true repetition_penalty: 1.0 diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml similarity index 64% rename from benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml rename to vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml index ca441d286d..87843634cb 100644 --- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml +++ b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml @@ -1,5 +1,3 @@ -# 
Qwen3-TTS batch_size=1 config (streaming with async_chunk) -# 2-stage pipeline: Talker -> Code2Wav async_chunk: true stage_args: - stage_id: 0 @@ -7,87 +5,85 @@ stage_args: is_comprehension: true runtime: devices: "0" + max_batch_size: 1 engine_args: - max_num_seqs: 1 - model_stage: qwen3_tts - model_arch: Qwen3TTSTalkerForConditionalGeneration + dtype: bfloat16 + model_stage: latent_generator + model_arch: VoxCPMForConditionalGeneration + # Optional persistent HF-compatible config dir for native VoxCPM models. + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} worker_type: ar scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - enforce_eager: false + enforce_eager: true trust_remote_code: true - async_scheduling: true + async_scheduling: false enable_prefix_caching: false engine_output_type: latent - gpu_memory_utilization: 0.3 + gpu_memory_utilization: 0.75 distributed_executor_backend: "mp" - max_num_batched_tokens: 512 + max_num_batched_tokens: 4096 max_model_len: 4096 - custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk + custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk output_connectors: to_stage_1: connector_of_shared_memory default_sampling_params: - temperature: 0.9 - top_k: 50 + temperature: 0.0 + top_p: 1.0 + top_k: -1 max_tokens: 4096 seed: 42 detokenize: false - repetition_penalty: 1.05 - stop_token_ids: [2150] + repetition_penalty: 1.0 + final_output: false - stage_id: 1 stage_type: llm runtime: devices: "0" + max_batch_size: 1 engine_args: - max_num_seqs: 1 - model_stage: code2wav - model_arch: Qwen3TTSCode2Wav + dtype: float32 + model_stage: vae + model_arch: VoxCPMForConditionalGeneration + hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,} worker_type: generation scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler enforce_eager: true trust_remote_code: true - async_scheduling: true + async_scheduling: false enable_prefix_caching: false engine_output_type: audio - gpu_memory_utilization: 0.3 + gpu_memory_utilization: 0.1 distributed_executor_backend: "mp" max_num_batched_tokens: 8192 - max_model_len: 32768 + max_model_len: 4096 engine_input_source: [0] - final_output: true - final_output_type: audio input_connectors: from_stage_0: connector_of_shared_memory - tts_args: - max_instructions_length: 500 + final_output: true + final_output_type: audio default_sampling_params: temperature: 0.0 top_p: 1.0 top_k: -1 - max_tokens: 65536 + max_tokens: 1 seed: 42 detokenize: true repetition_penalty: 1.0 runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 connectors: connector_of_shared_memory: name: SharedMemoryConnector extra: shm_threshold_bytes: 65536 - codec_streaming: true + codec_streaming: false connector_get_sleep_s: 0.01 connector_get_max_wait_first_chunk: 3000 connector_get_max_wait: 300 - codec_chunk_frames: 25 - codec_left_context_frames: 25 edges: - from: 0 to: 1 - window_size: -1 diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py index 138948064b..ffb997048b 100644 --- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py @@ -149,7 +149,15 @@ def execute_model( encoder_cache=self.encoder_cache, ) as ec_connector_output: self._execute_mm_encoder(scheduler_output) - return 
make_empty_encoder_model_runner_output(scheduler_output) + + kv_ids = self.kv_extracted_req_ids + self.kv_extracted_req_ids = None + + output = make_empty_encoder_model_runner_output(scheduler_output) + if kv_ids: + output = copy(output) + output.kv_extracted_req_ids = kv_ids + return output if not num_scheduled_tokens: if ( @@ -163,10 +171,20 @@ def execute_model( # dummy run to ensure coordinate_batch_across_dp # is called into to avoid out of sync issues. self._dummy_run(1) + + kv_ids = self.kv_extracted_req_ids + self.kv_extracted_req_ids = None + if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + output = EMPTY_MODEL_RUNNER_OUTPUT + else: + output = self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if kv_ids: + output = copy(output) + output.kv_extracted_req_ids = kv_ids + + return output if self.cache_config.kv_sharing_fast_prefill: assert not self.num_prompt_logprobs, ( "--kv-sharing-fast-prefill produces incorrect " diff --git a/vllm_omni/platforms/rocm/platform.py b/vllm_omni/platforms/rocm/platform.py index 4479e54f2a..7b0e09c128 100644 --- a/vllm_omni/platforms/rocm/platform.py +++ b/vllm_omni/platforms/rocm/platform.py @@ -16,6 +16,34 @@ class RocmOmniPlatform(OmniPlatform, RocmPlatform): Inherits all ROCm-specific implementations from vLLM's RocmPlatform, and adds Omni-specific interfaces from OmniPlatform. + + + NOTE: AR Attention Backend Overriding Logic: + ------------------------------------------ + Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm. + However, the compatibility of ROCM_ATTN with Omni is not guaranteed. + Therefore, we still use TRITON_ATTN as the default attention backend, + when the selected_backend is not specified. + + So the behaviour of the attention backend overriding logic currently lives in + extract_stage_metadata in `vllm_omni/engine/stage_init_utils.py` + + ``` + if current_omni_platform.is_rocm(): + print(f"engine_args: {str(engine_args)}") + if engine_args.get("attention_backend") is None: + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled(): + engine_args["attention_backend"] = "ROCM_AITER_FA" + # Before vLLM v0.19.0, the default attention backend is TRITON_ATTN for ROCm. + # Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm. + # However, the compatibility of ROCM_ATTN with Omni is not guaranteed. + # Therefore, we still use TRITON_ATTN as the default attention backend, + # when the selected_backend is not specified. + engine_args["attention_backend"] = "TRITON_ATTN" + ``` + """ _omni_enum = OmniPlatformEnum.ROCM diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml deleted file mode 100644 index 35e8193545..0000000000 --- a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml +++ /dev/null @@ -1,102 +0,0 @@ -# stage config for running qwen2.5-omni for multi-stage omni runtime. - -# The following config has been verified on 2x H100-80G GPU. 
-stage_args: - - stage_id: 0 - runtime: - process: true # Run this stage in a separate process - devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device) - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true # Now we only support eager mode - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - max_num_batched_tokens: 32768 - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - - stage_id: 1 - runtime: - process: true - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.8 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - max_num_batched_tokens: 32768 - engine_output_type: latent - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker - default_sampling_params: - temperature: 0.9 - top_p: 0.8 - top_k: 40 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - stop_token_ids: [8294] - - - stage_id: 2 - runtime: - process: true - devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - gpu_memory_utilization: 0.15 - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - max_num_batched_tokens: 32768 - engine_output_type: audio - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage - - edges: - - from: 0 # thinker → talker: trigger only after receiving full input (-1) - to: 1 - window_size: -1 - - from: 1 # talker → code2wav: trigger only after receiving full input (-1) - to: 2 - window_size: -1 diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml deleted file mode 100644 index 0ca150bee6..0000000000 --- a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml +++ /dev/null @@ -1,97 +0,0 @@ -# Stage config for running Qwen3-Omni-MoE with 3-stage architecture -# Stage 0: Thinker (multimodal understanding + text generation) -# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) -# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) - -# The following config has been verified on 2x H100-80G GPUs. 
-stage_args: - - stage_id: 0 - runtime: - devices: "0" - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 - enforce_eager: true - trust_remote_code: true - engine_output_type: latent # Output hidden states for talker - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - hf_config_name: thinker_config - tensor_parallel_size: 1 - final_output: true - final_output_type: text - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - - - stage_id: 1 - runtime: - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 - enforce_eager: true - trust_remote_code: true - engine_output_type: latent # Output codec codes for code2wav - # tensor_parallel_size: 2 - enable_prefix_caching: false - max_num_batched_tokens: 32768 - distributed_executor_backend: "mp" - hf_config_name: talker_config - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker - # final_output: true - # final_output_type: text - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: False - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 2 - runtime: - devices: "1" - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - gpu_memory_utilization: 0.1 - distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - hf_config_name: thinker_config - engine_input_source: [1] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: True - repetition_penalty: 1.1 diff --git a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml index 0fc8a25ea5..7b27f6a443 100644 --- a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml +++ b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml @@ -67,10 +67,6 @@ stage_args: # Runtime edges runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 - # Distributed connectors configuration (optional) # More connectors will be supported in the future. 
connectors: @@ -83,4 +79,3 @@ runtime: edges: - from: 0 to: 1 - window_size: -1 diff --git a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml similarity index 93% rename from vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml rename to vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml index 8f969ced5f..4e0005f82a 100644 --- a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml +++ b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml @@ -78,6 +78,3 @@ stage_args: # Top-level runtime config (concise): default windows and stage edges runtime: enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml deleted file mode 100644 index 7dbedb29a5..0000000000 --- a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml +++ /dev/null @@ -1,101 +0,0 @@ -# stage config for running qwen2.5-omni for multi-stage omni runtime. - -# The following config is verified with 2 * Intel Arc Pro B60 XPU. -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true # Run this stage in a separate process - devices: "0" # Visible devices for this stage - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 # thinker weight is around 16.74GB for Qwen2.5-Omni-7B - enforce_eager: false - trust_remote_code: true - engine_output_type: latent - enable_prefix_caching: false - is_comprehension: true - final_output: true - final_output_type: text - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "1" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.5 # talker weight is 6.03GB for Qwen2.5-Omni-7B - enforce_eager: false - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: latent - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker - default_sampling_params: - temperature: 0.9 - top_p: 0.8 - top_k: 40 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - stop_token_ids: [8294] - - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - process: true - devices: "1" - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen2_5OmniForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - gpu_memory_utilization: 0.3 # code2wav weight is around 1.46GB for Qwen2.5-Omni-7B - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: audio - engine_input_source: [1] - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 
2048 - seed: 42 - detokenize: True - repetition_penalty: 1.1 - -# Top-level runtime config (concise): default windows and stage edges -runtime: - enabled: true - defaults: - window_size: -1 # Simplified: trigger downstream only after full upstream completion - max_inflight: 1 # Simplified: process serially within each stage - - edges: - - from: 0 # thinker → talker: trigger only after receiving full input (-1) - to: 1 - window_size: -1 - - from: 1 # talker → code2wav: trigger only after receiving full input (-1) - to: 2 - window_size: -1 diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml deleted file mode 100644 index 49914bebc4..0000000000 --- a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml +++ /dev/null @@ -1,102 +0,0 @@ -# Stage config for running Qwen3-Omni-MoE with 3-stage architecture -# Stage 0: Thinker (multimodal understanding + text generation) -# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes) -# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform) - -# The following config is verified with 8 * Intel Arc Pro B60 XPU. -stage_args: - - stage_id: 0 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "0,1,2,3" - engine_args: - model_stage: thinker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.9 # thinker weight is around 61.08GB for Qwen3-Omni-30B-A3B-Instruct - enforce_eager: true - trust_remote_code: true - engine_output_type: latent # Output hidden states for talker - distributed_executor_backend: "mp" - enable_prefix_caching: false - max_num_batched_tokens: 32768 - hf_config_name: thinker_config - tensor_parallel_size: 4 - max_cudagraph_capture_size: 0 - final_output: true - final_output_type: text - is_comprehension: true - default_sampling_params: - temperature: 0.4 - top_p: 0.9 - top_k: 1 - max_tokens: 2048 - seed: 42 - detokenize: True - repetition_penalty: 1.05 - - - stage_id: 1 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "4" - engine_args: - model_stage: talker - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: ar - scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler - gpu_memory_utilization: 0.6 # talker weight is around 8.5GB for Qwen3-Omni-30B-A3B-Instruct - enforce_eager: true - trust_remote_code: true - engine_output_type: latent # Output codec codes for code2wav - enable_prefix_caching: false - max_num_batched_tokens: 32768 - distributed_executor_backend: "mp" - hf_config_name: talker_config - max_cudagraph_capture_size: 0 - engine_input_source: [0] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker - # final_output: true - # final_output_type: text - default_sampling_params: - temperature: 0.9 - top_k: 50 - max_tokens: 4096 - seed: 42 - detokenize: False - repetition_penalty: 1.05 - stop_token_ids: [2150] - - - stage_id: 2 - stage_type: llm # Use llm stage type for AR stages - runtime: - devices: "4" - engine_args: - model_stage: code2wav - max_num_seqs: 1 - model_arch: Qwen3OmniMoeForConditionalGeneration - worker_type: generation - scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler - enforce_eager: true - trust_remote_code: true - enable_prefix_caching: false - engine_output_type: audio # Final output: audio waveform - 
gpu_memory_utilization: 0.3 # code2wav weight is around 0.4GB for Qwen3-Omni-30B-A3B-Instruct - distributed_executor_backend: "mp" - max_num_batched_tokens: 1000000 - hf_config_name: thinker_config - max_cudagraph_capture_size: 0 - engine_input_source: [1] - custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav - final_output: true - final_output_type: audio - default_sampling_params: - temperature: 0.0 - top_p: 1.0 - top_k: -1 - max_tokens: 65536 - seed: 42 - detokenize: True - repetition_penalty: 1.1 diff --git a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml index 10051c1eda..0820ab6320 100644 --- a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml +++ b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml @@ -88,9 +88,6 @@ stage_args: runtime: enabled: true - defaults: - window_size: -1 - max_inflight: 1 connectors: connector_of_shared_memory: @@ -108,4 +105,3 @@ runtime: edges: - from: 0 to: 1 - window_size: -1 diff --git a/vllm_omni/quantization/component_config.py b/vllm_omni/quantization/component_config.py index 7986da8850..f9286079be 100644 --- a/vllm_omni/quantization/component_config.py +++ b/vllm_omni/quantization/component_config.py @@ -23,6 +23,31 @@ ) +# Pre-quantized checkpoints (modelopt FP8/FP4/MXFP8) only quantize the +# Thinker LM. Vision and audio encoder weights remain in BF16 with no +# corresponding scale tensors in the checkpoint. +PRE_QUANTIZED_METHODS: frozenset[str] = frozenset({"modelopt", "modelopt_fp4", "modelopt_mxfp8"}) + + +def resolve_encoder_quant_config( + quant_config: QuantizationConfig | None, +) -> QuantizationConfig | None: + """Resolve quantization config for vision / audio encoders. + + Returns *None* for pre-quantized methods so that FP8 kernels are never + applied to BF16 encoder weights (which lack scale tensors). All other + configs — including ``ComponentQuantizationConfig`` and ``None`` — are + returned as-is so the caller can handle them. + """ + if ( + quant_config is not None + and not isinstance(quant_config, ComponentQuantizationConfig) + and quant_config.get_name() in PRE_QUANTIZED_METHODS + ): + return None + return quant_config + + class ComponentQuantizationConfig(QuantizationConfig): """Routes quantization to different configs by layer prefix.""" diff --git a/vllm_omni/request.py b/vllm_omni/request.py index 3ec325316f..48cbf9b31d 100644 --- a/vllm_omni/request.py +++ b/vllm_omni/request.py @@ -1,8 +1,11 @@ from collections.abc import Callable +from dataclasses import dataclass from typing import TYPE_CHECKING import numpy as np import torch +from vllm.multimodal.inputs import MultiModalFeatureSpec +from vllm.sampling_params import SamplingParams from vllm.v1.request import Request if TYPE_CHECKING: @@ -92,3 +95,34 @@ def from_engine_core_request( resumable=request.resumable, reasoning_ended=request.reasoning_ended, ) + + +@dataclass +class OmniStreamingUpdate: + """ + Override: add additional information + Lightweight data for streaming session continuation. + + Contains only the fields needed to update an existing streaming session + with new input data. 
+ """ + + mm_features: list[MultiModalFeatureSpec] | None + prompt_token_ids: list[int] | None + max_tokens: int + arrival_time: float + sampling_params: SamplingParams | None + additional_information: AdditionalInformationPayload | None = None + + @classmethod + def from_request(cls, request: "Request") -> "OmniStreamingUpdate | None": + if not request.resumable: + return None + return cls( + mm_features=request.mm_features, + prompt_token_ids=request.prompt_token_ids, + max_tokens=request.max_tokens, + arrival_time=request.arrival_time, + sampling_params=request.sampling_params, + additional_information=request.additional_information, + ) diff --git a/vllm_omni/transformers_utils/configs/__init__.py b/vllm_omni/transformers_utils/configs/__init__.py index 59b23f9149..598ac3a965 100644 --- a/vllm_omni/transformers_utils/configs/__init__.py +++ b/vllm_omni/transformers_utils/configs/__init__.py @@ -17,6 +17,13 @@ "FishSpeechConfig": "vllm_omni.transformers_utils.configs.fish_speech", "FishSpeechSlowARConfig": "vllm_omni.transformers_utils.configs.fish_speech", "FishSpeechFastARConfig": "vllm_omni.transformers_utils.configs.fish_speech", + "VoxCPMConfig": "vllm_omni.transformers_utils.configs.voxcpm", + "VoxCPM2Config": "vllm_omni.transformers_utils.configs.voxcpm2", + "BailingMoeV2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni", + "BailingMM2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni", + "MingFlashOmniConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni", + "Qwen3VLMoeVisionConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni", + "WhisperEncoderConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni", } __all__ = [ @@ -27,6 +34,13 @@ "FishSpeechConfig", "FishSpeechSlowARConfig", "FishSpeechFastARConfig", + "VoxCPMConfig", + "VoxCPM2Config", + "BailingMoeV2Config", + "BailingMM2Config", + "MingFlashOmniConfig", + "Qwen3VLMoeVisionConfig", + "WhisperEncoderConfig", ] @@ -47,3 +61,6 @@ def __dir__(): # run as soon as `vllm_omni.transformers_utils.configs` is imported. from vllm_omni.transformers_utils.configs import fish_speech as _fish_speech # noqa: F401, E402 from vllm_omni.transformers_utils.configs import mammoth_moda2 as _mammoth_moda2 # noqa: F401, E402 +from vllm_omni.transformers_utils.configs import ming_flash_omni as _ming_flash_omni # noqa: F401, E402 +from vllm_omni.transformers_utils.configs import voxcpm as _voxcpm # noqa: F401, E402 +from vllm_omni.transformers_utils.configs import voxcpm2 as _voxcpm2 # noqa: F401, E402 diff --git a/vllm_omni/transformers_utils/configs/ming_flash_omni.py b/vllm_omni/transformers_utils/configs/ming_flash_omni.py new file mode 100644 index 0000000000..dd13b682de --- /dev/null +++ b/vllm_omni/transformers_utils/configs/ming_flash_omni.py @@ -0,0 +1,302 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration for Ming-flash-omni-2.0 model""" + +import os +from typing import Any, ClassVar + +from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class BailingMoeV2Config(PretrainedConfig): + model_type = "bailing_moe_v2" + + def __init__( + self, + vocab_size=30592, + hidden_size=1024, + intermediate_size=None, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=0, + hidden_act="silu", + use_qkv_bias=False, + use_qk_norm=False, + use_bias=True, + rms_norm_eps=1e-05, + norm_head=False, + tie_word_embeddings=False, + embedding_dropout=0.0, + attention_dropout=0.0, + output_dropout=0.0, + initializer_range=0.02, + max_position_embeddings=16384, + rope_theta=10000.0, + use_cache=True, + use_sliding_window=False, + sliding_window=81920, + max_window_layers=28, + rope_scaling=None, + mrope_section=None, + pad_token_id=126081, + num_experts=16, + num_shared_experts=1, + num_experts_per_tok=2, + n_group=8, + topk_group=4, + routed_scaling_factor=2.5, + moe_intermediate_size=None, + first_k_dense_replace=0, + head_dim=None, + output_router_logits=False, + partial_rotary_factor=0.5, + router_type="topN", + _attn_implementation="flash_attention_2", + use_interleaved_frame_timestamp=True, + # Multimodal token IDs + image_patch_token=157157, + video_patch_token=157175, + audio_patch_token=157168, + image_start_token=157158, + video_start_token=157160, + audio_start_token=157169, + image_end_token=157159, + video_end_token=157161, + audio_end_token=157170, + # Position encoding parameters + spatial_merge_size=2, + tokens_per_second=2, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.use_qkv_bias = use_qkv_bias + self.use_bias = use_bias + self.norm_head = norm_head + self.rms_norm_eps = rms_norm_eps + self.embedding_dropout = embedding_dropout + self.attention_dropout = attention_dropout + self.output_dropout = output_dropout + self.initializer_range = initializer_range + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.use_cache = use_cache + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.head_dim = head_dim or self.hidden_size // self.num_attention_heads + self.use_qk_norm = use_qk_norm # arg unused; QK norm is always applied + + # By default, match the value of `mrope_section` + # to `apply_3d_rotary_pos_emb` in Ming's repo: + # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/modeling_bailing_moe_v2.py + if mrope_section is None: + mrope_section = (rope_scaling or {}).get("mrope_section", [8, 12, 12]) + # Ensure mrope_section is stored inside rope_scaling + if rope_scaling is not None and isinstance(rope_scaling, dict): + rope_scaling = dict(rope_scaling) + rope_scaling.setdefault("mrope_section", mrope_section) + self.rope_scaling = rope_scaling + + # NOTE: Expose rope_parameters["mrope_section"] + # This refers to the pattern used for GLM-Image in vllm_omni/patch.py + rope_type = (rope_scaling or {}).get("type", (rope_scaling or {}).get("rope_type", "")) + if rope_type in ("video_rope", "3D", "mrope"): + self.rope_parameters = 
{"mrope_section": mrope_section} + else: + self.rope_parameters = None + + # MoE configs + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.n_group = n_group + self.topk_group = topk_group + self.moe_intermediate_size = moe_intermediate_size + self.first_k_dense_replace = first_k_dense_replace + self.output_router_logits = output_router_logits + self.routed_scaling_factor = routed_scaling_factor + self.partial_rotary_factor = partial_rotary_factor + self.router_type = router_type + self.use_interleaved_frame_timestamp = use_interleaved_frame_timestamp + self._attn_implementation = _attn_implementation + + # Multimodal token IDs and position encoding + self.image_patch_token = image_patch_token + self.video_patch_token = video_patch_token + self.audio_patch_token = audio_patch_token + self.image_start_token = image_start_token + self.video_start_token = video_start_token + self.audio_start_token = audio_start_token + self.image_end_token = image_end_token + self.video_end_token = video_end_token + self.audio_end_token = audio_end_token + self.spatial_merge_size = spatial_merge_size + self.tokens_per_second = tokens_per_second + + super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs) + + +class Qwen3VLMoeVisionConfig(PretrainedConfig): + """Configuration class for Qwen3 MoE Vision Transformer""" + + model_type = "qwen3_moe_vit" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if "vision_config" in config_dict: + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
+ ) + + return cls.from_dict(config_dict, **kwargs) + + +class WhisperEncoderConfig(PretrainedConfig): + """Configuration class for Whisper audio encoder""" + + model_type = "whisper_encoder" + + def __init__( + self, + whisper_encoder_config: dict[str, Any] | None = None, + ds_kernel_size=3, + ds_stride=2, + norm_query_embeds=True, + **kwargs, + ): + super().__init__(**kwargs) + self.whisper_encoder_config = whisper_encoder_config or {} + self.ds_kernel_size = ds_kernel_size + self.ds_stride = ds_stride + self.norm_query_embeds = norm_query_embeds + + +class BailingMM2Config(PretrainedConfig): + model_type = "bailingmm_moe_v2_lite" + is_composition = True + sub_configs: ClassVar = {"llm_config": AutoConfig} + + def __init__( + self, + mlp_depth=1, + llm_config: BailingMoeV2Config | None = None, + vision_config: Qwen3VLMoeVisionConfig | None = None, + audio_config: WhisperEncoderConfig | None = None, + **kwargs, + ): + self.audio_config = WhisperEncoderConfig(**audio_config) if isinstance(audio_config, dict) else audio_config + self.vision_config = ( + Qwen3VLMoeVisionConfig(**vision_config) if isinstance(vision_config, dict) else vision_config + ) + self.llm_config = BailingMoeV2Config(**llm_config) if isinstance(llm_config, dict) else llm_config + self.mlp_depth = mlp_depth + super().__init__(**kwargs) + + def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002 + return self.llm_config + + +class MingFlashOmniConfig(PretrainedConfig): + """Configuration class for unified Ming-flash-omni-2.0 model""" + + model_type = "ming_flash_omni" + is_composition = True + sub_configs: ClassVar = {"thinker_config": BailingMM2Config} + + def __init__( + self, + thinker_config: BailingMM2Config | None = None, + image_gen_config: dict[str, Any] | None = None, + talker_config: dict[str, Any] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + + if isinstance(thinker_config, dict): + self.thinker_config = BailingMM2Config(**thinker_config) + else: + self.thinker_config = thinker_config or BailingMM2Config() + + # Image generation config (for future implementation) + self.image_gen_config = image_gen_config + + # Talker config (for future implementation) + self.talker_config = talker_config + + def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002 + return self.thinker_config.get_text_config() + + +# Register model_type -> config class for AutoConfig +AutoConfig.register(BailingMoeV2Config.model_type, BailingMoeV2Config) +AutoConfig.register(BailingMM2Config.model_type, BailingMM2Config) +AutoConfig.register(MingFlashOmniConfig.model_type, MingFlashOmniConfig) + +# Register tokenizer mapping for composition configs so that +# AutoTokenizer.from_pretrained can resolve the tokenizer class +AutoTokenizer.register(BailingMM2Config, fast_tokenizer_class=PreTrainedTokenizerFast) +AutoTokenizer.register(MingFlashOmniConfig, fast_tokenizer_class=PreTrainedTokenizerFast) diff --git a/vllm_omni/transformers_utils/configs/voxcpm.py b/vllm_omni/transformers_utils/configs/voxcpm.py new file mode 100644 index 0000000000..0267838915 --- /dev/null +++ b/vllm_omni/transformers_utils/configs/voxcpm.py @@ -0,0 +1,68 @@ +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig + + +class VoxCPMConfig(PretrainedConfig): + model_type = "voxcpm" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + bos_token_id: int = 1, + eos_token_id: int = 2, + vocab_size: int = 32000, + hidden_size: int = 
1024, + intermediate_size: int = 4096, + max_position_embeddings: int = 4096, + num_attention_heads: int = 16, + num_hidden_layers: int = 24, + num_key_value_heads: int = 16, + rms_norm_eps: float = 1e-6, + rope_theta: float = 10000.0, + rope_scaling: dict | None = None, + lm_config: dict | None = None, + encoder_config: dict | None = None, + dit_config: dict | None = None, + audio_vae_config: dict | None = None, + patch_size: int = 2, + feat_dim: int = 64, + residual_lm_num_layers: int = 6, + scalar_quantization_latent_dim: int = 256, + scalar_quantization_scale: int = 9, + max_length: int = 4096, + device: str = "cuda", + dtype: str = "bfloat16", + dit_mean_mode: bool = False, + **kwargs, + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.lm_config = lm_config or {} + self.encoder_config = encoder_config or {} + self.dit_config = dit_config or {} + self.audio_vae_config = audio_vae_config + + self.patch_size = patch_size + self.feat_dim = feat_dim + self.residual_lm_num_layers = residual_lm_num_layers + self.scalar_quantization_latent_dim = scalar_quantization_latent_dim + self.scalar_quantization_scale = scalar_quantization_scale + self.max_length = max_length + self.device = device + self.dtype = dtype + self.dit_mean_mode = dit_mean_mode + + +AutoConfig.register("voxcpm", VoxCPMConfig) + +__all__ = ["VoxCPMConfig"] diff --git a/vllm_omni/transformers_utils/configs/voxcpm2.py b/vllm_omni/transformers_utils/configs/voxcpm2.py new file mode 100644 index 0000000000..c625284bd6 --- /dev/null +++ b/vllm_omni/transformers_utils/configs/voxcpm2.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class VoxCPM2Config(PretrainedConfig): + """Configuration for VoxCPM2 native AR integration. + + The HuggingFace checkpoint stores LM parameters inside a nested + ``lm_config`` dict. This class hoists them to top-level attributes + so that vllm's ``MiniCPMModel`` can consume them directly. + + vllm's MiniCPM **always** applies muP scaling (scale_emb, scale_depth, + dim_model_base). 
VoxCPM2 was trained with ``use_mup=false``, so we + neutralise the scalings: + * ``scale_emb = 1.0`` + * ``scale_depth = sqrt(num_hidden_layers)`` (cancels the division) + * ``dim_model_base = hidden_size`` (makes scale_width = 1.0) + """ + + model_type = "voxcpm2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + # -- top-level VoxCPM2 params -- + architecture: str = "voxcpm2", + lm_config: dict | None = None, + encoder_config: dict | None = None, + dit_config: dict | None = None, + audio_vae_config: dict | None = None, + patch_size: int = 4, + feat_dim: int = 64, + residual_lm_num_layers: int = 8, + residual_lm_no_rope: bool = True, + scalar_quantization_latent_dim: int = 512, + scalar_quantization_scale: int = 9, + max_length: int = 8192, + device: str = "cuda", + dtype: str = "bfloat16", + # -- LM defaults (overridden by lm_config if present) -- + bos_token_id: int = 1, + eos_token_id: int = 2, + vocab_size: int = 73448, + hidden_size: int = 2048, + intermediate_size: int = 6144, + max_position_embeddings: int = 32768, + num_attention_heads: int = 16, + num_hidden_layers: int = 28, + num_key_value_heads: int = 2, + rms_norm_eps: float = 1e-5, + rope_theta: float = 10000.0, + rope_scaling: dict | None = None, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.architecture = architecture + + # -- VoxCPM2-specific fields -- + self.lm_config = lm_config or {} + self.encoder_config = encoder_config or {} + self.dit_config = dit_config or {} + self.audio_vae_config = audio_vae_config or {} + self.patch_size = patch_size + self.feat_dim = feat_dim + self.residual_lm_num_layers = residual_lm_num_layers + self.residual_lm_no_rope = residual_lm_no_rope + self.scalar_quantization_latent_dim = scalar_quantization_latent_dim + self.scalar_quantization_scale = scalar_quantization_scale + self.max_length = max_length + self.device = device + self.dtype = dtype + + # -- Hoist LM parameters to top-level for MiniCPMModel -- + lm = self.lm_config + self.vocab_size = lm.get("vocab_size", vocab_size) + self.hidden_size = lm.get("hidden_size", hidden_size) + self.intermediate_size = lm.get("intermediate_size", intermediate_size) + self.max_position_embeddings = lm.get("max_position_embeddings", max_position_embeddings) + self.num_attention_heads = lm.get("num_attention_heads", num_attention_heads) + self.num_hidden_layers = lm.get("num_hidden_layers", num_hidden_layers) + self.num_key_value_heads = lm.get("num_key_value_heads", num_key_value_heads) + self.rms_norm_eps = lm.get("rms_norm_eps", rms_norm_eps) + self.rope_theta = lm.get("rope_theta", rope_theta) + + # MiniCPM-specific: kv_channels overrides head_dim when set. + kv_channels = lm.get("kv_channels") + if kv_channels is not None: + self.head_dim = kv_channels + else: + self.head_dim = self.hidden_size // self.num_attention_heads + + # MiniCPM requires hidden_act; VoxCPM2 uses SiLU. + self.hidden_act = "silu" + self.hidden_act_param = 0.0 + self.tie_word_embeddings = False + self.num_experts = 0 + + # -- muP scaling -- + # Native VoxCPM2 MiniCPM gates scale_depth behind use_mup: + # use_mup=True → residual += h * (scale_depth / sqrt(N)) + # use_mup=False → residual += h (plain add, no scaling) + # But vllm's MiniCPMModel ALWAYS applies scale_depth / sqrt(N). + # Native applies scale_emb externally; vllm applies it in embed_input_ids. 
+ use_mup = lm.get("use_mup", False) + self.scale_emb = lm.get("scale_emb", 1.0) + if use_mup: + self.scale_depth = lm.get("scale_depth", 1.0) + self.dim_model_base = lm.get("dim_model_base", self.hidden_size) + else: + # Neutralize: scale_depth/sqrt(N) = 1.0, scale_width = 1.0 + self.scale_depth = math.sqrt(self.num_hidden_layers) + self.dim_model_base = self.hidden_size + + # -- RoPE scaling (longrope) -- + raw_rope = lm.get("rope_scaling", rope_scaling) + if raw_rope is not None: + self.rope_scaling = dict(raw_rope) + # HF expects "rope_type" not "type" + if "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling.pop("type") + # longrope requires "factor" (used by HF validation) + if "factor" not in self.rope_scaling: + self.rope_scaling["factor"] = 1.0 + rope_config_validation(self) + + # vllm's MiniCPMAttention reads config.rope_parameters (a dict + # with rope_type, theta, scaling factors, etc.). HF transformers + # only auto-computes this for known model_types; for custom + # types we must build it manually. + if not getattr(self, "rope_parameters", None): + rp = dict(self.rope_scaling) + rp["rope_theta"] = self.rope_theta + self.rope_parameters = rp + else: + self.rope_scaling = None + + def get_text_config(self, **kwargs): + """Return self as the text config — LM attributes are top-level.""" + return self + + +AutoConfig.register("voxcpm2", VoxCPM2Config) + +__all__ = ["VoxCPM2Config"] diff --git a/vllm_omni/transformers_utils/processors/__init__.py b/vllm_omni/transformers_utils/processors/__init__.py new file mode 100644 index 0000000000..52ca657539 --- /dev/null +++ b/vllm_omni/transformers_utils/processors/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. + +from vllm_omni.transformers_utils.processors.ming import ( + MingFlashOmniProcessor, + MingWhisperFeatureExtractor, +) + +__all__ = [ + "MingFlashOmniProcessor", + "MingWhisperFeatureExtractor", +] diff --git a/vllm_omni/transformers_utils/processors/ming.py b/vllm_omni/transformers_utils/processors/ming.py new file mode 100644 index 0000000000..7f414b7268 --- /dev/null +++ b/vllm_omni/transformers_utils/processors/ming.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The vLLM-Omni team. +# Copyright 2024 ANT Group and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
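Sanity check for the muP neutralisation in VoxCPM2Config above. Purely illustrative, with made-up numbers rather than values from any checkpoint: it verifies the identity the config relies on, namely that setting scale_depth = sqrt(num_hidden_layers) collapses the residual scaling vllm's MiniCPM always applies to a plain add, and that dim_model_base = hidden_size makes the width scaling 1.0, matching the use_mup=False training setup.

# Illustrative check only; the constants below are example values.
import math

num_hidden_layers = 28   # example; the real value comes from lm_config
hidden_size = 2048       # example

scale_depth = math.sqrt(num_hidden_layers)  # what the config sets when use_mup is False
dim_model_base = hidden_size                # what the config sets when use_mup is False

# vllm's MiniCPM residual update applies: residual += h * (scale_depth / sqrt(N))
residual_factor = scale_depth / math.sqrt(num_hidden_layers)
# vllm's MiniCPM width scaling: scale_width = hidden_size / dim_model_base
scale_width = hidden_size / dim_model_base

assert math.isclose(residual_factor, 1.0)  # plain residual add, as trained
assert math.isclose(scale_width, 1.0)      # no muP width rescaling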
+ +from typing import Any + +import numpy as np +import torch +from transformers import AutoFeatureExtractor, AutoProcessor +from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" +DEFAULT_VID_START_TOKEN = "" +DEFAULT_FRAME_PATCH_TOKEN = "" + +DEFAULT_AUDIO_PATCH_TOKEN = "" +DEFAULT_AU_START_TOKEN = "" + +# High-level placeholders used in user prompts +PLACEHOLDER_IMAGE_TOKEN_IN_TEXT = "" +PLACEHOLDER_VIDEO_TOKEN_IN_TEXT = "